@@ -35,7 +35,7 @@ def _coreml_partitioner(self):
3535
3636 def _get_test_model (self ):
3737 model = torch .nn .Sequential (
38- torch .nn .Embedding (64 , 128 ), torch .nn .Linear (128 , 128 ), torch .nn .ReLU ()
38+ torch .nn .Embedding (64 , 128 ), torch .nn .Linear (128 , 256 ), torch .nn .ReLU ()
3939 )
4040 example_inputs = (torch .LongTensor ([0 ]),)
4141 return model , example_inputs
@@ -158,7 +158,7 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
158158 et_prog = delegated_program .to_executorch ()
159159 self ._compare_outputs (et_prog , model , example_inputs )
160160
161- def test_dequantize_codebook_linear (self ):
161+ def test_dequantize_codebook_linear_per_grouped_col (self ):
162162 model , example_inputs = self ._get_test_model ()
163163 quantize_ (
164164 model ,
@@ -185,7 +185,34 @@ def test_dequantize_codebook_linear(self):
185185 et_prog = delegated_program .to_executorch ()
186186 self ._compare_outputs (et_prog , model , example_inputs )
187187
188- def test_dequantize_codebook_embedding (self ):
188+ def test_dequantize_codebook_linear_per_grouped_row (self ):
189+ model , example_inputs = self ._get_test_model ()
190+ quantize_ (
191+ model ,
192+ CodebookWeightOnlyConfig (dtype = torch .uint2 , block_size = [16 , - 1 ]),
193+ )
194+ ep = torch .export .export (model , example_inputs )
195+ assert "torch.ops.quant.dequantize_codebook.default" in ep .graph_module .code
196+ delegated_program = executorch .exir .to_edge_transform_and_lower (
197+ ep ,
198+ partitioner = [self ._coreml_partitioner ()],
199+ )
200+ for node in delegated_program .exported_program ().graph .nodes :
201+ if node .op == "call_function" :
202+ assert node .target .__name__ in [
203+ "executorch_call_delegate" ,
204+ "getitem" ,
205+ ], f"Got unexpected node target after delegation: { node .target .__name__ } "
206+
207+ assert (
208+ "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
209+ in format_delegated_graph (delegated_program .exported_program ().graph_module )
210+ )
211+
212+ et_prog = delegated_program .to_executorch ()
213+ self ._compare_outputs (et_prog , model , example_inputs )
214+
215+ def test_dequantize_codebook_embedding_per_grouped_col (self ):
189216 model , example_inputs = self ._get_test_model ()
190217 quantize_ (
191218 model ,
@@ -213,6 +240,34 @@ def test_dequantize_codebook_embedding(self):
213240 et_prog = delegated_program .to_executorch ()
214241 self ._compare_outputs (et_prog , model , example_inputs )
215242
243+ def test_dequantize_codebook_embedding_per_grouped_row (self ):
244+ model , example_inputs = self ._get_test_model ()
245+ quantize_ (
246+ model ,
247+ CodebookWeightOnlyConfig (dtype = torch .uint3 , block_size = [16 , - 1 ]),
248+ lambda m , fqn : isinstance (m , torch .nn .Embedding ),
249+ )
250+ ep = torch .export .export (model , example_inputs )
251+ assert "torch.ops.quant.dequantize_codebook.default" in ep .graph_module .code
252+ delegated_program = executorch .exir .to_edge_transform_and_lower (
253+ ep ,
254+ partitioner = [self ._coreml_partitioner ()],
255+ )
256+ for node in delegated_program .exported_program ().graph .nodes :
257+ if node .op == "call_function" :
258+ assert node .target .__name__ in [
259+ "executorch_call_delegate" ,
260+ "getitem" ,
261+ ], f"Got unexpected node target after delegation: { node .target .__name__ } "
262+
263+ assert (
264+ "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
265+ in format_delegated_graph (delegated_program .exported_program ().graph_module )
266+ )
267+
268+ et_prog = delegated_program .to_executorch ()
269+ self ._compare_outputs (et_prog , model , example_inputs )
270+
216271
217272if __name__ == "__main__" :
218273 test_runner = TestTorchOps ()
@@ -221,5 +276,7 @@ def test_dequantize_codebook_embedding(self):
221276 test_runner .test_dequantize_affine_c4w_embedding ()
222277 test_runner .test_dequantize_affine_c4w_linear ()
223278 test_runner .test_dequantize_affine_c8w_embedding_b4w_linear ()
224- test_runner .test_dequantize_codebook_linear ()
225- test_runner .test_dequantize_codebook_embedding ()
279+ test_runner .test_dequantize_codebook_linear_per_grouped_col ()
280+ test_runner .test_dequantize_codebook_linear_per_grouped_row ()
281+ test_runner .test_dequantize_codebook_embedding_per_grouped_col ()
282+ test_runner .test_dequantize_codebook_embedding_per_grouped_row ()
0 commit comments