22# Copyright (c) 2023 Apple Inc. All rights reserved.
33# Provided subject to the LICENSE file in the top level directory.
44#
5-
65import logging
7- from typing import Dict , final , List
6+ from typing import ClassVar , Dict , final , List , Tuple
87
98import torch
109
1615)
1716
1817from executorch .backends .apple .mps .serialization .mps_graph_schema import (
18+ Buffer ,
19+ DataSegment ,
1920 MPSGraph ,
2021 MPSTensor ,
2122 OpType ,
2526 convert_to_flatbuffer ,
2627)
2728from executorch .backends .apple .mps .utils .mps_utils import is_parameter
29+ from executorch .exir ._serialize ._program import Cord
2830
2931from executorch .exir .backend .backend_details import (
3032 BackendDetails ,
3941
4042@final
4143class MPSBackend (BackendDetails ):
@staticmethod
def slice_len_max(s):
    """Return the number of elements spanned by slice `s`, never less than 1.

    Args:
        s: A slice whose `start` and `stop` are both set. `step` may be
            None, in which case a step of 1 is used.

    Returns:
        `(s.stop - s.start) // step`, or 1 when that quotient is below 1.

    Raises:
        AssertionError: If `s.start` or `s.stop` is None.
    """
    assert s.start is not None
    assert s.stop is not None
    # An unspecified step behaves like a step of 1.
    step = 1 if s.step is None else s.step
    # Floor at 1 so that a degenerate slice still contributes one unit.
    return max((s.stop - s.start) // step, 1)
52+
53+ MAGIC_IX : ClassVar [slice ] = slice (4 , 8 )
54+ DATA_SEGMENT_OFFSET_IX : ClassVar [slice ] = slice (8 , 16 )
55+ DATA_SEGMENT_SIZE_IX : ClassVar [slice ] = slice (16 , 24 )
56+
57+ # magic bytes that should be at the beginning of the header
58+ EXPECTED_MAGIC : ClassVar [bytes ] = b"MP00"
59+ # The length of the header in bytes
60+ EXPECTED_LENGTH : ClassVar [int ] = (
61+ 4
62+ + slice_len_max (MAGIC_IX )
63+ + slice_len_max (DATA_SEGMENT_OFFSET_IX )
64+ + slice_len_max (DATA_SEGMENT_SIZE_IX )
65+ )
66+
4267 @staticmethod
4368 def preprocess (
4469 edge_program : ExportedProgram ,
@@ -67,6 +92,7 @@ def preprocess(
6792 output_ids = [],
6893 constant_ids = [],
6994 graph_type = OpType .mps_graph ,
95+ constant_segment = DataSegment (0 , 0 ),
7096 )
7197
7298 convert_model_to_fp16 = True
@@ -100,10 +126,44 @@ def preprocess(
100126 else :
101127 op_handler [node .op ](edge_program , node_visitors , node , mps_graph )
102128
129+ segment_data , mps_graph = _extract_constant_segment (mps_graph )
130+
131+ # Add to aggregate segments cord with padding.
132+ padding_length = _padding_required (len (segment_data ), 16 )
133+ if padding_length > 0 :
134+ segment_data .append (b"\x00 " * padding_length )
135+
136+ # Combine mps_graph with segment data
137+ combined = Cord ()
138+ graph_bytes = convert_to_flatbuffer (mps_graph )
139+
140+ data_segment_offset : int = MPSBackend .EXPECTED_LENGTH
141+ data_segment_offset = data_segment_offset + len (graph_bytes )
142+
143+ graph_padding_length = _padding_required (data_segment_offset , 16 )
144+ data_segment_offset = data_segment_offset + graph_padding_length
145+ data_segment_size = len (segment_data )
146+
147+ data : bytes = (
148+ b"\x00 \x00 \x00 \x00 "
149+ + MPSBackend .EXPECTED_MAGIC
150+ + data_segment_offset .to_bytes (8 , byteorder = "little" )
151+ + data_segment_size .to_bytes (8 , byteorder = "little" )
152+ )
153+ assert len (data ) == MPSBackend .EXPECTED_LENGTH
154+
155+ combined .append (data )
156+ combined .append (graph_bytes )
157+
158+ if graph_padding_length > 0 :
159+ combined .append (b"\x00 " * graph_padding_length )
160+ # Append the segment data to the end of the mps graph
161+ combined .append (segment_data )
162+
103163 if logging .DEBUG >= logging .root .level :
104164 pretty_print (mps_graph )
105165
106- return PreprocessResult (processed_bytes = convert_to_flatbuffer ( mps_graph ))
166+ return PreprocessResult (processed_bytes = bytes ( combined ))
107167
108168 @staticmethod
109169 def handle_call_function (
@@ -164,12 +224,42 @@ def handle_get_attr(
164224 pass
165225
166226
227+ def _padding_required (offset : int , alignment : int ) -> int :
228+ """Returns the padding required to align `offset` to `alignment`."""
229+ remainder : int = offset % alignment
230+ if remainder != 0 :
231+ return alignment - remainder
232+ return 0
233+
234+
def _extract_constant_segment(mps_graph: MPSGraph) -> Tuple[Cord, MPSGraph]:
    """Extracts the constant segment from the MPSGraph.

    Walks `mps_graph.mps_values` in order. For every tensor whose
    `constant_buffer_size` is greater than 0, its `constant_buffer.storage`
    is appended to the segment data, its `constant_buffer` is replaced with
    an empty `Buffer`, and its `segment_offset` is set to the position of its
    data within the segment.

    Args:
        mps_graph: The graph whose tensors are scanned. It is modified in
            place.

    Returns:
        A tuple of the accumulated segment data and the same (now updated)
        `mps_graph` object.
    """
    # Note that the beginning of the segment data is not aligned. The caller
    # is responsible for handling that.
    segment_data = Cord()
    offset = 0
    for tensor in mps_graph.mps_values:
        if tensor.constant_buffer_size > 0:
            # Notice that buffer is already force aligned so we don't need to pad it
            segment_data.append(tensor.constant_buffer.storage)

            # Reset buffer to empty
            tensor.constant_buffer = Buffer(storage=b"")
            # Update segment offset
            tensor.segment_offset = offset
            # NOTE(review): the running offset advances by the declared
            # constant_buffer_size, which presumably equals the length of the
            # storage just appended -- confirm against the serializer.
            offset += tensor.constant_buffer_size

    return segment_data, mps_graph
253+
254+
def tensor_to_str(mps_tensor: MPSTensor):
    """Return a one-line, human-readable description of `mps_tensor`.

    Args:
        mps_tensor: The tensor to describe. Its `datatype`, `num_dims`,
            `dims`, `constant_buffer`, `constant_buffer_size` and
            `segment_offset` attributes are read.

    Returns:
        A string of the form `MPSTensor(name=value, ...)` listing the
        attributes above in that order, each rendered with `str()`.
    """
    fields = (
        ("datatype", mps_tensor.datatype),
        ("num_dims", mps_tensor.num_dims),
        ("dims", mps_tensor.dims),
        ("constant_buffer", mps_tensor.constant_buffer),
        ("constant_buffer_size", mps_tensor.constant_buffer_size),
        ("segment_offset", mps_tensor.segment_offset),
    )
    # `!s` forces str() so enum-like values render exactly as plain string
    # concatenation with str(...) would.
    body = ", ".join(f"{name}={value!s}" for name, value in fields)
    return "MPSTensor(" + body + ")"
@@ -193,3 +283,4 @@ def pretty_print(mps_graph: MPSGraph):
193283 logging .info (" Output ids:" )
194284 for out_id in mps_graph .output_ids :
195285 logging .info (f" { out_id } " )
286+ logging .info (f" Constant segment: { mps_graph .constant_segment } " )
0 commit comments