 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
-from vllm.model_executor.layers.fused_moe import fused_moe
+                              get_tensor_model_parallel_world_size)
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -22,7 +21,6 @@
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
@@ -54,63 +52,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return router_logits
 
 
-class DbrxExperts(nn.Module):
-    """A tensor-parallel MoE implementation for DBRX.
-
-    Each expert's weights are sharded across all ranks and a fused MoE
-    kernel is used for the forward pass, and finally we reduce the outputs
-    across ranks.
-    """
+class DbrxExperts(FusedMoE):
 
     def __init__(
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
     ):
-        super().__init__()
+        super().__init__(
+            num_experts=config.ffn_config.moe_num_experts,
+            top_k=config.ffn_config.moe_top_k,
+            hidden_size=config.d_model,
+            intermediate_size=config.ffn_config.ffn_hidden_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=True,
+            quant_config=quant_config,
+            tp_size=get_tensor_model_parallel_world_size(),
+        )
+        self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.ffn_config.moe_num_experts
-        self.top_k = config.ffn_config.moe_top_k
         self.d_model = config.d_model
-        self.intermediate_size = (config.ffn_config.ffn_hidden_size //
+        self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
                                   self.tp_size)
 
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-
-        self.router = DbrxRouter(config, self.params_dtype)
-        self.ws = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.d_model,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-        self.w2s = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.d_model,
-                self.intermediate_size,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-
-        set_weight_attrs(
-            self.ws,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2s,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-
+    # Define custom weight loader for dbrx model
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str):
         tp_rank = get_tensor_model_parallel_rank()
@@ -119,47 +86,61 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
         shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
         # DBRX uses GLU for each experts.
         # GLU has 3 linear layers: w1, v1 and w2.
-        if weight_name.endswith("w1"):
+        if weight_name.endswith("w1."):
             loaded_weight = torch.reshape(
                 loaded_weight,
                 [-1, self.intermediate_size * self.tp_size, self.d_model],
             )
             param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
-        if weight_name.endswith("v1"):
+        if weight_name.endswith("v1."):
             loaded_weight = torch.reshape(
                 loaded_weight,
                 [-1, self.intermediate_size * self.tp_size, self.d_model],
             )
             param_data[:,
                        shard_size:2 * shard_size, :] = loaded_weight[:,
                                                                      shard, :]
-        if weight_name.endswith("w2"):
+        if weight_name.endswith("w2."):
             loaded_weight = torch.reshape(
                 loaded_weight,
                 [-1, self.intermediate_size * self.tp_size, self.d_model],
             ).transpose(1, 2)
             param_data[:] = loaded_weight[:, :, shard]
 
+
+class DbrxMoE(nn.Module):
+    """A tensor-parallel MoE implementation for DBRX.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.router = DbrxRouter(config, self.params_dtype)
+
+        self.experts = DbrxExperts(config=config,
+                                   quant_config=quant_config,
+                                   params_dtype=self.params_dtype)
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.d_model)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.router(hidden_states)
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.ws,
-            self.w2s,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)
 
 
 class DbrxAttention(nn.Module):
@@ -288,7 +269,7 @@ def __init__(
         super().__init__()
         self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
                                                      quant_config)
-        self.ffn = DbrxExperts(config, quant_config)
+        self.ffn = DbrxMoE(config, quant_config)
 
     def forward(
         self,
@@ -409,12 +390,15 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
         expert_params_mapping = [(
-            "ws" if weight_name in ["w1", "v1"] else "w2s",
-            f"experts.mlp.{weight_name}",
+            "w13_" if weight_name in ["w1", "v1"] else "w2_",
+            f"mlp.{weight_name}.",
        ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
+            if name.endswith(("w1", "v1", "w2")):
+                name = name + ".weight"
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
                     continue
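
For reference, here is a minimal standalone sketch (plain PyTorch, no vLLM imports) of the layout conversion that the custom `weight_loader` above performs: the DBRX checkpoint keeps each GLU projection fused across all experts, while `FusedMoE` expects `w1`/`v1` stacked into a single per-rank `w13_weight` buffer and `w2` transposed into `w2_weight`. The tensor sizes below are illustrative assumptions, not taken from a real checkpoint.

```python
import torch

# Toy sizes (assumptions for illustration only).
num_experts, d_model, ffn_hidden_size = 4, 8, 16
tp_size, tp_rank = 2, 0
shard_size = ffn_hidden_size // tp_size          # per-rank slice of the FFN dim
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

# Assumed checkpoint layout: each projection fused across experts,
# shape (num_experts * ffn_hidden_size, d_model).
w1 = torch.randn(num_experts * ffn_hidden_size, d_model)
v1 = torch.randn(num_experts * ffn_hidden_size, d_model)
w2 = torch.randn(num_experts * ffn_hidden_size, d_model)

# Destination buffers in the stacked FusedMoE layout (per rank).
w13_weight = torch.empty(num_experts, 2 * shard_size, d_model)
w2_weight = torch.empty(num_experts, d_model, shard_size)

# w1 fills the first half of w13, v1 the second half, both sharded on the FFN dim.
w13_weight[:, 0:shard_size, :] = w1.reshape(
    num_experts, ffn_hidden_size, d_model)[:, shard, :]
w13_weight[:, shard_size:2 * shard_size, :] = v1.reshape(
    num_experts, ffn_hidden_size, d_model)[:, shard, :]
# w2 is transposed to (num_experts, d_model, ffn_hidden_size) before sharding.
w2_weight[:] = w2.reshape(
    num_experts, ffn_hidden_size, d_model).transpose(1, 2)[:, :, shard]
```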
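
And a small sketch of the checkpoint-key remapping that `load_weights` now does via `expert_params_mapping`: bare `w1`/`v1`/`w2` names get `.weight` appended and are then rewritten onto the stacked `w13_weight`/`w2_weight` parameters. The example key is hypothetical and the matching loop is simplified; it only shows the string transformation.

```python
# Same mapping as in the diff above.
expert_params_mapping = [(
    "w13_" if weight_name in ["w1", "v1"] else "w2_",
    f"mlp.{weight_name}.",
) for weight_name in ["w1", "v1", "w2"]]

name = "transformer.blocks.0.ffn.experts.mlp.v1"  # hypothetical checkpoint key
if name.endswith(("w1", "v1", "w2")):
    name = name + ".weight"
for param_name, weight_name in expert_params_mapping:
    if weight_name in name:
        name = name.replace(weight_name, param_name)
print(name)  # -> transformer.blocks.0.ffn.experts.w13_weight
```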