diff --git a/src/transformers/models/idefics3/configuration_idefics3.py b/src/transformers/models/idefics3/configuration_idefics3.py
index 159fa48a4f5f..31e976e3a58f 100644
--- a/src/transformers/models/idefics3/configuration_idefics3.py
+++ b/src/transformers/models/idefics3/configuration_idefics3.py
@@ -56,6 +56,10 @@ class Idefics3VisionConfig(PreTrainedConfig):
             The dropout ratio for the attention probabilities.
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        use_export_friendly (`bool`, *optional*, defaults to `False`):
+            Whether to use export-friendly mode for vision model operations. When True, uses simplified
+            operations that are compatible with export frameworks (e.g., avoids data-dependent loops).
+            Only enable this when exporting the model.
 
     Example:
 
@@ -89,6 +93,7 @@ def __init__(
         layer_norm_eps=1e-6,
         attention_dropout=0.0,
         initializer_range=0.02,
+        use_export_friendly=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -104,6 +109,7 @@
         self.layer_norm_eps = layer_norm_eps
         self.hidden_act = hidden_act
         self.initializer_range = initializer_range
+        self.use_export_friendly = use_export_friendly
 
 
 class Idefics3Config(PreTrainedConfig):
@@ -132,6 +138,10 @@ class Idefics3Config(PreTrainedConfig):
             The scale factor for the image encoder.
         pad_token_id (`int`, *optional*, defaults to 128002):
             The id of the padding token.
+        use_export_friendly (`bool`, *optional*, defaults to `False`):
+            Whether to use export-friendly mode for model operations. When True, uses simplified
+            operations that are compatible with export frameworks (e.g., skips dynamic padding detection).
+            Only enable this when exporting the model or when you're certain the input won't have padding.
 
     Example:
 
     ```python
@@ -156,18 +166,26 @@ def __init__(
         text_config=None,
         scale_factor=2,
         pad_token_id=128_002,
+        use_export_friendly=False,
         **kwargs,
     ):
         self.image_token_id = image_token_id
         self.use_cache = use_cache
         self.tie_word_embeddings = tie_word_embeddings
+        self.use_export_friendly = use_export_friendly
 
         if vision_config is None:
-            self.vision_config = Idefics3VisionConfig()
+            self.vision_config = Idefics3VisionConfig(use_export_friendly=use_export_friendly)
             logger.info("vision_config is None, using default vision config")
         elif isinstance(vision_config, dict):
+            # Propagate use_export_friendly to vision_config if not explicitly set
+            if "use_export_friendly" not in vision_config:
+                vision_config["use_export_friendly"] = use_export_friendly
             self.vision_config = Idefics3VisionConfig(**vision_config)
         elif isinstance(vision_config, Idefics3VisionConfig):
+            # Propagate use_export_friendly to vision_config if not explicitly set
+            if not hasattr(vision_config, "use_export_friendly"):
+                vision_config.use_export_friendly = use_export_friendly
             self.vision_config = vision_config
 
         if isinstance(text_config, dict):
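The constructor changes above are easiest to sanity-check interactively. Below is a minimal sketch, assuming this patch is applied, of how `use_export_friendly` propagates from `Idefics3Config` into the nested `Idefics3VisionConfig`; it exercises only the branches shown in the hunk above.

```python
from transformers import Idefics3Config

# Top-level flag with no explicit vision_config: the default Idefics3VisionConfig
# is built with use_export_friendly forwarded to it.
config = Idefics3Config(use_export_friendly=True)
print(config.use_export_friendly)                # True
print(config.vision_config.use_export_friendly)  # True (propagated)

# An explicit value in a vision_config dict is kept and not overwritten.
config = Idefics3Config(
    vision_config={"use_export_friendly": False},
    use_export_friendly=True,
)
print(config.vision_config.use_export_friendly)  # False (explicit value wins)
```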
diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py
index 998e8c7da96a..0ab9cf936051 100644
--- a/src/transformers/models/idefics3/modeling_idefics3.py
+++ b/src/transformers/models/idefics3/modeling_idefics3.py
@@ -115,6 +115,7 @@ class Idefics3VisionEmbeddings(nn.Module):
 
     def __init__(self, config: Idefics3VisionConfig):
         super().__init__()
+        self.config = config
         self.embed_dim = config.hidden_size
         self.image_size = config.image_size
         self.patch_size = config.patch_size
@@ -138,37 +139,58 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
         patch_embeds = self.patch_embedding(pixel_values)
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
 
-        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
-        boundaries = torch.arange(
-            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
-        )
-        position_ids = torch.full(
-            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
-        )
+        if self.config.use_export_friendly:
+            # Export-friendly version: assumes image input batch_size = 1
+            # This avoids the data-dependent loop over batch dimension
+            nb_patches_h = max_im_h // self.patch_size
+            nb_patches_w = max_im_w // self.patch_size
+            N = self.num_patches_per_side
 
-        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
-            nb_patches_h = p_attn_mask[:, 0].sum()
-            nb_patches_w = p_attn_mask[0].sum()
+            h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=torch.long)
+            w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=torch.long)
 
-            step_h = 1.0 / nb_patches_h
-            step_w = 1.0 / nb_patches_w
+            # This replaces bucketize(x, boundaries=[1/N, 2/N, ...], right=True) ≈ floor(x * N)
+            bucket_coords_h = (h_indices * N) // nb_patches_h
+            bucket_coords_w = (w_indices * N) // nb_patches_w
 
-            h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32)
-            w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32)
-            fractional_coords_h = h_indices * step_h
-            fractional_coords_w = w_indices * step_w
+            bucket_coords_h = torch.clamp(bucket_coords_h, max=N - 1)
+            bucket_coords_w = torch.clamp(bucket_coords_w, max=N - 1)
 
-            fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
-            fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))
+            pos_ids = (bucket_coords_h[:, None] * N + bucket_coords_w[None, :]).reshape(-1)
+            position_ids = pos_ids.unsqueeze(0).expand(batch_size, -1)
+        else:
+            # Original version: data-dependent loop for variable resolution images
+            max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+            boundaries = torch.arange(
+                1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
+            )
+            position_ids = torch.full(
+                size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
+            )
 
-            fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
-            fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)
+            for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+                nb_patches_h = p_attn_mask[:, 0].sum()
+                nb_patches_w = p_attn_mask[0].sum()
 
-            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
-            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+                step_h = 1.0 / nb_patches_h
+                step_w = 1.0 / nb_patches_w
 
-            pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
-            position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
+                h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32)
+                w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32)
+                fractional_coords_h = h_indices * step_h
+                fractional_coords_w = w_indices * step_w
+
+                fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
+                fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))
+
+                fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
+                fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)
+
+                bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+                bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+                pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
+                position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
 
         embeddings = embeddings + self.position_embedding(position_ids)
         return embeddings
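The comment in the export-friendly branch claims that the integer arithmetic reproduces the original `torch.bucketize` over fractional coordinates. A small standalone check of that equivalence is sketched below; `N` (standing in for `num_patches_per_side`) and the `nb_patches` values are arbitrary illustrative numbers, not taken from the model.

```python
import torch

N = 32  # stands in for num_patches_per_side
for nb_patches in (8, 20, 32):
    idx = torch.arange(nb_patches)

    # Export-friendly integer form from the diff: floor(i * N / nb_patches).
    bucket_int = torch.clamp((idx * N) // nb_patches, max=N - 1)

    # Original float form: bucketize fractional coords against [1/N, ..., (N-1)/N].
    boundaries = torch.arange(1 / N, 1.0, 1 / N)
    frac = torch.clamp(idx.float() / nb_patches, max=1.0 - 1e-6)
    bucket_float = torch.bucketize(frac, boundaries, right=True)

    assert torch.equal(bucket_int, bucket_float), nb_patches
```

The integer form avoids float rounding entirely; its `clamp(..., max=N - 1)` plays the role that the `1.0 - 1e-6` guard plays in the float path.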
@@ -632,9 +654,11 @@ def get_image_features(
         pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
 
         # Remove padding images - padding images are full 0.
-        nb_values_per_image = pixel_values.shape[1:].numel()
-        real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
-        pixel_values = pixel_values[real_images_inds].contiguous()
+        # Skip this in export-friendly mode for simpler operations
+        if not self.config.use_export_friendly:
+            nb_values_per_image = pixel_values.shape[1:].numel()
+            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+            pixel_values = pixel_values[real_images_inds].contiguous()
 
         # Handle the vision attention mask
         if pixel_attention_mask is None:
@@ -643,7 +667,7 @@
                 dtype=torch.bool,
                 device=pixel_values.device,
             )
-        else:
+        elif not self.config.use_export_friendly:
             # Remove padding images from the mask
             pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
             pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
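Finally, a sketch of the workflow the flag is aimed at: loading the model with export-friendly mode enabled and exporting the vision tower. The checkpoint id, the `model.model.vision_model` attribute path, and the 364x364 input size are illustrative assumptions rather than something this diff fixes, and both flags are set explicitly instead of relying on propagation.

```python
import torch
from transformers import Idefics3Config, Idefics3ForConditionalGeneration

model_id = "HuggingFaceM4/Idefics3-8B-Llama3"  # placeholder checkpoint

config = Idefics3Config.from_pretrained(model_id)
config.use_export_friendly = True                # skips the dynamic padding removal in get_image_features
config.vision_config.use_export_friendly = True  # static position-id computation in the vision embeddings

model = Idefics3ForConditionalGeneration.from_pretrained(model_id, config=config)
model.eval()

vision_model = model.model.vision_model    # assumed attribute path to the vision tower
pixel_values = torch.rand(1, 3, 364, 364)  # one full-grid image, no padding

# The export-friendly branches avoid the data-dependent loop and boolean indexing
# that a single static torch.export graph cannot express.
exported = torch.export.export(vision_model, (pixel_values,))
print(exported.graph_module.graph)
```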