Commit 76ed145

Add CoreML support for torchao quantize_ (#12664)
This adds CoreML support for blockwise/channelwise quantization of linear/embedding layers using torchao APIs. This PR is based on top of #12648 and will be rebased once that lands, so that the diff shows only the quantization changes.
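For orientation, the end-to-end flow this commit enables looks roughly like the sketch below (adapted from the tests added in this PR; the model and shapes are arbitrary). PerGroup granularity exercises the blockwise path and PerAxis the channelwise path; the resulting torchao.dequantize_affine ops are converted by the new handler in torch_ops.py:

import coremltools as ct
import executorch.exir
import torch
from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from torchao.quantization import IntxWeightOnlyConfig, PerGroup, quantize_

# Arbitrary toy model; any nn.Linear/nn.Embedding weights qualify.
model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU())

# Blockwise int4 weight-only quantization (group size 32) via torchao.
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
)

# Export, then partition and lower to Core ML (iOS 18 target, as in the tests below).
compile_specs = CoreMLBackend.generate_compile_specs(
    minimum_deployment_target=ct.target.iOS18
)
ep = torch.export.export(model, (torch.randn(1, 128),))
delegated = executorch.exir.to_edge_transform_and_lower(
    ep, partitioner=[CoreMLPartitioner(compile_specs=compile_specs)]
)
et_program = delegated.to_executorch()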
1 parent 413dee4 commit 76ed145

2 files changed: +245 −1 lines changed

backends/apple/coreml/compiler/torch_ops.py

Lines changed: 79 additions & 1 deletion
@@ -8,10 +8,21 @@
 # coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
 # the op to the coremltools library.

-from coremltools.converters.mil.frontend.torch.ops import transpose, unbind
+import torch as _torch
+from coremltools import _logger as logger
+from coremltools.converters.mil.frontend import _utils
+from coremltools.converters.mil.frontend.torch.ops import (
+    _get_inputs,
+    NUM_TO_NUMPY_DTYPE,
+    NUM_TO_TORCH_DTYPE,
+    transpose,
+    unbind,
+)
+
 from coremltools.converters.mil.frontend.torch.torch_op_registry import (
     register_torch_op,
 )
+from coremltools.converters.mil.mil import types


 # https://github.com/apple/coremltools/pull/2556
@@ -24,3 +35,70 @@ def transpose_copy(context, node):
 @register_torch_op(override=False)
 def unbind_copy(context, node):
     unbind(context, node)
+
+
+# https://github.com/apple/coremltools/pull/2558
+@register_torch_op(
+    torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
+    override=False,
+)
+def dequantize_affine(context, node):
+    inputs = _get_inputs(context, node, expected=[7, 8])
+    int_data = inputs[0].val
+    block_size = inputs[1].val
+    scale = inputs[2].val
+    zero_point = (
+        inputs[3].val if inputs[3] is not None and inputs[3].val is not None else None
+    )
+    # I do not think we need to worry about input_dtype because it gets cast to int4/int8.
+    # For now, we just check that it is int8 or int32.
+    input_dtype = inputs[4].val  # noqa: F841
+    assert NUM_TO_TORCH_DTYPE[input_dtype] in [
+        _torch.int8,
+        _torch.int32,
+    ], "input_dtype should be int8 or int32"
+
+    quant_min = inputs[5].val
+    quant_max = inputs[6].val
+
+    assert len(int_data.shape) == 2, "dequantize_affine only supports rank 2 inputs"
+
+    assert len(int_data.shape) == len(
+        block_size
+    ), "block_size must have the same length as int_data.shape"
+    assert block_size[0] == 1, "block_size[0] must be 1"
+    group_size = block_size[1]
+    k = int_data.shape[1]
+    assert k % group_size == 0, "k must be divisible by group_size"
+    scales_per_row = k // group_size
+    scale = scale.reshape(-1, scales_per_row)
+    if zero_point is not None:
+        zero_point = zero_point.reshape(-1, scales_per_row)
+
+    # TODO: I don't know if CoreML can make use of this.
+    # We could add a cast op to the output, but I'm pretty sure CoreML will remove it during a later pass.
+    # For now, we just log a warning.
+    out_np_dtype = None
+    if len(inputs) > 7:
+        out_np_dtype = NUM_TO_NUMPY_DTYPE[inputs[7].val]
+        logger.warning(
+            f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
+        )
+
+    if quant_min == -8 and quant_max == 7:
+        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int4"))
+    elif quant_min == -128 and quant_max == 127:
+        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int8"))
+    else:
+        raise ValueError(
+            f"Unsupported quantization range: {quant_min} to {quant_max}. CoreML only supports 4-bit and 8-bit quantization."
+        )
+
+    output = _utils._construct_constexpr_dequant_op(
+        int_data.astype(quantized_np_dtype),
+        zero_point,
+        scale,
+        axis=-1,
+        name=node.name,
+    )
+    context.add(output, node.name)
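
To make the block_size handling above concrete, here is a small standalone numpy sketch (illustrative only; toy shapes, no Core ML involved) of the per-group affine dequantization that the constexpr op encodes. With block_size == (1, group_size), each row of an (n, k) weight carries k // group_size scales, one per group of columns:

import numpy as np

n, k, group_size = 2, 8, 4                   # toy shapes; block_size == (1, 4)
rng = np.random.default_rng(0)
int_data = rng.integers(-8, 8, size=(n, k))  # int4 payload: quant_min=-8, quant_max=7
scale = rng.random(n * (k // group_size))    # scales as a flat array

scales_per_row = k // group_size             # 2 scale entries per row
scale = scale.reshape(-1, scales_per_row)    # (n, scales_per_row), as in the op above
# Dequantize: broadcast each group's scale across its group_size columns
# (zero_point omitted here; when present it is subtracted before scaling).
dequant = int_data * np.repeat(scale, group_size, axis=1)
assert dequant.shape == (n, k)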
Lines changed: 166 additions & 0 deletions (new file)
@@ -0,0 +1,166 @@
+# Copyright © 2023 Apple Inc. All rights reserved.
+#
+# Please refer to the license found in the LICENSE file in the root directory of the source tree.
+
+import platform
+import sys
+import unittest
+
+import coremltools as ct
+
+import executorch.exir
+
+import torch
+
+from executorch.backends.apple.coreml.compiler import CoreMLBackend
+from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.runtime import Runtime
+from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_
+
+_TEST_RUNTIME = sys.platform == "darwin" and tuple(
+    map(int, platform.mac_ver()[0].split("."))
+) >= (15, 0)
+
+
+class TestTorchOps(unittest.TestCase):
+    edge_compile_config = executorch.exir.EdgeCompileConfig()
+
+    def _coreml_partitioner(self):
+        compile_specs = CoreMLBackend.generate_compile_specs(
+            minimum_deployment_target=ct.target.iOS18
+        )
+        return CoreMLPartitioner(compile_specs=compile_specs)
+
+    def _get_test_model(self):
+        model = torch.nn.Sequential(
+            torch.nn.Embedding(64, 128), torch.nn.Linear(128, 128), torch.nn.ReLU()
+        )
+        example_inputs = (torch.LongTensor([0]),)
+        return model, example_inputs
+
+    def _compare_outputs(self, executorch_program, eager_program, example_inputs):
+        if not _TEST_RUNTIME:
+            return
+        runtime = Runtime.get()
+        program = runtime.load_program(executorch_program.buffer)
+        method = program.load_method("forward")
+        et_outputs = method.execute(example_inputs)[0]
+        eager_outputs = eager_program(*example_inputs)
+        self.assertTrue(
+            torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
+        )
+
+    def test_dequantize_affine_b4w_embedding(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_affine_b4w_linear(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
+        )
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_affine_c4w_embedding(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_affine_c4w_linear(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0))
+        )
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_affine_c8w_embedding_b4w_linear(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
+        )
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+
+if __name__ == "__main__":
+    test_runner = TestTorchOps()
+    test_runner.test_dequantize_affine_b4w_embedding()
+    test_runner.test_dequantize_affine_b4w_linear()
+    test_runner.test_dequantize_affine_c4w_embedding()
+    test_runner.test_dequantize_affine_c4w_linear()
+    test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
