 # ==-------------------------------------------------------------------------==
 # VLLM-HPU-EXT PATCH Start
 # ==-------------------------------------------------------------------------==
+import logging
+import os
 import torch
 from typing import Callable, Optional, Tuple
 import habana_frameworks.torch as htorch

+logging.basicConfig(level=logging.INFO)
+

 class MoeFP8Matmul(torch.nn.Module):
     def __init__(
@@ -66,7 +70,11 @@ def get_dequant_weights_func(

 class VllmMixtureOfExpertsOpFP8(torch.nn.Module):
     def __init__(
-        self, num_experts: int, experts_min: int = 0, experts_max: int = 8
+        self,
+        num_experts: int,
+        global_num_experts: int = 0,
+        experts_min: int = 0,
+        experts_max: int = 8,
     ):
         super().__init__()
         self.w13_list = torch.nn.ModuleList(
@@ -75,10 +83,52 @@ def __init__(
         self.w2_list = torch.nn.ModuleList(
             [MoeFP8Matmul() for _ in range(num_experts)]
         )
+        self.enable_moe_chunk = (
+            os.environ.get("VLLM_SUPPORT_MOE_CHUNK", "false").lower() == "true"
+        )
+        self.chunk_size_list = [
+            int(x)
+            for x in os.environ.get(
+                "PT_HPU_MOE_CHUNK", "64,128,512,1024,1536,2048,4096"
+            ).split(",")
+            if x.strip()
+        ]
+        self.token_boundary_list = [
+            int(x)
+            for x in os.environ.get(
+                "PT_HPU_MOE_TOKEN_BOUNDARY", "64,128,1536,1736,2048,3072,4096"
+            ).split(",")
+            if x.strip()
+        ]
+        assert len(self.chunk_size_list) == len(self.token_boundary_list), (
+            f"chunk_size_list({len(self.chunk_size_list)}) and "
+            f"token_boundary_list({len(self.token_boundary_list)}) must be the same length"
+        )
+        logger = logging.getLogger()
+        if self.enable_moe_chunk:
+            logger.info("token_boundary_list is: %s", self.token_boundary_list)
+            logger.info("chunk_size_list is: %s", self.chunk_size_list)
+
         self.num_experts = num_experts
+        self.global_num_experts = global_num_experts
         self.experts_min = experts_min
         self.experts_max = experts_max

+    def _get_extra_kwargs(self, tokens_num: int):
+        # Bucket the token count: use the chunk size paired with the first
+        # boundary that covers tokens_num, else fall back to the largest one.
+        if self.enable_moe_chunk:
+            chunk_size = self.chunk_size_list[-1]
+            for idx, threshold in enumerate(self.token_boundary_list):
+                if tokens_num <= threshold:
+                    chunk_size = self.chunk_size_list[idx]
+                    break
+            kwargs = {
+                "chunk_size": chunk_size,
+                "total_experts": self.global_num_experts,
+            }
+        else:
+            kwargs = {}
+        return kwargs
+
     def forward(
         self,
         x,
@@ -89,6 +139,8 @@ def forward(
         max_expert = self.experts_max
         w13_list_slice = []
         w2_list_slice = []
+        tokens_num, _ = x.shape
+        kwargs = self._get_extra_kwargs(tokens_num)
         for j in range(self.num_experts):
             w13_list_slice.append(self.w13_list[j].get_dequant_weight())
             w2_list_slice.append(self.w2_list[j].get_dequant_weight())
@@ -103,6 +155,7 @@ def forward(
             activation="silu",
             experts_min=min_expert,
             experts_max=max_expert,
+            **kwargs,
         )
         htorch.core.mark_step()
         return final_hidden_states
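
For reference, a minimal standalone sketch of the chunk-size selection this patch performs when VLLM_SUPPORT_MOE_CHUNK=true. The two lists mirror the PT_HPU_MOE_CHUNK and PT_HPU_MOE_TOKEN_BOUNDARY defaults above; the pick_chunk_size helper is hypothetical, named only for this illustration:

# Sketch of _get_extra_kwargs' bucketing, assuming the patch's default lists.
CHUNK_SIZES = [64, 128, 512, 1024, 1536, 2048, 4096]
TOKEN_BOUNDARIES = [64, 128, 1536, 1736, 2048, 3072, 4096]

def pick_chunk_size(tokens_num: int) -> int:
    # The first boundary that covers tokens_num selects the paired chunk size;
    # token counts past the last boundary fall back to the largest chunk.
    chunk_size = CHUNK_SIZES[-1]
    for idx, threshold in enumerate(TOKEN_BOUNDARIES):
        if tokens_num <= threshold:
            chunk_size = CHUNK_SIZES[idx]
            break
    return chunk_size

assert pick_chunk_size(64) == 64       # 64 <= 64 -> first bucket
assert pick_chunk_size(200) == 512     # 200 <= 1536 -> third bucket
assert pick_chunk_size(5000) == 4096   # past every boundary -> fallback

The resulting chunk_size (together with total_experts) is what forward() passes into the fused MoE call via **kwargs in the hunk above.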