1414
1515from executorch .backends .apple .coreml .compiler import CoreMLBackend
1616from executorch .backends .apple .coreml .partition import CoreMLPartitioner
17+ from executorch .exir .backend .utils import format_delegated_graph
18+
19+ from torchao .prototype .quantization .codebook_coreml import CodebookWeightOnlyConfig
1720from torchao .quantization import IntxWeightOnlyConfig , PerAxis , PerGroup , quantize_
1821
1922
@@ -164,6 +167,61 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
164167 et_prog = delegated_program .to_executorch ()
165168 self ._compare_outputs (et_prog , model , example_inputs )
166169
170+ def test_dequantize_codebook_linear (self ):
171+ model , example_inputs = self ._get_test_model ()
172+ quantize_ (
173+ model ,
174+ CodebookWeightOnlyConfig (dtype = torch .uint2 , block_size = [- 1 , 16 ]),
175+ )
176+ ep = torch .export .export (model , example_inputs )
177+ assert "torch.ops.quant.dequantize_codebook.default" in ep .graph_module .code
178+ delegated_program = executorch .exir .to_edge_transform_and_lower (
179+ ep ,
180+ partitioner = [self ._coreml_partitioner ()],
181+ )
182+ for node in delegated_program .exported_program ().graph .nodes :
183+ if node .op == "call_function" :
184+ assert node .target .__name__ in [
185+ "executorch_call_delegate" ,
186+ "getitem" ,
187+ ], f"Got unexpected node target after delegation: { node .target .__name__ } "
188+
189+ assert (
190+ "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
191+ in format_delegated_graph (delegated_program .exported_program ().graph_module )
192+ )
193+
194+ et_prog = delegated_program .to_executorch ()
195+ self ._compare_outputs (et_prog , model , example_inputs )
196+
197+ def test_dequantize_codebook_embedding (self ):
198+ model , example_inputs = self ._get_test_model ()
199+ quantize_ (
200+ model ,
201+ CodebookWeightOnlyConfig (dtype = torch .uint3 , block_size = [- 1 , 16 ]),
202+ lambda m , fqn : isinstance (m , torch .nn .Embedding ),
203+ )
204+ ep = torch .export .export (model , example_inputs )
205+ assert "torch.ops.quant.dequantize_codebook.default" in ep .graph_module .code
206+ delegated_program = executorch .exir .to_edge_transform_and_lower (
207+ ep ,
208+ partitioner = [self ._coreml_partitioner ()],
209+ )
210+ for node in delegated_program .exported_program ().graph .nodes :
211+ if node .op == "call_function" :
212+ assert node .target .__name__ in [
213+ "executorch_call_delegate" ,
214+ "getitem" ,
215+ ], f"Got unexpected node target after delegation: { node .target .__name__ } "
216+
217+ assert (
218+ "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
219+ in format_delegated_graph (delegated_program .exported_program ().graph_module )
220+ )
221+
222+ et_prog = delegated_program .to_executorch ()
223+ self ._compare_outputs (et_prog , model , example_inputs )
224+
167225
168226if __name__ == "__main__" :
169227 test_runner = TestTorchOps ()
@@ -172,3 +230,5 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
172230 test_runner .test_dequantize_affine_c4w_embedding ()
173231 test_runner .test_dequantize_affine_c4w_linear ()
174232 test_runner .test_dequantize_affine_c8w_embedding_b4w_linear ()
233+ test_runner .test_dequantize_codebook_linear ()
234+ test_runner .test_dequantize_codebook_embedding ()
0 commit comments