
Commit fa393a0

reverted unintended file change
Signed-off-by: Eran Geva <[email protected]>
1 parent e519caf commit fa393a0

File tree

1 file changed: +21 −21 lines changed


tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 21 additions & 21 deletions
@@ -752,27 +752,27 @@ def _(input_list, group, num_lists):
         for i in range(0, len(input_list), num_ranks)
     ]
 
-# @torch.library.register_fake("trtllm::alltoall_helix_native")
-# def _(partial_o, softmax_stats, workspace, cp_rank, cp_size):
-#     # Returns outputs with same shapes as inputs
-#     return partial_o.new_empty(partial_o.shape), softmax_stats.new_empty(
-#         softmax_stats.shape)
-
-# @torch.library.register_fake("trtllm::initialize_helix_workspace")
-# def _(workspace, cp_rank, cp_size):
-#     # This op initializes workspace in-place and returns nothing
-#     return None
-
-# @torch.library.register_fake("trtllm::helix_post_process")
-# def _(gathered_o, gathered_stats, scale):
-#     return gathered_o.new_empty(*gathered_o.shape[1:])
-
-# @torch.library.register_fake("trtllm::helix_post_process_native")
-# def _(gathered_o, gathered_stats, scale, cp_dim):
-#     # Remove the dimension at cp_dim (context parallelism dimension)
-#     out_shape = list(gathered_o.shape)
-#     del out_shape[cp_dim]
-#     return gathered_o.new_empty(*out_shape)
+@torch.library.register_fake("trtllm::alltoall_helix_native")
+def _(partial_o, softmax_stats, workspace, cp_rank, cp_size):
+    # Returns outputs with same shapes as inputs
+    return partial_o.new_empty(partial_o.shape), softmax_stats.new_empty(
+        softmax_stats.shape)
+
+@torch.library.register_fake("trtllm::initialize_helix_workspace")
+def _(workspace, cp_rank, cp_size):
+    # This op initializes workspace in-place and returns nothing
+    return None
+
+@torch.library.register_fake("trtllm::helix_post_process")
+def _(gathered_o, gathered_stats, scale):
+    return gathered_o.new_empty(*gathered_o.shape[1:])
+
+@torch.library.register_fake("trtllm::helix_post_process_native")
+def _(gathered_o, gathered_stats, scale, cp_dim):
+    # Remove the dimension at cp_dim (context parallelism dimension)
+    out_shape = list(gathered_o.shape)
+    del out_shape[cp_dim]
+    return gathered_o.new_empty(*out_shape)
 
 @torch.library.register_fake("trtllm::tinygemm2")
 def _(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
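
For context, `torch.library.register_fake` supplies a shape-and-dtype-only "fake" implementation that PyTorch uses when tracing with meta/fake tensors (for example under torch.compile), instead of running the real C++ kernel. Below is a minimal, self-contained sketch of the same pattern; the op name `demo::scale_rows` and its schema are hypothetical stand-ins for illustration, not part of TensorRT-LLM.

import torch

# Hypothetical stand-in for a C++-registered op like the trtllm:: ops above;
# "demo::scale_rows" does not exist in TensorRT-LLM.
torch.library.define("demo::scale_rows", "(Tensor x, float scale) -> Tensor")

@torch.library.impl("demo::scale_rows", "cpu")
def _(x, scale):
    # The "real" kernel: actually computes the result.
    return x * scale

# Same decorator form as the registrations restored in this commit: given
# fake inputs, allocate empty outputs with the right shape/dtype so shape
# propagation works without executing the kernel.
@torch.library.register_fake("demo::scale_rows")
def _(x, scale):
    return x.new_empty(x.shape)

# Tracing with meta tensors exercises only the fake implementation.
out = torch.ops.demo.scale_rows(torch.empty(4, 8, device="meta"), 2.0)
print(out.shape)  # torch.Size([4, 8])

Without these fake registrations, meta-tensor tracing of the helix ops would have no way to infer output shapes, which is presumably why commenting them out was an unintended change worth reverting.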
