4 changes: 3 additions & 1 deletion fast_llm/engine/training/trainer.py
@@ -203,7 +203,9 @@ def setup(self, distributed: Distributed, run: Run) -> None:
# Setup the model.
with torch.no_grad():
log_main_rank("Setting up model...")
-self._multi_stage.setup(distributed)
+self._multi_stage.setup(
+    distributed, mode=StageMode.inference if self._is_evaluation_only else StageMode.training
+)
for name, reference_model in self._reference_models.items():
log_main_rank(f"Setting up `{name}` reference model...")
reference_model.fast_llm_model.setup(distributed, StageMode.inference)
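The hunk above makes the trainer build the main model in inference mode whenever the run only evaluates, matching how reference models were already set up. A minimal sketch of the selection logic, using stand-in classes rather than the real Fast-LLM `Trainer`, `MultiStageModel`, and `StageMode`:

from enum import Enum


class StageMode(str, Enum):
    # Stand-in for Fast-LLM's StageMode; only the two values used in the hunk.
    training = "training"
    inference = "inference"


class TrainerSetupSketch:
    """Illustrative only: shows the mode selection, not the real Trainer."""

    def __init__(self, multi_stage, reference_models: dict, is_evaluation_only: bool):
        self._multi_stage = multi_stage
        self._reference_models = reference_models
        self._is_evaluation_only = is_evaluation_only

    def setup(self, distributed) -> None:
        # Evaluation-only runs presumably do not need optimizer or gradient state,
        # so the model is set up in inference mode.
        mode = StageMode.inference if self._is_evaluation_only else StageMode.training
        self._multi_stage.setup(distributed, mode=mode)
        # Reference models never train and are always set up for inference.
        for name, reference_model in self._reference_models.items():
            reference_model.fast_llm_model.setup(distributed, StageMode.inference)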
fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py
@@ -18,7 +18,7 @@
from transformers.modeling_utils import PreTrainedModel
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm
from transformers.processing_utils import Unpack
-from transformers.utils import LossKwargs, logging
+from transformers.utils import TransformersKwargs, logging
from transformers.utils.generic import ModelOutput

from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig
@@ -1252,9 +1252,6 @@ def forward(
return output


-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
class AprielHybridSSMPreTrainedModel(PreTrainedModel):
config_class = AprielHybridSSMConfig
base_model_prefix = "model"
@@ -1383,7 +1380,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
-**kwargs: Unpack[KwargsForCausalLM],
+**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
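The remaining modeling-file hunks make the same substitution: the per-model `KwargsForCausalLM` TypedDict (a local combination of `FlashAttentionKwargs` and `LossKwargs`) is dropped in favor of `TransformersKwargs`, which recent `transformers` releases export from `transformers.utils`. A hedged sketch of the resulting signature pattern, assuming such a `transformers` version is installed; the function name is illustrative only:

from typing import Optional

import torch
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs

# Before (defined per model):
#   class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
#   def forward(self, ..., **kwargs: Unpack[KwargsForCausalLM]): ...
# After (shared type, no local class needed):


def forward_sketch(
    input_ids: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> None:
    # Attention- and loss-related keyword arguments flow through **kwargs
    # and are forwarded unchanged to inner modules and loss helpers.
    ...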
@@ -706,7 +706,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
-**kwargs, # TODO: Kwargs for Diffusion? : Unpack[KwargsForCausalLM],
+**kwargs, # TODO: Kwargs for Diffusion? : Unpack[TransformersKwargs],
) -> MaskedLMOutput:
r"""
# TODO: Update docstring for diffusion
7 changes: 2 additions & 5 deletions fast_llm_external_models/mtp_llama/modeling_mtp_llama.py
@@ -15,7 +15,7 @@
from transformers.processing_utils import Unpack
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
-LossKwargs,
+TransformersKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
@@ -761,9 +761,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask


-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
class MTPLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -812,7 +809,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
-**kwargs: Unpack[KwargsForCausalLM],
+**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
8 changes: 8 additions & 0 deletions tests/models/test_checkpoint.py
@@ -379,6 +379,9 @@ def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_path,
# TODO: Test beyond 2 gpu configs?
import tests.models.distributed_test_checkpoint

+if torch.cuda.device_count() < 2:
+    pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2")

script = [
"-m",
tests.models.distributed_test_checkpoint.__name__,
@@ -405,6 +408,7 @@ def reference_distributed_shard(get_convert_path) -> torch.Tensor | None:


@requires_cuda
+# NOTE: Should it depend on test_model_distributed instead?
@pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"])
@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed)
def test_load_parallel_checkpoint_in_single_gpu(
@@ -425,6 +429,10 @@ def test_load_parallel_checkpoint_in_single_gpu(
distributed_save_load_config = distributed_save_load_config.resolve(
base_path=run_test_script_base_path, model_testing_config=model_testing_config
)
+if torch.cuda.device_count() < distributed_save_load_config.num_gpus:
+    pytest.skip(
+        f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {distributed_save_load_config.num_gpus}"
+    )
report_subtest(distributed_save_load_config.save_path, distributed_save_load_config.num_gpus)
load_and_compare_checkpoints(
DistributedCheckpointFormat,
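Both test additions follow the same guard: compare `torch.cuda.device_count()` against the number of GPUs the scenario (or the dependency that produced its checkpoint) needs, and skip early instead of failing. A standalone sketch of that pattern; the helper and test names are illustrative, not part of the test suite:

import pytest
import torch


def require_gpus(required: int) -> None:
    # Skip the calling test when the host has fewer GPUs than the scenario needs.
    available = torch.cuda.device_count()
    if available < required:
        pytest.skip(f"Not enough GPUs: {available} < {required}")


def test_parallel_checkpoint_sketch():
    require_gpus(2)  # e.g. a 2-GPU tensor-parallel save/load scenario
    # ... run the distributed save/load and compare shards ...
    ...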
7 changes: 4 additions & 3 deletions tests/utils/model_configs.py
@@ -637,7 +637,8 @@ def _update_and_add_testing_config(
groups={
ModelTestingGroup.basic: ModelTestingGroupAction.normal,
ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
-ModelTestingGroup.convert: ModelTestingGroupAction.normal,
+# TODO: Fix and bring back to `testing_groups`
+ModelTestingGroup.convert: ModelTestingGroupAction.broken,
ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
@@ -650,7 +651,7 @@ def _update_and_add_testing_config(
), # "pp","dp", "ce","16", "bf", "df", "stp"),
)


+# TODO: remove obsolete model
_update_and_add_testing_config(
# Tests hybrid discrete Mamba 2.
"llama",
@@ -682,7 +683,7 @@ def _update_and_add_testing_config(
groups={
ModelTestingGroup.basic: ModelTestingGroupAction.normal,
ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
-ModelTestingGroup.convert: ModelTestingGroupAction.normal,
+ModelTestingGroup.convert: ModelTestingGroupAction.broken,
ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
# TODO: Implement