Currently, if you try to use torchcomms with Dynamo or torch.compile, it throws an error:
torch._dynamo.exc.Unsupported: Unsupported method call
Explanation: Dynamo does not know how to trace method `all_reduce` of class `TorchComm`
Hint: Avoid calling `TorchComm.all_reduce` in your code.
Hint: Please report an issue to PyTorch.
Developer debug context: call_method UserDefinedObjectVariable(TorchComm) all_reduce [LazyVariableTracker(unrealized: <class 'torch.Tensor'>), LazyVariableTracker(unrealized: <class 'pybind11_builtins.pybind11_static_property'>)] {'async_op': ConstantVariable(bool: False)}
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0156.html
from user code:
File "/home/tristanr/scripts/torchcomms_dynamo.py", line 15, in my_func
comm.all_reduce(t, torchcomms.ReduceOp.SUM, async_op=False)
We want to support this, at least for graph-capture cases.
"""
Invoke with:
torchrun --nnodes 1 --nproc_per_node=gpu ~/scripts/torchcomms_dynamo.py
"""
import torch
import torch.distributed as dist
import torchcomms
comm = torchcomms.new_comm('ncclx', torch.device('cuda'), store=None, name='1234')
t = torch.ones(10, device=comm.get_device())
def my_func(t):
comm.all_reduce(t, torchcomms.ReduceOp.SUM, async_op=False)
t *= 10
return t
try:
compiled_func = torch.compile(my_func, fullgraph=True)
compiled_func(t)
finally:
comm.finalize()In c10d::ProcessGroup we register one torch op per collective and always run them through the dispatcher to support this tracing. See https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroup.hpp#L270-L295 for more details.
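
For comparison, here is a minimal Python-level sketch of that pattern applied to the repro above, assuming PyTorch 2.4+'s torch.library.custom_op API. The torchcomms_demo namespace and the all_reduce_ wrapper are hypothetical, not part of torchcomms; the point is only that routing the call through a registered op lets Dynamo trace the collective as a single opaque node instead of graph-breaking on the TorchComm method:

import torch
import torchcomms

comm = torchcomms.new_comm('ncclx', torch.device('cuda'), store=None, name='1234')

# Hypothetical wrapper op: "torchcomms_demo" is a made-up namespace for this
# sketch. mutates_args tells the dispatcher the collective writes into `t`.
@torch.library.custom_op("torchcomms_demo::all_reduce_", mutates_args={"t"})
def all_reduce_(t: torch.Tensor) -> None:
    comm.all_reduce(t, torchcomms.ReduceOp.SUM, async_op=False)

# Fake-tensor rule so torch.compile can trace the op without executing it:
# the collective mutates `t` in place and returns nothing.
@all_reduce_.register_fake
def _(t):
    return None

def my_func(t):
    all_reduce_(t)  # traced as one opaque dispatcher op, no graph break
    t *= 10
    return t

compiled_func = torch.compile(my_func, fullgraph=True)

The C++ side in c10d does the equivalent at scale: one op registration per collective, so every backend's collectives flow through the same dispatcher entry points that tracing hooks into.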