@@ -9,6 +9,7 @@ IPEX_DEFINE_DISPATCH(mixtral_moe_tpp_kernel_stub);
 IPEX_DEFINE_DISPATCH(mixtral_moe_woq_kernel_stub);
 IPEX_DEFINE_DISPATCH(deepseek_moe_woq_kernel_stub);
 IPEX_DEFINE_DISPATCH(mixtral_moe_kernel_stub);
+IPEX_DEFINE_DISPATCH(deepseek_moegate_kernel_stub);
 
 at::Tensor mixtral_moe_tpp(
     const at::Tensor& hidden_states,
@@ -39,25 +40,53 @@ at::Tensor mixtral_moe_tpp(
       is_distributed);
 }
 
+inline std::tuple<
+    std::vector<long>,
+    std::vector<std::vector<long>>,
+    std::vector<std::vector<long>>>
+get_expert_topx_idx(const at::Tensor& topk_ids, const int num_experts) {
+  auto token_num = topk_ids.size(0);
+  auto topk = topk_ids.size(1);
+  std::vector<long> expert_selected(num_experts, 0);
+  std::vector<std::vector<long>> expert_idx(num_experts);
+  std::vector<std::vector<long>> expert_top_x(num_experts);
+  auto topk_ids_ptr = topk_ids.data_ptr<long>();
+  auto topk_ids_stride0 = topk_ids.stride(0);
+  for (auto i = 0; i < token_num; i++) {
+    for (auto j = 0; j < topk; j++) {
+      auto expert_id = topk_ids_ptr[i * topk_ids_stride0 + j];
+      expert_selected[expert_id] += 1;
+      expert_top_x[expert_id].push_back(i);
+      expert_idx[expert_id].push_back(j);
+    }
+  }
+  return std::make_tuple(expert_selected, expert_idx, expert_top_x);
+}
+
 at::Tensor deepseek_moe_tpp(
     const at::Tensor& hidden_states,
-    const at::Tensor& expert_mask,
+    const at::Tensor& topk_ids,
     const std::vector<at::Tensor>& gate_wei,
     const std::vector<at::Tensor>& up_wei,
     const std::vector<at::Tensor>& down_wei,
     bool tpp_fallback,
     const at::Tensor& routing_weights,
-    at::Tensor& output,
     bool is_distributed) {
   RECORD_FUNCTION("ipex::deepseek_moe_tpp", c10::ArrayRef<c10::IValue>({}));
 
+  auto output = at::zeros_like(hidden_states);
   int num_experts = gate_wei.size();
+  std::vector<long> expert_selected;
+  std::vector<std::vector<long>> expert_idx, expert_top_x;
+  std::tie(expert_selected, expert_idx, expert_top_x) =
+      get_expert_topx_idx(topk_ids, num_experts);
   for (auto i = 0; i < num_experts; i++) {
-    auto non_zero = expert_mask[i].nonzero();
-    if (non_zero.sizes()[0] == 0)
+    if (expert_selected[i] == 0)
       continue;
-    auto idx = non_zero.select(1, 0);
-    auto top_x = non_zero.select(1, 1);
+    auto idx =
+        torch::from_blob(expert_idx[i].data(), {expert_selected[i]}, at::kLong);
+    auto top_x = torch::from_blob(
+        expert_top_x[i].data(), {expert_selected[i]}, at::kLong);
     output = mixtral_moe_tpp_kernel_stub(
         kCPU,
         hidden_states,
@@ -111,26 +140,30 @@ at::Tensor mixtral_moe(
 
 at::Tensor deepseek_moe(
     const at::Tensor& hidden_states,
-    const at::Tensor& expert_mask,
+    const at::Tensor& topk_ids,
     const std::vector<at::Tensor>& gate_wei,
     const std::vector<c10::intrusive_ptr<LinearOpContext>>& gate_op_ctx,
     const std::vector<at::Tensor>& up_wei,
     const std::vector<c10::intrusive_ptr<LinearOpContext>>& up_op_ctx,
     const std::vector<at::Tensor>& down_wei,
     const std::vector<c10::intrusive_ptr<LinearOpContext>>& down_op_ctx,
     const at::Tensor& routing_weights,
-    at::Tensor& output,
     bool is_distributed) {
   RECORD_FUNCTION("ipex::deepseek_moe", c10::ArrayRef<c10::IValue>({}));
 
+  auto output = at::zeros_like(hidden_states);
   int num_experts = gate_wei.size();
+  std::vector<long> expert_selected;
+  std::vector<std::vector<long>> expert_idx, expert_top_x;
+  std::tie(expert_selected, expert_idx, expert_top_x) =
+      get_expert_topx_idx(topk_ids, num_experts);
   for (auto i = 0; i < num_experts; i++) {
-    auto non_zero = expert_mask[i].nonzero();
-    if (non_zero.sizes()[0] == 0)
+    if (expert_selected[i] == 0)
       continue;
-    auto idx = non_zero.select(1, 0);
-    auto top_x = non_zero.select(1, 1);
-
+    auto idx =
+        torch::from_blob(expert_idx[i].data(), {expert_selected[i]}, at::kLong);
+    auto top_x = torch::from_blob(
+        expert_top_x[i].data(), {expert_selected[i]}, at::kLong);
     output = mixtral_moe_kernel_stub(
         kCPU,
         hidden_states,
@@ -152,25 +185,30 @@ at::Tensor deepseek_moe(
 
 at::Tensor deepseek_moe_mkl(
     const at::Tensor& hidden_states,
-    const at::Tensor& expert_mask,
+    const at::Tensor& topk_ids,
     const std::vector<at::Tensor>& gate_wei,
     const std::vector<c10::intrusive_ptr<MKLOpContext>>& gate_op_ctx,
     const std::vector<at::Tensor>& up_wei,
     const std::vector<c10::intrusive_ptr<MKLOpContext>>& up_op_ctx,
     const std::vector<at::Tensor>& down_wei,
     const std::vector<c10::intrusive_ptr<MKLOpContext>>& down_op_ctx,
     const at::Tensor& routing_weights,
-    at::Tensor& output,
     bool is_distributed) {
   RECORD_FUNCTION("ipex::deepseek_moe_mkl", c10::ArrayRef<c10::IValue>({}));
 
+  auto output = at::zeros_like(hidden_states);
   int num_experts = gate_wei.size();
+  std::vector<long> expert_selected;
+  std::vector<std::vector<long>> expert_idx, expert_top_x;
+  std::tie(expert_selected, expert_idx, expert_top_x) =
+      get_expert_topx_idx(topk_ids, num_experts);
   for (auto i = 0; i < num_experts; i++) {
-    auto non_zero = expert_mask[i].nonzero();
-    if (non_zero.sizes()[0] == 0)
+    if (expert_selected[i] == 0)
       continue;
-    auto idx = non_zero.select(1, 0);
-    auto top_x = non_zero.select(1, 1);
+    auto idx =
+        torch::from_blob(expert_idx[i].data(), {expert_selected[i]}, at::kLong);
+    auto top_x = torch::from_blob(
+        expert_top_x[i].data(), {expert_selected[i]}, at::kLong);
     output = mixtral_moe_kernel_stub(
         kCPU,
         hidden_states,
@@ -217,22 +255,27 @@ at::Tensor mixtral_moe_woq(
 }
 at::Tensor deepseek_moe_woq(
     const at::Tensor& hidden_states,
-    const at::Tensor& expert_mask,
+    const at::Tensor& topk_ids,
     const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& gate_ctx,
     const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& up_ctx,
     const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& down_ctx,
     const at::Tensor& routing_weights,
-    at::Tensor& output,
     bool is_distributed) {
   RECORD_FUNCTION("ipex::deepseek_moe_woq", c10::ArrayRef<c10::IValue>({}));
 
+  auto output = at::zeros_like(hidden_states);
   int num_experts = gate_ctx.size();
+  std::vector<long> expert_selected;
+  std::vector<std::vector<long>> expert_idx, expert_top_x;
+  std::tie(expert_selected, expert_idx, expert_top_x) =
+      get_expert_topx_idx(topk_ids, num_experts);
   for (auto i = 0; i < num_experts; i++) {
-    auto non_zero = expert_mask[i].nonzero();
-    if (non_zero.sizes()[0] == 0)
+    if (expert_selected[i] == 0)
       continue;
-    auto idx = non_zero.select(1, 0);
-    auto top_x = non_zero.select(1, 1);
+    auto idx =
+        torch::from_blob(expert_idx[i].data(), {expert_selected[i]}, at::kLong);
+    auto top_x = torch::from_blob(
+        expert_top_x[i].data(), {expert_selected[i]}, at::kLong);
     output = mixtral_moe_woq_kernel_stub(
         kCPU,
         hidden_states,
@@ -247,6 +290,27 @@ at::Tensor deepseek_moe_woq(
   }
   return output;
 }
+
+std::tuple<at::Tensor, at::Tensor> deepseek_moegate(
+    const at::Tensor& hidden_states,
+    const at::Tensor& scores,
+    const at::Tensor& routed_scaling_factor,
+    const int64_t n_group,
+    const int64_t topk_group,
+    const int64_t n_routed_experts,
+    const int64_t top_k) {
+  RECORD_FUNCTION("ipex::deepseek_moegate", c10::ArrayRef<c10::IValue>({}));
+
+  return deepseek_moegate_kernel_stub(
+      kCPU,
+      hidden_states,
+      scores,
+      routed_scaling_factor,
+      n_group,
+      topk_group,
+      n_routed_experts,
+      top_k);
+}
 } // namespace cpu
 } // namespace torch_ipex
 
@@ -262,9 +326,9 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       c10::DispatchKey::CPU,
       torch_ipex::cpu::mixtral_moe_tpp);
   m.def(
-      "deepseek_moe_tpp(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
+      "deepseek_moe_tpp(Tensor hidden_states, Tensor topk_ids, Tensor[] gate_wei, \
       Tensor[] up_wei, Tensor[] down_wei, bool tpp_fallback, Tensor routing_weights, \
-      Tensor output, bool is_distributed) -> Tensor");
+      bool is_distributed) -> Tensor");
   m.impl(
       "deepseek_moe_tpp",
       c10::DispatchKey::CPU,
@@ -275,18 +339,18 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       Tensor down_op_ctx, bool use_dnnl, Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
   m.impl("mixtral_moe", c10::DispatchKey::CPU, torch_ipex::cpu::mixtral_moe);
   m.def(
-      "deepseek_moe(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
+      "deepseek_moe(Tensor hidden_states, Tensor topk_ids, Tensor[] gate_wei, \
       __torch__.torch.classes.ipex_prepack.LinearOpContext[] gate_op_ctx, Tensor[] up_wei, \
       __torch__.torch.classes.ipex_prepack.LinearOpContext[] up_op_ctx, Tensor[] down_wei, \
       __torch__.torch.classes.ipex_prepack.LinearOpContext[] down_op_ctx, Tensor routing_weights, \
-      Tensor output, bool is_distributed) -> Tensor");
+      bool is_distributed) -> Tensor");
   m.impl("deepseek_moe", c10::DispatchKey::CPU, torch_ipex::cpu::deepseek_moe);
   m.def(
-      "deepseek_moe_mkl(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
+      "deepseek_moe_mkl(Tensor hidden_states, Tensor topk_ids, Tensor[] gate_wei, \
       __torch__.torch.classes.ipex_prepack.MKLOpContext[] gate_op_ctx, Tensor[] up_wei, \
       __torch__.torch.classes.ipex_prepack.MKLOpContext[] up_op_ctx, \
       Tensor[] down_wei, __torch__.torch.classes.ipex_prepack.MKLOpContext[] down_op_ctx, \
-      Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
+      Tensor routing_weights, bool is_distributed) -> Tensor");
   m.impl(
       "deepseek_moe_mkl",
       c10::DispatchKey::CPU,
@@ -299,15 +363,21 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       c10::DispatchKey::CPU,
       torch_ipex::cpu::mixtral_moe_woq);
   m.def(
-      "deepseek_moe_woq(Tensor hidden_states, Tensor expert_mask, \
+      "deepseek_moe_woq(Tensor hidden_states, Tensor topk_ids, \
       __torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] gate_ctx, \
       __torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] up_ctx, \
       __torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] down_ctx, \
-      Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
+      Tensor routing_weights, bool is_distributed) -> Tensor");
 
   m.impl(
       "deepseek_moe_woq",
       c10::DispatchKey::CPU,
       torch_ipex::cpu::deepseek_moe_woq);
+  m.def(
+      "deepseek_moegate(Tensor hidden_states, Tensor scores, Tensor routed_scaling_factor, int n_group, int topk_group, int n_routed_experts, int top_k) -> (Tensor, Tensor)");
+  m.impl(
+      "deepseek_moegate",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::deepseek_moegate);
 }
 } // namespace
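
Note: the core change above replaces the per-expert expert_mask[i].nonzero() tensor ops with a single pass over topk_ids in get_expert_topx_idx, which buckets every (token, top-k slot) pair by the expert it routes to; torch::from_blob then wraps each bucket as a tensor without copying. Below is a minimal standalone sketch of that grouping logic, assuming row-major [token_num, topk] ids; it uses plain std::vector instead of at::Tensor, and the names group_by_expert / main are chosen here purely for illustration.

#include <cstdio>
#include <tuple>
#include <vector>

// Sketch of the grouping done by get_expert_topx_idx: for each expert e,
// collect which tokens (top_x) and which top-k slots (idx) routed to e,
// plus a per-expert hit count used to skip idle experts.
std::tuple<
    std::vector<long>,
    std::vector<std::vector<long>>,
    std::vector<std::vector<long>>>
group_by_expert(
    const std::vector<long>& topk_ids, // row-major [token_num, topk]
    long token_num,
    long topk,
    int num_experts) {
  std::vector<long> expert_selected(num_experts, 0);
  std::vector<std::vector<long>> expert_idx(num_experts);
  std::vector<std::vector<long>> expert_top_x(num_experts);
  for (long i = 0; i < token_num; i++) {
    for (long j = 0; j < topk; j++) {
      long expert_id = topk_ids[i * topk + j];
      expert_selected[expert_id] += 1;
      expert_top_x[expert_id].push_back(i); // token index
      expert_idx[expert_id].push_back(j); // top-k slot index
    }
  }
  return std::make_tuple(expert_selected, expert_idx, expert_top_x);
}

int main() {
  // 3 tokens routed top-2 over 4 experts.
  std::vector<long> topk_ids = {0, 2, 2, 3, 0, 1};
  auto grouped = group_by_expert(topk_ids, 3, 2, 4);
  auto& selected = std::get<0>(grouped);
  for (int e = 0; e < 4; e++) {
    // Prints 2, 1, 2, 1 hits for experts 0..3 respectively.
    std::printf("expert %d: %ld hits\n", e, selected[e]);
  }
  return 0;
}

Compared with materializing an expert mask and calling nonzero() once per expert, this does one O(token_num * topk) walk over the routing ids on the host and avoids per-expert tensor allocations for experts that received no tokens.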