44# LICENSE file in the root directory of this source tree.
55
66# pyre-unsafe
7+ """Provide PyTorch-to-TOSA mapping helpers.
78
8- #
9- # PyTorch to Tosa mapping - simple mapping functions and multi-type extraction
10- # of key information. These are used by the initial compile stage which captures
11- # the standardised TOSA representation.
12- #
9+ Use these utilities to translate PyTorch dtypes and FX node metadata into
10+ the TOSA serializer types and shapes used during initial compilation.
11+
12+ """
1313
1414from typing import Any , Optional , Sequence
1515
3232
3333
3434def map_dtype (data_type : torch .dtype , tosa_spec : TosaSpecification ) -> Any :
35+ """Map a ``torch.dtype`` to a ``ts.DType``.
36+
37+ Args:
38+ data_type (torch.dtype): PyTorch dtype to convert.
39+ tosa_spec (TosaSpecification): Active spec (reserved for future checks).
40+
41+ Returns:
42+ Any: Matching ``ts.DType`` enum value.
43+
44+ Raises:
45+ ValueError: If the dtype is unsupported or unknown.
46+
47+ """
3548 if data_type in UNSUPPORTED_DTYPES :
3649 raise ValueError (f"Unsupported type: { data_type } " )
3750
@@ -57,6 +70,20 @@ def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any:
5770# TODO: other types, can be
5871# SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None
5972def extract_tensor_meta (meta , tosa_spec : TosaSpecification ):
73+ """Extract dtype, shape, and dimension order from FX metadata.
74+
75+ Args:
76+ meta (dict): FX node ``meta`` containing a ``val`` FakeTensor (or tuple).
77+ tosa_spec (TosaSpecification): Active TOSA spec for dtype mapping.
78+
79+ Returns:
80+ tuple: ``(dtype, shape, dim_order)`` where ``dtype`` is ``ts.DType``,
81+ ``shape`` is ``Tuple[int, ...]``, and ``dim_order`` is ``Tuple[int, ...]``.
82+
83+ Raises:
84+ ValueError: If ``meta['val']`` is not a ``FakeTensor``.
85+
86+ """
6087 assert meta .get ("val" ) is not None
6188 val = meta ["val" ]
6289 if type (val ) is tuple :
@@ -77,23 +104,66 @@ def extract_tensor_meta(meta, tosa_spec: TosaSpecification):
77104 return (dtype , shape , dim_order )
78105
79106
80- # Class to capture arguments and turn into tensor references for TOSA OPs
81107class TosaArg :
108+ """Capture and normalize TOSA operator arguments.
109+
110+ Use this to convert FX nodes, sequences, and numeric literals into a
111+ consistent structure suitable for TOSA serialization.
112+
113+ Attributes:
114+ name (str): Node name when argument is a ``torch.fx.Node``; empty otherwise.
115+ dtype (ts.DType | None): Inferred dtype when available.
116+ shape (tuple[int, ...] | None): Inferred shape when available.
117+ dim_order (tuple[int, ...] | None): Dimension order, defaulting to ``range(len(shape))``.
118+ special (list | None): Captured list when the argument is a sequence.
119+ number (float | int | None): Captured numeric value when given.
120+ tosa_spec (TosaSpecification): Active specification used for mapping.
121+
122+ """
123+
82124 def __process_node (self , argument : torch .fx .Node ):
125+ """Parse a ``torch.fx.Node`` and populate tensor attributes.
126+
127+ Args:
128+ argument (torch.fx.Node): FX node to inspect.
129+
130+ """
83131 self .name : str = argument .name
84132 self .dtype , self .shape , self .dim_order = extract_tensor_meta (
85133 argument .meta , self .tosa_spec
86134 )
87135
88136 def __process_list (self , argument ):
137+ """Capture a sequence argument as ``special``.
138+
139+ Args:
140+ argument (Sequence): Sequence to store.
141+
142+ """
89143 self .special : list = list (argument )
90144
91145 def __process_number (self , argument : float | int ):
146+ """Capture a numeric argument as ``number``.
147+
148+ Args:
149+ argument (float | int): Numeric value.
150+
151+ """
92152 self .number : float | int = argument
93153
94154 def __init__ (
95155 self , argument : Any , tosa_spec : Optional [TosaSpecification ] = None
96156 ) -> None :
157+ """Initialize the argument wrapper and populate fields.
158+
159+ Args:
160+ argument (Any): One of ``torch.fx.Node``, ``Sequence``, ``int``, ``float``, ``torch.dtype``, or ``None``.
161+ tosa_spec (Optional[TosaSpecification]): Active specification; required.
162+
163+ Raises:
164+ RuntimeError: If ``argument`` is of an unsupported type.
165+
166+ """
97167 if tosa_spec is None :
98168 raise ValueError ("tosa_spec is None" )
99169 elif not isinstance (tosa_spec , TosaSpecification ):
@@ -127,6 +197,12 @@ def __init__(
127197 )
128198
129199 def __repr__ (self ):
200+ """Return a compact representation of populated attributes.
201+
202+ Returns:
203+ str: Readable list of set attributes.
204+
205+ """
130206 attrs = []
131207 if hasattr (self , "name" ):
132208 if self .name is not None :
0 commit comments