2 changes: 1 addition & 1 deletion README.md
@@ -115,7 +115,7 @@ def main():
model = NeuronModelForCausalLM.from_pretrained(
model_id,
training_args.trn_config,
torch_dtype=torch.bfloat16,
dtype=torch.bfloat16,
attn_implementation="flash_attention_2", # Enable flash attention
)

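The hunk above only renames the keyword argument; for context, here is a hedged, self-contained sketch of the updated loading call. The import paths and the placeholder checkpoint are assumptions based on the surrounding README example, not part of this diff, and running it requires a Trainium instance with `optimum-neuron` installed.

```python
# Sketch of loading a model for Trainium training with the renamed `dtype`
# keyword (formerly `torch_dtype`). Import paths and model_id are assumed.
import torch
from optimum.neuron import NeuronModelForCausalLM, NeuronTrainingArguments

training_args = NeuronTrainingArguments(output_dir="output", bf16=True)
model_id = "Qwen/Qwen3-0.6B"  # placeholder checkpoint

model = NeuronModelForCausalLM.from_pretrained(
    model_id,
    training_args.trn_config,                 # Trainium parallelism config
    dtype=torch.bfloat16,                     # was: torch_dtype=torch.bfloat16
    attn_implementation="flash_attention_2",  # enable flash attention
)
```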
24 changes: 12 additions & 12 deletions docs/source/contribute/contribute_for_training.mdx
@@ -80,7 +80,7 @@ class YourModelEmbeddings(nn.Module):
self.embed_tokens = ParallelEmbedding(
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
dtype=config.dtype,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
)
```
@@ -105,7 +105,7 @@ class YourModelMLP(nn.Module, CustomModule):
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
dtype=config.dtype,
)

self.down_proj = RowParallelLinear(
@@ -114,7 +114,7 @@ class YourModelMLP(nn.Module, CustomModule):
bias=False,
input_is_parallel=True,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
dtype=config.dtype,
)

# Define transformation specs
@@ -151,23 +151,23 @@ class YourModelAttention(nn.Module, CustomModule):
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
dtype=config.dtype,
)
self.k_proj = ColumnParallelLinear(
config.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
dtype=config.dtype,
)
self.v_proj = ColumnParallelLinear(
config.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
dtype=config.dtype,
)

self.o_proj = RowParallelLinear(
@@ -176,7 +176,7 @@ class YourModelAttention(nn.Module, CustomModule):
bias=False,
input_is_parallel=True,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
dtype=config.dtype,
)

# No transformation specs needed - regular parallel layers
@@ -201,7 +201,7 @@ class YourModelAttention(nn.Module, CustomModule):
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
dtype=config.dtype,
)

# Define transformation specs for fused QKV
@@ -246,7 +246,7 @@ class YourModelAttention(nn.Module, CustomModule):
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
kv_size_multiplier=self.kv_size_multiplier,
fuse_qkv=trn_config.fuse_qkv,
dtype=config.torch_dtype,
dtype=config.dtype,
)

# Define transformation specs for GQA QKV
@@ -336,7 +336,7 @@ class YourModelForCausalLM(NeuronModelMixin, YourPreTrainedModel):
config.vocab_size,
bias=False,
gather_output=False,
dtype=config.torch_dtype,
dtype=config.dtype,
)

self.post_init()
@@ -473,7 +473,7 @@ Update `tests/training/test_modeling_auto.py`:
@is_trainium_test
def test_auto_model_with_supported_architecture(from_pretrained):
trn_config = TrainingNeuronConfig()
kwargs = {"torch_dtype": torch.bfloat16}
kwargs = {"dtype": torch.bfloat16}
for model_name_or_path in [
"michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
"michaelbenayoun/granite-tiny-4kv-heads-4layers-random",
@@ -487,7 +487,7 @@ def test_auto_model_with_supported_architecture(from_pretrained):
@is_trainium_test
def test_auto_model_for_causal_lm_with_supported_architecture(from_pretrained):
trn_config = TrainingNeuronConfig()
kwargs = {"torch_dtype": torch.bfloat16}
kwargs = {"dtype": torch.bfloat16}
for model_name_or_path in [
"michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
"michaelbenayoun/granite-tiny-4kv-heads-4layers-random",
2 changes: 1 addition & 1 deletion docs/source/quickstart.mdx
@@ -79,7 +79,7 @@ def main():
model = NeuronModelForCausalLM.from_pretrained(
model_id,
training_args.trn_config,
torch_dtype=torch.bfloat16,
dtype=torch.bfloat16,
attn_implementation="flash_attention_2", # Enable flash attention
)

2 changes: 1 addition & 1 deletion docs/source/training_tutorials/finetune_llama.mdx
@@ -138,7 +138,7 @@ dtype = torch.bfloat16 if training_args.bf16 else torch.float32
model = NeuronModelForCausalLM.from_pretrained(
model_id,
trn_config,
torch_dtype=dtype,
dtype=dtype,
# Use FlashAttention2 for better performance and to be able to use larger sequence lengths.
attn_implementation="flash_attention_2",
)
2 changes: 1 addition & 1 deletion docs/source/training_tutorials/finetune_qwen3.mdx
@@ -137,7 +137,7 @@ dtype = torch.bfloat16 if training_args.bf16 else torch.float32
model = NeuronModelForCausalLM.from_pretrained(
model_id,
trn_config,
torch_dtype=dtype,
dtype=dtype,
# Use FlashAttention2 for better performance and to be able to use larger sequence lengths.
attn_implementation="flash_attention_2",
)
2 changes: 1 addition & 1 deletion optimum/neuron/modeling_diffusion.py
@@ -1193,7 +1193,7 @@ def forward(

outputs = self.model(*inputs)
if self.config.model_type == "t5" and isinstance(outputs, dict): # Flux text encoder 2
return [outputs["last_hidden_state"].to(self.config.torch_dtype)]
return [outputs["last_hidden_state"].to(self.config.dtype)]

if return_dict and not isinstance(outputs, dict):
outputs = ModelOutput(dict(zip(self.neuron_config.outputs, outputs)))
4 changes: 3 additions & 1 deletion optimum/neuron/modeling_traced.py
@@ -612,7 +612,7 @@ def neuron_padding_manager(self, inputs: dict[str, "torch.Tensor"]):

@staticmethod
def remove_padding(
outputs: list[torch.Tensor],
outputs: list[torch.Tensor] | dict,
dims: list[int],
indices: list[int],
padding_side: Literal["right", "left"] = "right",
@@ -633,6 +633,8 @@ def remove_padding(
if len(dims) != len(indices):
raise ValueError(f"The size of `dims`({len(dims)}) and indices`({len(indices)}) must be equal.")

if isinstance(outputs, dict):
outputs = list(outputs.values())
for dim, indice in zip(dims, indices):
if padding_side == "right":
outputs = [
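The two added lines above let `remove_padding` accept a dict of named output tensors. Below is a standalone sketch of the same normalisation using a hypothetical `strip_right_padding` helper (not the library API); shapes and names are illustrative.

```python
# Hypothetical standalone helper mirroring the dict handling added above:
# dict outputs are flattened to a list of tensors, then each (dim, index)
# pair slices away right-side padding with torch.index_select.
import torch


def strip_right_padding(
    outputs: list[torch.Tensor] | dict,
    dims: list[int],
    indices: list[int],
) -> list[torch.Tensor]:
    if isinstance(outputs, dict):
        outputs = list(outputs.values())  # same normalisation as in the diff
    for dim, index in zip(dims, indices):
        outputs = [torch.index_select(t, dim, torch.arange(index)) for t in outputs]
    return outputs


padded = {"last_hidden_state": torch.zeros(1, 128, 32)}  # padded to length 128
(trimmed,) = strip_right_padding(padded, dims=[1], indices=[100])
print(trimmed.shape)  # torch.Size([1, 100, 32])
```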
@@ -17,7 +17,7 @@
from typing import Any

import torch
from transformers import GenerationConfig
from transformers import GenerationConfig, PreTrainedModel
from transformers.generation import GenerationMixin, SampleDecoderOnlyOutput
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
@@ -270,14 +270,13 @@ def _update_model_kwargs_for_generation(
def _assisted_decoding(
self,
input_ids: torch.LongTensor,
candidate_generator: "CandidateGenerator", # noqa
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
assistant_model: "PreTrainedModel | None" = None,
**model_kwargs,
):
pad_token_id = generation_config.pad_token_id
eos_token_id = generation_config.eos_token_id
assistant_model = candidate_generator.assistant_model

if assistant_model.neuron_config.on_device_sampling:
raise ValueError("Assistant model must not use on-device sampling")
2 changes: 1 addition & 1 deletion optimum/neuron/models/inference/modeling_utils.py
@@ -138,7 +138,7 @@ def get_neuron_config(
batch_size=batch_size,
sequence_length=sequence_length,
tensor_parallel_size=tensor_parallel_size,
dtype=DTYPE_MAPPER.pt(config.torch_dtype),
dtype=DTYPE_MAPPER.pt(config.dtype),
)

@classmethod
48 changes: 22 additions & 26 deletions optimum/neuron/models/inference/t5/modeling_t5.py
@@ -27,6 +27,7 @@
from torch import nn
from transformers import T5Config
from transformers.activations import ACT2FN
from transformers.cache_utils import EncoderDecoderCache
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5DenseActDense,
@@ -154,7 +155,7 @@ def forward(
mask=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
past_key_values=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
@@ -177,38 +178,38 @@
batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim
).transpose(1, 2)

if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
# Check if an encoder-decoder model is being used; otherwise we'll get a `DynamicCache`
is_updated = False
if isinstance(past_key_values, EncoderDecoderCache):
is_updated = past_key_values.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
curr_past_key_value = past_key_value.cross_attention_cache
curr_past_key_values = past_key_values.cross_attention_cache
else:
curr_past_key_value = past_key_value.self_attention_cache
curr_past_key_values = past_key_values.self_attention_cache
else:
curr_past_key_values = past_key_values

current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
if is_cross_attention and past_key_values is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_values.layers[self.layer_idx].keys
value_states = curr_past_key_values.layers[self.layer_idx].values
else:
key_states = self.k(current_states)
value_states = self.v(current_states)
key_states = key_states.view(
batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim
).transpose(1, 2)
value_states = value_states.view(
batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim
).transpose(1, 2)

if past_key_value is not None:
key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

if past_key_values is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = curr_past_key_value.update(
key_states, value_states = curr_past_key_values.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
past_key_values.is_updated[self.layer_idx] = True

# compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
scores = torch.matmul(query_states, key_states.transpose(3, 2))
@@ -235,14 +236,9 @@
causal_mask = mask[:, :, :, : key_states.shape[-2]]
position_bias = position_bias + causal_mask

if self.pruned_heads:
mask = torch.ones(position_bias.shape[1])
mask[list(self.pruned_heads)] = 0
position_bias_masked = position_bias[:, mask.bool()]
else:
position_bias_masked = position_bias

position_bias_masked = position_bias
scores += position_bias_masked

# (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
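The rewritten branch above relies on the newer transformers `Cache` API. A minimal sketch of that flow outside the model follows; it assumes a recent transformers release where `DynamicCache` exposes `.layers[i].keys/.values`, and the shapes and layer index are illustrative.

```python
# Minimal cache-handling sketch mirroring the cross-attention branch above.
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
layer_idx = 0

# First decoder step: cross-attention K/V are computed once from the encoder
# states and written to the cache, then the layer is flagged as updated.
key_states = torch.randn(1, 8, 16, 64)    # (batch, heads, enc_len, head_dim)
value_states = torch.randn(1, 8, 16, 64)
if not past_key_values.is_updated.get(layer_idx, False):
    past_key_values.cross_attention_cache.update(key_states, value_states, layer_idx)
    past_key_values.is_updated[layer_idx] = True

# Later steps: re-use the cached states, as the forward pass above does.
cached_keys = past_key_values.cross_attention_cache.layers[layer_idx].keys
cached_values = past_key_values.cross_attention_cache.layers[layer_idx].values
print(cached_keys.shape, cached_values.shape)
```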
20 changes: 10 additions & 10 deletions optimum/neuron/models/training/llama/modeling_llama.py
@@ -210,7 +210,7 @@ def __init__(self, config, trn_config: TrainingNeuronConfig):
init_method=init_method,
sequence_parallel_enabled=self.trn_config.sequence_parallel_enabled,
sequence_dimension=0,
dtype=self.config.torch_dtype,
dtype=self.config.dtype,
)
self.down_proj = RowParallelLinear(
self.intermediate_size,
Expand All @@ -220,7 +220,7 @@ def __init__(self, config, trn_config: TrainingNeuronConfig):
init_method=init_method,
sequence_parallel_enabled=self.trn_config.sequence_parallel_enabled,
sequence_dimension=0,
dtype=self.config.torch_dtype,
dtype=self.config.dtype,
)

def forward(self, x):
@@ -333,7 +333,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
kv_size_multiplier=self.kv_size_multiplier,
fuse_qkv=trn_config.fuse_qkv,
dtype=self.config.torch_dtype,
dtype=self.config.dtype,
)

gqa_qkv_specs = GQAQKVColumnParallelLinearSpec(
Expand Down Expand Up @@ -361,7 +361,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_
init_method=init_method,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
sequence_dimension=0,
dtype=self.config.torch_dtype,
dtype=self.config.dtype,
)
self.specs.add_spec(
FusedLinearsSpec(
@@ -382,7 +382,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_
init_method=init_method,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
sequence_dimension=0,
dtype=self.config.torch_dtype,
dtype=self.config.dtype,
)
self.k_proj = ColumnParallelLinear(
self.hidden_size,
@@ -392,7 +392,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_
init_method=init_method,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
sequence_dimension=0,
dtype=self.config.torch_dtype,
dtype=self.config.dtype,
)
self.v_proj = ColumnParallelLinear(
self.hidden_size,
@@ -402,7 +402,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_
init_method=init_method,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
sequence_dimension=0,
dtype=self.config.torch_dtype,
dtype=self.config.dtype,
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
@@ -412,7 +412,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_
init_method=init_method,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
sequence_dimension=0,
dtype=self.config.torch_dtype,
dtype=self.config.dtype,
)
self.num_heads = neuronx_dist_utils.divide(config.num_attention_heads, tp_size)
self.num_key_value_heads = neuronx_dist_utils.divide(
@@ -606,7 +606,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig):
self.padding_idx,
init_method=init_method,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
dtype=config.dtype,
)
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config, trn_config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
@@ -715,7 +715,7 @@ def __init__(self, config, trn_config: TrainingNeuronConfig):
init_method=init_method,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
sequence_dimension=0,
dtype=self.config.torch_dtype,
dtype=self.config.dtype,
)

self.vocab_size = config.vocab_size // get_tensor_model_parallel_size()
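The hunks in this file are the same `config.torch_dtype` to `config.dtype` rename applied to each parallel layer. A small illustrative sketch of the pattern follows, using a plain `nn.Linear` and a tiny stand-in config rather than the Neuron parallel layers or the real `LlamaConfig`; only the `dtype=` keyword usage matches the code above.

```python
# Illustrative only: a plain nn.Linear stands in for the Neuron parallel
# layers, which accept the same `dtype=` keyword in the hunks above.
import torch
from torch import nn


class TinyConfig:
    """Stand-in for the HF config; only the fields used below."""
    hidden_size = 128
    intermediate_size = 256
    dtype = torch.bfloat16  # what the layers now read instead of torch_dtype


config = TinyConfig()
gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=config.dtype)
up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=config.dtype)
print(gate_proj.weight.dtype)  # torch.bfloat16
```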