4 changes: 2 additions & 2 deletions fast_llm/engine/checkpoint/huggingface.py
@@ -120,14 +120,14 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]:
cls.base_model_converter_class.export_config(config.base_model),
{
"model_type": cls.get_huggingface_model_type(),
"architecture": cls.architecture,
"architectures": [cls.architecture],
},
)

@classmethod
def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig:
Assert.eq(config["model_type"], cls.get_huggingface_model_type())
Assert.eq(config["architecture"], cls.architecture)
Assert.eq(config["architectures"], [cls.architecture])
return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)})

def _create_weight_converters(self) -> list[WeightConverter]:
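The key rename follows the Hugging Face config.json convention, where the model class is stored under a list-valued "architectures" key. A minimal sketch of the exported dict after this change; the concrete values below are hypothetical stand-ins, not taken from this PR:

# Hypothetical exported config fragment; "llama" and "LlamaForCausalLM" stand in
# for whatever cls.get_huggingface_model_type() and cls.architecture return.
exported = {
    "model_type": "llama",
    "architectures": ["LlamaForCausalLM"],
}
# The importer now asserts against the one-element list instead of a scalar "architecture" key.
assert exported["architectures"] == ["LlamaForCausalLM"]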
4 changes: 3 additions & 1 deletion fast_llm/engine/training/trainer.py
@@ -203,7 +203,9 @@ def setup(self, distributed: Distributed, run: Run) -> None:
# Setup the model.
with torch.no_grad():
log_main_rank("Setting up model...")
self._multi_stage.setup(distributed)
self._multi_stage.setup(
distributed, mode=StageMode.inference if self._is_evaluation_only else StageMode.training
)
for name, reference_model in self._reference_models.items():
log_main_rank(f"Setting up `{name}` reference model...")
reference_model.fast_llm_model.setup(distributed, StageMode.inference)
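The setup call now picks the stage mode from the run type instead of always using the training default. A minimal, self-contained sketch of the selection logic; the StageMode class below is a stand-in mirroring only the two members used in the diff, not the real fast_llm enum:

import enum

class StageMode(enum.Enum):  # stand-in for the fast_llm StageMode enum
    training = "training"
    inference = "inference"

def select_mode(is_evaluation_only: bool) -> StageMode:
    # Evaluation-only trainers set the multi-stage model up without training
    # state, matching the reference-model path that already uses inference mode.
    return StageMode.inference if is_evaluation_only else StageMode.training

assert select_mode(True) is StageMode.inference
assert select_mode(False) is StageMode.training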
32 changes: 28 additions & 4 deletions fast_llm/models/gpt/conversion/apriel.py
@@ -8,10 +8,15 @@
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig
from fast_llm.layers.decoder.config import DecoderBlockConfig
from fast_llm.layers.decoder.mlp.config import MLPConfig
from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config
from fast_llm.models.gpt.config import GPTModelConfig
from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat
from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters
from fast_llm.models.gpt.conversion.llama import (
LlamaMLPConverter,
get_parameter_converter,
get_weight_and_bias_converters,
)
from fast_llm.models.gpt.conversion.mistral import (
MistralBaseModelConverter,
MistralBlockConverter,
@@ -224,12 +229,31 @@ def get_converters(
]


class AprielDiscreteMamba2BlockConverter(MistralBlockConverter):
class AprielMLPConverter(LlamaMLPConverter):
@classmethod
def import_config(cls, config: dict) -> dict:
config["mlp_bias"] = False
return super().import_config(config)

@classmethod
def export_config(cls, config: MLPConfig) -> dict:
out = super().export_config(config)
del out["mlp_bias"]
return out


class AprielBlockConverterBase(MistralBlockConverter):
mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter


class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase):
mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter
hf_mixer_name: typing.ClassVar[str] = "mixer"


class AprielMamba2BlockConverter(MistralBlockConverter):
class AprielMamba2BlockConverter(AprielBlockConverterBase):
mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter
hf_mixer_name: typing.ClassVar[str] = "mixer"


class AprielBlockConverter:
@@ -239,7 +263,7 @@ class AprielBlockConverter:
DiscreteMamba2Config: "m2d",
}
_converter_classes = {
AttentionConfig: MistralBlockConverter,
AttentionConfig: AprielBlockConverterBase,
Mamba2Config: AprielMamba2BlockConverter,
DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter,
}
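AprielMLPConverter shows the recurring pattern in this file: the Apriel checkpoint format omits a Hugging Face key (mlp_bias here, attention_bias in the Mistral converter below), so the subclass injects a fixed value before delegating import to the parent and strips the key again on export. A minimal, self-contained sketch of that pattern with hypothetical stand-in classes; the real converters operate on Fast-LLM config objects rather than plain dicts:

class BaseConverter:  # hypothetical stand-in for LlamaMLPConverter
    @classmethod
    def import_config(cls, config: dict) -> dict:
        # The Llama-style converter expects the key to be present.
        return {"bias_enabled": config["mlp_bias"]}

    @classmethod
    def export_config(cls, config: dict) -> dict:
        return {"mlp_bias": config["bias_enabled"]}

class FixedKeyConverter(BaseConverter):  # hypothetical stand-in for AprielMLPConverter
    @classmethod
    def import_config(cls, config: dict) -> dict:
        config["mlp_bias"] = False  # the source format never stores this key
        return super().import_config(config)

    @classmethod
    def export_config(cls, config: dict) -> dict:
        out = super().export_config(config)
        del out["mlp_bias"]  # drop it so the exported config carries no stray key
        return out

assert FixedKeyConverter.import_config({}) == {"bias_enabled": False}
assert "mlp_bias" not in FixedKeyConverter.export_config({"bias_enabled": False})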
10 changes: 8 additions & 2 deletions fast_llm/models/gpt/conversion/mistral.py
@@ -17,14 +17,20 @@
class MistralAttentionConverter(LlamaAttentionConverter):
@classmethod
def import_config(cls, config: dict) -> dict:
return safe_merge_dicts(super().import_config(config), {"window_size": config["sliding_window"]})
config["attention_bias"] = False
return safe_merge_dicts(
super().import_config(config),
{"window_size": config["sliding_window"]},
)

@classmethod
def export_config(cls, config: AttentionConfig) -> dict:
return safe_merge_dicts(
out = safe_merge_dicts(
super().export_config(config),
{"sliding_window": config.window_size},
)
del out["attention_bias"]
return out

@classmethod
def _check_config(cls, config: AttentionConfig) -> None:
@@ -18,7 +18,7 @@
from transformers.modeling_utils import PreTrainedModel
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, logging
from transformers.utils import TransformersKwargs, logging
from transformers.utils.generic import ModelOutput

from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig
@@ -1252,9 +1252,6 @@ def forward(
return output


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class AprielHybridSSMPreTrainedModel(PreTrainedModel):
config_class = AprielHybridSSMConfig
base_model_prefix = "model"
@@ -1383,7 +1380,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
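The modeling files drop the local KwargsForCausalLM alias (built from FlashAttentionKwargs and the removed LossKwargs) in favor of transformers.utils.TransformersKwargs. A minimal sketch of the new signature style, assuming a transformers version recent enough to expose TransformersKwargs; the toy forward below is illustrative, not the model's real signature:

from typing import Optional, Union

import torch
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs

def forward(
    input_ids: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[TransformersKwargs],  # the combined kwargs TypedDict that replaces the ad-hoc mix-in
) -> None:
    # Downstream calls simply forward **kwargs, as the updated models do.
    ...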
@@ -706,7 +706,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs, # TODO: Kwargs for Diffusion? : Unpack[KwargsForCausalLM],
**kwargs, # TODO: Kwargs for Diffusion? : Unpack[TransformersKwargs],
) -> MaskedLMOutput:
r"""
# TODO: Update docstring for diffusion
7 changes: 2 additions & 5 deletions fast_llm_external_models/mtp_llama/modeling_mtp_llama.py
@@ -15,7 +15,7 @@
from transformers.processing_utils import Unpack
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
LossKwargs,
TransformersKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
@@ -761,9 +761,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class MTPLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -812,7 +809,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
8 changes: 8 additions & 0 deletions tests/models/test_checkpoint.py
@@ -379,6 +379,9 @@ def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_
# TODO: Test beyond 2 gpu configs?
import tests.models.distributed_test_checkpoint

if torch.cuda.device_count() < 2:
pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2")

script = [
"-m",
tests.models.distributed_test_checkpoint.__name__,
@@ -405,6 +408,7 @@ def reference_distributed_shard(get_convert_path) -> torch.Tensor | None:


@requires_cuda
# NOTE: Should it depend on test_model_distributed instead?
@pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"])
@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed)
def test_load_parallel_checkpoint_in_single_gpu(
@@ -425,6 +429,10 @@ def test_load_parallel_checkpoint_in_single_gpu(
distributed_save_load_config = distributed_save_load_config.resolve(
base_path=run_test_script_base_path, model_testing_config=model_testing_config
)
if torch.cuda.device_count() < distributed_save_load_config.num_gpus:
pytest.skip(
f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {distributed_save_load_config.num_gpus}"
)
report_subtest(distributed_save_load_config.save_path, distributed_save_load_config.num_gpus)
load_and_compare_checkpoints(
DistributedCheckpointFormat,
5 changes: 3 additions & 2 deletions tests/utils/model_configs.py
@@ -636,6 +636,7 @@ def _update_and_add_testing_config(
groups={
ModelTestingGroup.basic: ModelTestingGroupAction.normal,
ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
# TODO: Fix and bring back to `testing_groups`
ModelTestingGroup.convert: ModelTestingGroupAction.normal,
ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
@@ -649,7 +650,7 @@
), # "pp","dp", "ce","16", "bf", "df", "stp"),
)


# TODO: remove obsolete model
_update_and_add_testing_config(
# Tests hybrid discrete Mamba 2.
"llama",
@@ -681,7 +682,7 @@
groups={
ModelTestingGroup.basic: ModelTestingGroupAction.normal,
ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
ModelTestingGroup.convert: ModelTestingGroupAction.normal,
ModelTestingGroup.convert: ModelTestingGroupAction.broken,
ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
# TODO: Implement