|  | 
|  | 1 | +from abc import ABC, abstractmethod | 
|  | 2 | +from dataclasses import dataclass | 
|  | 3 | +from typing import Dict, List, Sequence | 
|  | 4 | + | 
|  | 5 | +from executorch.exir._serialize._cord import Cord | 
|  | 6 | + | 
|  | 7 | +from executorch.exir.schema import ScalarType | 
|  | 8 | + | 
|  | 9 | + | 
|  | 10 | +@dataclass | 
|  | 11 | +class TensorLayout: | 
|  | 12 | +    """ | 
|  | 13 | +    Tensor layout information for externally-serialized tensors. | 
|  | 14 | +    """ | 
|  | 15 | + | 
|  | 16 | +    scalar_type: ScalarType | 
|  | 17 | +    sizes: List[int] | 
|  | 18 | +    dim_order: List[bytes] | 
|  | 19 | + | 
|  | 20 | + | 
|  | 21 | +@dataclass | 
|  | 22 | +class SerializationInfo: | 
|  | 23 | +    # A sequence of tensor data buffers. | 
|  | 24 | +    tensor_buffers: Sequence[bytes] | 
|  | 25 | + | 
|  | 26 | +    # A map from tensor name (fqn) to tensor index inside `tensor_buffers`. | 
|  | 27 | +    # Note: multiple tensor names may map to the same index as `tensor_buffers` | 
|  | 28 | +    # is likely deduplicated. | 
|  | 29 | +    fqn_to_buffer_index: Dict[str, int] | 
|  | 30 | + | 
|  | 31 | +    # A map from tensor name (fqn) to TensorLayout. | 
|  | 32 | +    fqn_to_tensor_layout: Dict[str, TensorLayout] | 
|  | 33 | + | 
|  | 34 | + | 
|  | 35 | +class DataSerializer(ABC): | 
|  | 36 | +    """Serializes and deserializes FQN-tagged tensor data. | 
|  | 37 | +
 | 
|  | 38 | +    This base class enables serialization into different formats. See | 
|  | 39 | +    executorch/extension/flat_tensor/ for an example. | 
|  | 40 | +    """ | 
|  | 41 | + | 
|  | 42 | +    @abstractmethod | 
|  | 43 | +    def __init__(self) -> None: | 
|  | 44 | +        """ | 
|  | 45 | +        This initializer may be overridden in derived classes to hold | 
|  | 46 | +        the data required for serialization, eg. configurations. | 
|  | 47 | +        """ | 
|  | 48 | +        pass | 
|  | 49 | + | 
|  | 50 | +    @abstractmethod | 
|  | 51 | +    def serialize_tensors( | 
|  | 52 | +        self, | 
|  | 53 | +        serialization_info: SerializationInfo, | 
|  | 54 | +    ) -> Cord: | 
|  | 55 | +        """ | 
|  | 56 | +        Serializes a list of tensors emitted by ExecuTorch into a binary blob. | 
|  | 57 | +
 | 
|  | 58 | +        Args: | 
|  | 59 | +            serialization_info: the tensor buffers and tensor layout | 
|  | 60 | +            information required for serialization. | 
|  | 61 | +
 | 
|  | 62 | +        Returns: | 
|  | 63 | +            A binary blob that contains the serialized data. | 
|  | 64 | +        """ | 
|  | 65 | +        raise NotImplementedError("serialize_data") | 
|  | 66 | + | 
|  | 67 | +    @abstractmethod | 
|  | 68 | +    def deserialize_tensors(self, blob: Cord) -> SerializationInfo: | 
|  | 69 | +        """ | 
|  | 70 | +        Deserializes a blob into a list of tensors. Reverses the effect of | 
|  | 71 | +        serialize_tensors. | 
|  | 72 | +
 | 
|  | 73 | +        Args: | 
|  | 74 | +            blob: A binary blob that contains the serialized data. | 
|  | 75 | +
 | 
|  | 76 | +        Returns: | 
|  | 77 | +            SerializationInfo: tensor buffers and tensor layout information | 
|  | 78 | +            deserialized from `blob`. | 
|  | 79 | +        """ | 
|  | 80 | +        raise NotImplementedError("deserialize_data") | 
0 commit comments