Commit a85d473

minor: read attn impl from config json
Signed-off-by: h-guo18 <[email protected]>
1 parent 75a4dc1 commit a85d473

2 files changed: +2 −6 lines changed

examples/speculative_decoding/eagle_config.json
2 additions & 1 deletion

@@ -6,5 +6,6 @@
         "original_max_position_embeddings": 8192,
         "rope_type": "llama3"
     },
-    "initializer_range": 0.02
+    "initializer_range": 0.02,
+    "attn_implementation": "flex_attention"
 }

modelopt/torch/speculative/plugins/transformers.py
0 additions & 5 deletions

@@ -184,10 +184,6 @@ def __init__(self, config, decoder_layer_cls, bias=False):
         super().__init__()
         self.config = config

-        # Use flex attention for efficient TTT
-        # config._attn_implementation = "flex_attention"
-        config.attn_implementation = "sdpa"
-
         self.layers = nn.ModuleList(
             [decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )

@@ -446,7 +442,6 @@ def modify(
             eagle_architecture_config=eagle_architecture_config,
         )
         self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config)
-        self.eagle_config._attn_implementation = "sdpa"
         decoder_cls = (
             type(self.model.layers[-1]) if self.eagle_reuse_base_decoder else LlamaDecoderLayer
         )
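The net effect is that the attention backend is no longer hardcoded in the Python plugin; it is read from eagle_config.json and surfaced as an attribute on the resulting PretrainedConfig. Below is a minimal sketch of that flow, assuming transformers' PretrainedConfig.from_dict (which keeps unrecognized keys as plain attributes); the file path and the final print are for illustration only.

import json

from transformers import PretrainedConfig

# Load the EAGLE architecture config; "attn_implementation" now lives here
# (set to "flex_attention" in the diff above) rather than being overridden in code.
with open("examples/speculative_decoding/eagle_config.json") as f:
    eagle_architecture_config = json.load(f)

# from_dict stores unknown keys as attributes on the config object, so the
# draft-model construction can pick the backend up directly from the config.
eagle_config = PretrainedConfig.from_dict(eagle_architecture_config)
print(getattr(eagle_config, "attn_implementation", None))  # e.g. "flex_attention"

With the hardcoded config.attn_implementation = "sdpa" and eagle_config._attn_implementation = "sdpa" assignments removed, whatever backend the JSON specifies is what the decoder layers see.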
