diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py
index 96fb53321..270171755 100644
--- a/fast_llm/engine/checkpoint/huggingface.py
+++ b/fast_llm/engine/checkpoint/huggingface.py
@@ -120,14 +120,14 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]:
             cls.base_model_converter_class.export_config(config.base_model),
             {
                 "model_type": cls.get_huggingface_model_type(),
-                "architecture": cls.architecture,
+                "architectures": [cls.architecture],
             },
         )
 
     @classmethod
     def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig:
         Assert.eq(config["model_type"], cls.get_huggingface_model_type())
-        Assert.eq(config["architecture"], cls.architecture)
+        Assert.eq(config["architectures"], [cls.architecture])
         return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)})
 
     def _create_weight_converters(self) -> list[WeightConverter]:
diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py
index aa4f2d570..9737546ad 100644
--- a/fast_llm/engine/training/trainer.py
+++ b/fast_llm/engine/training/trainer.py
@@ -203,7 +203,9 @@ def setup(self, distributed: Distributed, run: Run) -> None:
         # Setup the model.
         with torch.no_grad():
             log_main_rank("Setting up model...")
-            self._multi_stage.setup(distributed)
+            self._multi_stage.setup(
+                distributed, mode=StageMode.inference if self._is_evaluation_only else StageMode.training
+            )
         for name, reference_model in self._reference_models.items():
             log_main_rank(f"Setting up `{name}` reference model...")
             reference_model.fast_llm_model.setup(distributed, StageMode.inference)
diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py
index 4b9849630..7550df044 100644
--- a/fast_llm/models/gpt/conversion/apriel.py
+++ b/fast_llm/models/gpt/conversion/apriel.py
@@ -8,10 +8,15 @@
 from fast_llm.layers.attention.config import AttentionConfig
 from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig
 from fast_llm.layers.decoder.config import DecoderBlockConfig
+from fast_llm.layers.decoder.mlp.config import MLPConfig
 from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config
 from fast_llm.models.gpt.config import GPTModelConfig
 from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat
-from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters
+from fast_llm.models.gpt.conversion.llama import (
+    LlamaMLPConverter,
+    get_parameter_converter,
+    get_weight_and_bias_converters,
+)
 from fast_llm.models.gpt.conversion.mistral import (
     MistralBaseModelConverter,
     MistralBlockConverter,
@@ -224,12 +229,31 @@ def get_converters(
         ]
 
 
-class AprielDiscreteMamba2BlockConverter(MistralBlockConverter):
+class AprielMLPConverter(LlamaMLPConverter):
+    @classmethod
+    def import_config(cls, config: dict) -> dict:
+        config["mlp_bias"] = False
+        return super().import_config(config)
+
+    @classmethod
+    def export_config(cls, config: MLPConfig) -> dict:
+        out = super().export_config(config)
+        del out["mlp_bias"]
+        return out
+
+
+class AprielBlockConverterBase(MistralBlockConverter):
+    mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter
+
+
+class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase):
     mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter
+    hf_mixer_name: typing.ClassVar[str] = "mixer"
 
 
-class AprielMamba2BlockConverter(MistralBlockConverter):
+class AprielMamba2BlockConverter(AprielBlockConverterBase):
     mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter
+    hf_mixer_name: typing.ClassVar[str] = "mixer"
 
 
 class AprielBlockConverter:
@@ -239,7 +263,7 @@ class AprielBlockConverter:
         DiscreteMamba2Config: "m2d",
     }
     _converter_classes = {
-        AttentionConfig: MistralBlockConverter,
+        AttentionConfig: AprielBlockConverterBase,
         Mamba2Config: AprielMamba2BlockConverter,
         DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter,
     }
diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py
index bfc7d5569..b5db3fa06 100644
--- a/fast_llm/models/gpt/conversion/mistral.py
+++ b/fast_llm/models/gpt/conversion/mistral.py
@@ -17,14 +17,20 @@
 class MistralAttentionConverter(LlamaAttentionConverter):
     @classmethod
     def import_config(cls, config: dict) -> dict:
-        return safe_merge_dicts(super().import_config(config), {"window_size": config["sliding_window"]})
+        config["attention_bias"] = False
+        return safe_merge_dicts(
+            super().import_config(config),
+            {"window_size": config["sliding_window"]},
+        )
 
     @classmethod
     def export_config(cls, config: AttentionConfig) -> dict:
