
Commit d8ee14f

Sync with transformers 4.57.1 (#1016)
# What does this PR do?

As per title. This PR checks the differences between transformers `4.55.4` and `4.57.1` and updates Optimum Neuron to account for the changes.

cc @dacorvo @tengomucho, who are OOO at the time of this PR.

---------

Co-authored-by: JingyaHuang <[email protected]>
1 parent 454c634 commit d8ee14f
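
The recurring change across the docs, examples, and modeling code below is the rename of the `torch_dtype` argument and config attribute to `dtype`, following the upstream transformers rename. A minimal sketch of what this looks like from user code, using plain transformers and an illustrative model id ("gpt2"); as the diffs below show, the Neuron `from_pretrained` wrappers follow the same convention:

```python
import torch
from transformers import AutoModelForCausalLM

# Before (transformers <= 4.55.x), the load-time dtype was passed as `torch_dtype`:
# model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)

# With transformers 4.57.x, the argument is called `dtype`:
model = AutoModelForCausalLM.from_pretrained("gpt2", dtype=torch.bfloat16)

# The config attribute follows the same rename: `config.dtype` replaces `config.torch_dtype`.
print(model.config.dtype)  # e.g. torch.bfloat16
```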

File tree

18 files changed (+135, -111 lines)


README.md

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def main():
     model = NeuronModelForCausalLM.from_pretrained(
         model_id,
         training_args.trn_config,
-        torch_dtype=torch.bfloat16,
+        dtype=torch.bfloat16,
         attn_implementation="flash_attention_2", # Enable flash attention
     )

docs/source/contribute/contribute_for_training.mdx

Lines changed: 12 additions & 12 deletions
@@ -80,7 +80,7 @@ class YourModelEmbeddings(nn.Module):
         self.embed_tokens = ParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            dtype=config.torch_dtype,
+            dtype=config.dtype,
             sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
         )
 ```
@@ -105,7 +105,7 @@ class YourModelMLP(nn.Module, CustomModule):
             bias=False,
             gather_output=False,
             sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
-            dtype=config.torch_dtype,
+            dtype=config.dtype,
         )

         self.down_proj = RowParallelLinear(
@@ -114,7 +114,7 @@ class YourModelMLP(nn.Module, CustomModule):
             bias=False,
             input_is_parallel=True,
             sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
-            dtype=config.torch_dtype,
+            dtype=config.dtype,
         )

         # Define transformation specs
@@ -151,23 +151,23 @@ class YourModelAttention(nn.Module, CustomModule):
             bias=False,
             gather_output=False,
             sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
-            dtype=config.torch_dtype,
+            dtype=config.dtype,
         )
         self.k_proj = ColumnParallelLinear(
             config.hidden_size,
             self.num_key_value_heads * self.head_dim,
             bias=False,
             gather_output=False,
             sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
-            dtype=config.torch_dtype,
+            dtype=config.dtype,
         )
         self.v_proj = ColumnParallelLinear(
             config.hidden_size,
             self.num_key_value_heads * self.head_dim,
             bias=False,
             gather_output=False,
             sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
-            dtype=config.torch_dtype,
+            dtype=config.dtype,
         )

         self.o_proj = RowParallelLinear(
@@ -176,7 +176,7 @@ class YourModelAttention(nn.Module, CustomModule):
             bias=False,
             input_is_parallel=True,
             sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
-            dtype=config.torch_dtype,
+            dtype=config.dtype,
         )

         # No transformation specs needed - regular parallel layers
@@ -201,7 +201,7 @@ class YourModelAttention(nn.Module, CustomModule):
             bias=False,
             gather_output=False,
             sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
-            dtype=config.torch_dtype,
+            dtype=config.dtype,
         )

         # Define transformation specs for fused QKV
@@ -246,7 +246,7 @@ class YourModelAttention(nn.Module, CustomModule):
             sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
             kv_size_multiplier=self.kv_size_multiplier,
             fuse_qkv=trn_config.fuse_qkv,
-            dtype=config.torch_dtype,
+            dtype=config.dtype,
         )

         # Define transformation specs for GQA QKV
@@ -336,7 +336,7 @@ class YourModelForCausalLM(NeuronModelMixin, YourPreTrainedModel):
             config.vocab_size,
             bias=False,
             gather_output=False,
-            dtype=config.torch_dtype,
+            dtype=config.dtype,
         )

         self.post_init()
@@ -473,7 +473,7 @@ Update `tests/training/test_modeling_auto.py`:
 @is_trainium_test
 def test_auto_model_with_supported_architecture(from_pretrained):
     trn_config = TrainingNeuronConfig()
-    kwargs = {"torch_dtype": torch.bfloat16}
+    kwargs = {"dtype": torch.bfloat16}
     for model_name_or_path in [
         "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
         "michaelbenayoun/granite-tiny-4kv-heads-4layers-random",
@@ -487,7 +487,7 @@ def test_auto_model_with_supported_architecture(from_pretrained):
 @is_trainium_test
 def test_auto_model_for_causal_lm_with_supported_architecture(from_pretrained):
     trn_config = TrainingNeuronConfig()
-    kwargs = {"torch_dtype": torch.bfloat16}
+    kwargs = {"dtype": torch.bfloat16}
     for model_name_or_path in [
         "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
         "michaelbenayoun/granite-tiny-4kv-heads-4layers-random",

docs/source/quickstart.mdx

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ def main():
     model = NeuronModelForCausalLM.from_pretrained(
         model_id,
         training_args.trn_config,
-        torch_dtype=torch.bfloat16,
+        dtype=torch.bfloat16,
         attn_implementation="flash_attention_2", # Enable flash attention
     )

docs/source/training_tutorials/finetune_llama.mdx

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ dtype = torch.bfloat16 if training_args.bf16 else torch.float32
 model = NeuronModelForCausalLM.from_pretrained(
     model_id,
     trn_config,
-    torch_dtype=dtype,
+    dtype=dtype,
     # Use FlashAttention2 for better performance and to be able to use larger sequence lengths.
     attn_implementation="flash_attention_2",
 )

docs/source/training_tutorials/finetune_qwen3.mdx

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ dtype = torch.bfloat16 if training_args.bf16 else torch.float32
 model = NeuronModelForCausalLM.from_pretrained(
     model_id,
     trn_config,
-    torch_dtype=dtype,
+    dtype=dtype,
     # Use FlashAttention2 for better performance and to be able to use larger sequence lengths.
     attn_implementation="flash_attention_2",
 )

optimum/neuron/modeling_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -1193,7 +1193,7 @@ def forward(

         outputs = self.model(*inputs)
         if self.config.model_type == "t5" and isinstance(outputs, dict): # Flux text encoder 2
-            return [outputs["last_hidden_state"].to(self.config.torch_dtype)]
+            return [outputs["last_hidden_state"].to(self.config.dtype)]

         if return_dict and not isinstance(outputs, dict):
             outputs = ModelOutput(dict(zip(self.neuron_config.outputs, outputs)))

optimum/neuron/modeling_traced.py

Lines changed: 3 additions & 1 deletion
@@ -612,7 +612,7 @@ def neuron_padding_manager(self, inputs: dict[str, "torch.Tensor"]):

     @staticmethod
     def remove_padding(
-        outputs: list[torch.Tensor],
+        outputs: list[torch.Tensor] | dict,
         dims: list[int],
         indices: list[int],
         padding_side: Literal["right", "left"] = "right",
@@ -633,6 +633,8 @@ def remove_padding(
         if len(dims) != len(indices):
             raise ValueError(f"The size of `dims`({len(dims)}) and indices`({len(indices)}) must be equal.")

+        if isinstance(outputs, dict):
+            outputs = list(outputs.values())
         for dim, indice in zip(dims, indices):
             if padding_side == "right":
                 outputs = [
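
The change above lets `remove_padding` accept outputs that come back as a dict (for example a transformers `ModelOutput`) by flattening them to a list of tensors before stripping padding. A minimal, self-contained sketch of that normalization step; the function name `strip_right_padding` and the dummy tensors are purely illustrative, not the library's API:

```python
import torch

def strip_right_padding(outputs: list[torch.Tensor] | dict, dims: list[int], indices: list[int]) -> list[torch.Tensor]:
    """Illustrative re-implementation of the dict handling added to `remove_padding`."""
    # Dict-like outputs (e.g. a transformers ModelOutput) are flattened to a list of tensors first.
    if isinstance(outputs, dict):
        outputs = list(outputs.values())
    if len(dims) != len(indices):
        raise ValueError(f"The size of `dims` ({len(dims)}) and `indices` ({len(indices)}) must be equal.")
    for dim, index in zip(dims, indices):
        # Right-padding removal: keep only the first `index` elements along `dim`.
        outputs = [torch.index_select(t, dim, torch.arange(index)) for t in outputs]
    return outputs

# A dict of padded outputs trimmed back to sequence length 3 along dim 1.
padded = {"last_hidden_state": torch.zeros(1, 8, 16), "pooler_output": torch.zeros(1, 8, 16)}
print([t.shape for t in strip_right_padding(padded, dims=[1], indices=[3])])
# [torch.Size([1, 3, 16]), torch.Size([1, 3, 16])]
```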

optimum/neuron/models/inference/backend/modules/generation/generation_utils.py

Lines changed: 2 additions & 3 deletions
@@ -17,7 +17,7 @@
 from typing import Any

 import torch
-from transformers import GenerationConfig
+from transformers import GenerationConfig, PreTrainedModel
 from transformers.generation import GenerationMixin, SampleDecoderOnlyOutput
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.stopping_criteria import StoppingCriteriaList
@@ -270,14 +270,13 @@ def _update_model_kwargs_for_generation(
     def _assisted_decoding(
         self,
         input_ids: torch.LongTensor,
-        candidate_generator: "CandidateGenerator", # noqa
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
+        assistant_model: "PreTrainedModel | None" = None,
         **model_kwargs,
     ):
         pad_token_id = generation_config.pad_token_id
         eos_token_id = generation_config.eos_token_id
-        assistant_model = candidate_generator.assistant_model

         if assistant_model.neuron_config.on_device_sampling:
             raise ValueError("Assistant model must not use on-device sampling")
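
The `_assisted_decoding` override changes because transformers 4.57 passes the assistant model to this hook directly instead of wrapping it in a `CandidateGenerator`. The public entry point is unchanged: assisted (speculative) decoding is still requested through `generate` with `assistant_model`. A sketch of that call, with placeholder model ids:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# "main-model" and "draft-model" are placeholders for a target model and a smaller draft model.
tokenizer = AutoTokenizer.from_pretrained("main-model")
model = AutoModelForCausalLM.from_pretrained("main-model")
assistant = AutoModelForCausalLM.from_pretrained("draft-model")

inputs = tokenizer("Hello", return_tensors="pt")
# Passing `assistant_model` enables assisted decoding; internally this is routed through
# `_assisted_decoding`, which the Neuron generation mixin overrides as shown above.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```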

optimum/neuron/models/inference/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ def get_neuron_config(
             batch_size=batch_size,
             sequence_length=sequence_length,
             tensor_parallel_size=tensor_parallel_size,
-            dtype=DTYPE_MAPPER.pt(config.torch_dtype),
+            dtype=DTYPE_MAPPER.pt(config.dtype),
         )

     @classmethod

optimum/neuron/models/inference/t5/modeling_t5.py

Lines changed: 22 additions & 26 deletions
@@ -27,6 +27,7 @@
 from torch import nn
 from transformers import T5Config
 from transformers.activations import ACT2FN
+from transformers.cache_utils import EncoderDecoderCache
 from transformers.models.t5.modeling_t5 import (
     T5Attention,
     T5DenseActDense,
@@ -154,7 +155,7 @@ def forward(
         mask=None,
         key_value_states=None,
         position_bias=None,
-        past_key_value=None,
+        past_key_values=None,
         layer_head_mask=None,
         query_length=None,
         use_cache=False,
@@ -177,38 +178,38 @@ def forward(
             batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim
         ).transpose(1, 2)

-        if past_key_value is not None:
-            is_updated = past_key_value.is_updated.get(self.layer_idx)
+        # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
+        is_updated = False
+        if isinstance(past_key_values, EncoderDecoderCache):
+            is_updated = past_key_values.is_updated.get(self.layer_idx)
             if is_cross_attention:
                 # after the first generated id, we can subsequently re-use all key/value_states from cache
-                curr_past_key_value = past_key_value.cross_attention_cache
+                curr_past_key_values = past_key_values.cross_attention_cache
             else:
-                curr_past_key_value = past_key_value.self_attention_cache
+                curr_past_key_values = past_key_values.self_attention_cache
+        else:
+            curr_past_key_values = past_key_values

         current_states = key_value_states if is_cross_attention else hidden_states
-        if is_cross_attention and past_key_value is not None and is_updated:
+        if is_cross_attention and past_key_values is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = curr_past_key_value.key_cache[self.layer_idx]
-            value_states = curr_past_key_value.value_cache[self.layer_idx]
+            key_states = curr_past_key_values.layers[self.layer_idx].keys
+            value_states = curr_past_key_values.layers[self.layer_idx].values
         else:
             key_states = self.k(current_states)
             value_states = self.v(current_states)
-            key_states = key_states.view(
-                batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim
-            ).transpose(1, 2)
-            value_states = value_states.view(
-                batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim
-            ).transpose(1, 2)
-
-            if past_key_value is not None:
+            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+            if past_key_values is not None:
                 # save all key/value_states to cache to be re-used for fast auto-regressive generation
                 cache_position = cache_position if not is_cross_attention else None
-                key_states, value_states = curr_past_key_value.update(
+                key_states, value_states = curr_past_key_values.update(
                     key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                 )
                 # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
-                if is_cross_attention:
-                    past_key_value.is_updated[self.layer_idx] = True
+                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+                    past_key_values.is_updated[self.layer_idx] = True

         # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
         scores = torch.matmul(query_states, key_states.transpose(3, 2))
@@ -235,14 +236,9 @@ def forward(
                 causal_mask = mask[:, :, :, : key_states.shape[-2]]
                 position_bias = position_bias + causal_mask

-        if self.pruned_heads:
-            mask = torch.ones(position_bias.shape[1])
-            mask[list(self.pruned_heads)] = 0
-            position_bias_masked = position_bias[:, mask.bool()]
-        else:
-            position_bias_masked = position_bias
-
+        position_bias_masked = position_bias
         scores += position_bias_masked
+
         # (batch_size, n_heads, seq_length, key_length)
         attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
         attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
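
The T5 attention changes above track the cache refactor in recent transformers releases: cached keys and values are read through per-layer objects (`cache.layers[layer_idx].keys` / `.values`) rather than the former `key_cache` / `value_cache` lists, and the `is_updated` bookkeeping only exists on `EncoderDecoderCache`, hence the new `isinstance` guards. A minimal sketch of the new access pattern, assuming transformers 4.57 and arbitrary tensor shapes:

```python
import torch
from transformers.cache_utils import DynamicCache

# Fill a cache for layer 0 with dummy key/value states of shape (batch, heads, seq_len, head_dim).
cache = DynamicCache()
key_states = torch.randn(1, 8, 4, 64)
value_states = torch.randn(1, 8, 4, 64)
cache.update(key_states, value_states, layer_idx=0)

# New-style access used in the patched attention: per-layer `.keys` / `.values` attributes.
print(cache.layers[0].keys.shape)    # torch.Size([1, 8, 4, 64])
print(cache.layers[0].values.shape)  # torch.Size([1, 8, 4, 64])
```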
