18 changes: 18 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1861,6 +1861,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
_supports_attention_backend = False
_can_record_outputs = None

# Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
# Possible values are: text, image, video, audio and time
input_modalities: Union[str, list[str]] = "text" # most models are text

@property
@torch._dynamo.allow_in_graph
def can_record_outputs(self) -> dict[str, OutputRecorder]:
@@ -2224,6 +2228,20 @@ def base_model(self) -> nn.Module:
"""
return getattr(self, self.base_model_prefix, self)

@classmethod
def output_modalities(cls) -> Optional[Union[str, list[str]]]:
"""
Returns the output modalities that a model can generate. For non-generative models,
returns `None`. Multimodal models that can output several modalities or non-text modalities
should override this method.

Returns:
`Union[str, list[str]]`: Output modalities supported for models that can call `.generate()`.
"""
if cls.can_generate():
return "text"
return None
Comment on lines +2231 to +2243
@gante (Member), Oct 6, 2025:
I think one of two things should happen:

  1. [my preference] output_modalities can be used on all models, i.e. it is not limited to generative models. I suspect this is a useful piece of info to have in model-agnostic code, enabling better error handling and other functionality. (unimplemented cases could throw an exception for now?)
  2. If we truly only want to use this in models that inherit GenerationMixin, then this function should be moved to GenerationMixin. Otherwise, we're tangling the classes (= bad practice).

Member Author:
Hmm, if we define it for all models, what would it mean for models that output encodings (e.g. CLIP)? I could not think of a use case for that, tbh.
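
For reference, a quick sketch of how the two attributes would read from user code once this lands, assuming the defaults shown in this diff; the checkpoints are only illustrative:

```python
# Hypothetical usage sketch, not part of this PR's diff.
from transformers import AutoModel, AutoModelForCausalLM

lm = AutoModelForCausalLM.from_pretrained("gpt2")
print(lm.input_modalities)       # "text" (class-level default)
print(lm.output_modalities())    # "text", because the model can_generate()

clip = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
print(clip.input_modalities)     # ["image", "text"]
print(clip.output_modalities())  # None, CLIP is not generative
```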


@classmethod
def can_generate(cls) -> bool:
"""
1 change: 1 addition & 0 deletions src/transformers/models/aimv2/modeling_aimv2.py
@@ -393,6 +393,7 @@ class Aimv2PreTrainedModel(PreTrainedModel):

config: Aimv2Config
base_model_prefix = "aimv2"
input_modalities = "image"
Member:
suggestion: use a tuple always, even when length = 1

  • single possible type = simpler usage
  • immutable
  • tuples can be used as dictionary keys. In the future, this might be useful for modality-specific operations (e.g. SOME_MAPPING_OF_FUNCTIONS[model.input_modalities](**kwargs); see the sketch below)
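
A minimal sketch of that dispatch pattern, assuming input_modalities were normalized to tuples as suggested; the mapping and helper names below are hypothetical, not transformers APIs:

```python
# Hypothetical dispatch sketch; none of these names exist in transformers.
def prepare_text(**kwargs):
    return {"input_ids": kwargs["input_ids"]}

def prepare_image_text(**kwargs):
    return {"pixel_values": kwargs["pixel_values"], "input_ids": kwargs["input_ids"]}

MODALITY_PREPROCESSORS = {
    ("text",): prepare_text,
    ("image", "text"): prepare_image_text,
}

def preprocess(model, **kwargs):
    # Tuples are hashable, so they can key the mapping directly; with the
    # current str-or-list typing, the value has to be normalized first.
    modalities = model.input_modalities
    key = (modalities,) if isinstance(modalities, str) else tuple(modalities)
    return MODALITY_PREPROCESSORS[key](**kwargs)
```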

supports_gradient_checkpointing = True
_no_split_modules = [
"Aimv2EncoderLayer",
1 change: 1 addition & 0 deletions src/transformers/models/aimv2/modular_aimv2.py
@@ -437,6 +437,7 @@ class Aimv2PreTrainedModel(PreTrainedModel):

config: Aimv2Config
base_model_prefix = "aimv2"
input_modalities = "image"
supports_gradient_checkpointing = True
_no_split_modules = [
"Aimv2EncoderLayer",
3 changes: 3 additions & 0 deletions src/transformers/models/align/modeling_align.py
@@ -839,6 +839,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class AlignPreTrainedModel(PreTrainedModel):
config: AlignConfig
base_model_prefix = "align"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True

def _init_weights(self, module: nn.Module):
@@ -868,6 +869,7 @@ def _init_weights(self, module: nn.Module):
)
class AlignTextModel(AlignPreTrainedModel):
config: AlignTextConfig
input_modalities = "text"
_no_split_modules = ["AlignTextEmbeddings"]

def __init__(self, config: AlignTextConfig, add_pooling_layer: bool = True):
@@ -988,6 +990,7 @@ def forward(
class AlignVisionModel(AlignPreTrainedModel):
config: AlignVisionConfig
main_input_name = "pixel_values"
input_modalities = "image"
supports_gradient_checkpointing = False

def __init__(self, config: AlignVisionConfig):
3 changes: 3 additions & 0 deletions src/transformers/models/altclip/modeling_altclip.py
@@ -810,6 +810,7 @@ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=Fals
class AltCLIPPreTrainedModel(PreTrainedModel):
config: AltCLIPConfig
base_model_prefix = "altclip"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_no_split_module = []

@@ -914,6 +915,7 @@ def forward(
class AltCLIPVisionModel(AltCLIPPreTrainedModel):
config: AltCLIPVisionConfig
main_input_name = "pixel_values"
input_modalities = "image"

def __init__(self, config: AltCLIPVisionConfig):
super().__init__(config)
@@ -1080,6 +1082,7 @@ def forward(

class AltCLIPTextModel(AltCLIPPreTrainedModel):
config: AltCLIPTextConfig
input_modalities = "text"

def __init__(self, config):
super().__init__(config)
1 change: 1 addition & 0 deletions src/transformers/models/aria/modeling_aria.py
@@ -574,6 +574,7 @@ def forward(
class AriaTextPreTrainedModel(PreTrainedModel):
config: AriaTextConfig
base_model_prefix = "model"
input_modalities = ["image", "text"]
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
1 change: 1 addition & 0 deletions src/transformers/models/aria/modular_aria.py
@@ -1207,6 +1207,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int):
class AriaTextPreTrainedModel(PreTrainedModel):
config: AriaTextConfig
base_model_prefix = "model"
input_modalities = ["image", "text"]
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py
@@ -308,6 +308,7 @@ def forward(self, hidden_states: torch.Tensor) -> BaseModelOutput:
class ASTPreTrainedModel(PreTrainedModel):
config: ASTConfig
base_model_prefix = "audio_spectrogram_transformer"
input_modalities = "audio"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
1 change: 1 addition & 0 deletions src/transformers/models/autoformer/modeling_autoformer.py
@@ -825,6 +825,7 @@ def forward(
class AutoformerPreTrainedModel(PreTrainedModel):
config: AutoformerConfig
base_model_prefix = "model"
input_modalities = "time"
main_input_name = "past_values"
supports_gradient_checkpointing = True

1 change: 1 addition & 0 deletions src/transformers/models/aya_vision/modeling_aya_vision.py
@@ -91,6 +91,7 @@ def pixel_shuffle(self, image_features): # B, S, D
class AyaVisionPreTrainedModel(PreTrainedModel):
config: AyaVisionConfig
base_model_prefix = ""
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"

1 change: 1 addition & 0 deletions src/transformers/models/bark/modeling_bark.py
@@ -321,6 +321,7 @@ def forward(
@auto_docstring
class BarkPreTrainedModel(PreTrainedModel):
config: BarkConfig
output_modalities = "audio"
supports_gradient_checkpointing = False
_supports_flash_attn = True

1 change: 1 addition & 0 deletions src/transformers/models/beit/modeling_beit.py
@@ -704,6 +704,7 @@ def forward(
class BeitPreTrainedModel(PreTrainedModel):
config: BeitConfig
base_model_prefix = "beit"
input_modalities = "image"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["BeitLayer"]
1 change: 1 addition & 0 deletions src/transformers/models/bit/modeling_bit.py
@@ -624,6 +624,7 @@ def forward(
class BitPreTrainedModel(PreTrainedModel):
config: BitConfig
base_model_prefix = "bit"
input_modalities = "image"
main_input_name = "pixel_values"
_no_split_modules = ["BitEmbeddings"]

2 changes: 2 additions & 0 deletions src/transformers/models/blip/modeling_blip.py
@@ -414,6 +414,7 @@ def forward(
class BlipPreTrainedModel(PreTrainedModel):
config: BlipConfig
base_model_prefix = "blip"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"]
_skip_keys_device_placement = ["past_key_values"]
@@ -482,6 +483,7 @@ def forward(

class BlipVisionModel(BlipPreTrainedModel):
main_input_name = "pixel_values"
input_modalities = "image"
config: BlipVisionConfig
_can_record_outputs = {
"hidden_states": BlipEncoderLayer,
4 changes: 4 additions & 0 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -391,6 +391,7 @@ def forward(
class Blip2PreTrainedModel(PreTrainedModel):
config: Blip2Config
base_model_prefix = "blip"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_supports_attention_backend = True
_supports_flash_attn = True
@@ -473,6 +474,7 @@ def forward(
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2
class Blip2VisionModel(Blip2PreTrainedModel):
main_input_name = "pixel_values"
input_modalities = "image"
config: Blip2VisionConfig
_can_record_outputs = {
"hidden_states": Blip2EncoderLayer,
@@ -1536,6 +1538,7 @@ def forward(
@auto_docstring
class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
main_input_name = "pixel_values"
input_modalities = "image"
_keep_in_fp32_modules = ["query_tokens", "qformer"]
_supports_flash_attn = False # because self.qformer does not support FA2

@@ -2007,6 +2010,7 @@ def generate(
)
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
main_input_name = "pixel_values"
input_modalities = "image"
_keep_in_fp32_modules = ["query_tokens", "qformer"]
_supports_flash_attn = False # because self.qformer does not support FA2

1 change: 1 addition & 0 deletions src/transformers/models/blt/modeling_blt.py
@@ -422,6 +422,7 @@ def forward(
class BltPreTrainedModel(PreTrainedModel):
config: BltConfig
base_model_prefix = ""
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_no_split_modules = ["BltTransformerLayer"]
_can_compile_fullgraph = False # static cache cannot have different shapes for each layer
3 changes: 3 additions & 0 deletions src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -995,6 +995,7 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
class BridgeTowerPreTrainedModel(PreTrainedModel):
config: BridgeTowerConfig
base_model_prefix = "bridgetower"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = False
_no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
_skip_keys_device_placement = "past_key_values"
@@ -1028,6 +1029,7 @@ def _init_weights(self, module: nn.Module):

class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
config: BridgeTowerVisionConfig
input_modalities = "image"

def __init__(self, config):
super().__init__(config)
@@ -1057,6 +1059,7 @@ def forward(self, image, image_mask=None, interpolate_pos_encoding=False):
)
class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
config: BridgeTowerTextConfig
input_modalities = "text"

def __init__(self, config, add_pooling_layer=True):
r"""
1 change: 1 addition & 0 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -799,6 +799,7 @@ def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
class ChameleonPreTrainedModel(PreTrainedModel):
config: ChameleonConfig
base_model_prefix = "model"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"]
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
3 changes: 3 additions & 0 deletions src/transformers/models/chinese_clip/modeling_chinese_clip.py
@@ -578,6 +578,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class ChineseCLIPPreTrainedModel(PreTrainedModel):
config: ChineseCLIPConfig
base_model_prefix = "chinese_clip"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True

def _init_weights(self, module):
@@ -814,6 +815,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
"""

config: ChineseCLIPTextConfig
input_modalities = "text"
_no_split_modules = ["ChineseCLIPTextEmbeddings"]

def __init__(self, config, add_pooling_layer=True):
@@ -929,6 +931,7 @@ def forward(
class ChineseCLIPVisionModel(ChineseCLIPPreTrainedModel):
config: ChineseCLIPVisionConfig
main_input_name = "pixel_values"
input_modalities = "image"
_no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPVisionAttention"]

def __init__(self, config: ChineseCLIPVisionConfig):
5 changes: 5 additions & 0 deletions src/transformers/models/clap/modeling_clap.py
@@ -1343,6 +1343,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class ClapPreTrainedModel(PreTrainedModel):
config: ClapConfig
base_model_prefix = "clap"
input_modalities = ["audio", "text"]
supports_gradient_checkpointing = False

def _init_weights(self, module: nn.Module):
@@ -1372,6 +1373,7 @@ def _init_weights(self, module: nn.Module):
class ClapAudioModel(ClapPreTrainedModel):
config: ClapAudioConfig
main_input_name = "input_features"
input_modalities = "audio"

def __init__(self, config: ClapAudioConfig):
super().__init__(config)
@@ -1444,6 +1446,7 @@ def forward(
)
class ClapTextModel(ClapPreTrainedModel):
config: ClapTextConfig
input_modalities = "text"

def __init__(self, config, add_pooling_layer=True):
r"""
@@ -1748,6 +1751,7 @@ def forward(
@auto_docstring
class ClapTextModelWithProjection(ClapPreTrainedModel):
config: ClapTextConfig
input_modalities = "text"

def __init__(self, config: ClapTextConfig):
super().__init__(config)
@@ -1814,6 +1818,7 @@ def forward(
class ClapAudioModelWithProjection(ClapPreTrainedModel):
config: ClapAudioConfig
main_input_name = "input_features"
input_modalities = "audio"

def __init__(self, config: ClapAudioConfig):
super().__init__(config)
6 changes: 6 additions & 0 deletions src/transformers/models/clip/modeling_clip.py
@@ -419,6 +419,7 @@ def forward(
class CLIPPreTrainedModel(PreTrainedModel):
config: CLIPConfig
base_model_prefix = "clip"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn = True
@@ -661,6 +662,7 @@ def forward(
)
class CLIPTextModel(CLIPPreTrainedModel):
config: CLIPTextConfig
input_modalities = "text"

_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
_supports_flash_attn = False # mask creation only accounts for sdpa/eager
@@ -768,6 +770,7 @@ def forward(
class CLIPVisionModel(CLIPPreTrainedModel):
config: CLIPVisionConfig
main_input_name = "pixel_values"
input_modalities = "image"
_no_split_modules = ["CLIPEncoderLayer"]

def __init__(self, config: CLIPVisionConfig):
@@ -1028,6 +1031,7 @@ def forward(
@auto_docstring
class CLIPTextModelWithProjection(CLIPPreTrainedModel):
config: CLIPTextConfig
input_modalities = "text"

_supports_flash_attn = False
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
@@ -1098,6 +1102,7 @@ def forward(
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
config: CLIPVisionConfig
main_input_name = "pixel_values"
input_modalities = "image"

def __init__(self, config: CLIPVisionConfig):
super().__init__(config)
@@ -1168,6 +1173,7 @@ def forward(
)
class CLIPForImageClassification(CLIPPreTrainedModel):
main_input_name = "pixel_values"
input_modalities = "image"

def __init__(self, config: CLIPConfig) -> None:
super().__init__(config)
3 changes: 3 additions & 0 deletions src/transformers/models/clipseg/modeling_clipseg.py
@@ -423,6 +423,7 @@ def forward(
class CLIPSegPreTrainedModel(PreTrainedModel):
config: CLIPSegConfig
base_model_prefix = "clip"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True

def _init_weights(self, module):
@@ -647,6 +648,7 @@ def forward(

class CLIPSegTextModel(CLIPSegPreTrainedModel):
config: CLIPSegTextConfig
input_modalities = "text"

_no_split_modules = ["CLIPSegTextEmbeddings", "CLIPSegEncoderLayer"]

@@ -752,6 +754,7 @@ def forward(
class CLIPSegVisionModel(CLIPSegPreTrainedModel):
config: CLIPSegVisionConfig
main_input_name = "pixel_values"
input_modalities = "image"

def __init__(self, config: CLIPSegVisionConfig):
super().__init__(config)
src/transformers/models/cohere2_vision/modeling_cohere2_vision.py
@@ -130,6 +130,7 @@ class Cohere2VisionCausalLMOutputWithPast(ModelOutput):
class Cohere2VisionPreTrainedModel(PreTrainedModel):
config: Cohere2VisionConfig
base_model_prefix = ""
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"

1 change: 1 addition & 0 deletions src/transformers/models/colpali/modeling_colpali.py
@@ -32,6 +32,7 @@
class ColPaliPreTrainedModel(PreTrainedModel):
config: ColPaliConfig
base_model_prefix = "model"
input_modalities = ["image", "text"]
_no_split_modules = []
_supports_sdpa = True
_supports_flash_attn = True