Skip to content

Commit 18eeefa

Browse files
authored
Optimize deepseek (#3437) (#3441)
1 parent 9d58cd2 commit 18eeefa

File tree

14 files changed

+720
-142
lines changed

14 files changed

+720
-142
lines changed

csrc/cpu/aten/MoE.cpp

Lines changed: 103 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ IPEX_DEFINE_DISPATCH(mixtral_moe_tpp_kernel_stub);
99
IPEX_DEFINE_DISPATCH(mixtral_moe_woq_kernel_stub);
1010
IPEX_DEFINE_DISPATCH(deepseek_moe_woq_kernel_stub);
1111
IPEX_DEFINE_DISPATCH(mixtral_moe_kernel_stub);
12+
IPEX_DEFINE_DISPATCH(deepseek_moegate_kernel_stub);
1213

1314
at::Tensor mixtral_moe_tpp(
1415
const at::Tensor& hidden_states,
@@ -39,25 +40,53 @@ at::Tensor mixtral_moe_tpp(
3940
is_distributed);
4041
}
4142

43+
inline std::tuple<
44+
std::vector<long>,
45+
std::vector<std::vector<long>>,
46+
std::vector<std::vector<long>>>
47+
get_expert_topx_idx(const at::Tensor& topk_ids, const int num_experts) {
48+
auto token_num = topk_ids.size(0);
49+
auto topk = topk_ids.size(1);
50+
std::vector<long> expert_selected(num_experts, 0);
51+
std::vector<std::vector<long>> expert_idx(num_experts);
52+
std::vector<std::vector<long>> expert_top_x(num_experts);
53+
auto topk_ids_ptr = topk_ids.data_ptr<long>();
54+
auto topk_ids_stride0 = topk_ids.stride(0);
55+
for (auto i = 0; i < token_num; i++) {
56+
for (auto j = 0; j < topk; j++) {
57+
auto expert_id = topk_ids_ptr[i * topk_ids_stride0 + j];
58+
expert_selected[expert_id] += 1;
59+
expert_top_x[expert_id].push_back(i);
60+
expert_idx[expert_id].push_back(j);
61+
}
62+
}
63+
return std::make_tuple(expert_selected, expert_idx, expert_top_x);
64+
}
65+
4266
at::Tensor deepseek_moe_tpp(
4367
const at::Tensor& hidden_states,
44-
const at::Tensor& expert_mask,
68+
const at::Tensor& topk_ids,
4569
const std::vector<at::Tensor>& gate_wei,
4670
const std::vector<at::Tensor>& up_wei,
4771
const std::vector<at::Tensor>& down_wei,
4872
bool tpp_fallback,
4973
const at::Tensor& routing_weights,
50-
at::Tensor& output,
5174
bool is_distributed) {
5275
RECORD_FUNCTION("ipex::deepseek_moe_tpp", c10::ArrayRef<c10::IValue>({}));
5376

77+
auto output = at::zeros_like(hidden_states);
5478
int num_experts = gate_wei.size();
79+
std::vector<long> expert_selected;
80+
std::vector<std::vector<long>> expert_idx, expert_top_x;
81+
std::tie(expert_selected, expert_idx, expert_top_x) =
82+
get_expert_topx_idx(topk_ids, num_experts);
5583
for (auto i = 0; i < num_experts; i++) {
56-
auto non_zero = expert_mask[i].nonzero();
57-
if (non_zero.sizes()[0] == 0)
84+
if (expert_selected[i] == 0)
5885
continue;
59-
auto idx = non_zero.select(1, 0);
60-
auto top_x = non_zero.select(1, 1);
86+
auto idx =
87+
torch::from_blob(expert_idx[i].data(), {expert_selected[i]}, at::kLong);
88+
auto top_x = torch::from_blob(
89+
expert_top_x[i].data(), {expert_selected[i]}, at::kLong);
6190
output = mixtral_moe_tpp_kernel_stub(
6291
kCPU,
6392
hidden_states,
@@ -111,26 +140,30 @@ at::Tensor mixtral_moe(
111140

112141
at::Tensor deepseek_moe(
113142
const at::Tensor& hidden_states,
114-
const at::Tensor& expert_mask,
143+
const at::Tensor& topk_ids,
115144
const std::vector<at::Tensor>& gate_wei,
116145
const std::vector<c10::intrusive_ptr<LinearOpContext>>& gate_op_ctx,
117146
const std::vector<at::Tensor>& up_wei,
118147
const std::vector<c10::intrusive_ptr<LinearOpContext>>& up_op_ctx,
119148
const std::vector<at::Tensor>& down_wei,
120149
const std::vector<c10::intrusive_ptr<LinearOpContext>>& down_op_ctx,
121150
const at::Tensor& routing_weights,
122-
at::Tensor& output,
123151
bool is_distributed) {
124152
RECORD_FUNCTION("ipex::deepseek_moe", c10::ArrayRef<c10::IValue>({}));
125153

154+
auto output = at::zeros_like(hidden_states);
126155
int num_experts = gate_wei.size();
156+
std::vector<long> expert_selected;
157+
std::vector<std::vector<long>> expert_idx, expert_top_x;
158+
std::tie(expert_selected, expert_idx, expert_top_x) =
159+
get_expert_topx_idx(topk_ids, num_experts);
127160
for (auto i = 0; i < num_experts; i++) {
128-
auto non_zero = expert_mask[i].nonzero();
129-
if (non_zero.sizes()[0] == 0)
161+
if (expert_selected[i] == 0)
130162
continue;
131-
auto idx = non_zero.select(1, 0);
132-
auto top_x = non_zero.select(1, 1);
133-
163+
auto idx =
164+
torch::from_blob(expert_idx[i].data(), {expert_selected[i]}, at::kLong);
165+
auto top_x = torch::from_blob(
166+
expert_top_x[i].data(), {expert_selected[i]}, at::kLong);
134167
output = mixtral_moe_kernel_stub(
135168
kCPU,
136169
hidden_states,
@@ -152,25 +185,30 @@ at::Tensor deepseek_moe(
152185

153186
at::Tensor deepseek_moe_mkl(
154187
const at::Tensor& hidden_states,
155-
const at::Tensor& expert_mask,
188+
const at::Tensor& topk_ids,
156189
const std::vector<at::Tensor>& gate_wei,
157190
const std::vector<c10::intrusive_ptr<MKLOpContext>>& gate_op_ctx,
158191
const std::vector<at::Tensor>& up_wei,
159192
const std::vector<c10::intrusive_ptr<MKLOpContext>>& up_op_ctx,
160193
const std::vector<at::Tensor>& down_wei,
161194
const std::vector<c10::intrusive_ptr<MKLOpContext>>& down_op_ctx,
162195
const at::Tensor& routing_weights,
163-
at::Tensor& output,
164196
bool is_distributed) {
165197
RECORD_FUNCTION("ipex::deepseek_moe_mkl", c10::ArrayRef<c10::IValue>({}));
166198

199+
auto output = at::zeros_like(hidden_states);
167200
int num_experts = gate_wei.size();
201+
std::vector<long> expert_selected;
202+
std::vector<std::vector<long>> expert_idx, expert_top_x;
203+
std::tie(expert_selected, expert_idx, expert_top_x) =
204+
get_expert_topx_idx(topk_ids, num_experts);
168205
for (auto i = 0; i < num_experts; i++) {
169-
auto non_zero = expert_mask[i].nonzero();
170-
if (non_zero.sizes()[0] == 0)
206+
if (expert_selected[i] == 0)
171207
continue;
172-
auto idx = non_zero.select(1, 0);
173-
auto top_x = non_zero.select(1, 1);
208+
auto idx =
209+
torch::from_blob(expert_idx[i].data(), {expert_selected[i]}, at::kLong);
210+
auto top_x = torch::from_blob(
211+
expert_top_x[i].data(), {expert_selected[i]}, at::kLong);
174212
output = mixtral_moe_kernel_stub(
175213
kCPU,
176214
hidden_states,
@@ -217,22 +255,27 @@ at::Tensor mixtral_moe_woq(
217255
}
218256
at::Tensor deepseek_moe_woq(
219257
const at::Tensor& hidden_states,
220-
const at::Tensor& expert_mask,
258+
const at::Tensor& topk_ids,
221259
const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& gate_ctx,
222260
const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& up_ctx,
223261
const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& down_ctx,
224262
const at::Tensor& routing_weights,
225-
at::Tensor& output,
226263
bool is_distributed) {
227264
RECORD_FUNCTION("ipex::deepseek_moe_woq", c10::ArrayRef<c10::IValue>({}));
228265

266+
auto output = at::zeros_like(hidden_states);
229267
int num_experts = gate_ctx.size();
268+
std::vector<long> expert_selected;
269+
std::vector<std::vector<long>> expert_idx, expert_top_x;
270+
std::tie(expert_selected, expert_idx, expert_top_x) =
271+
get_expert_topx_idx(topk_ids, num_experts);
230272
for (auto i = 0; i < num_experts; i++) {
231-
auto non_zero = expert_mask[i].nonzero();
232-
if (non_zero.sizes()[0] == 0)
273+
if (expert_selected[i] == 0)
233274
continue;
234-
auto idx = non_zero.select(1, 0);
235-
auto top_x = non_zero.select(1, 1);
275+
auto idx =
276+
torch::from_blob(expert_idx[i].data(), {expert_selected[i]}, at::kLong);
277+
auto top_x = torch::from_blob(
278+
expert_top_x[i].data(), {expert_selected[i]}, at::kLong);
236279
output = mixtral_moe_woq_kernel_stub(
237280
kCPU,
238281
hidden_states,
@@ -247,6 +290,27 @@ at::Tensor deepseek_moe_woq(
247290
}
248291
return output;
249292
}
293+
294+
std::tuple<at::Tensor, at::Tensor> deepseek_moegate(
295+
const at::Tensor& hidden_states,
296+
const at::Tensor& scores,
297+
const at::Tensor& routed_scaling_factor,
298+
const int64_t n_group,
299+
const int64_t topk_group,
300+
const int64_t n_routed_experts,
301+
const int64_t top_k) {
302+
RECORD_FUNCTION("ipex::deepseek_moegate", c10::ArrayRef<c10::IValue>({}));
303+
304+
return deepseek_moegate_kernel_stub(
305+
kCPU,
306+
hidden_states,
307+
scores,
308+
routed_scaling_factor,
309+
n_group,
310+
topk_group,
311+
n_routed_experts,
312+
top_k);
313+
}
250314
} // namespace cpu
251315
} // namespace torch_ipex
252316

@@ -262,9 +326,9 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
262326
c10::DispatchKey::CPU,
263327
torch_ipex::cpu::mixtral_moe_tpp);
264328
m.def(
265-
"deepseek_moe_tpp(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
329+
"deepseek_moe_tpp(Tensor hidden_states, Tensor topk_ids, Tensor[] gate_wei, \
266330
Tensor[] up_wei, Tensor[] down_wei, bool tpp_fallback, Tensor routing_weights, \
267-
Tensor output, bool is_distributed) -> Tensor");
331+
bool is_distributed) -> Tensor");
268332
m.impl(
269333
"deepseek_moe_tpp",
270334
c10::DispatchKey::CPU,
@@ -275,18 +339,18 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
275339
Tensor down_op_ctx, bool use_dnnl, Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
276340
m.impl("mixtral_moe", c10::DispatchKey::CPU, torch_ipex::cpu::mixtral_moe);
277341
m.def(
278-
"deepseek_moe(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
342+
"deepseek_moe(Tensor hidden_states, Tensor topk_ids, Tensor[] gate_wei, \
279343
__torch__.torch.classes.ipex_prepack.LinearOpContext[] gate_op_ctx, Tensor[] up_wei, \
280344
__torch__.torch.classes.ipex_prepack.LinearOpContext[] up_op_ctx, Tensor[] down_wei, \
281345
__torch__.torch.classes.ipex_prepack.LinearOpContext[] down_op_ctx, Tensor routing_weights, \
282-
Tensor output, bool is_distributed) -> Tensor");
346+
bool is_distributed) -> Tensor");
283347
m.impl("deepseek_moe", c10::DispatchKey::CPU, torch_ipex::cpu::deepseek_moe);
284348
m.def(
285-
"deepseek_moe_mkl(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
349+
"deepseek_moe_mkl(Tensor hidden_states, Tensor topk_ids, Tensor[] gate_wei, \
286350
__torch__.torch.classes.ipex_prepack.MKLOpContext[] gate_op_ctx, Tensor[] up_wei, \
287351
__torch__.torch.classes.ipex_prepack.MKLOpContext[] up_op_ctx, \
288352
Tensor[] down_wei, __torch__.torch.classes.ipex_prepack.MKLOpContext[] down_op_ctx, \
289-
Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
353+
Tensor routing_weights, bool is_distributed) -> Tensor");
290354
m.impl(
291355
"deepseek_moe_mkl",
292356
c10::DispatchKey::CPU,
@@ -299,15 +363,21 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
299363
c10::DispatchKey::CPU,
300364
torch_ipex::cpu::mixtral_moe_woq);
301365
m.def(
302-
"deepseek_moe_woq(Tensor hidden_states, Tensor expert_mask, \
366+
"deepseek_moe_woq(Tensor hidden_states, Tensor topk_ids, \
303367
__torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] gate_ctx, \
304368
__torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] up_ctx, \
305369
__torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] down_ctx, \
306-
Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
370+
Tensor routing_weights, bool is_distributed) -> Tensor");
307371

308372
m.impl(
309373
"deepseek_moe_woq",
310374
c10::DispatchKey::CPU,
311375
torch_ipex::cpu::deepseek_moe_woq);
376+
m.def(
377+
"deepseek_moegate(Tensor hidden_states, Tensor scores, Tensor routed_scaling_factor, int n_group, int topk_group, int n_routed_experts, int top_k) -> (Tensor, Tensor)");
378+
m.impl(
379+
"deepseek_moegate",
380+
c10::DispatchKey::CPU,
381+
torch_ipex::cpu::deepseek_moegate);
312382
}
313383
} // namespace

csrc/cpu/aten/MoE.h

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ at::Tensor deepseek_moe_tpp(
2525
const std::vector<at::Tensor>&,
2626
bool,
2727
const at::Tensor&,
28-
at::Tensor&,
2928
bool);
3029
at::Tensor mixtral_moe_woq(
3130
const at::Tensor&,
@@ -44,7 +43,6 @@ at::Tensor deepseek_moe_woq(
4443
const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>&,
4544
const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>&,
4645
const at::Tensor&,
47-
at::Tensor&,
4846
bool);
4947
at::Tensor mixtral_moe_woq(
5048
const at::Tensor&,
@@ -80,7 +78,6 @@ at::Tensor deepseek_moe(
8078
const std::vector<at::Tensor>&,
8179
const std::vector<c10::intrusive_ptr<LinearOpContext>>&,
8280
const at::Tensor&,
83-
at::Tensor&,
8481
bool);
8582
at::Tensor deepseek_moe_mkl(
8683
const at::Tensor&,
@@ -92,8 +89,15 @@ at::Tensor deepseek_moe_mkl(
9289
const std::vector<at::Tensor>&,
9390
const std::vector<c10::intrusive_ptr<MKLOpContext>>&,
9491
const at::Tensor&,
95-
at::Tensor&,
9692
bool);
93+
std::tuple<at::Tensor, at::Tensor> deepseek_moegate(
94+
const at::Tensor& hidden_states,
95+
const at::Tensor& scores,
96+
const at::Tensor& routed_scaling_factor,
97+
const int64_t n_group,
98+
const int64_t topk_group,
99+
const int64_t n_routed_experts,
100+
const int64_t top_k);
97101
using mixtral_moe_tpp_kernel_fn = at::Tensor (*)(
98102
const at::Tensor& hidden_states,
99103
const at::Tensor& top_x,
@@ -107,13 +111,12 @@ using mixtral_moe_tpp_kernel_fn = at::Tensor (*)(
107111
bool is_distributed);
108112
using deepseek_moe_tpp_kernel_fn = at::Tensor (*)(
109113
const at::Tensor& hidden_states,
110-
const at::Tensor& expert_mask,
114+
const at::Tensor& topk_ids,
111115
const std::vector<at::Tensor>& gate_wei,
112116
const std::vector<at::Tensor>& up_wei,
113117
const std::vector<at::Tensor>& down_wei,
114118
bool tpp_fallback,
115119
const at::Tensor& routing_weights,
116-
at::Tensor& output,
117120
bool is_distributed);
118121
using mixtral_moe_woq_kernel_fn = at::Tensor (*)(
119122
const at::Tensor& hidden_states,
@@ -127,12 +130,11 @@ using mixtral_moe_woq_kernel_fn = at::Tensor (*)(
127130
bool is_distributed);
128131
using deepseek_moe_woq_kernel_fn = at::Tensor (*)(
129132
const at::Tensor& hidden_states,
130-
const at::Tensor& expert_mask,
133+
const at::Tensor& topk_ids,
131134
const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& gate_ctx,
132135
const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& up_ctx,
133136
const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& down_ctx,
134137
const at::Tensor& routing_weights,
135-
at::Tensor& output,
136138
bool is_distributed);
137139
using mixtral_moe_kernel_fn = at::Tensor (*)(
138140
const at::Tensor& hidden_states,
@@ -150,34 +152,41 @@ using mixtral_moe_kernel_fn = at::Tensor (*)(
150152
bool is_distributed);
151153
using deepseek_moe_kernel_fn = at::Tensor (*)(
152154
const at::Tensor& hidden_states,
153-
const at::Tensor& expert_mask,
155+
const at::Tensor& topk_ids,
154156
const std::vector<at::Tensor>& gate_wei,
155157
const std::vector<c10::intrusive_ptr<LinearOpContext>>& gate_op_ctx,
156158
const std::vector<at::Tensor>& up_wei,
157159
const std::vector<c10::intrusive_ptr<LinearOpContext>>& up_op_ctx,
158160
const std::vector<at::Tensor>& down_wei,
159161
const std::vector<c10::intrusive_ptr<LinearOpContext>>& down_op_ctx,
160162
const at::Tensor& routing_weights,
161-
at::Tensor& output,
162163
bool is_distributed);
163164
using deepseek_moe_mkl_kernel_fn = at::Tensor (*)(
164165
const at::Tensor& hidden_states,
165-
const at::Tensor& expert_mask,
166+
const at::Tensor& topk_ids,
166167
const std::vector<at::Tensor>& gate_wei,
167168
const std::vector<c10::intrusive_ptr<MKLOpContext>>& gate_op_ctx,
168169
const std::vector<at::Tensor>& up_wei,
169170
const std::vector<c10::intrusive_ptr<MKLOpContext>>& up_op_ctx,
170171
const std::vector<at::Tensor>& down_wei,
171172
const std::vector<c10::intrusive_ptr<MKLOpContext>>& down_op_ctx,
172173
const at::Tensor& routing_weights,
173-
at::Tensor& output,
174174
bool is_distributed);
175+
using deepseek_moegate_kernel_fn = std::tuple<at::Tensor, at::Tensor> (*)(
176+
const at::Tensor& hidden_states,
177+
const at::Tensor& scores,
178+
const at::Tensor& routed_scaling_factor,
179+
const int64_t n_group,
180+
const int64_t topk_group,
181+
const int64_t n_routed_experts,
182+
const int64_t top_k);
175183
IPEX_DECLARE_DISPATCH(mixtral_moe_tpp_kernel_fn, mixtral_moe_tpp_kernel_stub);
176184
IPEX_DECLARE_DISPATCH(deepseek_moe_tpp_kernel_fn, deepseek_moe_tpp_kernel_stub);
177185
IPEX_DECLARE_DISPATCH(mixtral_moe_woq_kernel_fn, mixtral_moe_woq_kernel_stub);
178186
IPEX_DECLARE_DISPATCH(deepseek_moe_woq_kernel_fn, deepseek_moe_woq_kernel_stub);
179187
IPEX_DECLARE_DISPATCH(mixtral_moe_kernel_fn, mixtral_moe_kernel_stub);
180188
IPEX_DECLARE_DISPATCH(deepseek_moe_kernel_fn, deepseek_moe_kernel_stub);
181189
IPEX_DECLARE_DISPATCH(deepseek_moe_mkl_kernel_fn, deepseek_moe_mkl_kernel_stub);
190+
IPEX_DECLARE_DISPATCH(deepseek_moegate_kernel_fn, deepseek_moegate_kernel_stub);
182191
} // namespace cpu
183192
} // namespace torch_ipex

0 commit comments

Comments
 (0)