Skip to content

Commit b91987e

Browse files
Arm backend: Move support_extension to base class (#15909)
Move support_extension to TosaSpecification base class to avoid having to check whether the TosaSpecification is an instance of TosaSpecification_1_00. cc @freddan80 @per @zingo @digantdesai Signed-off-by: Oscar Andersson <[email protected]>
1 parent 3305927 commit b91987e

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

backends/arm/_passes/decompose_int16_activation_conv2d_pass.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.backends.arm._passes.arm_pass import ArmPass
1111
from executorch.backends.arm._passes.quant_args import QuantArgs
1212

13-
from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00
13+
from executorch.backends.arm.tosa.specification import get_context_spec
1414
from executorch.exir.dialects._ops import ops as exir_ops
1515
from executorch.exir.pass_base import ExportPass
1616

@@ -40,9 +40,7 @@ def call_operator(self, op, args, kwargs, meta):
4040
if args[0].data.dtype == torch.int8:
4141
return super().call_operator(op, args, kwargs, meta)
4242
elif args[0].data.dtype == torch.int16:
43-
if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension(
44-
"int16"
45-
):
43+
if not tosa_spec.support_extension("int16"):
4644
raise ValueError(
4745
"int16 activation for convolution requires TOSA int16 extension"
4846
)

backends/arm/tosa/specification.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,18 @@ def support_float(self) -> bool:
105105
"""Return True if floating-point operations are supported."""
106106
raise NotImplementedError
107107

108+
def support_extension(self, extension: str) -> bool:
109+
"""Return True if an extension is supported and enabled.
110+
111+
Args:
112+
extension (str): Extension name (for example ``int4``, ``bf16``).
113+
114+
Returns:
115+
bool: True if the extension is valid for the active profiles and selected.
116+
117+
"""
118+
raise NotImplementedError
119+
108120
def __init__(self, version: Version, extras: List[str]):
109121
"""Initialize the base specification.
110122

0 commit comments

Comments
 (0)