Skip to content

Commit f6f0496

Browse files
kiya00crcrparpre-commit-ci[bot]
authored
Support CUDA stream operators in ThunderFX (#2761)
Co-authored-by: Masaki <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9e12768 commit f6f0496

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

thunder/dynamo/splitter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,23 @@
2828
if TYPE_CHECKING:
2929
from collections.abc import Callable
3030
from typing import Any
31+
from torch.fx import GraphModule
32+
33+
34+
# TODO: investigate and see if there's a cleaner way to prevent the error.
35+
# ``cannot extract sympy expressions from <torch.cuda.Stream device=cuda:0 cuda_stream=0x0> <class 'torch.cuda.streams.Stream'>``
36+
def _preprocess_cuda_stream_objects(gm: GraphModule) -> None:
37+
"""Preprocess the graph to handle :class:`torch.cuda.Stream` objects.
38+
39+
Since :class:`torch.cuda.Stream` does not have sympy expression apparently,
40+
manually setting its metadata to :obj:`None` to avoid an error such as
41+
``cannot extract sympy expressions from <torch.cuda.Stream device=cuda:0 cuda_stream=0x0> <class 'torch.cuda.streams.Stream'>``
42+
"""
43+
for node in gm.graph.nodes:
44+
if hasattr(node, "meta") and "example_value" in node.meta:
45+
example_value = node.meta["example_value"]
46+
if isinstance(example_value, torch.cuda.Stream):
47+
node.meta["example_value"] = None
3148

3249

3350
def _splitter(
@@ -166,6 +183,7 @@ def callback(node) -> int:
166183
for n in functionctx_nodes_to_del:
167184
gm.graph.erase_node(n)
168185
gm.recompile()
186+
_preprocess_cuda_stream_objects(gm)
169187

170188
# `split_module` iterates over nodes and determines the partition to place them based on the callback.
171189
split_gm: torch.fx.GraphModule = split_module(

thunder/dynamo/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,11 @@ def forward(self, *args):
220220
self.graph_module.graph = original_graph
221221
self.graph_module.recompile()
222222
self.compiled_fn = self.graph_module
223+
except (NotImplementedError, AssertionError) as e:
224+
warnings.warn(f"torch._inductor.compile failed: {e}. Falling back to eager.")
225+
self.graph_module.graph = original_graph
226+
self.graph_module.recompile()
227+
self.compiled_fn = self.graph_module
223228

224229
return self.compiled_fn(*args)
225230

@@ -495,6 +500,12 @@ def is_node_supported_by_thunder(
495500
target = node.target # Target is the function to call.
496501
if node.op == "call_method":
497502
target = getattr(torch.Tensor, node.target, None)
503+
if target is None and hasattr(torch.cuda.Stream, node.target):
504+
split_reason = SplitReason(
505+
SplitReasonType.MISSING_OP_SUPPORT,
506+
f"node with name {node.name} and target {node.target} is a `torch.cuda.Stream` method which is not supported by Thunder.",
507+
)
508+
return False, split_reason
498509
assert target is not None, f"Failed to find method {node.target}"
499510

500511
# If the operation has automatic registration, we mark it as unsupported as `inductor` might be

thunder/tests/test_dynamo.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,3 +1919,19 @@ def fake_compile(*args, **kwargs):
19191919
with patch("torch._inductor.compile", side_effect=fake_compile):
19201920
cfunc = thunderfx(func)
19211921
cfunc(x)
1922+
1923+
1924+
@requiresCUDA
1925+
def test_stream_op():
1926+
def fn():
1927+
cuda = torch.device("cuda")
1928+
s = torch.cuda.streams.Stream()
1929+
s.wait_stream(torch.cuda.current_stream(cuda))
1930+
1931+
cfunc = thunderfx(fn)
1932+
cfunc()
1933+
split_reasons = cfunc._backend.subgraph_infos[0].split_reasons
1934+
assert any(
1935+
"is a `torch.cuda.Stream` method which is not supported by Thunder" in getattr(reason, "info", "")
1936+
for reason in split_reasons
1937+
)

0 commit comments

Comments
 (0)