@@ -107,10 +107,8 @@ def __init__(
 
         if new_decoder_architecture is not None:
             self.new_decoder_architecture = new_decoder_architecture
-        elif model_type == "RefinedWeb":
-            self.new_decoder_architecture = True
         else:
-            self.new_decoder_architecture = False
+            self.new_decoder_architecture = (model_type == "RefinedWeb")
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
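The collapsed branch above is behavior-preserving: an explicit new_decoder_architecture value still wins, and otherwise only a model_type of "RefinedWeb" enables the new decoder architecture. A minimal sketch of that resolution (the helper name resolve_new_decoder_architecture is illustrative, not part of the change):

def resolve_new_decoder_architecture(new_decoder_architecture, model_type):
    # An explicit config value takes precedence over the model_type heuristic.
    if new_decoder_architecture is not None:
        return new_decoder_architecture
    # Otherwise only the "RefinedWeb" model type opts into the new architecture.
    return model_type == "RefinedWeb"

assert resolve_new_decoder_architecture(None, "RefinedWeb") is True
assert resolve_new_decoder_architecture(None, "RefinedWebModel") is False
assert resolve_new_decoder_architecture(False, "RefinedWeb") is False
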
@@ -205,7 +203,6 @@ def forward(
                 query,
                 layer_past[:, 0],
                 layer_past[:, 1],
-                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlens,
                 max_s,
@@ -524,7 +521,7 @@ def __init__(self, config, weights):
         self.word_embeddings = TensorParallelEmbedding(
             prefix="transformer.word_embeddings", weights=weights
         )
-        if config.new_decoder_architecture:
+        if config.new_decoder_architecture:  # "RefinedWeb"
             self.h = nn.ModuleList(
                 [
                     FlashRWLargeLayer(layer_id, config, weights)
@@ -536,14 +533,18 @@ def __init__(self, config, weights):
                     2,
                     self.h[0].self_attention.head_size,
                 )
-        else:
+        else:  # "RefinedWebModel"
             self.h = nn.ModuleList(
                 [
                     FlashRWLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = self.h[0].self_attention.num_heads_kv
+            self.cache_size = (
+                2,
+                self.h[0].self_attention.num_heads_kv,
+                self.h[0].self_attention.head_size,
+            )
 
         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f",