|
13 | 13 | import executorch.backends.cadence.aot.ops_registrations # noqa |
14 | 14 | import torch |
15 | 15 | from executorch.backends.cadence.aot import compiler |
16 | | -from executorch.backends.cadence.aot.compiler import ( |
17 | | - export_to_edge, |
18 | | - quantize_and_export_to_edge, |
19 | | -) |
20 | 16 | from executorch.backends.cadence.aot.fuse_ops import ( |
21 | 17 | FuseFullThenReshapePass, |
22 | 18 | FuseMulScalarIntoDequantPass, |
@@ -336,94 +332,144 @@ def test_replace_quant_view_dequant_with_requantize(self): |
336 | 332 | ) |
337 | 333 |
|
338 | 334 | def test_replace_dequant_quant_with_requantize(self): |
339 | | - class M(torch.nn.Module): |
340 | | - def __init__(self): |
341 | | - super().__init__() |
342 | | - |
343 | | - def forward(self, x): |
344 | | - x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
345 | | - x, 1.2, 3, 0, 127, torch.int8 |
346 | | - ) |
347 | | - x = torch.permute(x, [2, 0, 1, 3]) |
348 | | - x = torch.ops.quantized_decomposed.quantize_per_tensor( |
349 | | - x, 4.5, 6, 0, 127, torch.int8 |
350 | | - ) |
351 | | - return x |
352 | | - |
353 | | - inputs = torch.randn(2, 12, 1, 6).to(torch.int8) |
354 | | - model = M() |
355 | | - graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module |
356 | | - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module |
| 335 | + builder = GraphBuilder() |
| 336 | + x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) |
| 337 | + dequant = builder.call_operator( |
| 338 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 339 | + args=(x, 1.2, 3, 0, 127, torch.int8), |
| 340 | + ) |
| 341 | + quant = builder.call_operator( |
| 342 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 343 | + args=(dequant, 4.5, 6, 0, 127, torch.int8), |
| 344 | + ) |
| 345 | + builder.output(quant) |
| 346 | + graph_module = FuseQuantDequantToRequantizePass()( |
| 347 | + builder.get_graph_module() |
| 348 | + ).graph_module |
357 | 349 |
|
358 | 350 | self.check_op_counts( |
359 | 351 | graph_module, |
360 | 352 | expected_op_counts={ |
361 | | - # Verify that dequant -> permute -> quant was replaced with permute -> requantize. |
| 353 | + # Verify that dequant -> quant was replaced with requantize. |
362 | 354 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, |
363 | 355 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, |
364 | 356 | exir_ops.edge.cadence.requantize.default: 1, |
365 | 357 | }, |
366 | 358 | ) |
367 | 359 |
|
368 | 360 | def test_replace_dequant_permute_quant_with_requantize(self): |
369 | | - class M(torch.nn.Module): |
370 | | - def __init__(self): |
371 | | - super().__init__() |
372 | | - |
373 | | - def forward(self, x): |
374 | | - x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
375 | | - x, 1.2, 3, 0, 127, torch.int8 |
376 | | - ) |
377 | | - x = torch.permute(x, [2, 0, 1, 3]) |
378 | | - x = torch.ops.quantized_decomposed.quantize_per_tensor( |
379 | | - x, 4.5, 6, 0, 127, torch.int8 |
380 | | - ) |
381 | | - return x |
382 | | - |
383 | | - inputs = torch.randn(2, 12, 1, 6).to(torch.int8) |
384 | | - model = M() |
385 | | - graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module |
386 | | - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module |
| 361 | + builder = GraphBuilder() |
| 362 | + x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) |
| 363 | + dequant = builder.call_operator( |
| 364 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 365 | + args=(x, 1.2, 3, 0, 127, torch.int8), |
| 366 | + ) |
| 367 | + permute = builder.call_operator( |
| 368 | + op=exir_ops.edge.aten.permute_copy.default, args=(dequant, [2, 0, 1, 3]) |
| 369 | + ) |
| 370 | + quant = builder.call_operator( |
| 371 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 372 | + args=(permute, 4.5, 6, 0, 127, torch.int8), |
| 373 | + ) |
| 374 | + builder.output(quant) |
| 375 | + graph_module = FuseQuantDequantToRequantizePass()( |
| 376 | + builder.get_graph_module() |
| 377 | + ).graph_module |
387 | 378 |
|
388 | 379 | self.check_op_counts( |
389 | 380 | graph_module, |
390 | 381 | expected_op_counts={ |
391 | 382 | # Verify that dequant -> permute -> quant was replaced with permute -> requantize. |
392 | 383 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, |
393 | 384 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, |
| 385 | + exir_ops.edge.aten.permute_copy.default: 1, |
394 | 386 | exir_ops.edge.cadence.requantize.default: 1, |
395 | 387 | }, |
396 | 388 | ) |
397 | 389 |
|
398 | 390 | def test_remove_nop_dequant_quant(self): |
399 | | - class M(torch.nn.Module): |
400 | | - def __init__(self): |
401 | | - super(M, self).__init__() |
402 | | - self.lin1 = torch.nn.Linear(6, 12, bias=False) |
403 | | - self.lin2 = torch.nn.Linear(12, 24, bias=False) |
| 391 | + LEADING_DIMS: Final[int] = 12 |
| 392 | + IN_DIM: Final[int] = 6 |
| 393 | + OUT_DIM: Final[int] = 12 |
404 | 394 |
|
405 | | - def forward(self, x): |
406 | | - x = self.lin1(x) |
407 | | - # redundant dequant+quant will be created around this permute |
408 | | - x = torch.permute(x, [0, 2, 1, 3]) |
409 | | - x = self.lin2(x) |
410 | | - return x |
411 | | - |
412 | | - inputs = torch.randn(2, 12, 1, 6) |
413 | | - model = M() |
414 | | - graph_module = ( |
415 | | - quantize_and_export_to_edge(model, (inputs,)) |
416 | | - .exported_program() |
417 | | - .graph_module |
| 395 | + builder = GraphBuilder() |
| 396 | + x = builder.placeholder( |
| 397 | + "x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32) |
| 398 | + ) |
| 399 | + quant1 = builder.call_operator( |
| 400 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 401 | + args=(x, 4.5, 6, 0, 127, torch.int8), |
| 402 | + ) |
| 403 | + weights = builder.call_operator( |
| 404 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1) |
| 405 | + ) |
| 406 | + bias = builder.call_operator( |
| 407 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1) |
| 408 | + ) |
| 409 | + weight_zero_point = builder.call_operator( |
| 410 | + op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0) |
| 411 | + ) |
| 412 | + out_multiplier = builder.call_operator( |
| 413 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1) |
| 414 | + ) |
| 415 | + out_shift = builder.call_operator( |
| 416 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0) |
418 | 417 | ) |
419 | | - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module |
| 418 | + linear1 = builder.call_operator( |
| 419 | + op=exir_ops.edge.cadence.quantized_linear.default, |
| 420 | + args=( |
| 421 | + quant1, |
| 422 | + weights, |
| 423 | + bias, |
| 424 | + 0, # src_zero_point |
| 425 | + weight_zero_point, |
| 426 | + out_multiplier, |
| 427 | + out_shift, |
| 428 | + 0, # out_zero_point |
| 429 | + None, |
| 430 | + ), |
| 431 | + ) |
| 432 | + dequant1 = builder.call_operator( |
| 433 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 434 | + args=(linear1, 1.2, 3, 0, 127, torch.int8), |
| 435 | + ) |
| 436 | + permute = builder.call_operator( |
| 437 | + op=exir_ops.edge.aten.permute_copy.default, args=(dequant1, [1, 0]) |
| 438 | + ) |
| 439 | + quant2 = builder.call_operator( |
| 440 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 441 | + args=(permute, 4.5, 6, 0, 127, torch.int8), |
| 442 | + ) |
| 443 | + linear2 = builder.call_operator( |
| 444 | + op=exir_ops.edge.cadence.quantized_linear.default, |
| 445 | + args=( |
| 446 | + quant2, |
| 447 | + weights, |
| 448 | + bias, |
| 449 | + 0, # src_zero_point |
| 450 | + weight_zero_point, |
| 451 | + out_multiplier, |
| 452 | + out_shift, |
| 453 | + 0, # out_zero_point |
| 454 | + None, |
| 455 | + ), |
| 456 | + ) |
| 457 | + dequant2 = builder.call_operator( |
| 458 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 459 | + args=(linear2, 1.2, 3, 0, 127, torch.int8), |
| 460 | + ) |
| 461 | + builder.output(dequant2) |
| 462 | + graph_module = FuseQuantDequantToRequantizePass()( |
| 463 | + builder.get_graph_module() |
| 464 | + ).graph_module |
420 | 465 | self.check_op_counts( |
421 | 466 | graph_module, |
422 | 467 | expected_op_counts={ |
423 | | - # Verify that one dequant/quant pair was removed |
424 | | - # Expect 1 quantize ops: 1 input |
| 468 | + # Verify that one dequant/quant pair was removed from chain: |
| 469 | + # quant->linear->dequant->permute->quant->linear->dequant |
| 470 | + # gets converted to: |
| 471 | + # quant->linear->permute->linear->dequant |
425 | 472 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, |
426 | | - # Expect 1 dequant op at the end (output of second linear) |
427 | 473 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1, |
428 | 474 | }, |
429 | 475 | ) |
|
0 commit comments