Skip to content

Commit 2cd27d4

Browse files
committed
Add a TOSA specification class
Signed-off-by: Per Åstrand <[email protected]> Change-Id: I45374812daaa831f8084f6534b4903d14d6634b8
1 parent 39e5b91 commit 2cd27d4

File tree

2 files changed

+331
-0
lines changed

2 files changed

+331
-0
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from executorch.backends.arm.tosa_specification import (
10+
Tosa_0_80,
11+
Tosa_1_00,
12+
TosaSpecification,
13+
)
14+
15+
from executorch.exir.backend.compile_spec_schema import CompileSpec
16+
from parameterized import parameterized
17+
18+
test_valid_0_80_strings = [
19+
"TOSA-0.80.0+BI",
20+
"TOSA-0.80.0+MI+8k",
21+
"TOSA-0.80.0+BI+u55",
22+
]
23+
test_valid_1_00_strings = [
24+
"TOSA-1.00.0+INT+FP+fft",
25+
"TOSA-1.00.0+FP+bf16+fft",
26+
"TOSA-1.00.0+INT+int4+cf",
27+
"TOSA-1.00.0+FP+cf+bf16+8k",
28+
"TOSA-1.00.0+FP+INT+bf16+fft+int4+cf",
29+
"TOSA-1.00.0+FP+INT+fft+int4+cf+8k",
30+
]
31+
32+
test_valid_1_00_extensions = {
33+
"INT": ["int16", "int4", "var", "cf"],
34+
"FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
35+
}
36+
37+
test_invalid_strings = [
38+
"TOSA-0.80.0+bi",
39+
"TOSA-0.80.0",
40+
"TOSA-0.80.0+8k",
41+
"TOSA-0.80.0+BI+MI",
42+
"TOSA-0.80.0+BI+U55",
43+
"TOSA-1.00.0+fft",
44+
"TOSA-1.00.0+fp+bf16+fft",
45+
"TOSA-1.00.0+INT+INT4+cf",
46+
"TOSA-1.00.0+BI",
47+
"TOSA-1.00.0+FP+FP+INT",
48+
"TOSA-1.00.0+FP+CF+bf16",
49+
"TOSA-1.00.0+BF16+fft+int4+cf+INT",
50+
]
51+
52+
test_compile_specs = [
53+
([CompileSpec("tosa_version", "TOSA-0.80.0+BI".encode())],),
54+
([CompileSpec("tosa_version", "TOSA-0.80.0+BI+u55".encode())],),
55+
([CompileSpec("tosa_version", "TOSA-1.00.0+INT".encode())],),
56+
]
57+
58+
test_compile_specs_no_version = [
59+
([CompileSpec("other_key", "TOSA-0.80.0+BI".encode())],),
60+
([CompileSpec("other_key", "some_value".encode())],),
61+
]
62+
63+
64+
class TestTosaSpecification(unittest.TestCase):
65+
"""Tests the TOSA specification class"""
66+
67+
@parameterized.expand(test_valid_0_80_strings)
68+
def test_version_string_0_80(self, version_string: str):
69+
tosa_spec = TosaSpecification.create_from_string(version_string)
70+
assert isinstance(tosa_spec, Tosa_0_80)
71+
assert tosa_spec.profile in ["BI", "MI"]
72+
73+
@parameterized.expand(test_valid_1_00_strings)
74+
def test_version_string_1_00(self, version_string: str):
75+
tosa_spec = TosaSpecification.create_from_string(version_string)
76+
assert isinstance(tosa_spec, Tosa_1_00)
77+
assert [profile in ["INT", "FP"] for profile in tosa_spec.profiles].count(
78+
True
79+
) > 0
80+
81+
for profile in tosa_spec.profiles:
82+
assert [
83+
e in test_valid_1_00_extensions[profile] for e in tosa_spec.extensions
84+
]
85+
86+
@parameterized.expand(test_invalid_strings)
87+
def test_invalid_version_strings(self, version_string: str):
88+
tosa_spec = None
89+
with self.assertRaises(ValueError):
90+
tosa_spec = TosaSpecification.create_from_string(version_string)
91+
92+
assert tosa_spec is None
93+
94+
@parameterized.expand(test_compile_specs)
95+
def test_create_from_compilespec(self, compile_specs: list[CompileSpec]):
96+
tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)
97+
assert isinstance(tosa_spec, TosaSpecification)
98+
99+
@parameterized.expand(test_compile_specs_no_version)
100+
def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec]):
101+
tosa_spec = None
102+
with self.assertRaises(ValueError):
103+
tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)
104+
105+
assert tosa_spec is None

