20 changes: 19 additions & 1 deletion src/transformers/models/idefics3/configuration_idefics3.py
@@ -56,6 +56,10 @@ class Idefics3VisionConfig(PreTrainedConfig):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_export_friendly (`bool`, *optional*, defaults to `False`):
Whether to use export-friendly mode for vision model operations. When `True`, uses simplified
operations that are compatible with export frameworks (e.g., avoids data-dependent loops).
Only enable this when exporting the model.

Example:

@@ -89,6 +93,7 @@ def __init__(
layer_norm_eps=1e-6,
attention_dropout=0.0,
initializer_range=0.02,
use_export_friendly=False,
**kwargs,
):
super().__init__(**kwargs)
@@ -104,6 +109,7 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.use_export_friendly = use_export_friendly


class Idefics3Config(PreTrainedConfig):
@@ -132,6 +138,10 @@ class Idefics3Config(PreTrainedConfig):
The scale factor for the image encoder.
pad_token_id (`int`, *optional*, defaults to 128002):
The id of the padding token.
use_export_friendly (`bool`, *optional*, defaults to `False`):
Whether to use export-friendly mode for model operations. When `True`, uses simplified
operations that are compatible with export frameworks (e.g., skips dynamic padding detection).
Only enable this when exporting the model or when you are certain the input contains no padding images.

Example:
```python
@@ -156,18 +166,26 @@ def __init__(
text_config=None,
scale_factor=2,
pad_token_id=128_002,
use_export_friendly=False,
**kwargs,
):
self.image_token_id = image_token_id
self.use_cache = use_cache
self.tie_word_embeddings = tie_word_embeddings
self.use_export_friendly = use_export_friendly

if vision_config is None:
self.vision_config = Idefics3VisionConfig()
self.vision_config = Idefics3VisionConfig(use_export_friendly=use_export_friendly)
logger.info("vision_config is None, using default vision config")
elif isinstance(vision_config, dict):
# Propagate use_export_friendly to vision_config if not explicitly set
if "use_export_friendly" not in vision_config:
vision_config["use_export_friendly"] = use_export_friendly
self.vision_config = Idefics3VisionConfig(**vision_config)
elif isinstance(vision_config, Idefics3VisionConfig):
# Propagate use_export_friendly to vision_config if not explicitly set
if not hasattr(vision_config, "use_export_friendly"):
vision_config.use_export_friendly = use_export_friendly
self.vision_config = vision_config

if isinstance(text_config, dict):
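For reference, a minimal usage sketch (not part of the diff, and only valid with this patch applied) of how the new flag is expected to propagate from the composite config to the nested vision config:

```python
from transformers import Idefics3Config

# Enabling export-friendly mode on the composite config should propagate to the
# nested vision config when the latter does not set the flag itself.
config = Idefics3Config(use_export_friendly=True)
assert config.use_export_friendly
assert config.vision_config.use_export_friendly

# An explicit value in the vision_config dict is respected and not overridden.
config = Idefics3Config(
    use_export_friendly=True,
    vision_config={"use_export_friendly": False},
)
assert config.vision_config.use_export_friendly is False
```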
80 changes: 52 additions & 28 deletions src/transformers/models/idefics3/modeling_idefics3.py
@@ -115,6 +115,7 @@ class Idefics3VisionEmbeddings(nn.Module):

def __init__(self, config: Idefics3VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
@@ -138,37 +139,58 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)

max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
boundaries = torch.arange(
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
)
position_ids = torch.full(
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
)
if self.config.use_export_friendly:
# Export-friendly version: assumes a single, unpadded image (batch_size = 1), so
# patch_attention_mask is ignored and the data-dependent loop over the batch dimension is avoided
nb_patches_h = max_im_h // self.patch_size
nb_patches_w = max_im_w // self.patch_size
N = self.num_patches_per_side

for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=torch.long)
w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=torch.long)

step_h = 1.0 / nb_patches_h
step_w = 1.0 / nb_patches_w
# Integer equivalent of bucketize(x, boundaries=[1/N, 2/N, ...], right=True):
# for x = i / nb_patches this is floor(x * N) = (i * N) // nb_patches
bucket_coords_h = (h_indices * N) // nb_patches_h
bucket_coords_w = (w_indices * N) // nb_patches_w

h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32)
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32)
fractional_coords_h = h_indices * step_h
fractional_coords_w = w_indices * step_w
bucket_coords_h = torch.clamp(bucket_coords_h, max=N - 1)
bucket_coords_w = torch.clamp(bucket_coords_w, max=N - 1)

fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))
pos_ids = (bucket_coords_h[:, None] * N + bucket_coords_w[None, :]).reshape(-1)
position_ids = pos_ids.unsqueeze(0).expand(batch_size, -1)
else:
# Original version: data-dependent loop for variable resolution images
max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
boundaries = torch.arange(
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
)
position_ids = torch.full(
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
)

fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
step_h = 1.0 / nb_patches_h
step_w = 1.0 / nb_patches_w

pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32)
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32)
fractional_coords_h = h_indices * step_h
fractional_coords_w = w_indices * step_w

fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))

fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids

embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
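To see why the integer floor division above is a drop-in replacement for `torch.bucketize`, here is a small standalone check with illustrative values (`N` and `nb_patches` are made up, not from the PR): for a grid of `nb_patches` positions along one axis, `(i * N) // nb_patches` reproduces the bucket indices obtained by bucketizing the clamped fractional coordinates.

```python
import torch

N = 32            # hypothetical num_patches_per_side
nb_patches = 20   # hypothetical number of patches present along one axis

# Original path: bucketize fractional coordinates against [1/N, 2/N, ..., (N-1)/N]
boundaries = torch.arange(1 / N, 1.0, 1 / N)
idx = torch.arange(nb_patches, dtype=torch.float32)
frac = torch.clamp(idx / nb_patches, max=1.0 - 1e-6)
buckets_ref = torch.bucketize(frac, boundaries, right=True)

# Export-friendly path: pure integer arithmetic, no float boundaries and no bucketize
idx_int = torch.arange(nb_patches, dtype=torch.long)
buckets_int = torch.clamp((idx_int * N) // nb_patches, max=N - 1)

assert torch.equal(buckets_ref, buckets_int)
print(buckets_int.tolist())
```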
@@ -632,9 +654,11 @@ def get_image_features(
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Skip this in export-friendly mode for simpler operations
if not self.config.use_export_friendly:
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()

# Handle the vision attention mask
if pixel_attention_mask is None:
@@ -643,7 +667,7 @@
dtype=torch.bool,
device=pixel_values.device,
)
else:
elif not self.config.use_export_friendly:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
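For context, the padding-image filter that the export-friendly path skips works by counting zero-valued elements per image; the boolean indexing it relies on produces a data-dependent output shape, which is why it is disabled for export. A minimal standalone sketch with made-up tensor sizes:

```python
import torch

# Three candidate images: one real, two all-zero padding images (illustrative sizes).
pixel_values = torch.zeros(3, 3, 16, 16)  # (num_images, channels, height, width)
pixel_values[0].normal_()                 # only the first image carries real content

nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()

print(pixel_values.shape)  # torch.Size([1, 3, 16, 16]) -- shape depends on the data
```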