@@ -752,27 +752,27 @@ def _(input_list, group, num_lists):
752752 for i in range (0 , len (input_list ), num_ranks )
753753 ]
754754
755- # @torch.library.register_fake("trtllm::alltoall_helix_native")
756- # def _(partial_o, softmax_stats, workspace, cp_rank, cp_size):
757- # # Returns outputs with same shapes as inputs
758- # return partial_o.new_empty(partial_o.shape), softmax_stats.new_empty(
759- # softmax_stats.shape)
760-
761- # @torch.library.register_fake("trtllm::initialize_helix_workspace")
762- # def _(workspace, cp_rank, cp_size):
763- # # This op initializes workspace in-place and returns nothing
764- # return None
765-
766- # @torch.library.register_fake("trtllm::helix_post_process")
767- # def _(gathered_o, gathered_stats, scale):
768- # return gathered_o.new_empty(*gathered_o.shape[1:])
769-
770- # @torch.library.register_fake("trtllm::helix_post_process_native")
771- # def _(gathered_o, gathered_stats, scale, cp_dim):
772- # # Remove the dimension at cp_dim (context parallelism dimension)
773- # out_shape = list(gathered_o.shape)
774- # del out_shape[cp_dim]
775- # return gathered_o.new_empty(*out_shape)
@torch.library.register_fake("trtllm::alltoall_helix_native")
def _(partial_o, softmax_stats, workspace, cp_rank, cp_size):
    """Fake (meta) kernel for the helix all-to-all: the exchange moves data
    between ranks but preserves tensor shapes, so the meta outputs mirror
    the two input tensors exactly.
    """
    # workspace / cp_rank / cp_size only matter to the real kernel; they do
    # not influence shape inference here.
    fake_o = partial_o.new_empty(partial_o.shape)
    fake_stats = softmax_stats.new_empty(softmax_stats.shape)
    return fake_o, fake_stats
760+
@torch.library.register_fake("trtllm::initialize_helix_workspace")
def _(workspace, cp_rank, cp_size):
    """Fake (meta) kernel: the real op initializes ``workspace`` in-place
    and produces no output, so the meta function simply returns ``None``.
    """
    return None
765+
@torch.library.register_fake("trtllm::helix_post_process")
def _(gathered_o, gathered_stats, scale):
    """Fake (meta) kernel: post-processing reduces over the leading
    (gather) axis, so the output shape is ``gathered_o``'s shape with
    dim 0 dropped.
    """
    # new_empty accepts a shape sequence directly — no need to unpack.
    return gathered_o.new_empty(gathered_o.shape[1:])
769+
@torch.library.register_fake("trtllm::helix_post_process_native")
def _(gathered_o, gathered_stats, scale, cp_dim):
    """Fake (meta) kernel: the native post-process collapses the context-
    parallelism axis, so the output shape is ``gathered_o``'s shape with
    the axis at ``cp_dim`` removed.
    """
    # Delete from a list (rather than slice-concatenate) so a negative
    # cp_dim follows normal Python indexing semantics.
    dims = list(gathered_o.shape)
    del dims[cp_dim]
    return gathered_o.new_empty(*dims)
776776
777777 @torch .library .register_fake ("trtllm::tinygemm2" )
778778 def _ (input : torch .Tensor , weight : torch .Tensor , bias : torch .Tensor ):
0 commit comments