Commit 92f41cc
[Inductor] Support precomputed size args in the FX backend. (pytorch#157758)
# Feature
If a Triton kernel has a complicated indexing expression, Inductor may decide to precompute it on the host and pass it to the kernel as an argument. This happens in situations like broadcasts with dynamic shapes.
This PR adds support for this feature to Inductor's FX IR backend.
We generate FX IR for precomputed size args in 3 steps:
1. In `PythonWrapperCodegen`, this PR refactors the relevant code to use a `SymbolicCallArgLine` instead of raw Python strings. This stores a (symbol, expr) pair. (Prior to this PR, it was (str, expr), but changing this to a symbol makes it easier to do substitutions later on.)
2. In `WrapperFxCodegen`, keep a dict of {symbol: expr} arg defs which gets updated whenever we see a `SymbolicCallArgLine`.
3. When the FX backend sees a `KernelCallLine`, it uses this dict to replace symbolic call args with their definitions.
In the longer run, it might be desirable to emit FX nodes defining these symbolic call args. That way, we could reuse the size computation when the same kernel is called multiple times. However, I wasn't sure if there was an existing way to generate FX nodes from a sympy expression, and implementing that seemed like overkill for the present purposes.
# Test plan
Added a new CI test exercising this feature.
Pull Request resolved: pytorch#157758
Approved by: https://github.com/jansel1 parent 95bc3da commit 92f41cc
File tree
3 files changed
+39
-13
lines changed- test/inductor
- torch/_inductor/codegen
3 files changed
+39
-13
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
393 | 393 | | |
394 | 394 | | |
395 | 395 | | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
396 | 412 | | |
397 | 413 | | |
398 | 414 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
322 | 322 | | |
323 | 323 | | |
324 | 324 | | |
325 | | - | |
| 325 | + | |
326 | 326 | | |
327 | 327 | | |
328 | 328 | | |
| |||
1726 | 1726 | | |
1727 | 1727 | | |
1728 | 1728 | | |
1729 | | - | |
| 1729 | + | |
| 1730 | + | |
1730 | 1731 | | |
1731 | 1732 | | |
1732 | 1733 | | |
| |||
2257 | 2258 | | |
2258 | 2259 | | |
2259 | 2260 | | |
2260 | | - | |
| 2261 | + | |
2261 | 2262 | | |
2262 | | - | |
| 2263 | + | |
| 2264 | + | |
2263 | 2265 | | |
2264 | 2266 | | |
2265 | 2267 | | |
| |||
2268 | 2270 | | |
2269 | 2271 | | |
2270 | 2272 | | |
2271 | | - | |
| 2273 | + | |
2272 | 2274 | | |
2273 | 2275 | | |
2274 | 2276 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
17 | 17 | | |
18 | 18 | | |
19 | 19 | | |
20 | | - | |
| 20 | + | |
21 | 21 | | |
22 | 22 | | |
23 | 23 | | |
| |||
155 | 155 | | |
156 | 156 | | |
157 | 157 | | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
158 | 161 | | |
159 | 162 | | |
160 | 163 | | |
| |||
576 | 579 | | |
577 | 580 | | |
578 | 581 | | |
579 | | - | |
580 | | - | |
581 | | - | |
582 | | - | |
583 | | - | |
584 | | - | |
| 582 | + | |
| 583 | + | |
| 584 | + | |
| 585 | + | |
| 586 | + | |
| 587 | + | |
| 588 | + | |
| 589 | + | |
| 590 | + | |
585 | 591 | | |
586 | 592 | | |
587 | 593 | | |
| |||
691 | 697 | | |
692 | 698 | | |
693 | 699 | | |
694 | | - | |
| 700 | + | |
| 701 | + | |
| 702 | + | |
0 commit comments