4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
# pyre-unsafe
7
+ """Provide PyTorch-to-TOSA mapping helpers.
7
8
8
- #
9
- # PyTorch to Tosa mapping - simple mapping functions and multi-type extraction
10
- # of key information. These are used by the initial compile stage which captures
11
- # the standardised TOSA representation.
12
- #
9
+ Use these utilities to translate PyTorch dtypes and FX node metadata into
10
+ the TOSA serializer types and shapes used during initial compilation.
11
+
12
+ """
13
13
14
14
from typing import Any , Optional , Sequence
15
15
32
32
33
33
34
34
def map_dtype (data_type : torch .dtype , tosa_spec : TosaSpecification ) -> Any :
35
+ """Map a ``torch.dtype`` to a ``ts.DType``.
36
+
37
+ Args:
38
+ data_type (torch.dtype): PyTorch dtype to convert.
39
+ tosa_spec (TosaSpecification): Active spec (reserved for future checks).
40
+
41
+ Returns:
42
+ Any: Matching ``ts.DType`` enum value.
43
+
44
+ Raises:
45
+ ValueError: If the dtype is unsupported or unknown.
46
+
47
+ """
35
48
if data_type in UNSUPPORTED_DTYPES :
36
49
raise ValueError (f"Unsupported type: { data_type } " )
37
50
@@ -57,6 +70,20 @@ def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any:
57
70
# TODO: other types, can be
58
71
# SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None
59
72
def extract_tensor_meta (meta , tosa_spec : TosaSpecification ):
73
+ """Extract dtype, shape, and dimension order from FX metadata.
74
+
75
+ Args:
76
+ meta (dict): FX node ``meta`` containing a ``val`` FakeTensor (or tuple).
77
+ tosa_spec (TosaSpecification): Active TOSA spec for dtype mapping.
78
+
79
+ Returns:
80
+ tuple: ``(dtype, shape, dim_order)`` where ``dtype`` is ``ts.DType``,
81
+ ``shape`` is ``Tuple[int, ...]``, and ``dim_order`` is ``Tuple[int, ...]``.
82
+
83
+ Raises:
84
+ ValueError: If ``meta['val']`` is not a ``FakeTensor``.
85
+
86
+ """
60
87
assert meta .get ("val" ) is not None
61
88
val = meta ["val" ]
62
89
if type (val ) is tuple :
@@ -77,23 +104,66 @@ def extract_tensor_meta(meta, tosa_spec: TosaSpecification):
77
104
return (dtype , shape , dim_order )
78
105
79
106
80
- # Class to capture arguments and turn into tensor references for TOSA OPs
81
107
class TosaArg :
108
+ """Capture and normalize TOSA operator arguments.
109
+
110
+ Use this to convert FX nodes, sequences, and numeric literals into a
111
+ consistent structure suitable for TOSA serialization.
112
+
113
+ Attributes:
114
+ name (str): Node name when argument is a ``torch.fx.Node``; empty otherwise.
115
+ dtype (ts.DType | None): Inferred dtype when available.
116
+ shape (tuple[int, ...] | None): Inferred shape when available.
117
+ dim_order (tuple[int, ...] | None): Dimension order, defaulting to ``range(len(shape))``.
118
+ special (list | None): Captured list when the argument is a sequence.
119
+ number (float | int | None): Captured numeric value when given.
120
+ tosa_spec (TosaSpecification): Active specification used for mapping.
121
+
122
+ """
123
+
82
124
def __process_node (self , argument : torch .fx .Node ):
125
+ """Parse a ``torch.fx.Node`` and populate tensor attributes.
126
+
127
+ Args:
128
+ argument (torch.fx.Node): FX node to inspect.
129
+
130
+ """
83
131
self .name : str = argument .name
84
132
self .dtype , self .shape , self .dim_order = extract_tensor_meta (
85
133
argument .meta , self .tosa_spec
86
134
)
87
135
88
136
def __process_list (self , argument ):
137
+ """Capture a sequence argument as ``special``.
138
+
139
+ Args:
140
+ argument (Sequence): Sequence to store.
141
+
142
+ """
89
143
self .special : list = list (argument )
90
144
91
145
def __process_number (self , argument : float | int ):
146
+ """Capture a numeric argument as ``number``.
147
+
148
+ Args:
149
+ argument (float | int): Numeric value.
150
+
151
+ """
92
152
self .number : float | int = argument
93
153
94
154
def __init__ (
95
155
self , argument : Any , tosa_spec : Optional [TosaSpecification ] = None
96
156
) -> None :
157
+ """Initialize the argument wrapper and populate fields.
158
+
159
+ Args:
160
+ argument (Any): One of ``torch.fx.Node``, ``Sequence``, ``int``, ``float``, ``torch.dtype``, or ``None``.
161
+ tosa_spec (Optional[TosaSpecification]): Active specification; required.
162
+
163
+ Raises:
164
+ RuntimeError: If ``argument`` is of an unsupported type.
165
+
166
+ """
97
167
if tosa_spec is None :
98
168
raise ValueError ("tosa_spec is None" )
99
169
elif not isinstance (tosa_spec , TosaSpecification ):
@@ -127,6 +197,12 @@ def __init__(
127
197
)
128
198
129
199
def __repr__ (self ):
200
+ """Return a compact representation of populated attributes.
201
+
202
+ Returns:
203
+ str: Readable list of set attributes.
204
+
205
+ """
130
206
attrs = []
131
207
if hasattr (self , "name" ):
132
208
if self .name is not None :
0 commit comments