Skip to content

Commit fa17d40

Browse files
committed
Unify unpermute API for top1 and topK.
1 parent 17fb8ba commit fa17d40

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

grouped_gemm/backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,7 @@ def unpermute(input, row_id_map, prob, max_tokens, num_topK):
3737
return backend.unpermute(input, row_id_map, prob, max_tokens, num_topK)
3838

3939
def unpermute_bwd(input_bwd, input_fwd, row_id_map, prob):
    """Backward pass of the token unpermute op.

    Args:
        input_bwd: gradient tensor flowing back into the unpermute output.
        input_fwd: the activations that were fed to the forward unpermute.
        row_id_map: routing map produced by the corresponding permute op.
        prob: per-token routing probabilities, or ``None`` for the top-1
            case where no probability weighting was applied.

    Returns:
        Result of ``backend.unpermute_bwd`` — presumably the gradients for
        the permuted input (and probs); confirm against the kernel binding.
    """
    # TODO: @Jiang fix the case in kernel to allow None probs
    if prob is None:
        # Kernel cannot accept None probs yet, so substitute uniform
        # weights of 1.0 (shape [num_tokens, 1], i.e. top-1) as a stand-in.
        prob = torch.ones([input_bwd.size(0), 1], dtype=torch.float32, device=input_bwd.device)
    return backend.unpermute_bwd(input_bwd, input_fwd, row_id_map, prob)

grouped_gemm/ops.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,19 @@ def forward(ctx,
5858
input_act: torch.Tensor,
5959
indices: torch.Tensor,
6060
max_token_num: int):
61+
'''
62+
indices: for topK=1, indices in a 1-d tensor of shape [num_tokens],
63+
otherwise, it's a 2-d tensor of shape [num_tokens, topK]
64+
'''
6165
nvtx.range_push("permute_topK forward")
6266
# Empty input check
6367
if not input_act.numel():
6468
return input_act, None
6569

70+
# For top1 case, view the indices as 2D tensor to unify the shape for topk>=2 cases.
71+
if indices.dim() == 1:
72+
indices = indices.view(-1, 1)
73+
6674
# Device check
6775
if input_act.is_cpu:
6876
raise RuntimeError("[Error] The input `input_act` of permute_topK op is on the device: CPU!")
@@ -108,7 +116,7 @@ def forward(ctx,
108116

109117
ctx.row_id_map = row_id_map
110118
ctx.num_tokens = indices.size(0)
111-
ctx.num_topK = indices.size(1)
119+
ctx.num_topK = num_topK
112120
nvtx.range_pop()
113121
return permuted_act, row_id_map
114122

@@ -148,7 +156,7 @@ class UnpermuteMoE_topK(torch.autograd.Function):
148156
def forward(ctx,
149157
input_act: torch.Tensor,
150158
row_id_map: torch.Tensor,
151-
probs: torch.Tensor):
159+
probs: torch.Tensor = None):
152160
nvtx.range_push("unpermute_topK forward")
153161
# Empty input check
154162
if not input_act.numel():
@@ -161,15 +169,15 @@ def forward(ctx,
161169
if row_id_map.is_cpu:
162170
warnings.warn("The input `row_id_map` of unpermute_topK op is on the device: CPU!")
163171
row_id_map = row_id_map.cuda()
164-
if probs.is_cpu:
172+
if probs is not None and probs.is_cpu:
165173
warnings.warn("The input `probs` of unpermute_topK op is on the device: CPU!")
166174
probs = probs.cuda()
167175

168176
# Shape check
169177
if row_id_map.size(0) != input_act.size(0):
170178
raise RuntimeError(f"[Error] unpermute_topK op input `row_id_map` shape mismatch! "
171179
f"Expect {input_act.size(0)}, but got {row_id_map.size(0)}.")
172-
if input_act.size(0) != probs.size(0) * probs.size(1):
180+
if probs is not None and input_act.size(0) != probs.size(0) * probs.size(1):
173181
raise RuntimeError(f"[Error] unpermute_topK op input `probs` shape mismatch! "
174182
f"Expect {input_act.size(0)}, but got {probs.size(0) * probs.size(1)}.")
175183

@@ -178,7 +186,7 @@ def forward(ctx,
178186
warnings.warn(f"The data type of the input `row_id_map` of unpermute_topK op is {row_id_map.dtype}! "
179187
"The recommended type is torch.int32.")
180188
row_id_map = row_id_map.to(torch.int32)
181-
if probs.dtype != torch.float32:
189+
if probs is not None and probs.dtype != torch.float32:
182190
warnings.warn(f"The data type of the input `probs` of unpermute_topK op is {probs.dtype}! "
183191
"The recommended type is torch.float32.")
184192
probs = probs.to(torch.float32)
@@ -190,17 +198,17 @@ def forward(ctx,
190198
if not row_id_map.is_contiguous():
191199
warnings.warn("The input `row_id_map` of unpermute_topK op is discontiguous!")
192200
row_id_map = row_id_map.contiguous()
193-
if not probs.is_contiguous():
201+
if probs is not None and not probs.is_contiguous():
194202
warnings.warn("The input `probs` of unpermute_topK op is discontiguous!")
195203
probs = probs.contiguous()
196204

197-
num_tokens = probs.size(0)
198-
num_topK = probs.size(1)
205+
num_tokens = probs.size(0) if probs is not None else input_act.size(0)
206+
num_topK = probs.size(1) if probs is not None else 1
199207

200208
unpermuted_output = backend.unpermute(
201209
input_act,
202210
row_id_map,
203-
probs,
211+
probs if probs is not None else torch.tensor([]),
204212
num_tokens,
205213
num_topK)
206214

@@ -236,5 +244,5 @@ def backward(ctx, unpermuted_act_grad):
236244
def permute(input_act, indices, max_token_num=0):
    """Thin functional wrapper over the ``PermuteMoE_topK`` autograd op."""
    return PermuteMoE_topK.apply(input_act, indices, max_token_num)

def unpermute(input_act, row_id_map, probs=None):
    """Thin functional wrapper over the ``UnpermuteMoE_topK`` autograd op.

    ``probs`` defaults to ``None`` for the top-1 case, where no probability
    weighting is supplied by the caller.
    """
    return UnpermuteMoE_topK.apply(input_act, row_id_map, probs)

0 commit comments

Comments
 (0)