backends/arm/tosa_specification.py

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
#
9+
# Main implementation of AoT flow to partition and preprocess for Arm target
10+
# backends. Converts via TOSA as an intermediate form supported by AoT and
11+
# JIT compiler flows.
12+
#
13+
14+
import re
15+
from typing import List
16+
17+
from executorch.exir.backend.compile_spec_schema import CompileSpec
18+
from packaging.version import Version
19+
20+
21+
class TosaSpecification:
22+
"""
23+
This class implements a representation of TOSA specification
24+
(https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile
25+
(with extension) and a level (8k).
26+
For 0.80 releases the profile is BI or MI, with u55 handled as an inofficial extension
27+
For 1.00 releases the profile is INT or FP, and the extensions are for
28+
INT: int16, int4, var, cf
29+
FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf
30+
31+
The TOSA specification is encoded in the string represenatation
32+
TOSA-major.minor.patch+profile[+level][+extensions]
33+
34+
For 0.80 MI implies BI, while for 1.0 the profiles has to explicitely be specified.
35+
36+
Profiles are uppercase letters and extensions and level is lowercase.
37+
"""
38+
39+
version: Version
40+
41+
def support_integer(self) -> bool:
42+
"""
43+
Returns true if any integer operations are supported for the specification.
44+
"""
45+
raise NotImplementedError
46+
47+
def support_float(self) -> bool:
48+
"""
49+
Returns true if any float operations are supported for the specification.
50+
"""
51+
raise NotImplementedError
52+
53+
def __init__(self, version: Version):
54+
self.version = version
55+
56+
@staticmethod
57+
def create_from_compilespecs(
58+
compile_specs: List[CompileSpec],
59+
) -> "TosaSpecification":
60+
"""
61+
Search the CompileSpec list for 'tosa_version' and instantiate a
62+
class from the found value or return None on failure.
63+
"""
64+
for spec in compile_specs:
65+
if spec.key == "tosa_version":
66+
return TosaSpecification.create_from_string(spec.value.decode())
67+
raise ValueError(
68+
"No TOSA version key found in any of the supplied CompileSpecs"
69+
)
70+
71+
@staticmethod
72+
def create_from_string(repr: str) -> "TosaSpecification":
73+
"""
74+
Creates a TOSA specification class from a string representation:
75+
TOSA-0.80.0+MI
76+
TOSA-0.80.0+BI+8k
77+
TOSA-0.80.0+BI+u55 # Ethos-U55 extension to handle TOSA subset
78+
TOSA-0.90.0+MI
79+
TOSA-1.00.0+INT+FP+int4+cf
80+
"""
81+
82+
pattern = r"^(TOSA)-([\d.]+)\+(.+)$"
83+
match = re.match(pattern, repr)
84+
if match:
85+
name = match.group(1)
86+
version = Version(match.group(2))
87+
extras = match.group(3).split("+")
88+
if name != "TOSA":
89+
raise ValueError(f"Malformed TOSA specification representation: {repr}")
90+
match version:
91+
case _ if version.major == 0 and version.minor == 80:
92+
return Tosa_0_80(version, extras)
93+
case _ if version.major == 1 and version.minor == 0:
94+
return Tosa_1_00(version, extras)
95+
case _:
96+
raise ValueError(f"Wrong TOSA version: {version} from {repr}")
97+
98+
raise ValueError(f"Failed to parse TOSA specification representation: {repr}")
99+
100+
101+
class Tosa_0_80(TosaSpecification):
102+
profile: str
103+
level_8k: bool
104+
is_U55_subset: bool
105+
available_profiles = ["BI", "MI"] # MT is not defined
106+
107+
def __init__(self, version: Version, extras: List[str]):
108+
super().__init__(version)
109+
assert version >= Version("0.80") and version < Version("0.90")
110+
111+
# Check that we only have one profile in the extensions list
112+
if [e in Tosa_0_80.available_profiles for e in extras].count(True) != 1:
113+
raise ValueError(
114+
f"Bad combination of extras: {extras}, more than one of {Tosa_0_80.available_profiles} found."
115+
)
116+
117+
# The list contains one profile at most, so pick it
118+
self.profile = [e for e in extras if e in Tosa_0_80.available_profiles][0]
119+
extras.remove(self.profile)
120+
121+
self.level_8k = "8k" in extras
122+
if self.level_8k:
123+
extras.remove("8k")
124+
self.is_U55_subset = "u55" in extras
125+
if self.is_U55_subset:
126+
extras.remove("u55")
127+
128+
if len(extras) > 0:
129+
raise ValueError(f"Unhandled extras found: {extras}")
130+
131+
def __repr__(self):
132+
extensions = ""
133+
if self.level_8k:
134+
extensions += "+8K"
135+
if self.is_U55_subset:
136+
extensions += "+u55"
137+
return f"TOSA-{str(self.version)}+{self.profile}{extensions}"
138+
139+
def __hash__(self) -> int:
140+
return hash(str(self.version) + self.profile)
141+
142+
def __eq__(self, other: object) -> bool:
143+
if isinstance(other, Tosa_0_80):
144+
return (self.version == other.version) and (self.profile == other.profile)
145+
return False
146+
147+
def support_integer(self):
148+
return True
149+
150+
def support_float(self):
151+
return self.profile == "MI"
152+
153+
154+
class Tosa_1_00(TosaSpecification):
155+
profiles: List[str]
156+
level_8k: bool
157+
extensions: List[str]
158+
159+
available_profiles = ["INT", "FP"]
160+
valid_extensions = {
161+
"INT": ["int16", "int4", "var", "cf"],
162+
"FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
163+
}
164+
165+
def __init__(self, version: Version, extras: List[str]):
166+
super().__init__(version)
167+
168+
# Check that we have at least one profile in the extensions list
169+
if [e in Tosa_1_00.available_profiles for e in extras].count(True) == 0:
170+
raise ValueError(
171+
f"No profile ({Tosa_1_00.available_profiles}) found in: {extras}."
172+
)
173+
174+
# and not more than number of available profiles
175+
if [e in Tosa_1_00.available_profiles for e in extras].count(True) > len(
176+
Tosa_1_00.available_profiles
177+
):
178+
raise ValueError(
179+
f"Too many profiles ({Tosa_1_00.available_profiles}) found in: {extras}."
180+
)
181+
182+
# The list contains one profile at least, so pick them
183+
self.profiles = [e for e in extras if e in Tosa_1_00.available_profiles]
184+
for p in self.profiles:
185+
extras.remove(p)
186+
187+
self.level_8k = "8k" in extras
188+
if self.level_8k:
189+
extras.remove("8k")
190+
191+
combined_extensions = []
192+
for p in self.profiles:
193+
combined_extensions += Tosa_1_00.valid_extensions[p]
194+
195+
if not all(e in combined_extensions for e in extras):
196+
raise ValueError(
197+
f"Bad extensions for TOSA-{version}{self._get_profiles_string()}: {extras}"
198+
)
199+
200+
# all the rest of the extras are handled extensions
201+
self.extensions = extras
202+
203+
def _get_profiles_string(self) -> str:
204+
return "".join(["+" + p for p in self.profiles])
205+
206+
def _get_extensions_string(self) -> str:
207+
return "".join(["+" + e for e in self.extensions])
208+
209+
def __repr__(self):
210+
return f"TOSA-{self.version}{self._get_profiles_string()}{self._get_profiles_string()}"
211+
212+
def __hash__(self) -> int:
213+
return hash(str(self.version) + self._get_profiles_string())
214+
215+
def __eq__(self, other: object) -> bool:
216+
if isinstance(other, Tosa_1_00):
217+
return (self.version == other.version) and (
218+
self._get_profiles_string() == other._get_profiles_string()
219+
)
220+
return False
221+
222+
def support_integer(self):
223+
return "INT" in self.profiles
224+
225+
def support_float(self):
226+
return "FP" in self.profiles

0 commit comments

Comments
 (0)