@@ -35,7 +35,7 @@ def _coreml_partitioner(self):

     def _get_test_model(self):
         model = torch.nn.Sequential(
-            torch.nn.Embedding(64, 128), torch.nn.Linear(128, 128), torch.nn.ReLU()
+            torch.nn.Embedding(64, 128), torch.nn.Linear(128, 256), torch.nn.ReLU()
         )
         example_inputs = (torch.LongTensor([0]),)
         return model, example_inputs
@@ -158,7 +158,7 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)

-    def test_dequantize_codebook_linear(self):
+    def test_dequantize_codebook_linear_per_grouped_col(self):
         model, example_inputs = self._get_test_model()
         quantize_(
             model,
@@ -185,7 +185,34 @@ def test_dequantize_codebook_linear(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)

-    def test_dequantize_codebook_embedding(self):
+    def test_dequantize_codebook_linear_per_grouped_row(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[16, -1]),
+        )
+        ep = torch.export.export(model, example_inputs)
+        assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+
+        assert (
+            "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
+            in format_delegated_graph(delegated_program.exported_program().graph_module)
+        )
+
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_codebook_embedding_per_grouped_col(self):
         model, example_inputs = self._get_test_model()
         quantize_(
             model,
@@ -213,6 +240,34 @@ def test_dequantize_codebook_embedding(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)

+    def test_dequantize_codebook_embedding_per_grouped_row(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            CodebookWeightOnlyConfig(dtype=torch.uint3, block_size=[16, -1]),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        ep = torch.export.export(model, example_inputs)
+        assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+
+        assert (
+            "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
+            in format_delegated_graph(delegated_program.exported_program().graph_module)
+        )
+
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+

 if __name__ == "__main__":
     test_runner = TestTorchOps()
@@ -221,5 +276,7 @@ def test_dequantize_codebook_embedding(self):
     test_runner.test_dequantize_affine_c4w_embedding()
    test_runner.test_dequantize_affine_c4w_linear()
     test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
-    test_runner.test_dequantize_codebook_linear()
-    test_runner.test_dequantize_codebook_embedding()
+    test_runner.test_dequantize_codebook_linear_per_grouped_col()
+    test_runner.test_dequantize_codebook_linear_per_grouped_row()
+    test_runner.test_dequantize_codebook_embedding_per_grouped_col()
+    test_runner.test_dequantize_codebook_embedding_per_grouped_row()