-        return safe_merge_dicts(
+        out = safe_merge_dicts(
             super().export_config(config),
             {"sliding_window": config.window_size},
         )
+        del out["attention_bias"]
+        return out
 
     @classmethod
     def _check_config(cls, config: AttentionConfig) -> None:
diff --git a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py
index 40c4cfa87..a80c031aa 100644
--- a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py
+++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py
@@ -18,7 +18,7 @@
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm
 from transformers.processing_utils import Unpack
-from transformers.utils import LossKwargs, logging
+from transformers.utils import TransformersKwargs, logging
 from transformers.utils.generic import ModelOutput
 
 from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig
@@ -1252,9 +1252,6 @@ def forward(
         return output
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 class AprielHybridSSMPreTrainedModel(PreTrainedModel):
     config_class = AprielHybridSSMConfig
     base_model_prefix = "model"
@@ -1383,7 +1380,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[KwargsForCausalLM],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
diff --git a/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py b/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py
index c8723af5d..a67a302ef 100644
--- a/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py
+++ b/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py
@@ -706,7 +706,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs,  # TODO: Kwargs for Diffusion? : Unpack[KwargsForCausalLM],
+        **kwargs,  # TODO: Kwargs for Diffusion? : Unpack[TransformersKwargs],
     ) -> MaskedLMOutput:
         r"""
         # TODO: Update docstring for diffusion
diff --git a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py
index 5ad99ff96..d0e1988f1 100644
--- a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py
+++ b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py
@@ -15,7 +15,7 @@
 from transformers.processing_utils import Unpack
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.utils import (
-    LossKwargs,
+    TransformersKwargs,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
@@ -761,9 +761,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         return causal_mask
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 class MTPLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
@@ -812,7 +809,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[KwargsForCausalLM],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py
index 3c3bfb833..c5a5e1c5b 100644
--- a/tests/models/test_checkpoint.py
+++ b/tests/models/test_checkpoint.py
@@ -379,6 +379,9 @@ def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_
     # TODO: Test beyond 2 gpu configs?
     import tests.models.distributed_test_checkpoint
 
+    if torch.cuda.device_count() < 2:
+        pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2")
+
     script = [
         "-m",
         tests.models.distributed_test_checkpoint.__name__,
@@ -405,6 +408,7 @@ def reference_distributed_shard(get_convert_path) -> torch.Tensor | None:
 
 
 @requires_cuda
+# NOTE: Should it depend on test_model_distributed instead?
 @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"])
 @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed)
 def test_load_parallel_checkpoint_in_single_gpu(
@@ -425,6 +429,10 @@ def test_load_parallel_checkpoint_in_single_gpu(
     distributed_save_load_config = distributed_save_load_config.resolve(
         base_path=run_test_script_base_path, model_testing_config=model_testing_config
     )
+    if torch.cuda.device_count() < distributed_save_load_config.num_gpus:
+        pytest.skip(
+            f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {distributed_save_load_config.num_gpus}"
+        )
     report_subtest(distributed_save_load_config.save_path, distributed_save_load_config.num_gpus)
     load_and_compare_checkpoints(
         DistributedCheckpointFormat,
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index 956aaea5a..5cfb7c952 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -636,6 +636,7 @@ def _update_and_add_testing_config(
     groups={
         ModelTestingGroup.basic: ModelTestingGroupAction.normal,
         ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
+        # TODO: Fix and bring back to `testing_groups`
         ModelTestingGroup.convert: ModelTestingGroupAction.normal,
         ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
         ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
@@ -649,7 +650,7 @@ def _update_and_add_testing_config(
     ),
     # "pp","dp", "ce","16", "bf", "df", "stp"),
 )
-
+# TODO: remove obsolete model
 _update_and_add_testing_config(
     # Tests hybrid discrete Mamba 2.
     "llama",
@@ -681,7 +682,7 @@ def _update_and_add_testing_config(
     groups={
         ModelTestingGroup.basic: ModelTestingGroupAction.normal,
         ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
-        ModelTestingGroup.convert: ModelTestingGroupAction.normal,
+        ModelTestingGroup.convert: ModelTestingGroupAction.broken,
         ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
         ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
         # TODO: Implement