diff --git a/README.md b/README.md index 5e7916c4c..e499d696a 100644 --- a/README.md +++ b/README.md @@ -115,7 +115,7 @@ def main(): model = NeuronModelForCausalLM.from_pretrained( model_id, training_args.trn_config, - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, attn_implementation="flash_attention_2", # Enable flash attention ) diff --git a/docs/source/contribute/contribute_for_training.mdx b/docs/source/contribute/contribute_for_training.mdx index 994dc9991..62475c1e0 100644 --- a/docs/source/contribute/contribute_for_training.mdx +++ b/docs/source/contribute/contribute_for_training.mdx @@ -80,7 +80,7 @@ class YourModelEmbeddings(nn.Module): self.embed_tokens = ParallelEmbedding( config.vocab_size, config.hidden_size, - dtype=config.torch_dtype, + dtype=config.dtype, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, ) ``` @@ -105,7 +105,7 @@ class YourModelMLP(nn.Module, CustomModule): bias=False, gather_output=False, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, - dtype=config.torch_dtype, + dtype=config.dtype, ) self.down_proj = RowParallelLinear( @@ -114,7 +114,7 @@ class YourModelMLP(nn.Module, CustomModule): bias=False, input_is_parallel=True, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, - dtype=config.torch_dtype, + dtype=config.dtype, ) # Define transformation specs @@ -151,7 +151,7 @@ class YourModelAttention(nn.Module, CustomModule): bias=False, gather_output=False, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, - dtype=config.torch_dtype, + dtype=config.dtype, ) self.k_proj = ColumnParallelLinear( config.hidden_size, @@ -159,7 +159,7 @@ class YourModelAttention(nn.Module, CustomModule): bias=False, gather_output=False, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, - dtype=config.torch_dtype, + dtype=config.dtype, ) self.v_proj = ColumnParallelLinear( config.hidden_size, @@ -167,7 +167,7 @@ class YourModelAttention(nn.Module, CustomModule): bias=False, gather_output=False, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, - dtype=config.torch_dtype, + dtype=config.dtype, ) self.o_proj = RowParallelLinear( @@ -176,7 +176,7 @@ class YourModelAttention(nn.Module, CustomModule): bias=False, input_is_parallel=True, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, - dtype=config.torch_dtype, + dtype=config.dtype, ) # No transformation specs needed - regular parallel layers @@ -201,7 +201,7 @@ class YourModelAttention(nn.Module, CustomModule): bias=False, gather_output=False, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, - dtype=config.torch_dtype, + dtype=config.dtype, ) # Define transformation specs for fused QKV @@ -246,7 +246,7 @@ class YourModelAttention(nn.Module, CustomModule): sequence_parallel_enabled=trn_config.sequence_parallel_enabled, kv_size_multiplier=self.kv_size_multiplier, fuse_qkv=trn_config.fuse_qkv, - dtype=config.torch_dtype, + dtype=config.dtype, ) # Define transformation specs for GQA QKV @@ -336,7 +336,7 @@ class YourModelForCausalLM(NeuronModelMixin, YourPreTrainedModel): config.vocab_size, bias=False, gather_output=False, - dtype=config.torch_dtype, + dtype=config.dtype, ) self.post_init() @@ -473,7 +473,7 @@ Update `tests/training/test_modeling_auto.py`: @is_trainium_test def test_auto_model_with_supported_architecture(from_pretrained): trn_config = TrainingNeuronConfig() - kwargs = {"torch_dtype": torch.bfloat16} + kwargs = {"dtype": torch.bfloat16} for model_name_or_path in [ "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random", "michaelbenayoun/granite-tiny-4kv-heads-4layers-random", @@ -487,7 +487,7 @@ def test_auto_model_with_supported_architecture(from_pretrained): @is_trainium_test def test_auto_model_for_causal_lm_with_supported_architecture(from_pretrained): trn_config = TrainingNeuronConfig() - kwargs = {"torch_dtype": torch.bfloat16} + kwargs = {"dtype": torch.bfloat16} for model_name_or_path in [ "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random", "michaelbenayoun/granite-tiny-4kv-heads-4layers-random", diff --git a/docs/source/quickstart.mdx b/docs/source/quickstart.mdx index 95ddfc833..e9739de66 100644 --- a/docs/source/quickstart.mdx +++ b/docs/source/quickstart.mdx @@ -79,7 +79,7 @@ def main(): model = NeuronModelForCausalLM.from_pretrained( model_id, training_args.trn_config, - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, attn_implementation="flash_attention_2", # Enable flash attention ) diff --git a/docs/source/training_tutorials/finetune_llama.mdx b/docs/source/training_tutorials/finetune_llama.mdx index a28f5c04c..4e7ae6697 100644 --- a/docs/source/training_tutorials/finetune_llama.mdx +++ b/docs/source/training_tutorials/finetune_llama.mdx @@ -138,7 +138,7 @@ dtype = torch.bfloat16 if training_args.bf16 else torch.float32 model = NeuronModelForCausalLM.from_pretrained( model_id, trn_config, - torch_dtype=dtype, + dtype=dtype, # Use FlashAttention2 for better performance and to be able to use larger sequence lengths. attn_implementation="flash_attention_2", ) diff --git a/docs/source/training_tutorials/finetune_qwen3.mdx b/docs/source/training_tutorials/finetune_qwen3.mdx index 0c9f4f379..1cf4a2d6e 100644 --- a/docs/source/training_tutorials/finetune_qwen3.mdx +++ b/docs/source/training_tutorials/finetune_qwen3.mdx @@ -137,7 +137,7 @@ dtype = torch.bfloat16 if training_args.bf16 else torch.float32 model = NeuronModelForCausalLM.from_pretrained( model_id, trn_config, - torch_dtype=dtype, + dtype=dtype, # Use FlashAttention2 for better performance and to be able to use larger sequence lengths. attn_implementation="flash_attention_2", ) diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index 4bf9cb212..45d579a36 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -1193,7 +1193,7 @@ def forward( outputs = self.model(*inputs) if self.config.model_type == "t5" and isinstance(outputs, dict): # Flux text encoder 2 - return [outputs["last_hidden_state"].to(self.config.torch_dtype)] + return [outputs["last_hidden_state"].to(self.config.dtype)] if return_dict and not isinstance(outputs, dict): outputs = ModelOutput(dict(zip(self.neuron_config.outputs, outputs))) diff --git a/optimum/neuron/modeling_traced.py b/optimum/neuron/modeling_traced.py index 99184f8f9..7f00214c4 100644 --- a/optimum/neuron/modeling_traced.py +++ b/optimum/neuron/modeling_traced.py @@ -612,7 +612,7 @@ def neuron_padding_manager(self, inputs: dict[str, "torch.Tensor"]): @staticmethod def remove_padding( - outputs: list[torch.Tensor], + outputs: list[torch.Tensor] | dict, dims: list[int], indices: list[int], padding_side: Literal["right", "left"] = "right", @@ -633,6 +633,8 @@ def remove_padding( if len(dims) != len(indices): raise ValueError(f"The size of `dims`({len(dims)}) and indices`({len(indices)}) must be equal.") + if isinstance(outputs, dict): + outputs = list(outputs.values()) for dim, indice in zip(dims, indices): if padding_side == "right": outputs = [ diff --git a/optimum/neuron/models/inference/backend/modules/generation/generation_utils.py b/optimum/neuron/models/inference/backend/modules/generation/generation_utils.py index c72bd5762..163690355 100644 --- a/optimum/neuron/models/inference/backend/modules/generation/generation_utils.py +++ b/optimum/neuron/models/inference/backend/modules/generation/generation_utils.py @@ -17,7 +17,7 @@ from typing import Any import torch -from transformers import GenerationConfig +from transformers import GenerationConfig, PreTrainedModel from transformers.generation import GenerationMixin, SampleDecoderOnlyOutput from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.stopping_criteria import StoppingCriteriaList @@ -270,14 +270,13 @@ def _update_model_kwargs_for_generation( def _assisted_decoding( self, input_ids: torch.LongTensor, - candidate_generator: "CandidateGenerator", # noqa stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, + assistant_model: "PreTrainedModel | None" = None, **model_kwargs, ): pad_token_id = generation_config.pad_token_id eos_token_id = generation_config.eos_token_id - assistant_model = candidate_generator.assistant_model if assistant_model.neuron_config.on_device_sampling: raise ValueError("Assistant model must not use on-device sampling") diff --git a/optimum/neuron/models/inference/modeling_utils.py b/optimum/neuron/models/inference/modeling_utils.py index 36c31afdb..3be29ae78 100644 --- a/optimum/neuron/models/inference/modeling_utils.py +++ b/optimum/neuron/models/inference/modeling_utils.py @@ -138,7 +138,7 @@ def get_neuron_config( batch_size=batch_size, sequence_length=sequence_length, tensor_parallel_size=tensor_parallel_size, - dtype=DTYPE_MAPPER.pt(config.torch_dtype), + dtype=DTYPE_MAPPER.pt(config.dtype), ) @classmethod diff --git a/optimum/neuron/models/inference/t5/modeling_t5.py b/optimum/neuron/models/inference/t5/modeling_t5.py index 15c3e8c67..49b69ff30 100644 --- a/optimum/neuron/models/inference/t5/modeling_t5.py +++ b/optimum/neuron/models/inference/t5/modeling_t5.py @@ -27,6 +27,7 @@ from torch import nn from transformers import T5Config from transformers.activations import ACT2FN +from transformers.cache_utils import EncoderDecoderCache from transformers.models.t5.modeling_t5 import ( T5Attention, T5DenseActDense, @@ -154,7 +155,7 @@ def forward( mask=None, key_value_states=None, position_bias=None, - past_key_value=None, + past_key_values=None, layer_head_mask=None, query_length=None, use_cache=False, @@ -177,38 +178,38 @@ def forward( batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim ).transpose(1, 2) - if past_key_value is not None: - is_updated = past_key_value.is_updated.get(self.layer_idx) + # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` + is_updated = False + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_values = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_values = past_key_values.self_attention_cache + else: + curr_past_key_values = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_values.layers[self.layer_idx].keys + value_states = curr_past_key_values.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) - key_states = key_states.view( - batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim - ).transpose(1, 2) - value_states = value_states.view( - batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim - ).transpose(1, 2) - - if past_key_value is not None: + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past_key_value.update( + key_states, value_states = curr_past_key_values.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) @@ -235,14 +236,9 @@ def forward( causal_mask = mask[:, :, :, : key_states.shape[-2]] position_bias = position_bias + causal_mask - if self.pruned_heads: - mask = torch.ones(position_bias.shape[1]) - mask[list(self.pruned_heads)] = 0 - position_bias_masked = position_bias[:, mask.bool()] - else: - position_bias_masked = position_bias - + position_bias_masked = position_bias scores += position_bias_masked + # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) diff --git a/optimum/neuron/models/training/llama/modeling_llama.py b/optimum/neuron/models/training/llama/modeling_llama.py index f4cfe4c38..45c913626 100644 --- a/optimum/neuron/models/training/llama/modeling_llama.py +++ b/optimum/neuron/models/training/llama/modeling_llama.py @@ -210,7 +210,7 @@ def __init__(self, config, trn_config: TrainingNeuronConfig): init_method=init_method, sequence_parallel_enabled=self.trn_config.sequence_parallel_enabled, sequence_dimension=0, - dtype=self.config.torch_dtype, + dtype=self.config.dtype, ) self.down_proj = RowParallelLinear( self.intermediate_size, @@ -220,7 +220,7 @@ def __init__(self, config, trn_config: TrainingNeuronConfig): init_method=init_method, sequence_parallel_enabled=self.trn_config.sequence_parallel_enabled, sequence_dimension=0, - dtype=self.config.torch_dtype, + dtype=self.config.dtype, ) def forward(self, x): @@ -333,7 +333,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_ sequence_parallel_enabled=trn_config.sequence_parallel_enabled, kv_size_multiplier=self.kv_size_multiplier, fuse_qkv=trn_config.fuse_qkv, - dtype=self.config.torch_dtype, + dtype=self.config.dtype, ) gqa_qkv_specs = GQAQKVColumnParallelLinearSpec( @@ -361,7 +361,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_ init_method=init_method, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, sequence_dimension=0, - dtype=self.config.torch_dtype, + dtype=self.config.dtype, ) self.specs.add_spec( FusedLinearsSpec( @@ -382,7 +382,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_ init_method=init_method, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, sequence_dimension=0, - dtype=self.config.torch_dtype, + dtype=self.config.dtype, ) self.k_proj = ColumnParallelLinear( self.hidden_size, @@ -392,7 +392,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_ init_method=init_method, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, sequence_dimension=0, - dtype=self.config.torch_dtype, + dtype=self.config.dtype, ) self.v_proj = ColumnParallelLinear( self.hidden_size, @@ -402,7 +402,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_ init_method=init_method, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, sequence_dimension=0, - dtype=self.config.torch_dtype, + dtype=self.config.dtype, ) self.o_proj = RowParallelLinear( self.num_heads * self.head_dim, @@ -412,7 +412,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_ init_method=init_method, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, sequence_dimension=0, - dtype=self.config.torch_dtype, + dtype=self.config.dtype, ) self.num_heads = neuronx_dist_utils.divide(config.num_attention_heads, tp_size) self.num_key_value_heads = neuronx_dist_utils.divide( @@ -606,7 +606,7 @@ def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig): self.padding_idx, init_method=init_method, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, - dtype=config.torch_dtype, + dtype=config.dtype, ) self.layers = nn.ModuleList( [LlamaDecoderLayer(config, trn_config, layer_idx) for layer_idx in range(config.num_hidden_layers)] @@ -715,7 +715,7 @@ def __init__(self, config, trn_config: TrainingNeuronConfig): init_method=init_method, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, sequence_dimension=0, - dtype=self.config.torch_dtype, + dtype=self.config.dtype, ) self.vocab_size = config.vocab_size // get_tensor_model_parallel_size() diff --git a/optimum/neuron/models/training/modeling_utils.py b/optimum/neuron/models/training/modeling_utils.py index dfd5fd23c..31efbab17 100644 --- a/optimum/neuron/models/training/modeling_utils.py +++ b/optimum/neuron/models/training/modeling_utils.py @@ -61,7 +61,6 @@ get_state_dict_dtype, load_state_dict, no_init_weights, - set_initialized_submodules, ) from transformers.pytorch_utils import id_tensor_storage from transformers.quantizers import AutoHfQuantizer @@ -213,7 +212,7 @@ def _check_and_adjust_attn_implementation( return attn_implementation def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool: - torch_dtype = self.config.torch_dtype + dtype = self.config.dtype if not self._supports_flash_attn: raise ValueError( @@ -221,9 +220,9 @@ def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool: "https://github.com/huggingface/optimum-neuron/issues" ) - if torch_dtype is None: + if dtype is None: logger.warning_once( - "You are attempting to use Flash Attention 2 without specifying a torch dtype. This might lead to unexpected behaviour" + "You are attempting to use Flash Attention 2 without specifying a dtype. This might lead to unexpected behaviour" ) # If no error raise by this point, we can return `True` @@ -447,7 +446,14 @@ def _load_pretrained_model( _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys] else: _loaded_keys = loaded_keys - not_initialized_submodules = set_initialized_submodules(model, _loaded_keys) + + # Mark loaded parameters/buffers as initialized (transformers 4.56.0+ approach) + for key in model.state_dict(): + if key in _loaded_keys: + param_or_buffer = model.get_parameter_or_buffer(key) + if param_or_buffer is not None: + param_or_buffer._is_hf_initialized = True + # If we're about to tie the output embeds to the input embeds we don't need to init them if ( hasattr(model.config.get_text_config(decoder=True), "tie_word_embeddings") @@ -458,6 +464,26 @@ def _load_pretrained_model( # Still need to initialize if there is a bias term since biases are not tied. if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None: output_embeddings._is_hf_initialized = True + + # Set the flag on modules recursively + def set_is_initialized_for_modules(module): + if ( + all(getattr(child, "_is_hf_initialized", False) for child in module.children()) + and all( + getattr(param, "_is_hf_initialized", False) for param in module.parameters(recurse=False) + ) + and all( + getattr(buffer, "_is_hf_initialized", False) + for buffer in module.buffers(recurse=False) + if buffer not in module._non_persistent_buffers_set + ) + ): + module._is_hf_initialized = True + + model.apply(set_is_initialized_for_modules) + not_initialized_submodules = { + name: mod for name, mod in model.named_modules() if not getattr(mod, "_is_hf_initialized", False) + } else: not_initialized_submodules = dict(model.named_modules()) @@ -704,7 +730,11 @@ def from_pretrained( from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) _fast_init = kwargs.pop("_fast_init", True) + dtype = kwargs.pop("dtype", None) torch_dtype = kwargs.pop("torch_dtype", None) + # For BC on torch_dtype argument (deprecated in favor of dtype) + if torch_dtype is not None: + dtype = dtype if dtype is not None else torch_dtype low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None) device_map = kwargs.pop("device_map", None) kwargs.pop("max_memory", None) @@ -1172,18 +1202,18 @@ def from_pretrained( # 1. If torch_dtype is not None, we use that dtype # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype - # We also may have config.torch_dtype available, but we won't rely on it till v5 + # We also may have config.dtype available, but we won't rely on it till v5 dtype_orig = None - if torch_dtype is not None: - if isinstance(torch_dtype, str): - if torch_dtype == "auto": - if hasattr(config, "torch_dtype") and config.torch_dtype is not None: - torch_dtype = config.torch_dtype - logger.info(f"Will use torch_dtype={torch_dtype} as defined in model's config object") + if dtype is not None: + if isinstance(dtype, str): + if dtype == "auto": + if hasattr(config, "dtype") and config.dtype is not None: + dtype = config.dtype + logger.info(f"Will use dtype={dtype} as defined in model's config object") else: if is_sharded and "dtype" in sharded_metadata: - torch_dtype = sharded_metadata["dtype"] + dtype = sharded_metadata["dtype"] elif not is_sharded: # ** Difference from original from_pretrained ** # Here we load the state dict only if we end up in this case, otherwise we defer the @@ -1193,52 +1223,52 @@ def from_pretrained( one_time_state_dict = load_state_dict( resolved_archive_file, weights_only=weights_only ) - torch_dtype = get_state_dict_dtype(one_time_state_dict) + dtype = get_state_dict_dtype(one_time_state_dict) del one_time_state_dict - xm.rendezvous(f"auto torch_dtype_{worker}") + xm.rendezvous(f"auto dtype_{worker}") else: one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only) - torch_dtype = get_state_dict_dtype(one_state_dict) + dtype = get_state_dict_dtype(one_state_dict) del one_state_dict # free CPU memory logger.info( - "Since the `torch_dtype` attribute can't be found in model's config object, " - "will use torch_dtype={torch_dtype} as derived from model's weights" + "Since the `dtype` attribute can't be found in model's config object, " + "will use dtype={dtype} as derived from model's weights" ) - elif hasattr(torch, torch_dtype): - torch_dtype = getattr(torch, torch_dtype) + elif hasattr(torch, dtype): + dtype = getattr(torch, dtype) for sub_config_key in config.sub_configs.keys(): sub_config = getattr(config, sub_config_key) - sub_config.torch_dtype = torch_dtype - elif isinstance(torch_dtype, torch.dtype): + sub_config.dtype = dtype + elif isinstance(dtype, torch.dtype): for sub_config_key in config.sub_configs.keys(): sub_config = getattr(config, sub_config_key) - sub_config.torch_dtype = torch_dtype - elif isinstance(torch_dtype, dict): - for key, curr_dtype in torch_dtype.items(): + sub_config.dtype = dtype + elif isinstance(dtype, dict): + for key, curr_dtype in dtype.items(): if hasattr(config, key): value = getattr(config, key) - value.torch_dtype = curr_dtype + value.dtype = curr_dtype # main torch dtype for modules that aren't part of any sub-config - torch_dtype = torch_dtype.get("") - config.torch_dtype = torch_dtype - if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype): - torch_dtype = getattr(torch, torch_dtype) - elif torch_dtype is None: - torch_dtype = torch.float32 + dtype = dtype.get("") + config.dtype = dtype + if isinstance(dtype, str) and hasattr(torch, dtype): + dtype = getattr(torch, dtype) + elif dtype is None: + dtype = torch.float32 else: raise ValueError( - f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` " - f"for each sub-config in composite configs, but received {torch_dtype}" + f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` " + f"for each sub-config in composite configs, but received {dtype}" ) - dtype_orig = cls._set_default_torch_dtype(torch_dtype) + dtype_orig = cls._set_default_dtype(dtype) else: # set fp32 as the default dtype for BC default_dtype = str(torch.get_default_dtype()).split(".")[-1] - config.torch_dtype = default_dtype + config.dtype = default_dtype for key in config.sub_configs.keys(): value = getattr(config, key) - value.torch_dtype = default_dtype + value.dtype = default_dtype # ** Difference from original from_pretrained ** # We do not handle `use_keep_in_fp32_modules` here since it is not relevant for us. @@ -1264,9 +1294,9 @@ def from_pretrained( config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. # ** Difference from original from_pretrained ** - # We make sure that config.torch_dtype is of type torch.dtype. + # We make sure that config.dtype is of type torch.dtype. # We do not change the config inplace since we are working from a deepcopy. - config.torch_dtype = torch_dtype + config.dtype = dtype # ** Difference from original from_pretrained ** # We do not support the `tie_word_embeddings` feature in pipeline parallelism. @@ -1316,7 +1346,7 @@ def from_pretrained( sharded_metadata=sharded_metadata, _fast_init=_fast_init, device_map=device_map, - dtype=torch_dtype, + dtype=dtype, weights_only=weights_only, ) @@ -1419,7 +1449,7 @@ def save_pretrained( # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" # we currently don't use this setting automatically, but may start to use with v5 dtype = get_parameter_dtype(model_to_save) - model_to_save.config.torch_dtype = str(dtype).split(".")[1] + model_to_save.config.dtype = str(dtype).split(".")[1] # Attach architecture to the config model_to_save.config.architectures = [model_to_save.__class__.__name__] diff --git a/optimum/neuron/models/training/qwen3/modeling_qwen3.py b/optimum/neuron/models/training/qwen3/modeling_qwen3.py index b5536d910..8b1eddb6d 100644 --- a/optimum/neuron/models/training/qwen3/modeling_qwen3.py +++ b/optimum/neuron/models/training/qwen3/modeling_qwen3.py @@ -197,7 +197,7 @@ def __init__(self, config: Qwen3Config, trn_config: TrainingNeuronConfig): self.padding_idx, init_method=init_method, sequence_parallel_enabled=trn_config.sequence_parallel_enabled, - dtype=config.torch_dtype, + dtype=config.dtype, ) self.layers = nn.ModuleList( [Qwen3DecoderLayer(config, trn_config, layer_idx) for layer_idx in range(config.num_hidden_layers)] diff --git a/optimum/neuron/trainers/sft_trainer.py b/optimum/neuron/trainers/sft_trainer.py index b69e909d2..c9a481bb4 100644 --- a/optimum/neuron/trainers/sft_trainer.py +++ b/optimum/neuron/trainers/sft_trainer.py @@ -138,7 +138,7 @@ def __init__( raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.") else: model_init_kwargs = args.model_init_kwargs - torch_dtype = model_init_kwargs.get("torch_dtype") + torch_dtype = model_init_kwargs.get("dtype") if torch_dtype is not None: # Convert to `torch.dtype` if an str is passed if isinstance(torch_dtype, str) and torch_dtype != "auto": @@ -147,7 +147,7 @@ def __init__( raise ValueError( f"Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." ) - model_init_kwargs["torch_dtype"] = torch_dtype + model_init_kwargs["dtype"] = torch_dtype if isinstance(model, str): logging.warning( diff --git a/optimum/neuron/trainers/training_args.py b/optimum/neuron/trainers/training_args.py index f0e6f164e..e692d38e3 100644 --- a/optimum/neuron/trainers/training_args.py +++ b/optimum/neuron/trainers/training_args.py @@ -18,6 +18,7 @@ import os from dataclasses import dataclass, field, fields from enum import Enum +from functools import cached_property from typing import Any import torch @@ -30,9 +31,6 @@ get_last_checkpoint, ) from transformers.training_args import OptimizerNames, _convert_str_dict, default_logdir, trainer_log_levels -from transformers.utils import ( - cached_property, -) from ...utils import logging from ..accelerate import NeuronAcceleratorState, NeuronPartialState @@ -759,8 +757,8 @@ def _dict_torch_dtype_to_str(self, d: dict[str, Any]) -> None: converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"* string, which can then be stored in the json format. """ - if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str): - d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] + if d.get("dtype", None) is not None and not isinstance(d["dtype"], str): + d["dtype"] = str(d["dtype"]).split(".")[1] for value in d.values(): if isinstance(value, dict): self._dict_torch_dtype_to_str(value) diff --git a/pyproject.toml b/pyproject.toml index 6f82fa577..9c1836bba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "transformers ~= 4.55.4", + "transformers ~= 4.57.1", "accelerate == 1.8.1", "optimum ~= 1.24.0", "huggingface_hub >= 0.31.4", diff --git a/tests/training/test_custom_modeling.py b/tests/training/test_custom_modeling.py index 29746b9d1..bda6d3b89 100644 --- a/tests/training/test_custom_modeling.py +++ b/tests/training/test_custom_modeling.py @@ -614,8 +614,6 @@ def test_each_pp_rank_only_loads_relevant_parameters(set_cache_for_ci): ("flash_attention_2", "flash_attention_2"), ("eager", "eager"), (None, "eager"), - # Unsupported attention implementation - should default to eager - ("sdpa", "eager"), ], ) @distributed_test(world_size=8, tp_size=2, pp_size=1) diff --git a/tools/cache/auto_fill_diffusion_cache.py b/tools/cache/auto_fill_diffusion_cache.py index 7213dda88..72dcbac45 100644 --- a/tools/cache/auto_fill_diffusion_cache.py +++ b/tools/cache/auto_fill_diffusion_cache.py @@ -219,7 +219,7 @@ def compile_and_cache_model( task=model_config.get("task", None), auto_cast=model_config.get("auto_cast", None), auto_cast_type=model_config.get("auto_cast_type", None), - torch_dtype=model_config.get("torch_dtype", None), + torch_dtype=model_config.get("dtype", None) or model_config.get("torch_dtype", None), ) elif args.hf_model_id is None: raise ValueError("You must provide --hf_model_id to compile a model without a config file.") @@ -235,4 +235,5 @@ def compile_and_cache_model( task=args.task, auto_cast=args.auto_cast, auto_cast_type=args.auto_cast_type, + torch_dtype=args.torch_dtype, )