
from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_
-
+from executorch.exir.backend.utils import format_delegated_graph


def is_fbcode():
    return not hasattr(torch.version, "git_version")
@@ -173,7 +173,7 @@ def test_dequantize_codebook_linear(self):
            CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[-1, 16]),
        )
        ep = torch.export.export(model, example_inputs)
-        print("ORIGINAL MODEL", ep)
+        assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
        delegated_program = executorch.exir.to_edge_transform_and_lower(
            ep,
            partitioner=[self._coreml_partitioner()],
@@ -184,8 +184,11 @@ def test_dequantize_codebook_linear(self):
                    "executorch_call_delegate",
                    "getitem",
                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+
+        print(format_delegated_graph(delegated_program.exported_program().graph_module))
+
+
        et_prog = delegated_program.to_executorch()
-        print(et_prog.exported_program())
        self._compare_outputs(et_prog, model, example_inputs)

    def test_dequantize_codebook_embedding(self):
@@ -196,6 +199,7 @@ def test_dequantize_codebook_embedding(self):
            lambda m, fqn: isinstance(m, torch.nn.Embedding),
        )
        ep = torch.export.export(model, example_inputs)
+        assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
        delegated_program = executorch.exir.to_edge_transform_and_lower(
            ep,
            partitioner=[self._coreml_partitioner()],
@@ -207,16 +211,15 @@ def test_dequantize_codebook_embedding(self):
                    "getitem",
                ], f"Got unexpected node target after delegation: {node.target.__name__}"
        et_prog = delegated_program.to_executorch()
-        print(et_prog.exported_program())
        self._compare_outputs(et_prog, model, example_inputs)


if __name__ == "__main__":
    test_runner = TestTorchOps()
-    test_runner.test_dequantize_affine_b4w_embedding()
-    test_runner.test_dequantize_affine_b4w_linear()
-    test_runner.test_dequantize_affine_c4w_embedding()
-    test_runner.test_dequantize_affine_c4w_linear()
-    test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
+    # test_runner.test_dequantize_affine_b4w_embedding()
+    # test_runner.test_dequantize_affine_b4w_linear()
+    # test_runner.test_dequantize_affine_c4w_embedding()
+    # test_runner.test_dequantize_affine_c4w_linear()
+    # test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
    test_runner.test_dequantize_codebook_linear()
    test_runner.test_dequantize_codebook_embedding()
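
For context, a minimal standalone sketch of the codebook flow these tests now assert on. The toy `nn.Linear` model and shapes below are assumptions (the test's own model helper and the CoreML lowering step are omitted); the config and the asserted op name come straight from the diff.

```python
import torch
from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
from torchao.quantization import quantize_

# Assumed toy model; the real test builds its model via a helper not shown here.
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
example_inputs = (torch.randn(1, 64),)

# Codebook (palettized) weight-only quantization, as in test_dequantize_codebook_linear:
# 2-bit codebook with block_size=[-1, 16].
quantize_(model, CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[-1, 16]))

ep = torch.export.export(model, example_inputs)
# Before lowering, the quantized weights surface as a dequantize_codebook op
# in the exported graph, which is what the new asserts check.
assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
```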