
Commit 145e3be

metascroy and yifan_shen3 authored
Fix dequantize_affine before iOS18 (#2589)
* Fix dequantize_affine before iOS18
* Update test_torch_quantization_ops.py
* fix ci

Co-authored-by: yifan_shen3 <[email protected]>
1 parent 428d4b2 · commit 145e3be

File tree

2 files changed: 45 additions & 1 deletion

coremltools/converters/mil/frontend/torch/quantization_ops.py

Lines changed: 0 additions & 1 deletion
@@ -803,7 +803,6 @@ def dequantize_affine(context, node):
         int_data.astype(quantized_np_dtype),
         zero_point,
         scale,
-        axis=-1,
         name=node.name,
     )
     context.add(output, node.name)
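
The substantive change is the removal of the hard-coded axis=-1. In the new test below, the weight has shape (n, k) and the scale has shape (n, 1), one scale per output channel along axis 0, so pinning the quantization axis to -1 points at the wrong dimension; dropping the argument apparently lets the lowering helper pick the axis that matches the scale's shape. A hypothetical numpy sketch of the broadcasting problem (illustrative only; affine_dequantize and its logic here are assumptions, not the converter's actual helper):

import numpy as np

def affine_dequantize(q, zero_point, scale, axis):
    # Reshape scale/zero_point so they broadcast along `axis`.
    shape = [1] * q.ndim
    shape[axis] = -1
    zp = 0 if zero_point is None else zero_point.reshape(shape)
    return (q.astype(np.float32) - zp) * scale.reshape(shape)

n, k = 4, 128
q = np.random.randint(-128, 128, size=(n, k)).astype(np.int8)
scale = np.random.rand(n).astype(np.float32)

w = affine_dequantize(q, None, scale, axis=0)    # OK: 4 scales broadcast over 4 rows
# affine_dequantize(q, None, scale, axis=-1)     # ValueError: 4 scales vs. 128 columns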

coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py

Lines changed: 45 additions & 0 deletions
@@ -272,6 +272,51 @@ def forward(self, x):
         prog = res[1]._mil_program
         assert get_op_types_in_program(prog) == ["constexpr_blockwise_shift_scale", "linear"]
 
+    @pytest.mark.skipif(not _HAS_TORCHAO, reason=MSG_TORCHAO_NOT_FOUND)
+    @pytest.mark.parametrize(
+        "compute_unit, has_zeros, minimum_deployment_target",
+        itertools.product(compute_units, [True, False], [ct.target.iOS16, ct.target.iOS17]),
+    )
+    def test_dequantize_affine_before_ios18(self, compute_unit, has_zeros, minimum_deployment_target):
+        quant_min = -128
+        quant_max = 127
+
+        n = 4
+        k = 128
+        input_dtype = torch.int8
+        int_data = torch.randint(low=quant_min, high=quant_max, size=(n, k)).to(input_dtype)
+        scale = torch.rand(n, 1)
+
+        zero_point = None
+        if has_zeros:
+            zero_point = torch.randint(low=quant_min, high=quant_max, size=(n, 1)).to(input_dtype)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("int_data", int_data)
+                self.register_buffer("scale", scale)
+                self.register_buffer("zero_point", zero_point)
+
+            def forward(self, x):
+                w = torchao_quant.dequantize_affine(self.int_data, [1, k], self.scale, self.zero_point, input_dtype, quant_min, quant_max)
+                return torch.nn.functional.linear(x, w)
+
+
+        model = Model()
+        model = model.to(torch.device("cpu"))
+
+        input_shape = [(3, k)]
+        res = self.run_compare_torch(
+            input_shape,
+            model,
+            minimum_deployment_target=minimum_deployment_target,
+            compute_unit=compute_unit,
+            rtol=0.1,
+            frontend=TorchFrontend.TORCHEXPORT,
+        )
+        prog = res[1]._mil_program
+        assert get_op_types_in_program(prog) == ["constexpr_affine_dequantize", "linear"]
 
 
 # TODO(rdar://108463675): refactor torch op tests later to parametrize quantized vs standard ops
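
For context, a standalone sketch of the behavior the new test locks in: exporting a model that calls torchao's dequantize_affine and converting it with a pre-iOS18 deployment target should now lower the weight to a constexpr_affine_dequantize op, while iOS18+ uses constexpr_blockwise_shift_scale (per the existing test above). The snippet mirrors the test but is an unverified sketch; the torchao import path and the direct ct.convert call on the exported program are assumptions.

import torch
import coremltools as ct
# Assumed import path; the test above aliases it as `torchao_quant`.
from torchao.quantization import quant_primitives as torchao_quant

n, k = 4, 128
int_data = torch.randint(-128, 127, (n, k), dtype=torch.int8)
scale = torch.rand(n, 1)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("int_data", int_data)
        self.register_buffer("scale", scale)

    def forward(self, x):
        # Dequantize the stored int8 weight, then apply it as a linear layer.
        w = torchao_quant.dequantize_affine(
            self.int_data, [1, k], self.scale, None, torch.int8, -128, 127
        )
        return torch.nn.functional.linear(x, w)

exported = torch.export.export(Model().eval(), (torch.rand(3, k),))

# With this fix, a pre-iOS18 target converts cleanly; the weight lowers to
# constexpr_affine_dequantize in the resulting MIL program.
mlmodel = ct.convert(exported, minimum_deployment_target=ct.target.iOS17)
print(mlmodel._mil_program)  # expect constexpr_affine_dequantize + linear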
