Commit f68de58
[Inductor-FX] Support symbol and dynamic scalar graph inputs and outputs (pytorch#163596)
# Problems
This PR fixes a few edge cases that the FX converter missed related to dynamic shapes.
1. Inductor graphs can sometimes take `sympy.Symbol` inputs. We have logic to convert these to FX placeholder nodes. However, this logic did not update the `self.expr_to_proxy` table mapping symbols to proxy nodes. (There was existing logic to do this for `ir.TensorBox` inputs, but not `sympy.Symbol`.) This caused sympy tracing to fail when these symbol inputs were used in other expressions.
2. We lacked codegen for `ShapeAsConstantBuffer`. This IR node is seen when the graph input or output is a scalar computed from dynamic shapes.
# Fixes
a. Update `self.expr_to_proxy` when generating placeholders for `sympy.Symbol` inputs. Change `SymbolBuffer.get_example` to convert the symbol to a `torch.SymInt`, so we can populate `meta["val"]` correctly and use the value in other computations.
b. Support `ShapeAsConstantBuffer` by tracing the sympy expression.
c. Move output generation inside the metadata hook, allowing us to populate `meta["val"]` for the nodes computing `ShapeAsConstantBuffer`.
# Test plan
Added several new CI tests:
1. `torch.cond` with dynamic shapes. This exposes both issues, as the predicate is a `ShapeAsConstantBuffer` and one of the subgraphs uses a symbol input, due to the closure. Also tests when the parent and subgraphs have different input shapes.
2. Output dynamic shape scalar. This tests `ShapeAsConstantBuffer` as an output.
Pull Request resolved: pytorch#163596
Approved by: https://github.com/angelayi, https://github.com/jansel1 parent a8e9ed2 commit f68de58
File tree
3 files changed
+81
-14
lines changed- test/inductor
- torch/_inductor
- codegen
3 files changed
+81
-14
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
990 | 990 | | |
991 | 991 | | |
992 | 992 | | |
| 993 | + | |
| 994 | + | |
| 995 | + | |
| 996 | + | |
| 997 | + | |
| 998 | + | |
| 999 | + | |
| 1000 | + | |
| 1001 | + | |
| 1002 | + | |
| 1003 | + | |
| 1004 | + | |
| 1005 | + | |
| 1006 | + | |
| 1007 | + | |
| 1008 | + | |
| 1009 | + | |
| 1010 | + | |
| 1011 | + | |
| 1012 | + | |
| 1013 | + | |
| 1014 | + | |
| 1015 | + | |
| 1016 | + | |
| 1017 | + | |
| 1018 | + | |
| 1019 | + | |
| 1020 | + | |
| 1021 | + | |
| 1022 | + | |
| 1023 | + | |
| 1024 | + | |
| 1025 | + | |
| 1026 | + | |
| 1027 | + | |
| 1028 | + | |
| 1029 | + | |
| 1030 | + | |
| 1031 | + | |
| 1032 | + | |
| 1033 | + | |
| 1034 | + | |
993 | 1035 | | |
994 | 1036 | | |
995 | 1037 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
23 | 23 | | |
24 | 24 | | |
25 | 25 | | |
26 | | - | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
27 | 31 | | |
28 | 32 | | |
29 | 33 | | |
| |||
89 | 93 | | |
90 | 94 | | |
91 | 95 | | |
92 | | - | |
93 | | - | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
94 | 100 | | |
95 | 101 | | |
96 | 102 | | |
| |||
386 | 392 | | |
387 | 393 | | |
388 | 394 | | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
389 | 402 | | |
390 | 403 | | |
391 | 404 | | |
| |||
398 | 411 | | |
399 | 412 | | |
400 | 413 | | |
| 414 | + | |
401 | 415 | | |
402 | 416 | | |
403 | | - | |
| 417 | + | |
404 | 418 | | |
405 | 419 | | |
406 | 420 | | |
407 | | - | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
408 | 424 | | |
409 | 425 | | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
410 | 430 | | |
411 | 431 | | |
412 | 432 | | |
| |||
421 | 441 | | |
422 | 442 | | |
423 | 443 | | |
424 | | - | |
425 | | - | |
| 444 | + | |
426 | 445 | | |
427 | 446 | | |
428 | 447 | | |
| |||
475 | 494 | | |
476 | 495 | | |
477 | 496 | | |
478 | | - | |
479 | | - | |
480 | | - | |
| 497 | + | |
481 | 498 | | |
482 | 499 | | |
483 | 500 | | |
| |||
504 | 521 | | |
505 | 522 | | |
506 | 523 | | |
| 524 | + | |
| 525 | + | |
| 526 | + | |
| 527 | + | |
507 | 528 | | |
508 | 529 | | |
509 | 530 | | |
| |||
539 | 560 | | |
540 | 561 | | |
541 | 562 | | |
542 | | - | |
| 563 | + | |
| 564 | + | |
| 565 | + | |
543 | 566 | | |
544 | 567 | | |
545 | 568 | | |
| |||
554 | 577 | | |
555 | 578 | | |
556 | 579 | | |
557 | | - | |
| 580 | + | |
558 | 581 | | |
559 | 582 | | |
560 | 583 | | |
| |||
614 | 637 | | |
615 | 638 | | |
616 | 639 | | |
617 | | - | |
| 640 | + | |
| 641 | + | |
| 642 | + | |
618 | 643 | | |
619 | 644 | | |
620 | 645 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4220 | 4220 | | |
4221 | 4221 | | |
4222 | 4222 | | |
4223 | | - | |
| 4223 | + | |
4224 | 4224 | | |
4225 | 4225 | | |
4226 | 4226 | | |
| |||
0 commit comments