Skip to content

Commit d5dff72

Browse files
Arm backend: Add docstrings to tosa/mapping.py (#14374)
Signed-off-by: Sebastian Larsson <[email protected]>
1 parent d43cde5 commit d5dff72

File tree

1 file changed

+82
-6
lines changed

1 file changed

+82
-6
lines changed

backends/arm/tosa/mapping.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7+
"""Provide PyTorch-to-TOSA mapping helpers.
78
8-
#
9-
# PyTorch to Tosa mapping - simple mapping functions and multi-type extraction
10-
# of key information. These are used by the initial compile stage which captures
11-
# the standardised TOSA representation.
12-
#
9+
Use these utilities to translate PyTorch dtypes and FX node metadata into
10+
the TOSA serializer types and shapes used during initial compilation.
11+
12+
"""
1313

1414
from typing import Any, Optional, Sequence
1515

@@ -32,6 +32,19 @@
3232

3333

3434
def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any:
35+
"""Map a ``torch.dtype`` to a ``ts.DType``.
36+
37+
Args:
38+
data_type (torch.dtype): PyTorch dtype to convert.
39+
tosa_spec (TosaSpecification): Active spec (reserved for future checks).
40+
41+
Returns:
42+
Any: Matching ``ts.DType`` enum value.
43+
44+
Raises:
45+
ValueError: If the dtype is unsupported or unknown.
46+
47+
"""
3548
if data_type in UNSUPPORTED_DTYPES:
3649
raise ValueError(f"Unsupported type: {data_type}")
3750

@@ -57,6 +70,20 @@ def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any:
5770
# TODO: other types, can be
5871
# SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None
5972
def extract_tensor_meta(meta, tosa_spec: TosaSpecification):
73+
"""Extract dtype, shape, and dimension order from FX metadata.
74+
75+
Args:
76+
meta (dict): FX node ``meta`` containing a ``val`` FakeTensor (or tuple).
77+
tosa_spec (TosaSpecification): Active TOSA spec for dtype mapping.
78+
79+
Returns:
80+
tuple: ``(dtype, shape, dim_order)`` where ``dtype`` is ``ts.DType``,
81+
``shape`` is ``Tuple[int, ...]``, and ``dim_order`` is ``Tuple[int, ...]``.
82+
83+
Raises:
84+
ValueError: If ``meta['val']`` is not a ``FakeTensor``.
85+
86+
"""
6087
assert meta.get("val") is not None
6188
val = meta["val"]
6289
if type(val) is tuple:
@@ -77,23 +104,66 @@ def extract_tensor_meta(meta, tosa_spec: TosaSpecification):
77104
return (dtype, shape, dim_order)
78105

79106

80-
# Class to capture arguments and turn into tensor references for TOSA OPs
81107
class TosaArg:
108+
"""Capture and normalize TOSA operator arguments.
109+
110+
Use this to convert FX nodes, sequences, and numeric literals into a
111+
consistent structure suitable for TOSA serialization.
112+
113+
Attributes:
114+
name (str): Node name when argument is a ``torch.fx.Node``; empty otherwise.
115+
dtype (ts.DType | None): Inferred dtype when available.
116+
shape (tuple[int, ...] | None): Inferred shape when available.
117+
dim_order (tuple[int, ...] | None): Dimension order, defaulting to ``range(len(shape))``.
118+
special (list | None): Captured list when the argument is a sequence.
119+
number (float | int | None): Captured numeric value when given.
120+
tosa_spec (TosaSpecification): Active specification used for mapping.
121+
122+
"""
123+
82124
def __process_node(self, argument: torch.fx.Node):
125+
"""Parse a ``torch.fx.Node`` and populate tensor attributes.
126+
127+
Args:
128+
argument (torch.fx.Node): FX node to inspect.
129+
130+
"""
83131
self.name: str = argument.name
84132
self.dtype, self.shape, self.dim_order = extract_tensor_meta(
85133
argument.meta, self.tosa_spec
86134
)
87135

88136
def __process_list(self, argument):
137+
"""Capture a sequence argument as ``special``.
138+
139+
Args:
140+
argument (Sequence): Sequence to store.
141+
142+
"""
89143
self.special: list = list(argument)
90144

91145
def __process_number(self, argument: float | int):
146+
"""Capture a numeric argument as ``number``.
147+
148+
Args:
149+
argument (float | int): Numeric value.
150+
151+
"""
92152
self.number: float | int = argument
93153

94154
def __init__(
95155
self, argument: Any, tosa_spec: Optional[TosaSpecification] = None
96156
) -> None:
157+
"""Initialize the argument wrapper and populate fields.
158+
159+
Args:
160+
argument (Any): One of ``torch.fx.Node``, ``Sequence``, ``int``, ``float``, ``torch.dtype``, or ``None``.
161+
tosa_spec (Optional[TosaSpecification]): Active specification; required.
162+
163+
Raises:
164+
RuntimeError: If ``argument`` is of an unsupported type.
165+
166+
"""
97167
if tosa_spec is None:
98168
raise ValueError("tosa_spec is None")
99169
elif not isinstance(tosa_spec, TosaSpecification):
@@ -127,6 +197,12 @@ def __init__(
127197
)
128198

129199
def __repr__(self):
200+
"""Return a compact representation of populated attributes.
201+
202+
Returns:
203+
str: Readable list of set attributes.
204+
205+
"""
130206
attrs = []
131207
if hasattr(self, "name"):
132208
if self.name is not None:

0 commit comments

Comments
 (0)