# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import json
import os
import tempfile
from dataclasses import dataclass
from typing import ClassVar, Dict, List, Literal, Optional

import pkg_resources
from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder

from executorch.exir._serialize._flatbuffer import _flatc_compile
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer

from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required

from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
    DataSegment,
    FlatTensor,
    TensorMetadata,
)

# Byte order of numbers written to flat tensor headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"


def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
    """Converts a FlatTensor to a flatbuffer and returns the serialized data."""
    flat_tensor_json = json.dumps(flat_tensor, cls=_DataclassEncoder)
    with tempfile.TemporaryDirectory() as d:
        schema_path = os.path.join(d, "flat_tensor.fbs")
        with open(schema_path, "wb") as schema_file:
            schema_file.write(
                pkg_resources.resource_string(__name__, "flat_tensor.fbs")
            )
        scalar_type_path = os.path.join(d, "scalar_type.fbs")
        with open(scalar_type_path, "wb") as scalar_type_file:
            scalar_type_file.write(
                pkg_resources.resource_string(__name__, "scalar_type.fbs")
            )
        json_path = os.path.join(d, "flat_tensor.json")
        with open(json_path, "wb") as json_file:
            json_file.write(flat_tensor_json.encode("ascii"))

        _flatc_compile(d, schema_path, json_path)
        output_path = os.path.join(d, "flat_tensor.ptd")
        with open(output_path, "rb") as output_file:
            return Cord(output_file.read())


@dataclass
class FlatTensorConfig:
    tensor_alignment: int = 16
    segment_alignment: int = 16


@dataclass
class FlatTensorHeader:
    # Class constants.
    # The magic bytes that should be at the beginning of the header.
    EXPECTED_MAGIC: ClassVar[bytes] = b"FH01"
    EXPECTED_LENGTH: ClassVar[int] = (
        # Header magic
        4
        # Header length
        + 4
        # Flatbuffer offset
        + 8
        # Flatbuffer data size
        + 8
        # Segment base offset
        + 8
        # Data size
        + 8
    )
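    # Sums to 40 bytes: 4 + 4 + 8 + 8 + 8 + 8.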

    # Instance attributes. @dataclass will turn these into ctor args.

    # Offset to the start of the flatbuffer data, in bytes.
    flatbuffer_offset: int
    # The size of the serialized data in bytes.
    flatbuffer_size: int
    # Offset to the start of the first segment, or zero if there
    # are no segments.
    segment_base_offset: int
    # Size of all the segment data, in bytes.
    segment_data_size: int

    # The magic bytes read from or to be written to the binary header.
    magic: bytes = EXPECTED_MAGIC
    # The header length, in bytes, read from or to be written to the binary
    # header.
    length: int = EXPECTED_LENGTH

    @staticmethod
    def from_bytes(data: bytes) -> "FlatTensorHeader":
        """Tries to read a flat_tensor header from the provided data.

        Does not validate that the header is well-formed. Callers should
        use is_valid().

        Args:
            data: The data to read from.
        Returns:
            The contents of the flat_tensor header.
        Raises:
            ValueError: If not enough data is provided.
        """
        if len(data) < FlatTensorHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Not enough data for flat_tensor header: {len(data)} "
                + f"< {FlatTensorHeader.EXPECTED_LENGTH}"
            )

        return FlatTensorHeader(
            magic=data[0:4],
            length=int.from_bytes(data[4:8], byteorder=_HEADER_BYTEORDER),
            flatbuffer_offset=int.from_bytes(data[8:16], byteorder=_HEADER_BYTEORDER),
            flatbuffer_size=int.from_bytes(data[16:24], byteorder=_HEADER_BYTEORDER),
            segment_base_offset=int.from_bytes(
                data[24:32], byteorder=_HEADER_BYTEORDER
            ),
            segment_data_size=int.from_bytes(data[32:40], byteorder=_HEADER_BYTEORDER),
        )

    def is_valid(self) -> bool:
        """Returns true if the flat_tensor header appears to be well-formed."""
        return (
            self.magic == FlatTensorHeader.EXPECTED_MAGIC
            and self.length >= FlatTensorHeader.EXPECTED_LENGTH
        )

    def to_bytes(self) -> bytes:
        """Returns the binary representation of the flat_tensor header.

        Note that this will ignore self.magic and self.length and will always
        write the proper magic/length.
        """
        data: bytes = (
            # Extended header magic. This lets consumers detect whether the
            # header was inserted or not. Always use the proper magic value
            # (i.e., ignore self.magic) since there's no reason to create an
            # invalid header.
            self.EXPECTED_MAGIC
            # uint32_t: Size of this header. This makes it easier to add new
            # fields to this header in the future. Always use the proper size
            # (i.e., ignore self.length) since there's no reason to create an
            # invalid header.
            + self.EXPECTED_LENGTH.to_bytes(4, byteorder=_HEADER_BYTEORDER)
            # uint64_t: Offset to the start of the flatbuffer data, in bytes.
            + self.flatbuffer_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER)
            # uint64_t: Size of the serialized data in bytes.
            + self.flatbuffer_size.to_bytes(8, byteorder=_HEADER_BYTEORDER)
            # uint64_t: Offset to the start of the first segment, or zero if
            # there are no segments.
            + self.segment_base_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER)
            # uint64_t: Size of all the segment data, in bytes.
            + self.segment_data_size.to_bytes(8, byteorder=_HEADER_BYTEORDER)
        )
        return data
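
# A minimal round-trip sketch of the header (illustrative only, not executed as
# part of the module). The field values are hypothetical; the byte layout
# follows the comments in to_bytes():
#
#     header = FlatTensorHeader(
#         flatbuffer_offset=48,
#         flatbuffer_size=200,
#         segment_base_offset=256,
#         segment_data_size=1024,
#     )
#     raw = header.to_bytes()          # 40 bytes, little-endian fields
#     parsed = FlatTensorHeader.from_bytes(raw)
#     assert parsed.is_valid()
#     assert parsed.flatbuffer_offset == 48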


class FlatTensorSerializer(DataSerializer):
    """A concrete implementation of the DataSerializer interface that
    serializes and deserializes data to/from the FlatTensor format.
    """

    def __init__(self, config: Optional[FlatTensorConfig] = None) -> None:
        """FlatTensorConfig holds information required for serialization,
        e.g. alignment.
        """
        if config is None:
            self.config: FlatTensorConfig = FlatTensorConfig()
        else:
            self.config: FlatTensorConfig = config

    def serialize(
        self,
        data: DataPayload,
    ) -> Cord:
        """Serializes a list of tensor metadata and tensors into a blob."""

        flat_tensor_metadata: List[TensorMetadata] = []
        flat_tensor_data: Cord = Cord()

        # Map from buffer index to the offset of that buffer within
        # flat_tensor_data.
        saved_offsets: Dict[int, int] = {}

        for fqn, tensor_entry in data.fqn_to_tensor.items():
            assert tensor_entry.layout is not None
            # Check that the index into the tensor buffers is valid.
            assert tensor_entry.buffer_index < len(
                data.buffers
            ), f"Invalid index {tensor_entry.buffer_index}; must be less than the number of tensor buffers ({len(data.buffers)})."

            # Check if the tensor has already been appended to the flat_tensor_data.
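            # Entries that share a buffer_index reuse the same offset, so the
            # underlying bytes are stored only once.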
            offset = saved_offsets.get(tensor_entry.buffer_index, -1)
            if offset == -1:
                if len(flat_tensor_data) > 0:
                    # Pad the previous tensor so that this one starts at an
                    # aligned offset.
                    pad_length = padding_required(
                        len(flat_tensor_data), self.config.tensor_alignment
                    )
                    flat_tensor_data.append(b"\x00" * pad_length)
                # Add to saved offsets.
                offset = len(flat_tensor_data)
                saved_offsets[tensor_entry.buffer_index] = offset
                # Append to flat_tensor_data at the offset.
                flat_tensor_data.append(data.buffers[tensor_entry.buffer_index])

            flat_tensor_metadata.append(
                TensorMetadata(
                    fully_qualified_name=fqn,
                    scalar_type=tensor_entry.layout.scalar_type,
                    sizes=tensor_entry.layout.sizes,
                    dim_order=tensor_entry.layout.dim_order,
                    segment_index=0,
                    offset=offset,
                )
            )

        # Pad flat_tensor_data to segment alignment.
        segment_pad_length = padding_required(
            len(flat_tensor_data), self.config.segment_alignment
        )
        if segment_pad_length > 0:
            flat_tensor_data.append(b"\x00" * segment_pad_length)

        # Create FlatTensor, which describes the contents of the file and
        # points to all the data segments. It will be serialized to flatbuffer.
        flat_tensor = FlatTensor(
            version=0,
            tensor_alignment=self.config.tensor_alignment,
            tensors=flat_tensor_metadata,
            segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
        )

        flatbuffer_payload = _convert_to_flatbuffer(flat_tensor)
        padded_flatbuffer_length: int = aligned_size(
            input_size=len(flatbuffer_payload),
            alignment=self.config.tensor_alignment,
        )

        padded_header_length: int = aligned_size(
            input_size=FlatTensorHeader.EXPECTED_LENGTH,
            alignment=self.config.tensor_alignment,
        )

        segment_base_offset = aligned_size(
            padded_flatbuffer_length + padded_header_length,
            self.config.segment_alignment,
        )
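        # For example, with the default 16-byte alignments: the 40-byte header
        # pads to 48, so the flatbuffer starts at offset 48; a 200-byte
        # flatbuffer pads to 208; and segment_base_offset is 48 + 208 = 256,
        # already a multiple of the segment alignment.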

        # Create FlatTensorHeader, which stores the offsets and sizes of the
        # FlatTensor flatbuffer and the segment data.
        header_data: bytes = FlatTensorHeader(
            flatbuffer_offset=padded_header_length,
            flatbuffer_size=len(flatbuffer_payload),
            segment_base_offset=segment_base_offset,
            segment_data_size=len(flat_tensor_data),
        ).to_bytes()

        # Pad the header and the flatbuffer payload to their aligned lengths.
        header_data = pad_to(header_data, padded_header_length)
        flatbuffer_payload.append(
            b"\x00" * (padded_flatbuffer_length - len(flatbuffer_payload))
        )

        # Concatenate the header, the flatbuffer, and the segment data into the
        # final blob.
        payload = Cord()
        payload.append(header_data)
        payload.append(flatbuffer_payload)
        payload.append(flat_tensor_data)

        return payload

    def deserialize(self, blob: Cord) -> DataPayload:
        """
        Deserializes a flat_tensor blob into a list of tensor metadata and tensors.
        """
        raise NotImplementedError("deserialize_data")
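
# A minimal usage sketch (illustrative only, not executed as part of the
# module). It assumes a `payload` object that satisfies the DataPayload
# interface used above (fqn_to_tensor and buffers), and that the returned Cord
# can be converted to bytes with bytes():
#
#     serializer = FlatTensorSerializer(FlatTensorConfig(tensor_alignment=16))
#     blob = serializer.serialize(payload)
#     with open("model.ptd", "wb") as f:
#         f.write(bytes(blob))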