
Commit 17ed16f

Revert "nccl op changes"
This reverts commit ee4a9c8.
1 parent ee4a9c8 · commit 17ed16f

File tree: 4 files changed, +12 −24 lines changed

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py

Lines changed: 5 additions & 3 deletions
@@ -3,7 +3,6 @@
 import logging
 from typing import Dict, Sequence, Tuple, Union
 
-import tensorrt as trt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -17,6 +16,8 @@
     tensorrt_fused_nccl_reduce_scatter_op,
 )
 
+import tensorrt as trt
+
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 if load_tensorrt_llm():
@@ -29,7 +30,7 @@ def fused_nccl_gather(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.nccl_ops.nccl_gather(
+        return impl.distributed.nccl_gather(
             ctx,
             target,
             SourceIR.ATEN,
@@ -45,14 +46,15 @@ def fused_nccl_reduce_scatter(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.nccl_ops.nccl_reduce_scatter(
+        return impl.distributed.nccl_reduce_scatter(
             ctx,
             target,
             SourceIR.ATEN,
             name,
             [args[0]],
         )
 
+    breakpoint()
 else:
     _LOGGER.debug(
         "Did not load torch.distributed converters since TensorRT-LLM is not available"

py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py

Lines changed: 1 addition & 6 deletions
@@ -106,12 +106,7 @@ def update_node_meta(node: torch.fx.Node, fake_mode: FakeTensorMode) -> None:
 
     if op_target in shape_inference_funcs:
         new_shape = shape_inference_funcs[op_target](node)
-        new_node_dtype = None
-        if node.meta["val"].dtype == torch.complex64:
-            new_node_dtype = torch.float32
-        else:
-            new_node_dtype = torch.float64
-        real_tensor = torch.empty(new_shape, dtype=new_node_dtype)
+        real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
         node.meta["val"] = fake_mode.from_tensor(real_tensor)
     else:
         print("No shape for the inference function", {op_name})

py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py

Lines changed: 6 additions & 6 deletions
@@ -49,6 +49,7 @@ def fuse_distributed_ops(
             == torch.ops._c10d_functional.wait_tensor.default
         ):
             wait_tensor_node = list(node.users)[0]
+            fused_op = None
             if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
                 with gm.graph.inserting_after(wait_tensor_node):
                     fused_node = gm.graph.create_node(
@@ -57,12 +58,11 @@ def fuse_distributed_ops(
                         args=(node.args[0], node.args[1], node.args[2]),
                     )
             else:
-                with gm.graph.inserting_after(wait_tensor_node):
-                    fused_node = gm.graph.create_node(
-                        op="call_function",
-                        target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
-                        args=(node.args[0], node.args[1], node.args[2], node.args[3]),
-                    )
+                fused_node = gm.graph.create_node(
+                    op="call_function",
+                    target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
+                    args=(node.args[0], node.args[1], node.args[2], node.args[3]),
+                )
 
             wait_tensor_node.replace_all_uses_with(fused_node)
             fused_node.meta.update(node.meta)
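
The hunk above follows the standard torch.fx rewrite recipe: create the replacement node (optionally inside gm.graph.inserting_after(...) so it lands right after a chosen anchor), redirect every user of the old node to it, and copy node.meta across. Here is a self-contained sketch of that recipe on a toy graph; my_fused_op and fuse_add_one are illustrative stand-ins, not the real fused NCCL targets.

import operator
import torch
import torch.fx

def my_fused_op(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a fused collective / plugin call.
    return x + 1

def fuse_add_one(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == operator.add and node.args[1] == 1:
            # Place the fused node immediately after the node it replaces.
            with gm.graph.inserting_after(node):
                fused_node = gm.graph.create_node(
                    op="call_function",
                    target=my_fused_op,
                    args=(node.args[0],),
                )
            node.replace_all_uses_with(fused_node)
            fused_node.meta.update(node.meta)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm

class M(torch.nn.Module):
    def forward(self, x):
        return (x + 1) * 2

gm = fuse_add_one(torch.fx.symbolic_trace(M()))
print(gm.graph)  # the add node is now a call to my_fused_op feeding the multiply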

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 0 additions & 9 deletions
@@ -364,15 +364,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
             for i in inputs
         ]
-
-        for i, contiguous_input in enumerate(contiguous_inputs):
-            if contiguous_input.dtype == torch.complex64:
-                contiguous_input_real = contiguous_input.real
-                contiguous_input_imag = contiguous_input.imag
-                contiguous_inputs[i] = torch.stack(
-                    (contiguous_input_real, contiguous_input_imag), dim=-1
-                )
-
         with (
             torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
             if self.profiling_enabled
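
The deleted block above pre-processed complex64 inputs into real-valued tensors with a trailing (real, imag) dimension. A quick standalone check of what that stacking produces, and that it matches torch.view_as_real for a contiguous complex tensor:

import torch

x = torch.randn(4, 5, dtype=torch.complex64)

# What the deleted loop built: an explicit copy with real/imag pairs in the last dim.
stacked = torch.stack((x.real, x.imag), dim=-1)  # shape (4, 5, 2), dtype float32

# torch.view_as_real gives the same layout as a zero-copy view.
viewed = torch.view_as_real(x)

assert stacked.shape == (4, 5, 2) and stacked.dtype == torch.float32
assert torch.equal(stacked, viewed)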
