Skip to content

Commit c04bc99

Browse files
[MPS] Add support for flatbuffer serialization > 4GB
Differential Revision: D60876064 Pull Request resolved: #4574
1 parent bbabd28 commit c04bc99

File tree

11 files changed

+336
-19
lines changed

11 files changed

+336
-19
lines changed

backends/apple/mps/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ add_library(mpsdelegate ${_mps_backend__srcs})
6969
find_library(FOUNDATION_FRAMEWORK Foundation)
7070
find_library(METAL_FRAMEWORK Metal)
7171
find_library(MPS_FRAMEWORK MetalPerformanceShaders)
72-
find_library(MPS_GRAPG_FRAMEWORK MetalPerformanceShadersGraph)
72+
find_library(MPS_GRAPH_FRAMEWORK MetalPerformanceShadersGraph)
7373

7474
target_link_libraries(
7575
mpsdelegate
@@ -79,7 +79,7 @@ target_link_libraries(
7979
${FOUNDATION_FRAMEWORK}
8080
${METAL_FRAMEWORK}
8181
${MPS_FRAMEWORK}
82-
${MPS_GRAPG_FRAMEWORK}
82+
${MPS_GRAPH_FRAMEWORK}
8383
)
8484

8585
target_link_options_shared_lib(mpsdelegate)

backends/apple/mps/mps_preprocess.py

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
# Copyright (c) 2023 Apple Inc. All rights reserved.
33
# Provided subject to the LICENSE file in the top level directory.
44
#
5-
65
import logging
7-
from typing import Dict, final, List
6+
from typing import ClassVar, Dict, final, List, Tuple
87

98
import torch
109

@@ -16,6 +15,8 @@
1615
)
1716

1817
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
18+
Buffer,
19+
DataSegment,
1920
MPSGraph,
2021
MPSTensor,
2122
OpType,
@@ -25,6 +26,7 @@
2526
convert_to_flatbuffer,
2627
)
2728
from executorch.backends.apple.mps.utils.mps_utils import is_parameter
29+
from executorch.exir._serialize._program import Cord
2830

