Commit 4fc602b

Enable per-row/per-col grouping in CoreML LUT ops (#13674)
This PR enables per-row grouping in CoreML LUT ops; per-column grouping was already supported.
1 parent 17cad55 · commit 4fc602b
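For context on the granularity convention this change relies on: the rank-4 codebook passed to `dequantize_codebook` is laid out as `(row_groups, col_groups, 2**nbits, 1)`, and the op previously accepted only `row_groups == 1` (one LUT per group of columns). Below is a minimal sketch of the shape logic, mirroring the asserts in the diff that follows; the helper name and example shapes are illustrative, not part of the PR.

```python
# Illustrative sketch, not code from this PR: how LUT granularity is read off
# the rank-4 codebook shape. Helper name and example shapes are assumptions.
import torch

def lut_granularity(codebook: torch.Tensor, codes: torch.Tensor) -> str:
    assert codebook.dim() == 4  # (row_groups, col_groups, 2**nbits, 1)
    if codebook.shape[0] == 1:
        n_luts = codebook.shape[1]           # one LUT per group of columns
        assert codes.shape[1] % n_luts == 0
        return f"per-column grouping, {n_luts} LUT(s)"
    assert codebook.shape[1] == 1            # the case this PR adds
    n_luts = codebook.shape[0]               # one LUT per group of rows
    assert codes.shape[0] % n_luts == 0
    return f"per-row grouping, {n_luts} LUT(s)"

codes = torch.zeros(64, 128, dtype=torch.long)
codebook = torch.zeros(4, 1, 16, 1)          # 4 row groups, 4-bit (2**4 = 16) LUTs
print(lut_granularity(codebook, codes))      # per-row grouping, 4 LUT(s)
```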

3 files changed, +79 -11 lines changed

backends/apple/coreml/compiler/torch_ops.py

Lines changed: 16 additions & 5 deletions
@@ -175,11 +175,22 @@ def dequantize_codebook(context, node):

     # Assert codebook is as expected. codebook.dim() = codes.dim() + 2
     assert len(codebook.shape) == 4, "Only rank 4 inputs are supported for codebook"
-    assert codebook.shape[0] == 1, "Only grouped_channel granularity is supported"
-    n_luts = codebook.shape[1]
-    assert (
-        codes.shape[1] % n_luts == 0
-    ), "codes.shape[1] must be divisible by codebook.shape[1]"
+    assert (codebook.shape[0] == 1) or (
+        codebook.shape[1] == 1
+    ), "Only grouped_channel granularity is supported"
+    if codebook.shape[0] == 1:
+        # LUT is per column group
+        n_luts = codebook.shape[1]
+        assert (
+            codes.shape[1] % n_luts == 0
+        ), "codes.shape[1] must be divisible by codebook.shape[1]"
+    else:
+        # LUT is per row group
+        n_luts = codebook.shape[0]
+        assert (
+            codes.shape[0] % n_luts == 0
+        ), "codes.shape[0] must be divisible by codebook.shape[0]"

     assert codebook.shape[2] == 2**nbits
     assert codebook.shape[3] == 1, "Only scalar look up values are supported"
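To make the two branches above concrete, here is a pure-PyTorch reference of what per-row-group LUT dequantization computes: each group of rows shares one lookup table, and every code indexes into its group's table. This is a hedged sketch of the op's semantics, not CoreML's kernel.

```python
# Minimal reference sketch (an assumption, not CoreML's implementation) of
# per-row-group LUT dequantization for the new codebook.shape[1] == 1 case.
import torch

def dequantize_codebook_ref(codes: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    n_luts = codebook.shape[0]                   # per-row grouping: (n_luts, 1, 2**nbits, 1)
    rows_per_group = codes.shape[0] // n_luts
    luts = codebook[:, 0, :, 0]                  # (n_luts, 2**nbits)
    group = torch.arange(codes.shape[0]) // rows_per_group
    return luts[group.unsqueeze(1), codes]       # gather each row's value from its group's LUT

codes = torch.randint(0, 4, (8, 6))              # 2-bit codes for an 8x6 weight
codebook = torch.randn(2, 1, 4, 1)               # 2 LUTs, one per 4-row group
w = dequantize_codebook_ref(codes, codebook)
assert w.shape == codes.shape
```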

backends/apple/coreml/test/test_torch_ops.py

Lines changed: 62 additions & 5 deletions
@@ -35,7 +35,7 @@ def _coreml_partitioner(self):

     def _get_test_model(self):
         model = torch.nn.Sequential(
-            torch.nn.Embedding(64, 128), torch.nn.Linear(128, 128), torch.nn.ReLU()
+            torch.nn.Embedding(64, 128), torch.nn.Linear(128, 256), torch.nn.ReLU()
         )
         example_inputs = (torch.LongTensor([0]),)
         return model, example_inputs
@@ -158,7 +158,7 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)

-    def test_dequantize_codebook_linear(self):
+    def test_dequantize_codebook_linear_per_grouped_col(self):
         model, example_inputs = self._get_test_model()
         quantize_(
             model,
@@ -185,7 +185,34 @@ def test_dequantize_codebook_linear(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)

-    def test_dequantize_codebook_embedding(self):
+    def test_dequantize_codebook_linear_per_grouped_row(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[16, -1]),
+        )
+        ep = torch.export.export(model, example_inputs)
+        assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+
+        assert (
+            "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
+            in format_delegated_graph(delegated_program.exported_program().graph_module)
+        )
+
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_codebook_embedding_per_grouped_col(self):
         model, example_inputs = self._get_test_model()
         quantize_(
             model,
@@ -213,6 +240,34 @@ def test_dequantize_codebook_embedding(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)

+    def test_dequantize_codebook_embedding_per_grouped_row(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            CodebookWeightOnlyConfig(dtype=torch.uint3, block_size=[16, -1]),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        ep = torch.export.export(model, example_inputs)
+        assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+
+        assert (
+            "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
+            in format_delegated_graph(delegated_program.exported_program().graph_module)
+        )
+
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+

 if __name__ == "__main__":
     test_runner = TestTorchOps()
@@ -221,5 +276,7 @@ def test_dequantize_codebook_embedding(self):
     test_runner.test_dequantize_affine_c4w_embedding()
     test_runner.test_dequantize_affine_c4w_linear()
     test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
-    test_runner.test_dequantize_codebook_linear()
-    test_runner.test_dequantize_codebook_embedding()
+    test_runner.test_dequantize_codebook_linear_per_grouped_col()
+    test_runner.test_dequantize_codebook_linear_per_grouped_row()
+    test_runner.test_dequantize_codebook_embedding_per_grouped_col()
+    test_runner.test_dequantize_codebook_embedding_per_grouped_row()
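As a quick sanity check on the new tests (my reading of `block_size`, not verified against torchao's documentation): with the test model's `Linear(128, 256)` weight of shape `(256, 128)`, `block_size=[16, -1]` should mean one LUT per block of 16 rows spanning all columns, i.e. the per-row branch added in torch_ops.py above.

```python
# Back-of-envelope shape check for test_dequantize_codebook_linear_per_grouped_row.
# Assumes block_size=[16, -1] means 16 rows per group over all 128 columns.
out_features, in_features = 256, 128         # Linear(128, 256) weight is (256, 128)
rows_per_group, nbits = 16, 2                # torch.uint2 codes
n_luts = out_features // rows_per_group      # 16 row groups
codebook_shape = (n_luts, 1, 2**nbits, 1)    # hits the new per-row assert branch
print(codebook_shape)                        # (16, 1, 4, 1)
```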

third-party/ao

Submodule ao updated 289 files
