
Commit 7ceed47

rls2.6: Fuse moegate ops for deepseekv3 (#3488)

1 parent 6fc2ad8 commit 7ceed47

File tree

5 files changed: +195 −33 lines

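This commit replaces the eager-mode DeepSeek-V3 "noaux_tc" gating path with a fused CPU op: `torch.ops.torch_ipex.deepseek_moegate` gains a trailing optional `e_score_cbias` tensor (the expert-score correction bias), and a new OpenMP-parallelized kernel handles the grouped top-k routing when it is present. A minimal usage sketch follows; the shapes and hyperparameters mirror the new unit test, and it assumes an intel_extension_for_pytorch build that contains this commit.

import torch
import intel_extension_for_pytorch  # noqa: F401  (registers torch.ops.torch_ipex.*)

# Hyperparameters mirror test_deepseekv3_moegate below.
n_group, topk_group, n_routed_experts, top_k = 8, 3, 16, 6
hidden_states = torch.rand(10, 2560)
scores = torch.rand(10, n_routed_experts)  # router scores, e.g. logits.sigmoid()
bias = torch.rand(n_routed_experts)        # e_score_correction_bias

# Passing the trailing bias tensor selects the fused DeepSeek-V3 "noaux_tc"
# path; omitting it (schema default: None) keeps the pre-existing behavior.
topk_idx, topk_weight = torch.ops.torch_ipex.deepseek_moegate(
    hidden_states,
    scores,
    torch.tensor(16.0),  # routed_scaling_factor
    n_group,
    topk_group,
    n_routed_experts,
    top_k,
    bias,
)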
csrc/cpu/aten/MoE.cpp

Lines changed: 5 additions & 3 deletions

@@ -298,7 +298,8 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate(
     const int64_t n_group,
     const int64_t topk_group,
     const int64_t n_routed_experts,
-    const int64_t top_k) {
+    const int64_t top_k,
+    c10::optional<at::Tensor> e_score_cbias) {
   RECORD_FUNCTION("ipex::deepseek_moegate", c10::ArrayRef<c10::IValue>({}));
 
   return deepseek_moegate_kernel_stub(
@@ -309,7 +310,8 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate(
       n_group,
       topk_group,
       n_routed_experts,
-      top_k);
+      top_k,
+      e_score_cbias);
 }
 } // namespace cpu
 } // namespace torch_ipex
@@ -374,7 +376,7 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       c10::DispatchKey::CPU,
       torch_ipex::cpu::deepseek_moe_woq);
   m.def(
-      "deepseek_moegate(Tensor hidden_states, Tensor scores, Tensor routed_scaling_factor, int n_group, int topk_group, int n_routed_experts, int top_k) -> (Tensor, Tensor)");
+      "deepseek_moegate(Tensor hidden_states, Tensor scores, Tensor routed_scaling_factor, int n_group, int topk_group, int n_routed_experts, int top_k, Tensor? e_score_cbias=None) -> (Tensor, Tensor)");
   m.impl(
       "deepseek_moegate",
       c10::DispatchKey::CPU,

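Note on the schema change: the new argument is registered as `Tensor? e_score_cbias=None`, so existing seven-argument callers of `deepseek_moegate` are unaffected; only callers that pass the bias opt into the DeepSeek-V3 routing path.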
csrc/cpu/aten/MoE.h

Lines changed: 4 additions & 2 deletions

@@ -97,7 +97,8 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate(
    const int64_t n_group,
    const int64_t topk_group,
    const int64_t n_routed_experts,
-    const int64_t top_k);
+    const int64_t top_k,
+    c10::optional<at::Tensor> e_score_cbias);
 using mixtral_moe_tpp_kernel_fn = at::Tensor (*)(
     const at::Tensor& hidden_states,
     const at::Tensor& top_x,
@@ -179,7 +180,8 @@ using deepseek_moegate_kernel_fn = std::tuple<at::Tensor, at::Tensor> (*)(
     const int64_t n_group,
     const int64_t topk_group,
     const int64_t n_routed_experts,
-    const int64_t top_k);
+    const int64_t top_k,
+    c10::optional<at::Tensor> e_score_cbias);
 IPEX_DECLARE_DISPATCH(mixtral_moe_tpp_kernel_fn, mixtral_moe_tpp_kernel_stub);
 IPEX_DECLARE_DISPATCH(deepseek_moe_tpp_kernel_fn, deepseek_moe_tpp_kernel_stub);
 IPEX_DECLARE_DISPATCH(mixtral_moe_woq_kernel_fn, mixtral_moe_woq_kernel_stub);

csrc/cpu/aten/kernels/MoEKrnl.cpp

Lines changed: 112 additions & 6 deletions

@@ -292,7 +292,6 @@ at::Tensor mixtral_moe_woq_kernl_impl(
 
 template <typename T>
 std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel(
-    const at::Tensor& hidden_states,
     const at::Tensor& scores,
     const at::Tensor& routed_scaling_factor,
     const int64_t n_group,
@@ -302,7 +301,7 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel(
   auto group_size = n_routed_experts / n_group;
   auto n = scores.size(0);
   auto h = scores.size(1);
-  auto group_scores = at::empty({n, n_group}, hidden_states.options());
+  auto group_scores = at::empty({n, n_group}, scores.options());
   auto group_scores_ptr = group_scores.data_ptr<T>();
   auto scores_ptr = scores.data_ptr<T>();
 #pragma omp parallel for collapse(2)
@@ -319,7 +318,7 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel(
   }
 
   auto group_idx = std::get<1>(group_scores.topk(topk_group, -1, true, false));
-  auto tmp_scores = at::zeros_like(scores, hidden_states.options());
+  auto tmp_scores = at::zeros_like(scores, scores.options());
   auto group_idx_ptr = group_idx.data_ptr<int64_t>();
   auto tmp_scores_ptr = tmp_scores.data_ptr<T>();
   T scale = routed_scaling_factor.item<T>();
@@ -339,17 +338,117 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel(
   return std::make_tuple(topk, topk_weight);
 }
 
+template <typename T>
+std::tuple<at::Tensor, at::Tensor> deepseekv3_moegate_kernel(
+    const at::Tensor& scores,
+    const at::Tensor& routed_scaling_factor,
+    const int64_t n_group,
+    const int64_t topk_group,
+    const int64_t n_routed_experts,
+    const int64_t top_k,
+    const at::Tensor& e_score_cbias) {
+  auto group_size = n_routed_experts / n_group;
+  auto n = scores.size(0);
+  auto h = scores.size(1);
+  auto scores_for_choice = at::empty({n, n_group, group_size}, at::kFloat);
+  auto scores_ptr = scores.data_ptr<T>();
+  auto scores_for_choice_ptr = scores_for_choice.data_ptr<float>();
+  auto scores_for_choice_stride0 = scores_for_choice.stride(0);
+  auto e_score_cbias_ptr = e_score_cbias.data_ptr<float>();
+#pragma omp parallel for collapse(2)
+  for (auto i = 0; i < n; i++) {
+    for (auto j = 0; j < n_group; j++) {
+      auto k_start = j * group_size;
+      auto k_end = k_start + group_size;
+      for (auto k = k_start; k < k_end; k++) {
+        scores_for_choice_ptr[i * scores_for_choice_stride0 + k] =
+            scores_ptr[i * h + k] + e_score_cbias_ptr[k];
+      }
+    }
+  }
+  auto group_scores =
+      std::get<0>(scores_for_choice.topk(2, -1, true, false)).sum(-1);
+  auto group_idx = std::get<1>(group_scores.topk(topk_group, -1, true, false));
+  auto tmp_scores = at::zeros_like(scores, at::kFloat);
+  auto group_idx_ptr = group_idx.data_ptr<int64_t>();
+  auto tmp_scores_ptr = tmp_scores.data_ptr<float>();
+#pragma omp parallel for collapse(2)
+  for (auto i = 0; i < n; i++) {
+    for (auto j = 0; j < topk_group; j++) {
+      auto selected_idx = group_idx_ptr[i * topk_group + j];
+      auto k_start = selected_idx * group_size;
+      auto k_end = k_start + group_size;
+      for (auto k = k_start; k < k_end; k++) {
+        tmp_scores_ptr[i * h + k] =
+            scores_for_choice_ptr[i * scores_for_choice_stride0 + k];
+      }
+    }
+  }
+  auto topk = std::get<1>(tmp_scores.topk(top_k, -1, true, false));
+  auto topk_weight = scores.gather(1, topk);
+  return std::make_tuple(topk, topk_weight);
+}
+
 std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel_impl(
     const at::Tensor& hidden_states,
     const at::Tensor& scores,
     const at::Tensor& routed_scaling_factor,
     const int64_t n_group,
     const int64_t topk_group,
     const int64_t n_routed_experts,
-    const int64_t top_k) {
+    const int64_t top_k,
+    c10::optional<at::Tensor> e_score_cbias) {
+  if (e_score_cbias.has_value()) { // deepseekv3
+    if (hidden_states.scalar_type() == at::ScalarType::Float) {
+      return deepseekv3_moegate_kernel<float>(
+          scores,
+          routed_scaling_factor,
+          n_group,
+          topk_group,
+          n_routed_experts,
+          top_k,
+          e_score_cbias.value());
+    } else if (hidden_states.scalar_type() == at::ScalarType::BFloat16) {
+      return deepseekv3_moegate_kernel<at::BFloat16>(
+          scores,
+          routed_scaling_factor,
+          n_group,
+          topk_group,
+          n_routed_experts,
+          top_k,
+          e_score_cbias.value());
+    } else if (hidden_states.scalar_type() == at::ScalarType::Half) {
+      return deepseekv3_moegate_kernel<at::Half>(
+          scores,
+          routed_scaling_factor,
+          n_group,
+          topk_group,
+          n_routed_experts,
+          top_k,
+          e_score_cbias.value());
+    }
+    auto n = hidden_states.size(0);
+    auto group_size = n_routed_experts / n_group;
+    auto scores_for_choice =
+        scores.view({n, -1}) + e_score_cbias.value().unsqueeze(0);
+    auto group_scores = std::get<0>(
+        scores_for_choice.view({n, n_group, -1}).topk(2, -1, true, false));
+    group_scores = group_scores.sum(-1);
+    auto group_idx =
+        std::get<1>(group_scores.topk(topk_group, -1, true, false));
+    auto group_mask = at::zeros_like(group_scores);
+    group_mask.scatter_(1, group_idx, 1);
+    auto score_mask = group_mask.unsqueeze(-1)
+                          .expand({n, n_group, group_size})
+                          .reshape({n, -1});
+    auto tmp_scores =
+        scores_for_choice.masked_fill(~score_mask.to(at::kBool), 0.0);
+    auto topk = std::get<1>(tmp_scores.topk(top_k, -1, true, false));
+    auto topk_weight = scores.gather(1, topk);
+    return std::make_tuple(topk, topk_weight.to(hidden_states.scalar_type()));
+  }
   if (hidden_states.scalar_type() == at::ScalarType::Float) {
     return deepseek_moegate_kernel<float>(
-        hidden_states,
         scores,
         routed_scaling_factor,
         n_group,
@@ -358,7 +457,14 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel_impl(
         top_k);
   } else if (hidden_states.scalar_type() == at::ScalarType::BFloat16) {
     return deepseek_moegate_kernel<at::BFloat16>(
-        hidden_states,
+        scores,
+        routed_scaling_factor,
+        n_group,
+        topk_group,
+        n_routed_experts,
+        top_k);
+  } else if (hidden_states.scalar_type() == at::ScalarType::Half) {
+    return deepseek_moegate_kernel<at::Half>(
         scores,
         routed_scaling_factor,
         n_group,

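The fused `deepseekv3_moegate_kernel` does the grouped routing in two OpenMP-parallel loops: it first materializes bias-corrected scores per group, scores each group by the sum of its two best corrected experts, keeps only the experts belonging to the `topk_group` winning groups, and finally takes `top_k` among those. Note that `topk_weight` is gathered from the original, unbiased `scores`, so the correction bias steers which experts are chosen but not their weights; dtypes other than fp32/bf16/fp16 fall through to the ATen-composed fallback inside the same branch. For reference, a compact eager-mode sketch of what the kernel computes (the helper name `noaux_tc_gate` is ours, for illustration only):

import torch

def noaux_tc_gate(scores, bias, n_group, topk_group, top_k):
    # scores: [n, e] unbiased router scores; bias: [e] correction bias.
    n, e = scores.shape
    # Bias is added only for expert *selection*, never for the returned weights.
    biased = (scores + bias).view(n, n_group, e // n_group)
    # Score each group by the sum of its two best corrected experts.
    group_scores = biased.topk(2, dim=-1)[0].sum(-1)        # [n, n_group]
    keep = group_scores.topk(topk_group, dim=-1)[1]         # [n, topk_group]
    mask = torch.zeros(n, n_group, dtype=torch.bool).scatter_(1, keep, True)
    # Zero out experts outside the winning groups, then pick top_k overall.
    masked = biased.masked_fill(~mask.unsqueeze(-1), 0.0).view(n, e)
    topk_idx = masked.topk(top_k, dim=-1, sorted=False)[1]
    return topk_idx, scores.gather(1, topk_idx)             # unbiased weights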
intel_extension_for_pytorch/transformers/models/reference/models.py

Lines changed: 10 additions & 21 deletions

@@ -5875,27 +5875,16 @@ def Deepseek_MoEGate_forward(self, hidden_states):
             self.top_k,
         )
     elif self.topk_method == "noaux_tc":
-        # TODO: fuse the following ops.
-        n = hidden_states.size(0)
-        scores_for_choice = scores.view(n, -1) + self.e_score_correction_bias.unsqueeze(
-            0
-        )
-        group_scores = (
-            scores_for_choice.view(n, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
-        )  # [n, n_group]
-        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[
-            1
-        ]  # [n, top_k_group]
-        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-        score_mask = (
-            group_mask.unsqueeze(-1)
-            .expand(n, self.n_group, self.n_routed_experts // self.n_group)
-            .reshape(n, -1)
-        )  # [n, e]
-        tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-        _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
-        topk_weight = scores.gather(1, topk_idx)
+        topk_idx, topk_weight = torch.ops.torch_ipex.deepseek_moegate(
+            hidden_states,
+            scores,
+            torch.tensor(self.routed_scaling_factor),
+            self.n_group,
+            self.topk_group,
+            self.n_routed_experts,
+            self.top_k,
+            torch.tensor(self.e_score_correction_bias),
+        )
 
     # norm gate to sum 1
     if self.top_k > 1 and self.norm_topk_prob:

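This resolves the `# TODO: fuse the following ops.` comment: the whole eager view/topk/scatter_/masked_fill chain in `Deepseek_MoEGate_forward` collapses into a single call to the fused op. `routed_scaling_factor` and `e_score_correction_bias` are wrapped with `torch.tensor(...)` because the op schema takes Tensor arguments.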
tests/cpu/test_cpu_ops.py

Lines changed: 64 additions & 1 deletion

@@ -2245,7 +2245,10 @@ def moe_gate(scores):
             topk_weight = topk_weight * routed_scaling_factor
             return topk_idx, topk_weight
 
-        for dtype in [torch.float32, torch.bfloat16]:
+        dtypes = [torch.float32, torch.bfloat16]
+        if core.onednn_has_fp16_support():
+            dtypes.append(torch.float16)
+        for dtype in dtypes:
             hidden_states = torch.rand(10, 2560, dtype=dtype)
             weight = torch.rand(16, 2560, dtype=dtype)
             logits = torch.nn.functional.linear(
@@ -2267,6 +2270,66 @@ def moe_gate(scores):
             self.assertEqual(topk_idx_ref, topk_idx_ipex)
             self.assertEqual(topk_weight_ref, topk_weight_ipex)
 
+    def test_deepseekv3_moegate(self):
+        n_group = 8
+        topk_group = 3
+        n_routed_experts = 16
+        top_k = 6
+        routed_scaling_factor = 16.0
+        e_score_correction_bias = torch.rand(n_routed_experts)
+
+        def moe_gate(scores):
+            n, h = scores.shape
+            scores_for_choice = scores.view(n, -1) + e_score_correction_bias.unsqueeze(
+                0
+            )
+            group_scores = (
+                scores_for_choice.view(n, n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
+            )  # [n, n_group]
+            group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+                1
+            ]  # [n, top_k_group]
+            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+            score_mask = (
+                group_mask.unsqueeze(-1)
+                .expand(n, n_group, n_routed_experts // n_group)
+                .reshape(n, -1)
+            )  # [n, e]
+            tmp_scores = scores_for_choice.masked_fill(
+                ~score_mask.bool(), 0.0
+            )  # [n, e]
+            _, topk_idx = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)
+            topk_weight = scores.gather(1, topk_idx)
+
+            return topk_idx, topk_weight
+
+        dtypes = [torch.float32, torch.bfloat16]
+        if core.onednn_has_fp16_support():
+            dtypes.append(torch.float16)
+        for dtype in dtypes:
+            hidden_states = torch.rand(10, 2560, dtype=dtype)
+            weight = torch.rand(16, 2560, dtype=dtype)
+            logits = torch.nn.functional.linear(
+                hidden_states.type(torch.float32), weight.type(torch.float32), None
+            )
+            scores = logits.sigmoid()
+            enable_autocast = dtype == torch.bfloat16
+            with torch.no_grad(), torch.cpu.amp.autocast(enabled=enable_autocast):
+                topk_idx_ref, topk_weight_ref = moe_gate(scores)
+                topk_idx_ipex, topk_weight_ipex = torch.ops.torch_ipex.deepseek_moegate(
+                    hidden_states,
+                    scores.to(dtype),
+                    torch.tensor(routed_scaling_factor),
+                    n_group,
+                    topk_group,
+                    n_routed_experts,
+                    top_k,
+                    torch.tensor(e_score_correction_bias),
+                )
+            self.assertEqual(topk_idx_ref, topk_idx_ipex)
+            self.assertEqual(topk_weight_ref, topk_weight_ipex)
+
 
 if __name__ == "__main__":
     test = unittest.main()

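The new `test_deepseekv3_moegate` keeps the removed eager implementation as a reference and checks that the fused op returns identical `topk_idx` and `topk_weight` for fp32 and bf16, adding fp16 only when `core.onednn_has_fp16_support()` reports support.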