Skip to content

Commit 128cc86

Browse files
committed
Wrap in submodule
1 parent add4688 commit 128cc86

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

thunder/dynamo/splitter.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,18 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
198198
)
199199
elif node.name.startswith("submod"): # For inductor
200200
graph_module = getattr(split_gm, node.name)
201-
jit_fn = torch_inductor(graph_module)
201+
202+
class Wrapped(torch.nn.Module):
203+
def __init__(self, gm):
204+
super().__init__()
205+
self.gm = gm
206+
207+
def forward(self, *a):
208+
return self.gm(*a)
209+
210+
# Make sure Inductor does not skip graph_module's compilation by wrapping it
211+
# See https://github.com/Lightning-AI/lightning-thunder/issues/2527#issuecomment-3345877210
212+
jit_fn = torch_inductor(Wrapped(graph_module))
202213
# Update the node name from "submod_*" to "inductor_*" for more user-friendly names
203214
update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
204215
submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(

0 commit comments

Comments
 (0)