44# LICENSE file in the root directory of this source tree.
55
66# pyre-unsafe
7+ """Provide TOSA specification parsing and context utilities.
78
8- #
9- # Main implementation of AoT flow to partition and preprocess for Arm target
10- # backends. Converts via TOSA as an intermediate form supported by AoT and
11- # JIT compiler flows.
12- #
9+ Use these helpers to parse and validate TOSA profile/extension strings and to
10+ manage a lowering-time context for the active specification.
11+
12+ """
1313
1414import contextvars
1515import re
1919
2020
2121class TosaSpecification :
22- """
23- This class implements a representation of TOSA specification
24- (https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile
25- (with extension) and a level (8k).
26- For 1.00 releases the profile is INT or FP, and the extensions are for
27- INT: int16, int4, var, cf
28- FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf
22+ """Represent a TOSA specification.
2923
30- The TOSA specification is encoded in the string represenatation
31- TOSA-major.minor.patch+profile[+level][+extensions]
24+ A specification includes a semantic version, one or more profiles, and
25+ optional extensions and levels (for example ``8k``).
26+ The encoded form follows ``TOSA-<major>.<minor>.<patch>+<PROFILE>[+<LEVEL>][+<EXT>...]``.
27+ Profiles use uppercase (for example ``INT``, ``FP``); levels and extensions
28+ use lowercase.
29+
30+ Attributes:
31+ version (Version): Parsed TOSA semantic version.
32+ is_U55_subset (bool): True if the ``u55`` subset is requested.
3233
33- Profiles are uppercase letters and extensions and level is lowercase.
3434 """
3535
3636 version : Version
3737 is_U55_subset : bool
3838
3939 def support_integer (self ) -> bool :
40- """
41- Returns true if any integer operations are supported for the specification.
42- """
40+ """Return True if integer operations are supported."""
4341 raise NotImplementedError
4442
4543 def support_float (self ) -> bool :
46- """
47- Returns true if any float operations are supported for the specification.
48- """
44+ """Return True if floating-point operations are supported."""
4945 raise NotImplementedError
5046
5147 def __init__ (self , version : Version , extras : List [str ]):
48+ """Initialize the base specification.
49+
50+ Args:
51+ version (Version): Parsed TOSA semantic version.
52+ extras (List[str]): Remaining tokens such as profiles, levels, and extensions.
53+
54+ """
5255 self .version = version
5356
5457 self .is_U55_subset = "u55" in extras
@@ -57,11 +60,20 @@ def __init__(self, version: Version, extras: List[str]):
5760
5861 @staticmethod
5962 def create_from_string (repr : str ) -> "TosaSpecification" :
60- """
61- Creates a TOSA specification class from a string representation:
62- TOSA-1.00.0+INT+FP+int4+cf
63- """
63+ """Create a specification from a standard string format.
64+
65+ Example: ``TOSA-1.00.0+INT+FP+int4+cf``.
6466
67+ Args:
68+ repr (str): Standard representation string.
69+
70+ Returns:
71+ TosaSpecification: Parsed specification instance.
72+
73+ Raises:
74+ ValueError: If the representation is malformed or version is unsupported.
75+
76+ """
6577 pattern = r"^(TOSA)-([\d.]+)\+(.+)$"
6678 match = re .match (pattern , repr )
6779 if match :
@@ -80,6 +92,18 @@ def create_from_string(repr: str) -> "TosaSpecification":
8092
8193
8294class Tosa_1_00 (TosaSpecification ):
95+ """Provide TOSA 1.00 profile and extension semantics.
96+
97+ This variant validates profiles (``INT``, ``FP``), the optional ``8k`` level,
98+ and allowed extensions based on the selected profiles.
99+
100+ Attributes:
101+ profiles (List[str]): Selected profiles, e.g., ``["INT"]`` or ``["INT", "FP"]``.
102+ level_8k (bool): True if the ``8k`` level is enabled.
103+ extensions (List[str]): Enabled extensions valid for the chosen profiles.
104+
105+ """
106+
83107 profiles : List [str ]
84108 level_8k : bool
85109 extensions : List [str ]
@@ -91,6 +115,16 @@ class Tosa_1_00(TosaSpecification):
91115 }
92116
93117 def __init__ (self , version : Version , extras : List [str ]):
118+ """Initialize the 1.00 specification and validate extras.
119+
120+ Args:
121+ version (Version): Semantic version (major=1, minor=0).
122+ extras (List[str]): Tokens including profiles, level, and extensions.
123+
124+ Raises:
125+ ValueError: If no/too many profiles are provided or extensions are invalid.
126+
127+ """
94128 super ().__init__ (version , extras )
95129
96130 # Check that we have at least one profile in the extensions list
@@ -129,12 +163,20 @@ def __init__(self, version: Version, extras: List[str]):
129163 self .extensions = extras
130164
131165 def _get_profiles_string (self ) -> str :
166+ """Return the ``+``-joined profile segment (e.g., ``+INT+FP``)."""
132167 return "" .join (["+" + p for p in self .profiles ])
133168
134169 def _get_extensions_string (self ) -> str :
170+ """Return the ``+``-joined extensions segment (e.g., ``+int4+cf``)."""
135171 return "" .join (["+" + e for e in self .extensions ])
136172
137173 def __repr__ (self ):
174+ """Return the standard specification string format.
175+
176+ Returns:
177+ str: Standard form like ``TOSA-1.00.0+INT+8k+int4``.
178+
179+ """
138180 extensions = self ._get_extensions_string ()
139181 if self .level_8k :
140182 extensions += "+8k"
@@ -143,22 +185,48 @@ def __repr__(self):
143185 return f"TOSA-{ self .version } { self ._get_profiles_string ()} { extensions } "
144186
145187 def __hash__ (self ) -> int :
188+ """Return a stable hash for use in sets and dict keys.
189+
190+ Returns:
191+ int: Hash value derived from version and profiles.
192+
193+ """
146194 return hash (str (self .version ) + self ._get_profiles_string ())
147195
148196 def __eq__ (self , other : object ) -> bool :
197+ """Return True if another instance represents the same spec.
198+
199+ Args:
200+ other (object): Object to compare.
201+
202+ Returns:
203+ bool: True if versions and profiles match.
204+
205+ """
149206 if isinstance (other , Tosa_1_00 ):
150207 return (self .version == other .version ) and (
151208 self ._get_profiles_string () == other ._get_profiles_string ()
152209 )
153210 return False
154211
155212 def support_integer (self ):
213+ """Return True if the ``INT`` profile is present."""
156214 return "INT" in self .profiles
157215
158216 def support_float (self ):
217+ """Return True if the ``FP`` profile is present."""
159218 return "FP" in self .profiles
160219
161220 def support_extension (self , extension : str ) -> bool :
221+ """Return True if an extension is supported and enabled.
222+
223+ Args:
224+ extension (str): Extension name (for example ``int4``, ``bf16``).
225+
226+ Returns:
227+ bool: True if the extension is valid for the active profiles and selected.
228+
229+ """
162230 for p in self .profiles :
163231 if extension in self .valid_extensions [p ] and extension in self .extensions :
164232 return True
@@ -167,30 +235,63 @@ def support_extension(self, extension: str) -> bool:
167235
168236
169237class TosaLoweringContext :
170- """
171- A context manager to handle the TOSA specific aspects of the lowering process.
172- For now it only handles the TOSA specification context, but it can be extended
173- to include other policies or configurations.
238+ """Manage the TOSA specification context for lowering.
239+
240+ For now, only the active ``TosaSpecification`` is tracked, but this can be
241+ extended to carry additional lowering policies or configuration.
242+
243+ Attributes:
244+ tosa_spec_var (contextvars.ContextVar): Context variable storing the active spec.
245+ spec (TosaSpecification): Specification passed to the context manager.
246+
174247 """
175248
176249 # Define a context variable for the spec
177250 tosa_spec_var : contextvars .ContextVar = contextvars .ContextVar ("tosa_spec" )
178251
179252 def __init__ (self , spec : TosaSpecification ):
253+ """Initialize the lowering context with a specification.
254+
255+ Args:
256+ spec (TosaSpecification): Active specification to put into context.
257+
258+ """
180259 self .spec = spec
181260
182261 def __enter__ (self ):
262+ """Set the context variable and return self.
263+
264+ Returns:
265+ TosaLoweringContext: This context manager instance.
266+
267+ """
183268 # Set the spec in the context variable and store the token for later reset
184269 self .token = TosaLoweringContext .tosa_spec_var .set (self .spec )
185270 return self
186271
187272 def __exit__ (self , exc_type , exc_value , traceback ):
273+ """Reset the context variable to its previous state.
274+
275+ Args:
276+ exc_type (type | None): Exception type, if any.
277+ exc_value (BaseException | None): Exception instance, if any.
278+ traceback (TracebackType | None): Traceback, if any.
279+
280+ """
188281 # Reset the context variable to its previous state
189282 TosaLoweringContext .tosa_spec_var .reset (self .token )
190283
191284
192- # A helper function to retrieve the current spec anywhere in your code
193285def get_context_spec () -> TosaSpecification :
286+ """Get the current ``TosaSpecification`` from the lowering context.
287+
288+ Returns:
289+ TosaSpecification: Active specification retrieved from the context var.
290+
291+ Raises:
292+ RuntimeError: If called outside a ``TosaLoweringContext``.
293+
294+ """
194295 try :
195296 return TosaLoweringContext .tosa_spec_var .get ()
196297 except LookupError :
0 commit comments