@@ -58,11 +58,19 @@ def forward(ctx,
5858 input_act : torch .Tensor ,
5959 indices : torch .Tensor ,
6060 max_token_num : int ):
61+ '''
62+ indices: for topK=1, indices in a 1-d tensor of shape [num_tokens],
63+ otherwise, it's a 2-d tensor of shape [num_tokens, topK]
64+ '''
6165 nvtx .range_push ("permute_topK forward" )
6266 # Empty input check
6367 if not input_act .numel ():
6468 return input_act , None
6569
70+ # For top1 case, view the indices as 2D tensor to unify the shape for topk>=2 cases.
71+ if indices .dim () == 1 :
72+ indices = indices .view (- 1 , 1 )
73+
6674 # Device check
6775 if input_act .is_cpu :
6876 raise RuntimeError ("[Error] The input `input_act` of permute_topK op is on the device: CPU!" )
@@ -108,7 +116,7 @@ def forward(ctx,
108116
109117 ctx .row_id_map = row_id_map
110118 ctx .num_tokens = indices .size (0 )
111- ctx .num_topK = indices . size ( 1 )
119+ ctx .num_topK = num_topK
112120 nvtx .range_pop ()
113121 return permuted_act , row_id_map
114122
@@ -148,7 +156,7 @@ class UnpermuteMoE_topK(torch.autograd.Function):
148156 def forward (ctx ,
149157 input_act : torch .Tensor ,
150158 row_id_map : torch .Tensor ,
151- probs : torch .Tensor ):
159+ probs : torch .Tensor = None ):
152160 nvtx .range_push ("unpermute_topK forward" )
153161 # Empty input check
154162 if not input_act .numel ():
@@ -161,15 +169,15 @@ def forward(ctx,
161169 if row_id_map .is_cpu :
162170 warnings .warn ("The input `row_id_map` of unpermute_topK op is on the device: CPU!" )
163171 row_id_map = row_id_map .cuda ()
164- if probs .is_cpu :
172+ if probs is not None and probs .is_cpu :
165173 warnings .warn ("The input `probs` of unpermute_topK op is on the device: CPU!" )
166174 probs = probs .cuda ()
167175
168176 # Shape check
169177 if row_id_map .size (0 ) != input_act .size (0 ):
170178 raise RuntimeError (f"[Error] unpermute_topK op input `row_id_map` shape mismatch! "
171179 f"Expect { input_act .size (0 )} , but got { row_id_map .size (0 )} ." )
172- if input_act .size (0 ) != probs .size (0 ) * probs .size (1 ):
180+ if probs is not None and input_act .size (0 ) != probs .size (0 ) * probs .size (1 ):
173181 raise RuntimeError (f"[Error] unpermute_topK op input `probs` shape mismatch! "
174182 f"Expect { input_act .size (0 )} , but got { probs .size (0 ) * probs .size (1 )} ." )
175183
@@ -178,7 +186,7 @@ def forward(ctx,
178186 warnings .warn (f"The data type of the input `row_id_map` of unpermute_topK op is { row_id_map .dtype } ! "
179187 "The recommended type is torch.int32." )
180188 row_id_map = row_id_map .to (torch .int32 )
181- if probs .dtype != torch .float32 :
189+ if probs is not None and probs .dtype != torch .float32 :
182190 warnings .warn (f"The data type of the input `probs` of unpermute_topK op is { probs .dtype } ! "
183191 "The recommended type is torch.float32." )
184192 probs = probs .to (torch .float32 )
@@ -190,17 +198,17 @@ def forward(ctx,
190198 if not row_id_map .is_contiguous ():
191199 warnings .warn ("The input `row_id_map` of unpermute_topK op is discontiguous!" )
192200 row_id_map = row_id_map .contiguous ()
193- if not probs .is_contiguous ():
201+ if probs is not None and not probs .is_contiguous ():
194202 warnings .warn ("The input `probs` of unpermute_topK op is discontiguous!" )
195203 probs = probs .contiguous ()
196204
197- num_tokens = probs .size (0 )
198- num_topK = probs .size (1 )
205+ num_tokens = probs .size (0 ) if probs is not None else input_act . size ( 0 )
206+ num_topK = probs .size (1 ) if probs is not None else 1
199207
200208 unpermuted_output = backend .unpermute (
201209 input_act ,
202210 row_id_map ,
203- probs ,
211+ probs if probs is not None else torch . tensor ([]) ,
204212 num_tokens ,
205213 num_topK )
206214
@@ -236,5 +244,5 @@ def backward(ctx, unpermuted_act_grad):
236244def permute (input_act , indices , max_token_num = 0 ):
237245 return PermuteMoE_topK .apply (input_act , indices , max_token_num )
238246
239- def unpermute (input_act , row_id_map , probs ):
247+ def unpermute (input_act , row_id_map , probs = None ):
240248 return UnpermuteMoE_topK .apply (input_act , row_id_map , probs )
0 commit comments