Commit a65fa49

fix(examples): te_llama compatibility with HuggingFace transformers >= 4.57
The te_llama.py example was failing with HuggingFace transformers 4.57+ due to API changes in how decoder layer outputs are handled.

Changes:
- Handle the case where hidden_states is passed as a tuple (older HF versions)
- Return the tensor directly instead of wrapped in a tuple (HF 4.57+ expects this)
- Fix the regex pattern to use a raw string (fixes a SyntaxWarning)

Error fixed: AttributeError: 'tuple' object has no attribute 'contiguous'

Tested with:
- transformer_engine 2.5.0
- transformers 4.57.3
- PyTorch container nvcr.io/nvidia/pytorch:25.08-py3

Signed-off-by: Santosh Bhavani <[email protected]>
1 parent 702fc5e commit a65fa49
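
Note on the failure mode: before 4.57, HuggingFace unpacked each decoder layer's return value with layer_outputs[0]; from 4.57 on, the return value is consumed directly as the hidden-states tensor, so a tuple-wrapped return hits the first tensor operation and fails. A minimal sketch of that interaction (the layer class and tensor shape here are illustrative, not from the commit):

import torch

# Pre-4.57 convention: decoder layers returned a tuple and the caller
# extracted the tensor with layer_outputs[0].
class TupleReturningLayer(torch.nn.Module):
    def forward(self, hidden_states):
        return (hidden_states,)  # tuple-wrapped output

layer = TupleReturningLayer()
hidden = torch.randn(1, 4, 8)  # (batch, seq, hidden) -- illustrative shape
out = layer(hidden)

# transformers >= 4.57 treats the return value as the tensor itself, so the
# next tensor operation raises:
out.contiguous()  # AttributeError: 'tuple' object has no attribute 'contiguous'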

File tree

1 file changed (+10 −5)

docs/examples/te_llama/te_llama.py

Lines changed: 10 additions & 5 deletions
@@ -72,10 +72,15 @@ def forward(self, hidden_states, *args, attention_mask, **kwargs):
         forward pass of the `TransformerLayer`. Also, make sure the output
         format matches the output of the HF's `LlamaDecoderLayer`.
         """
-        return (
-            super().forward(
-                hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
-            ),
+        # Handle case where hidden_states might be a tuple (from previous layer output)
+        # This can happen with older versions of HuggingFace transformers
+        if isinstance(hidden_states, tuple):
+            hidden_states = hidden_states[0]
+
+        # Return tensor directly for HuggingFace transformers >= 4.57
+        # (older versions wrapped output in tuple and extracted with layer_outputs[0])
+        return super().forward(
+            hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
         )

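The committed fix returns a bare tensor unconditionally, matching transformers >= 4.57. If the example had to keep supporting both sides of the 4.57 boundary, one hypothetical approach (a sketch, not part of this commit) is to gate the output format on the installed transformers version:

import transformers
from packaging import version

# Hypothetical helper, not in the commit: tuple-wrap the layer output only
# for pre-4.57 transformers, which unpacked decoder outputs with [0].
def wrap_layer_output(tensor):
    if version.parse(transformers.__version__) >= version.parse("4.57.0"):
        return tensor  # >= 4.57: caller consumes the tensor directly
    return (tensor,)  # < 4.57: caller unpacks layer_outputs[0]
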
@@ -162,7 +167,7 @@ def replace_params(hf_state_dict, te_state_dict, config):
     # collect all layer prefixes to update
     all_layer_prefixes = set()
     for param_key in hf_state_dict.keys():
-        layer_prefix_pat = "model.layers.\d+."
+        layer_prefix_pat = r"model.layers.\d+."
         m = re.match(layer_prefix_pat, param_key)
         if m is not None:
             all_layer_prefixes.add(m.group())
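
The second hunk is independent of the output-format fix: "\d" inside a regular string literal is an invalid escape sequence, which newer Python versions flag with a SyntaxWarning; the raw-string prefix hands the backslash through to the regex engine unchanged. A quick check (the state-dict key below is an illustrative example):

import re

layer_prefix_pat = r"model.layers.\d+."
m = re.match(layer_prefix_pat, "model.layers.0.self_attn.q_proj.weight")
print(m.group())  # -> model.layers.0.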
