8 changes: 7 additions & 1 deletion fast_llm/layers/vision/config.py
@@ -85,7 +85,7 @@ class PatchConvolutionConfig(BlockConfig):
     )
 
     @functools.cached_property
-    def input_channels(self):
+    def input_channels(self) -> int:
         # Number of input channels. Currently hard-coded to 3 (RGB).
         return 3
 
@@ -99,6 +99,7 @@ def layer_class(self) -> "type[PatchConvolution]":
 @config_class(registry=True)
 class VisionEncoderConfig(BlockConfig):
     _abstract = False
+    # TODO: ====== Rename to patch_embeddings? ======
     patch_convolution: PatchConvolutionConfig = Field(
         desc="Configuration for the patch convolution layer.",
         hint=FieldHint.architecture,
@@ -132,6 +133,11 @@ class VisionMultiModalModelConfig(LanguageModelConfig):
         hint=FieldHint.architecture,
         desc="Configuration for the vision encoder.",
     )
+    image_token_index: int | None = Field(
+        default=None,
+        hint=FieldHint.optional,
+        desc="Index of the image token. Unused, but required for Hugging Face conversion.",
+    )
 
     @property
     def layer_class(self) -> "type[VisionMultiModalModel]":
4 changes: 2 additions & 2 deletions fast_llm/models/gpt/conversion/llama.py
@@ -10,7 +10,7 @@
     SplitWeightConverter,
     WeightConverter,
 )
-from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler
+from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler
 from fast_llm.engine.multi_stage.config import FastLLMModelConfig
 from fast_llm.functional.config import ActivationType
 from fast_llm.layers.attention.config import AttentionConfig
@@ -498,7 +498,7 @@ def get_converters(
         ]
 
 
-class LlamaBaseModelConverter:
+class LlamaBaseModelConverter(HuggingFaceBaseModelConverter):
     # TODO: Peft?
     decoder_converter_class: typing.ClassVar[type[LlamaDecoderConverter]] = LlamaDecoderConverter
     embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter
4 changes: 2 additions & 2 deletions fast_llm/models/gpt/conversion/mistral.py
@@ -40,7 +40,7 @@ def _check_config(cls, config: AttentionConfig) -> None:
         assert not config.add_linear_biases
 
 
-class MistrallMLPConverter(LlamaMLPConverter):
+class MistralMLPConverter(LlamaMLPConverter):
     @classmethod
     def import_config(cls, config: dict) -> dict:
         config["mlp_bias"] = False
@@ -56,7 +56,7 @@ def export_config(cls, config: MLPConfig) -> dict:
 
 class MistralBlockConverter(LlamaBlockConverter):
     mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter
-    mlp_converter_class: typing.ClassVar[type[MistrallMLPConverter]] = MistrallMLPConverter
+    mlp_converter_class: typing.ClassVar[type[MistralMLPConverter]] = MistralMLPConverter
 
 
 class MistralDecoderConverter(LlamaDecoderConverter):
1 change: 1 addition & 0 deletions fast_llm/models/gpt/huggingface.py
@@ -45,6 +45,7 @@ def inner_forward(
         output_attentions: bool | None = None,
         output_hidden_states: bool | None = None,
         return_dict: bool | None = None,
+        pixel_values: torch.Tensor | None = None,
     ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast:
         # TODO: Most of this is generalizable.
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
7 changes: 5 additions & 2 deletions fast_llm/models/multimodal/config.py
@@ -14,6 +14,7 @@
     GPTTrainerConfig,
     PretrainedGPTModelConfig,
 )
+from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat, LlavaHybridSSMCheckpointFormat
 
 if typing.TYPE_CHECKING:
     from fast_llm.models.multimodal.model import MultiModalBaseModel, MultiModalModel
@@ -41,8 +42,10 @@ class MultiModalModelConfig(GPTModelConfig):
     _abstract = False
     model_name: typing.ClassVar[str] = "multimodal"
     base_model: MultiModalBaseModelConfig = FieldUpdate()
-    # TODO: ====== Conversion ======
-    checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats
+    checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + (
+        LlavaCheckpointFormat,
+        LlavaHybridSSMCheckpointFormat,
+    )
 
     @classmethod
     def get_model_class(cls) -> type["MultiModalModel"]:
Empty file.
17 changes: 17 additions & 0 deletions fast_llm/models/multimodal/conversion/auto.py
@@ -0,0 +1,17 @@
+import abc
+
+from fast_llm.engine.checkpoint.external import AutoStateDictCheckpointHandler
+from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler
+from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat, LlavaHybridSSMCheckpointFormat
+from fast_llm.models.multimodal.conversion.llava import LlavaHuggingfaceCheckpointHandler
+from fast_llm.models.multimodal.conversion.llava_hybrid import LlavaHybridSSMHuggingfaceCheckpointHandler
+
+
+class AutoMultimodalHuggingfaceCheckpointHandler(
+    AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC
+):
+
+    handler_map = {
+        LlavaCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler,
+        LlavaHybridSSMCheckpointFormat.name: LlavaHybridSSMHuggingfaceCheckpointHandler,
+    }
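Note (illustration, not part of the diff): the base AutoStateDictCheckpointHandler is not shown here, so the standalone sketch below only illustrates the kind of name-keyed lookup that handler_map is presumably wired into; the _LlavaHandler classes are stand-ins, not the real handler classes.

# Standalone sketch of a name-keyed handler registry (assumption about how
# AutoStateDictCheckpointHandler consumes handler_map; not code from this PR).
class _LlavaHandler: ...        # stand-in for LlavaHuggingfaceCheckpointHandler
class _LlavaHybridHandler: ...  # stand-in for LlavaHybridSSMHuggingfaceCheckpointHandler

handler_map: dict[str, type] = {
    "llava": _LlavaHandler,
    "llava_hybrid_ssm": _LlavaHybridHandler,
}

def get_handler_class(name: str) -> type:
    # Strict lookup: an unknown format name fails loudly with KeyError.
    return handler_map[name]

assert get_handler_class("llava") is _LlavaHandler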
25 changes: 25 additions & 0 deletions fast_llm/models/multimodal/conversion/config.py
@@ -0,0 +1,25 @@
+import typing
+
+from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler
+
+
+class MultimodalHuggingfaceCheckpointFormat(CheckpointFormat):
+    support_optimizer: typing.ClassVar[bool] = False
+
+    @classmethod
+    def get_handler_class(cls) -> type[CheckpointHandler]:
+        from fast_llm.models.multimodal.conversion.auto import AutoMultimodalHuggingfaceCheckpointHandler
+
+        return AutoMultimodalHuggingfaceCheckpointHandler.get_handler_class(cls.name)
+
+
+class AutoMultimodalHuggingfaceCheckpointFormat(MultimodalHuggingfaceCheckpointFormat):
+    name: typing.ClassVar[str] = "auto"
+
+
+class LlavaCheckpointFormat(MultimodalHuggingfaceCheckpointFormat):
+    name: typing.ClassVar[str] = "llava"
+
+
+class LlavaHybridSSMCheckpointFormat(MultimodalHuggingfaceCheckpointFormat):
+    name: typing.ClassVar[str] = "llava_hybrid_ssm"
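Usage sketch (not part of the diff): together with auto.py above, each new format class should resolve to its handler by name. This assumes an environment with this PR applied; the module and class names are taken verbatim from the diff, and the expected results follow from handler_map in auto.py.

from fast_llm.models.multimodal.conversion.config import (
    LlavaCheckpointFormat,
    LlavaHybridSSMCheckpointFormat,
)

# get_handler_class() defers to AutoMultimodalHuggingfaceCheckpointHandler,
# which maps the format's name to a concrete handler class.
print(LlavaCheckpointFormat.get_handler_class().__name__)
# expected: LlavaHuggingfaceCheckpointHandler
print(LlavaHybridSSMCheckpointFormat.get_handler_class().__name__)
# expected: LlavaHybridSSMHuggingfaceCheckpointHandler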