Skip to content

Commit ab39f2d

Browse files
committed
Fix CoreML torchao-quant for iOS16
1 parent 0d0769a commit ab39f2d

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

backends/apple/coreml/compiler/torch_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ def dequantize_affine(context, node):
152152
int_data.astype(quantized_np_dtype),
153153
zero_point,
154154
scale,
155-
axis=-1,
156155
name=node.name,
157156
)
158157
context.add(output, node.name)

backends/apple/coreml/test/test_torch_ops.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@
2727
class TestTorchOps(unittest.TestCase):
2828
edge_compile_config = executorch.exir.EdgeCompileConfig()
2929

30-
def _coreml_partitioner(self):
30+
def _coreml_partitioner(self, *, minimum_deployment_target=ct.target.iOS18):
3131
compile_specs = CoreMLBackend.generate_compile_specs(
32-
minimum_deployment_target=ct.target.iOS18
32+
minimum_deployment_target=minimum_deployment_target
3333
)
3434
return CoreMLPartitioner(compile_specs=compile_specs)
3535

3636
def _get_test_model(self):
3737
model = torch.nn.Sequential(
38-
torch.nn.Embedding(64, 128), torch.nn.Linear(128, 128), torch.nn.ReLU()
38+
torch.nn.Embedding(64, 128), torch.nn.Linear(128, 256), torch.nn.ReLU()
3939
)
4040
example_inputs = (torch.LongTensor([0]),)
4141
return model, example_inputs
@@ -117,7 +117,7 @@ def test_dequantize_affine_c4w_embedding(self):
117117
def test_dequantize_affine_c4w_linear(self):
118118
model, example_inputs = self._get_test_model()
119119
quantize_(
120-
model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0))
120+
model, IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0))
121121
)
122122
ep = torch.export.export(model, example_inputs)
123123
delegated_program = executorch.exir.to_edge_transform_and_lower(
@@ -158,6 +158,33 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
158158
et_prog = delegated_program.to_executorch()
159159
self._compare_outputs(et_prog, model, example_inputs)
160160

161+
def test_dequantize_affine_c8w_embedding_c8w_linear_ios16(self):
162+
model, example_inputs = self._get_test_model()
163+
quantize_(
164+
model,
165+
IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
166+
lambda m, fqn: isinstance(m, torch.nn.Embedding),
167+
)
168+
quantize_(
169+
model,
170+
IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
171+
)
172+
ep = torch.export.export(model, example_inputs)
173+
delegated_program = executorch.exir.to_edge_transform_and_lower(
174+
ep,
175+
partitioner=[
176+
self._coreml_partitioner(minimum_deployment_target=ct.target.iOS16)
177+
],
178+
)
179+
for node in delegated_program.exported_program().graph.nodes:
180+
if node.op == "call_function":
181+
assert node.target.__name__ in [
182+
"executorch_call_delegate",
183+
"getitem",
184+
], f"Got unexpected node target after delegation: {node.target.__name__}"
185+
et_prog = delegated_program.to_executorch()
186+
self._compare_outputs(et_prog, model, example_inputs)
187+
161188
def test_dequantize_codebook_linear(self):
162189
model, example_inputs = self._get_test_model()
163190
quantize_(
@@ -221,5 +248,6 @@ def test_dequantize_codebook_embedding(self):
221248
test_runner.test_dequantize_affine_c4w_embedding()
222249
test_runner.test_dequantize_affine_c4w_linear()
223250
test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
251+
test_runner.test_dequantize_affine_c8w_embedding_c8w_linear_ios16()
224252
test_runner.test_dequantize_codebook_linear()
225253
test_runner.test_dequantize_codebook_embedding()

0 commit comments

Comments (0)