14
14
15
15
from executorch .backends .apple .coreml .compiler import CoreMLBackend
16
16
from executorch .backends .apple .coreml .partition import CoreMLPartitioner
17
+ from executorch .exir .backend .utils import format_delegated_graph
18
+
19
+ from torchao .prototype .quantization .codebook_coreml import CodebookWeightOnlyConfig
17
20
from torchao .quantization import IntxWeightOnlyConfig , PerAxis , PerGroup , quantize_
18
21
19
22
@@ -164,6 +167,61 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
164
167
et_prog = delegated_program .to_executorch ()
165
168
self ._compare_outputs (et_prog , model , example_inputs )
166
169
170
def test_dequantize_codebook_linear(self):
    """Check that a codebook-quantized model lowers fully through CoreML.

    The test model is quantized with a ``CodebookWeightOnlyConfig`` using
    ``torch.uint2`` and ``block_size=[-1, 16]`` (no module filter is passed
    to ``quantize_``), exported, and lowered with the CoreML partitioner.
    It then asserts that:

    * the exported graph code mentions ``dequantize_codebook``;
    * every ``call_function`` node left after lowering is either a
      delegate call or a ``getitem``;
    * the formatted delegated graph mentions the edge-dialect
      ``dequantize_codebook`` op.

    Finally the ExecuTorch program is handed to ``self._compare_outputs``
    together with the model and its example inputs.
    """
    model, example_inputs = self._get_test_model()
    codebook_config = CodebookWeightOnlyConfig(
        dtype=torch.uint2, block_size=[-1, 16]
    )
    quantize_(model, codebook_config)

    exported = torch.export.export(model, example_inputs)
    assert "torch.ops.quant.dequantize_codebook.default" in exported.graph_module.code

    lowered = executorch.exir.to_edge_transform_and_lower(
        exported,
        partitioner=[self._coreml_partitioner()],
    )

    # Only these call_function targets may remain once lowering is done.
    allowed_targets = ["executorch_call_delegate", "getitem"]
    for node in lowered.exported_program().graph.nodes:
        if node.op != "call_function":
            continue
        assert (
            node.target.__name__ in allowed_targets
        ), f"Got unexpected node target after delegation: { node .target .__name__ } "

    # The op itself must show up inside the delegated graph's text dump.
    delegated_text = format_delegated_graph(
        lowered.exported_program().graph_module
    )
    assert (
        "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
        in delegated_text
    )

    self._compare_outputs(lowered.to_executorch(), model, example_inputs)
196
+
197
def test_dequantize_codebook_embedding(self):
    """Check that codebook-quantized embeddings lower fully through CoreML.

    The test model is quantized with a ``CodebookWeightOnlyConfig`` using
    ``torch.uint3`` and ``block_size=[-1, 16]``; the filter passed to
    ``quantize_`` accepts only ``torch.nn.Embedding`` modules. The model is
    then exported and lowered with the CoreML partitioner, and the test
    asserts that:

    * the exported graph code mentions ``dequantize_codebook``;
    * every ``call_function`` node left after lowering is either a
      delegate call or a ``getitem``;
    * the formatted delegated graph mentions the edge-dialect
      ``dequantize_codebook`` op.

    Finally the ExecuTorch program is handed to ``self._compare_outputs``
    together with the model and its example inputs.
    """
    model, example_inputs = self._get_test_model()
    codebook_config = CodebookWeightOnlyConfig(
        dtype=torch.uint3, block_size=[-1, 16]
    )
    quantize_(
        model,
        codebook_config,
        lambda m, fqn: isinstance(m, torch.nn.Embedding),
    )

    exported = torch.export.export(model, example_inputs)
    assert "torch.ops.quant.dequantize_codebook.default" in exported.graph_module.code

    lowered = executorch.exir.to_edge_transform_and_lower(
        exported,
        partitioner=[self._coreml_partitioner()],
    )

    # Only these call_function targets may remain once lowering is done.
    allowed_targets = ["executorch_call_delegate", "getitem"]
    for node in lowered.exported_program().graph.nodes:
        if node.op != "call_function":
            continue
        assert (
            node.target.__name__ in allowed_targets
        ), f"Got unexpected node target after delegation: { node .target .__name__ } "

    # The op itself must show up inside the delegated graph's text dump.
    delegated_text = format_delegated_graph(
        lowered.exported_program().graph_module
    )
    assert (
        "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
        in delegated_text
    )

    self._compare_outputs(lowered.to_executorch(), model, example_inputs)
224
+
167
225
168
226
if __name__ == "__main__" :
169
227
test_runner = TestTorchOps ()
@@ -172,3 +230,5 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
172
230
test_runner .test_dequantize_affine_c4w_embedding ()
173
231
test_runner .test_dequantize_affine_c4w_linear ()
174
232
test_runner .test_dequantize_affine_c8w_embedding_b4w_linear ()
233
+ test_runner .test_dequantize_codebook_linear ()
234
+ test_runner .test_dequantize_codebook_embedding ()
0 commit comments