Commit 9574270

Add palettization/codebook support to CoreML backend (#13051)
This adds palettization support for embedding/linear layers in CoreML using TorchAO's quantize_ API. Note: this needs to wait for pytorch/ao#2648 to land in ao, plus a pin bump in ExecuTorch, before landing.
1 parent 0d4aec1 commit 9574270
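
For orientation, the flow this commit enables looks roughly like the sketch below, distilled from the tests added in this commit. The toy model, shapes, and variable names are illustrative assumptions; the tests use their own _get_test_model() helper and partitioner configuration.

# Minimal sketch of the palettization flow, distilled from the tests in this
# commit. The toy model and shapes are assumptions, not part of the commit.
import torch
import executorch.exir
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
from torchao.quantization import quantize_

model = torch.nn.Sequential(torch.nn.Linear(64, 32)).eval()
example_inputs = (torch.randn(1, 64),)

# Palettize linear weights: 2-bit codes index a 4-entry lookup table,
# with one table per group of 16 input channels.
quantize_(model, CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[-1, 16]))

# After export, the palettized weights show up as
# torch.ops.quant.dequantize_codebook ops in the graph.
ep = torch.export.export(model, example_inputs)

# The Core ML partitioner can now consume those ops and lower them to a
# constexpr LUT in the delegate payload.
delegated = executorch.exir.to_edge_transform_and_lower(
    ep, partitioner=[CoreMLPartitioner()]
)
et_program = delegated.to_executorch()

To target embedding layers instead, the embedding test passes a filter as the third argument to quantize_ (lambda m, fqn: isinstance(m, torch.nn.Embedding)).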

File tree

4 files changed: +104 −2 lines changed


backends/apple/coreml/compiler/torch_ops.py

Lines changed: 41 additions & 1 deletion
@@ -8,6 +8,7 @@
 # coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
 # the op to the coremltools library.
 
+import numpy as np
 import torch as _torch
 from coremltools import _logger
 from coremltools.converters.mil.frontend import _utils
@@ -21,7 +22,6 @@
     transpose,
     unbind,
 )
-
 from coremltools.converters.mil.frontend.torch.torch_op_registry import (
     register_torch_op,
 )
@@ -132,3 +132,43 @@ def dequantize_affine(context, node):
         name=node.name,
     )
     context.add(output, node.name)
+
+
+@register_torch_op(
+    torch_alias=["quant::dequantize_codebook", "quant.dequantize_codebook"],
+    override=False,
+)
+def dequantize_codebook(context, node):
+    inputs = _get_inputs(context, node, expected=[4, 5])
+    codes = inputs[0].val
+    codebook = inputs[1].val
+    nbits = inputs[2].val
+
+    # The information in block_size is redundant with codebook.shape
+    block_size = inputs[3].val  # noqa: F841
+
+    assert len(codes.shape) == 2, "Only rank 2 inputs are supported"
+
+    # Assert codebook is as expected: codebook.dim() = codes.dim() + 2
+    assert len(codebook.shape) == 4, "Only rank 4 inputs are supported for codebook"
+    assert codebook.shape[0] == 1, "Only grouped_channel granularity is supported"
+    n_luts = codebook.shape[1]
+    assert (
+        codes.shape[1] % n_luts == 0
+    ), "codes.shape[1] must be divisible by codebook.shape[1]"
+    assert codebook.shape[2] == 2**nbits
+    assert codebook.shape[3] == 1, "Only scalar lookup values are supported"
+
+    if len(inputs) > 4:
+        output_dtype = inputs[4].val
+        out_np_dtype = NUM_TO_NUMPY_DTYPE[output_dtype]
+        _logger.warning(
+            f"Core ML ignores output_dtype {out_np_dtype} on quant.dequantize_codebook and instead uses the native precision."
+        )
+
+    output = _utils._construct_constexpr_lut_op(
+        codes.astype(np.int8),
+        codebook,
+        name=node.name,
+    )
+    context.add(output, node.name)
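
To make the layout those assertions pin down concrete: codes is a rank-2 integer tensor, and codebook has shape [1, n_luts, 2**nbits, 1], i.e. one scalar lookup table per group of columns. Below is a hedged numpy reference sketch of the dequantization the constexpr LUT op encodes; the helper name and the contiguous-column-group detail are illustrative assumptions, not coremltools API.

# Hedged reference sketch of the numerics behind the constexpr LUT op above:
# each column group of `codes` indexes into its own scalar lookup table.
import numpy as np

def dequantize_codebook_reference(codes: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    n, k = codes.shape                       # rank-2 codes, as asserted above
    _, n_luts, lut_size, _ = codebook.shape  # [1, n_luts, 2**nbits, 1]
    group_size = k // n_luts                 # codes.shape[1] % n_luts == 0
    out = np.empty((n, k), dtype=codebook.dtype)
    for j in range(k):
        lut = codebook[0, j // group_size, :, 0]  # scalar LUT for this column group
        out[:, j] = lut[codes[:, j]]
    return out

# Example: 2-bit codes (4 LUT entries), one LUT per group of 2 columns.
codes = np.array([[0, 3, 1, 2]], dtype=np.int8)
codebook = np.arange(8, dtype=np.float32).reshape(1, 2, 4, 1)
print(dequantize_codebook_reference(codes, codebook))  # [[0. 3. 5. 6.]]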

backends/apple/coreml/test/test_torch_ops.py

Lines changed: 60 additions & 0 deletions
@@ -14,6 +14,9 @@
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.exir.backend.utils import format_delegated_graph
+
+from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
 from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_
 
 
@@ -164,6 +167,61 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)
 
+    def test_dequantize_codebook_linear(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[-1, 16]),
+        )
+        ep = torch.export.export(model, example_inputs)
+        assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+
+        assert (
+            "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
+            in format_delegated_graph(delegated_program.exported_program().graph_module)
+        )
+
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_codebook_embedding(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            CodebookWeightOnlyConfig(dtype=torch.uint3, block_size=[-1, 16]),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        ep = torch.export.export(model, example_inputs)
+        assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+
+        assert (
+            "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
+            in format_delegated_graph(delegated_program.exported_program().graph_module)
+        )
+
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
 
 if __name__ == "__main__":
     test_runner = TestTorchOps()
@@ -172,3 +230,5 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
     test_runner.test_dequantize_affine_c4w_embedding()
     test_runner.test_dequantize_affine_c4w_linear()
     test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
+    test_runner.test_dequantize_codebook_linear()
+    test_runner.test_dequantize_codebook_embedding()
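
Outside the test harness, the format_delegated_graph helper imported above is also a convenient way to check what ended up inside the Core ML delegate; a minimal sketch, assuming delegated is a lowered program like the one in the sketch near the top of this page:

# Sketch: inspect what was lowered. Assumes `delegated` is the output of
# to_edge_transform_and_lower, as in the earlier sketch.
from executorch.exir.backend.utils import format_delegated_graph

printed = format_delegated_graph(delegated.exported_program().graph_module)
# The tests pair this substring check with a walk over top-level nodes to
# confirm only executorch_call_delegate/getitem remain after partitioning.
assert "quant.dequantize_codebook" in printed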

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -72,6 +72,8 @@ dependencies=[
     "typing-extensions>=4.10.0",
     # Keep this version in sync with: ./backends/apple/coreml/scripts/install_requirements.sh
     "coremltools==8.3; platform_system == 'Darwin' or platform_system == 'Linux'",
+    # scikit-learn is used to support palettization in the coreml backend
+    "scikit-learn==1.7.1",
     "hydra-core>=1.3.0",
     "omegaconf>=2.3.0",
 ]

third-party/ao

Submodule ao updated 125 files