2931
from executorch.exir.backend.backend_details import (
3032
BackendDetails,
@@ -39,6 +41,29 @@
3941

4042
@final
4143
class MPSBackend(BackendDetails):
44+
@staticmethod
45+
def slice_len_max(s):
46+
assert s.start is not None
47+
assert s.stop is not None
48+
step = 1
49+
if s.step is not None:
50+
step = s.step
51+
return max((s.stop - s.start) // step, 1)
52+
53+
MAGIC_IX: ClassVar[slice] = slice(4, 8)
54+
DATA_SEGMENT_OFFSET_IX: ClassVar[slice] = slice(8, 16)
55+
DATA_SEGMENT_SIZE_IX: ClassVar[slice] = slice(16, 24)
56+
57+
# magic bytes that should be at the beginning of the header
58+
EXPECTED_MAGIC: ClassVar[bytes] = b"MP00"
59+
# The length of the header in bytes
60+
EXPECTED_LENGTH: ClassVar[int] = (
61+
4
62+
+ slice_len_max(MAGIC_IX)
63+
+ slice_len_max(DATA_SEGMENT_OFFSET_IX)
64+
+ slice_len_max(DATA_SEGMENT_SIZE_IX)
65+
)
66+
4267
@staticmethod
4368
def preprocess(
4469
edge_program: ExportedProgram,
@@ -67,6 +92,7 @@ def preprocess(
6792
output_ids=[],
6893
constant_ids=[],
6994
graph_type=OpType.mps_graph,
95+
constant_segment=DataSegment(0, 0),
7096
)
7197

7298
convert_model_to_fp16 = True
@@ -100,10 +126,44 @@ def preprocess(
100126
else:
101127
op_handler[node.op](edge_program, node_visitors, node, mps_graph)
102128

129+
segment_data, mps_graph = _extract_constant_segment(mps_graph)
130+
131+
# Add to aggregate segments cord with padding.
132+
padding_length = _padding_required(len(segment_data), 16)
133+
if padding_length > 0:
134+
segment_data.append(b"\x00" * padding_length)
135+
136+
# Combine mps_graph with segment data
137+
combined = Cord()
138+
graph_bytes = convert_to_flatbuffer(mps_graph)
139+
140+
data_segment_offset: int = MPSBackend.EXPECTED_LENGTH
141+
data_segment_offset = data_segment_offset + len(graph_bytes)
142+
143+
graph_padding_length = _padding_required(data_segment_offset, 16)
144+
data_segment_offset = data_segment_offset + graph_padding_length
145+
data_segment_size = len(segment_data)
146+
147+
data: bytes = (
148+
b"\x00\x00\x00\x00"
149+
+ MPSBackend.EXPECTED_MAGIC
150+
+ data_segment_offset.to_bytes(8, byteorder="little")
151+
+ data_segment_size.to_bytes(8, byteorder="little")
152+
)
153+
assert len(data) == MPSBackend.EXPECTED_LENGTH
154+
155+
combined.append(data)
156+
combined.append(graph_bytes)
157+
158+
if graph_padding_length > 0:
159+
combined.append(b"\x00" * graph_padding_length)
160+
# Append the segment data to the end of the mps graph
161+
combined.append(segment_data)
162+
103163
if logging.DEBUG >= logging.root.level:
104164
pretty_print(mps_graph)
105165

106-
return PreprocessResult(processed_bytes=convert_to_flatbuffer(mps_graph))
166+
return PreprocessResult(processed_bytes=bytes(combined))
107167

108168
@staticmethod
109169
def handle_call_function(
@@ -164,12 +224,42 @@ def handle_get_attr(
164224
pass
165225

166226

227+
def _padding_required(offset: int, alignment: int) -> int:
228+
"""Returns the padding required to align `offset` to `alignment`."""
229+
remainder: int = offset % alignment
230+
if remainder != 0:
231+
return alignment - remainder
232+
return 0
233+
234+
235+
def _extract_constant_segment(mps_graph: MPSGraph) -> Tuple[Cord, MPSGraph]:
236+
"""Extracts the constant segment from the MPSGraph and returns the updated MPSGraph along with the segment data."""
237+
# Note that the beginning of the segment data is not aligned. Need to handle out of this call.
238+
segment_data = Cord()
239+
offset = 0
240+
for i in range(len(mps_graph.mps_values)):
241+
tensor = mps_graph.mps_values[i]
242+
if tensor.constant_buffer_size > 0:
243+
# Notice that buffer is already force aligned so we don't need to pad it
244+
segment_data.append(tensor.constant_buffer.storage)
245+
246+
# Reset buffer to empty
247+
tensor.constant_buffer = Buffer(storage=b"")
248+
# Update segment offset
249+
tensor.segment_offset = offset
250+
offset += tensor.constant_buffer_size
251+
252+
return segment_data, mps_graph
253+
254+
167255
def tensor_to_str(mps_tensor: MPSTensor):
168256
tensor_str = "MPSTensor("
169257
tensor_str += "datatype=" + str(mps_tensor.datatype) + ", "
170258
tensor_str += "num_dims=" + str(mps_tensor.num_dims) + ", "
171259
tensor_str += "dims=" + str(mps_tensor.dims) + ", "
172-
tensor_str += "constant_buffer_size=" + str(mps_tensor.constant_buffer_size)
260+
tensor_str += "constant_buffer=" + str(mps_tensor.constant_buffer) + ", "
261+
tensor_str += "constant_buffer_size=" + str(mps_tensor.constant_buffer_size) + ", "
262+
tensor_str += "segment_offset=" + str(mps_tensor.segment_offset)
173263
tensor_str += ")"
174264

175265
return tensor_str
@@ -193,3 +283,4 @@ def pretty_print(mps_graph: MPSGraph):
193283
logging.info(" Output ids:")
194284
for out_id in mps_graph.output_ids:
195285
logging.info(f" {out_id}")
286+
logging.info(f" Constant segment: {mps_graph.constant_segment}")

backends/apple/mps/runtime/MPSCompiler.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
Error err = Error::Ok;
4444

4545
std::unique_ptr<MPSGraphBuilder> mpsGraphBuilder(
46-
new MPSGraphBuilder(buffer_pointer, executor->_mpsGraphTensorToId));
46+
new MPSGraphBuilder(buffer_pointer, num_bytes, executor->_mpsGraphTensorToId));
4747
err = mpsGraphBuilder->compileModel();
4848
ET_CHECK_OR_RETURN_ERROR(
4949
err == Error::Ok, Internal, "Failed to construct the MPS graph object");
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
//
2+
// Copyright (c) 2024 Apple Inc. All rights reserved.
3+
// Provided subject to the LICENSE file in the top level directory.
4+
//
5+
6+
#pragma once
7+
8+
#include <executorch/runtime/core/result.h>
9+
10+
namespace torch {
11+
namespace executor {
12+
namespace mps {
13+
namespace delegate {
14+
15+
/**
16+
* MPS-header that is embedded before the flatbuffer payload
17+
*
18+
*/
19+
struct MPSDelegateHeader {
20+
/**
21+
* The minimum size of the MPSDelegateHeader. The caller should provide at
22+
* least this many bytes of the head of the serialized MPS Data
23+
*/
24+
static constexpr size_t kMinSize = 30;
25+
26+
/**
27+
* The magic offset. This offset is the same as the offset for flatbuffer
28+
* header so we will be able to check if the header is is either the
29+
* flatbuffer head or the wrapper header we introduce here
30+
*/
31+
static constexpr size_t kMagicOffset = 4;
32+
33+
/**
34+
* The magic bytes that identify the header.
35+
*
36+
* This is the canonical definition of the expected value. If the header
37+
* layout ever changes in a compatibility-breaking way, increment the digits
38+
* in the magic. But, doing so will prevent older binaries from recognizing
39+
* the presence of the header. The compatibility-preserving way to make
40+
* changes is to increase the header's length field and add new fields at the
41+
* end.
42+
*/
43+
static constexpr size_t kMagicSize = 4;
44+
static constexpr char kMagic[kMagicSize] = {'M', 'P', '0', '0'};
45+
46+
/**
47+
* The size in bytes of the header length. We store 2 bytes for the header
48+
* length
49+
*/
50+
static constexpr size_t kHeaderLengthSize = 2;
51+
52+
/**
53+
* The expected location of the header length field relative to the beginning
54+
* of the header.
55+
*/
56+
static constexpr size_t kHeaderLengthOffset =
57+
MPSDelegateHeader::kMagicOffset + MPSDelegateHeader::kMagicSize;
58+
59+
/*
60+
* The expected location of the constant data offset field relative to the
61+
* beginning of the header.
62+
*/
63+
static constexpr size_t kConstantDataSegmentOffset = kHeaderLengthOffset;
64+
65+
/*
66+
* The expected location of the constant data size field relative to the
67+
* beginning of the header.
68+
*/
69+
static constexpr size_t kConstantDataSizeOffset =
70+
kConstantDataSegmentOffset + sizeof(uint64_t);
71+
72+
/**
73+
* The expected location of the flatbuffer data offset field relative to the
74+
* beginning of the header.
75+
*/
76+
static constexpr size_t kFlatbufferDataOffsetOffset =
77+
kConstantDataSizeOffset + sizeof(uint64_t);
78+
79+
/**
80+
* Look for and parse an ExtendedHeader in the provided data.
81+
*
82+
* @param[in] data The contents of the beginning of the serialized binary
83+
* Program data, starting at offset 0 (i.e., the head of the file).
84+
* @param[in] size Length of `data` in bytes.
85+
*
86+
* @returns an MPSHeader if the header was found and is valid. Returns an
87+
* error if size was too short, if the header was not found, or if the
88+
* header appeared to be corrupt.
89+
*/
90+
static Result<MPSDelegateHeader> Parse(const void* data, size_t size);
91+
92+
/**
93+
* The offset in bytes to the beginning of the constant data.
94+
*/
95+
uint64_t constant_data_offset;
96+
/**
97+
* The size in bytes of the constant data.
98+
*/
99+
uint64_t constant_data_size;
100+
/**
101+
* The offset in bytes to the beginning of the flatbuffer data.
102+
*/
103+
uint64_t flatbuffer_offset;
104+
/**
105+
* The size in bytes of the flatbuffer data.
106+
*/
107+
uint64_t flatbuffer_size;
108+
};
109+
110+
} // namespace delegate
111+
} // namespace mps
112+
} // namespace executor
113+
} // namespace torch
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//
2+
// Copyright (c) 2024 Apple Inc. All rights reserved.
3+
// Provided subject to the LICENSE file in the top level directory.
4+
//
5+
6+
#include <executorch/backends/apple/mps/runtime/MPSDelegateHeader.h>
7+
8+
#include <cstring>
9+
10+
#include <executorch/runtime/core/error.h>
11+
#include <executorch/runtime/core/result.h>
12+
13+
namespace torch {
14+
namespace executor {
15+
namespace mps {
16+
namespace delegate {
17+
18+
/// Interprets the 8 bytes at `data` as a little-endian uint64_t.
19+
uint64_t getUInt64LE(const uint8_t* data) {
20+
return (uint64_t)data[0] | ((uint64_t)data[1] << 8) |
21+
((uint64_t)data[2] << 16) | ((uint64_t)data[3] << 24) |
22+
((uint64_t)data[4] << 32) | ((uint64_t)data[5] << 40) |
23+
((uint64_t)data[6] << 48) | ((uint64_t)data[7] << 56);
24+
}
25+
26+
Result<MPSDelegateHeader> MPSDelegateHeader::Parse(const void* data, size_t size) {
27+
const uint8_t* header_data = (const uint8_t*)data;
28+
29+
if (size < MPSDelegateHeader::kMinSize) {
30+
return Error::InvalidArgument;
31+
}
32+
33+
const uint8_t* magic_start = header_data + MPSDelegateHeader::kMagicOffset;
34+
if (std::memcmp(magic_start, MPSDelegateHeader::kMagic, MPSDelegateHeader::kMagicSize) != 0) {
35+
return Error::NotFound;
36+
}
37+
38+
uint64_t constant_data_offset = getUInt64LE(header_data + MPSDelegateHeader::kConstantDataSegmentOffset);
39+
uint64_t constant_data_size = getUInt64LE(header_data + MPSDelegateHeader::kConstantDataSizeOffset);
40+
uint64_t flatbuffer_offset = MPSDelegateHeader::kFlatbufferDataOffsetOffset;
41+
uint64_t flatbuffer_size = size - flatbuffer_offset;
42+
43+
return MPSDelegateHeader{
44+
constant_data_offset,
45+
constant_data_size,
46+
flatbuffer_offset,
47+
flatbuffer_size};
48+
}
49+
50+
} // namespace delegate
51+
} // namespace mps
52+
} // namespace executor
53+
} // namespace torch

backends/apple/mps/runtime/MPSGraphBuilder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ using NodePtr = const mpsgraph::MPSNode *;
4040
*/
4141
class MPSGraphBuilder {
4242
public:
43-
MPSGraphBuilder(const void *buffer_pointer, std::unordered_map<MPSGraphTensor *, int32_t> &mpsGraphTensorToId);
43+
MPSGraphBuilder(const void *buffer_pointer, size_t num_bytes,
44+
std::unordered_map<MPSGraphTensor *, int32_t> &mpsGraphTensorToId);
4445
~MPSGraphBuilder() = default;
4546

4647
Error compileModel();
@@ -178,12 +179,15 @@ class MPSGraphBuilder {
178179
const mpsgraph::MPSGraph *_flatBufferGraph;
179180
// FlatBuffer raw bytes of the serialized MPS model.
180181
const void *_buffer_pointer;
182+
size_t _num_bytes;
181183

182184
bool _metal_kernel;
183185
MPSGraph *_mpsGraph;
184186
MPSGraphExecutable *_mpsGraphExecutable;
185187
NSMutableDictionary<MPSGraphTensor *, MPSGraphShapedType *> *_feeds;
186188
NSMutableArray<MPSGraphTensor *> *_targetTensors;
189+
190+
const uint8_t *_constant_data_ptr;
187191
};
188192

189193
#undef _DEFINE_MPS_OP

0 commit comments

Comments
 (0)