# Use `torch._inductor.compile` for ThunderFX fallback entrypoint #2600
Fixes #2539 and other issues it caused.
## What this fixes
`torch.compile` skips the lowering of the `GraphModule` that Thunder's splitter emits. This PR passes the `GraphModule` to `torch._inductor.compile` instead, which then calls `torch._inductor.compile_fx.compile_fx`, to make sure Inductor lowers it. Although it was found in #2527 (comment) that wrapping the `GraphModule` in a new `Module` is sufficient, it was also found in #2551 (comment) that Dynamo cannot trace the enter/exit of a `torch.autocast()` region.
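A minimal sketch of the change, assuming a hypothetical `compile_fallback_submodule` helper (Thunder's real entrypoint is not shown here):

```python
import torch
from torch._inductor import compile as inductor_compile


def compile_fallback_submodule(gm: torch.fx.GraphModule, example_inputs):
    # Before: torch.compile(gm) re-enters Dynamo, which skips the Inductor
    # lowering for the already-traced GraphModule.
    # After: hand the GraphModule to Inductor directly; this routes through
    # torch._inductor.compile_fx.compile_fx and lowers the graph.
    return inductor_compile(gm, example_inputs)
```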
## Why this is more than just replacing the `torch.compile` call

### (A) Graph Splitting
Inductor assumes that the output node of the `GraphModule` has the form `return (t0, ..., tN)`. Usually `torch.fx.passes.split_module.split_module` (the underlying algorithm of the graph splitter) makes sure this holds for the submodules it creates, but when the splitter cuts out a node that returns a tuple, the output value will look like `return tuple_value`. This breaks Inductor: the resulting error complains about `return tag_activation_checkpoint`, which `compile_fx.make_graph_return_tuple` converts to `return [tag_activation_checkpoint]`.
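To illustrate the two output forms (illustrative only; `torch.topk` stands in for a tuple-returning node such as `tag_activation_checkpoint`):

```python
import operator
import torch
import torch.fx as fx

# A graph whose output node is a single tuple-valued node -- the shape that
# breaks Inductor's assumption about the output.
g = fx.Graph()
x = g.placeholder("x")
topk = g.call_function(torch.topk, args=(x, 2))  # one node that returns a (values, indices) tuple
g.output(topk)                                   # "return topk", not "return (t0, ..., tN)"
print(fx.GraphModule(torch.nn.Module(), g).code)

# The form Inductor expects: individual tensors collected into a tuple literal.
g2 = fx.Graph()
x2 = g2.placeholder("x")
topk2 = g2.call_function(torch.topk, args=(x2, 2))
values = g2.call_function(operator.getitem, args=(topk2, 0))
indices = g2.call_function(operator.getitem, args=(topk2, 1))
g2.output((values, indices))                     # "return (values, indices)"
print(fx.GraphModule(torch.nn.Module(), g2).code)
```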
We work around this problem by grouping the subsequent `getitem` nodes together with the node that was cut out. This hack is applied to nodes whose `example_value` is a tuple, which includes `torch.ops.higher_order.tag_activation_checkpoint` and `torch.ops.higher_order.autograd_function_apply`.
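A rough sketch of this grouping, assuming the splitter assigns partitions through a `split_module`-style callback (`group_getitems` and `base_partition_fn` are hypothetical names, not the PR's actual code):

```python
import operator
from torch.fx.node import Node


def group_getitems(base_partition_fn):
    # Wrap the partition callback passed to
    # torch.fx.passes.split_module.split_module so that getitem nodes reading
    # from a tuple-valued node land in the same partition as their producer,
    # keeping each submodule's output a plain tuple of tensors.
    def partition_fn(node: Node) -> int:
        if node.op == "call_function" and node.target is operator.getitem:
            producer = node.args[0]
            # example_value is the metadata Dynamo records on each node.
            if isinstance(producer, Node) and isinstance(
                producer.meta.get("example_value"), tuple
            ):
                # e.g. torch.ops.higher_order.tag_activation_checkpoint,
                #      torch.ops.higher_order.autograd_function_apply
                return base_partition_fn(producer)
        return base_partition_fn(node)

    return partition_fn
```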
### (B) Graph break inside Inductor's fallback submodule
Inductor is not responsible for breaking the graph, so it raises an exception when it encounters something it cannot compile.
We must catch the exception and let Dynamo handle it by falling further back to `torch.compile`. I expect there are many more exceptions that `torch._inductor.compile` may raise. There is no common supertype for such exceptions, so we will have to cover each of them as it comes up.
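A hedged sketch of that fallback path (the broad `except Exception` is illustrative only; the actual change catches specific exception types as they show up):

```python
import torch
from torch._inductor import compile as inductor_compile


def compile_or_fall_back(gm: torch.fx.GraphModule, example_inputs):
    try:
        return inductor_compile(gm, example_inputs)
    except Exception:
        # There is no common supertype for the errors Inductor may raise here.
        # Falling back to torch.compile lets Dynamo break the graph instead.
        return torch.compile(gm)
```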
### (C) Lack of `example_value`
Inductor requires `FakeTensor` instances representing the inputs. These are sometimes unavailable: see #2429. We will simply fall back to `torch.compile` in such cases.
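A sketch of that guard, assuming the `FakeTensor`s are read from the `example_value` metadata Dynamo records on placeholder nodes (helper names are illustrative):

```python
import torch
from torch._inductor import compile as inductor_compile


def example_inputs_or_none(gm: torch.fx.GraphModule):
    # Collect the FakeTensor recorded for each placeholder, if present.
    inputs = []
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            fake = node.meta.get("example_value")
            if fake is None:
                return None  # metadata missing (see #2429)
            inputs.append(fake)
    return inputs


def compile_submodule(gm: torch.fx.GraphModule):
    example_inputs = example_inputs_or_none(gm)
    if example_inputs is None:
        return torch.compile(gm)  # (C): no FakeTensors available
    return inductor_compile(gm, example_inputs)
```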
## Concern
This PR does not cover cases where a submodule containing a `torch.autocast()` region falls back to `torch.compile` (when (B) or (C) applies). An `AssertionError` is raised in such cases (see the new test, which is xfailed). Should we catch such errors and resort to eager mode? I wonder how big the impact will be.