
Commit 2756820

Fix flash impl for "old" Falcon arch models (incl. falcon-7b)

1 parent 45842ad

1 file changed: 8 additions, 7 deletions

server/text_generation_server/models/custom_modeling/flash_rw_modeling.py

@@ -107,10 +107,8 @@ def __init__(
 
         if new_decoder_architecture is not None:
             self.new_decoder_architecture = new_decoder_architecture
-        elif model_type == "RefinedWeb":
-            self.new_decoder_architecture = True
         else:
-            self.new_decoder_architecture = False
+            self.new_decoder_architecture = (model_type == "RefinedWeb")
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
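The config change above is a behavior-preserving simplification: when new_decoder_architecture is absent, only the "RefinedWeb" model type resolves to the new decoder architecture, exactly as the old elif/else chain did. A minimal sketch of the resolved values, using the two model_type strings named in the comments added further down in this diff:

# Resolution of new_decoder_architecture when the config omits it.
for model_type in ("RefinedWeb", "RefinedWebModel"):
    new_decoder_architecture = model_type == "RefinedWeb"
    print(model_type, "->", new_decoder_architecture)
# RefinedWeb -> True
# RefinedWebModel -> False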
@@ -205,7 +203,6 @@ def forward(
                 query,
                 layer_past[:, 0],
                 layer_past[:, 1],
-                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlens,
                 max_s,
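The deleted line is the core of the fix for the old Falcon attention path: on the decode branch the key/value caches come from layer_past, and the stray torch.select(kv, dim=1, index=1) argument shifted every following positional argument of the attention call by one. A minimal sketch of the failure mode; attention_decode is a hypothetical stand-in for the kernel actually being invoked, not TGI's real API:

def attention_decode(query, key_cache, value_cache, out, cu_seqlens, max_s):
    ...  # hypothetical signature, for illustration only

# Before the fix, the stray tensor landed in the `out` slot and pushed
# attn_output, cu_seqlens, and max_s each one position to the right:
#   attention_decode(query, layer_past[:, 0], layer_past[:, 1],
#                    torch.select(kv, dim=1, index=1),  # stray argument
#                    attn_output, cu_seqlens, max_s)
# After the fix, the arguments line up with the signature again:
#   attention_decode(query, layer_past[:, 0], layer_past[:, 1],
#                    attn_output, cu_seqlens, max_s)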
@@ -524,7 +521,7 @@ def __init__(self, config, weights):
         self.word_embeddings = TensorParallelEmbedding(
             prefix="transformer.word_embeddings", weights=weights
         )
-        if config.new_decoder_architecture:
+        if config.new_decoder_architecture:  # "RefinedWeb"
             self.h = nn.ModuleList(
                 [
                     FlashRWLargeLayer(layer_id, config, weights)
@@ -536,14 +533,18 @@ def __init__(self, config, weights):
                 2,
                 self.h[0].self_attention.head_size,
             )
-        else:
+        else:  # "RefinedWebModel"
             self.h = nn.ModuleList(
                 [
                     FlashRWLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = self.h[0].self_attention.num_heads_kv
+            self.cache_size = (
+                2,
+                self.h[0].self_attention.num_heads_kv,
+                self.h[0].self_attention.head_size,
+            )
 
         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f",
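With this hunk, both branches publish cache_size as a full per-token KV-slot shape, (2, num_heads_kv, head_size) for the old architecture, instead of a bare head count, so whatever allocates the cache can treat the two Falcon variants uniformly. A minimal sketch of how a tuple-shaped cache_size could be consumed; allocate_kv_cache and its arguments are illustrative assumptions, not TGI's actual allocator:

import torch

def allocate_kv_cache(num_layers, max_tokens, cache_size, dtype, device):
    # One (max_tokens, 2, num_heads_kv, head_size) buffer per layer;
    # index 0 along the second dimension holds keys, index 1 values.
    return [
        torch.empty((max_tokens, *cache_size), dtype=dtype, device=device)
        for _ in range(num_layers)
    ]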
