From 9e253a7bb7ca847efc25b3bb727ae8ee3991dd88 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Sun, 26 Oct 2025 16:41:39 +0000 Subject: [PATCH 01/15] Bria FIBO pipeline --- docs/source/en/api/pipelines/bria_fibo.md | 37 + src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 2 + .../transformers/transformer_bria_fibo.py | 446 ++++++++++ src/diffusers/modular_pipelines/__init__.py | 2 + .../modular_pipelines/bria_fibo/__init__.py | 47 + .../bria_fibo/fibo_vlm_prompt_to_json.py | 377 ++++++++ .../bria_fibo/gemini_prompt_to_json.py | 804 +++++++++++++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/bria_fibo/__init__.py | 48 + .../pipelines/bria_fibo/pipeline_bria_fibo.py | 826 ++++++++++++++++++ .../pipelines/bria_fibo/pipeline_output.py | 21 + .../test_models_transformer_bria_fibo.py | 132 +++ tests/pipelines/bria_fibo/__init__.py | 0 .../bria_fibo/test_pipeline_bria_fibo.py | 198 +++++ 15 files changed, 2948 insertions(+) create mode 100644 docs/source/en/api/pipelines/bria_fibo.md create mode 100644 src/diffusers/models/transformers/transformer_bria_fibo.py create mode 100644 src/diffusers/modular_pipelines/bria_fibo/__init__.py create mode 100644 src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py create mode 100644 src/diffusers/modular_pipelines/bria_fibo/gemini_prompt_to_json.py create mode 100644 src/diffusers/pipelines/bria_fibo/__init__.py create mode 100644 src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py create mode 100644 src/diffusers/pipelines/bria_fibo/pipeline_output.py create mode 100644 tests/models/transformers/test_models_transformer_bria_fibo.py create mode 100644 tests/pipelines/bria_fibo/__init__.py create mode 100644 tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py diff --git a/docs/source/en/api/pipelines/bria_fibo.md b/docs/source/en/api/pipelines/bria_fibo.md new file mode 100644 index 000000000000..086da924bdf1 --- /dev/null +++ b/docs/source/en/api/pipelines/bria_fibo.md @@ -0,0 +1,37 @@ + + +# Bria Fibo + +Text-to-image models have mastered imagination - but not control. FIBO changes that. + +FIBO is trained on structured JSON captions up to 1,000+ words and designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs. + +With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control. + +## Usage + +_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate. 
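+After accepting the gate and logging in (the login command is shown below), generation follows the usual diffusers text-to-image flow. The snippet below is only a sketch: the call signature is assumed to match other diffusers pipelines, and the JSON caption is an illustrative example rather than one taken from the model card.
+
+```python
+import torch
+from diffusers import BriaFiboPipeline
+
+# Sketch only: check the BriaFiboPipeline docstring for the exact arguments.
+pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# FIBO consumes structured JSON captions, so the prompt is a JSON string describing the scene.
+prompt = '{"short_description": "A red vintage car parked on a cobblestone street at golden hour"}'
+
+image = pipe(prompt=prompt, num_inference_steps=30).images[0]
+image.save("fibo.png")
+```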
Once you are in, you need to login so that your system knows you’ve accepted the gate._ + +Use the command below to log in: + +```bash +hf auth login +``` + + +## BriaPipeline + +[[autodoc]] BriaPipeline + - all + - __call__ + diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2a8b589846ad..e0e31f58f5d2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -199,6 +199,7 @@ "AutoencoderTiny", "AutoModel", "BriaTransformer2DModel", + "BriaFiboTransformer2DModel", "CacheMixin", "ChromaTransformer2DModel", "CogVideoXTransformer3DModel", @@ -392,6 +393,8 @@ else: _import_structure["modular_pipelines"].extend( [ + "BriaFiboVLMPromptToJson", + "BriaFiboGeminiPromptToJson", "FluxAutoBlocks", "FluxKontextAutoBlocks", "FluxKontextModularPipeline", @@ -431,6 +434,7 @@ "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", "BriaPipeline", + "BriaFiboPipeline", "ChromaImg2ImgPipeline", "ChromaPipeline", "CLIPImageProjection", @@ -902,6 +906,7 @@ AutoencoderTiny, AutoModel, BriaTransformer2DModel, + BriaFiboTransformer2DModel, CacheMixin, ChromaTransformer2DModel, CogVideoXTransformer3DModel, @@ -1104,6 +1109,7 @@ AudioLDMPipeline, AuraFlowPipeline, BriaPipeline, + BriaFiboPipeline, ChromaImg2ImgPipeline, ChromaPipeline, CLIPImageProjection, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 241ccf7b785a..d25ef37b7cdf 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -84,6 +84,7 @@ _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"] + _import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"] _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"] @@ -175,6 +176,7 @@ AllegroTransformer3DModel, AuraFlowTransformer2DModel, BriaTransformer2DModel, + BriaFiboTransformer2DModel, ChromaTransformer2DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py new file mode 100644 index 000000000000..9521b7f3dd72 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_bria_fibo.py @@ -0,0 +1,446 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...models.attention_processor import Attention +from ...models.embeddings import TimestepEmbedding, get_timestep_embedding +from ...models.modeling_outputs import Transformer2DModelOutput +from ...models.modeling_utils import ModelMixin +from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZeroSingle +from ...models.transformers.transformer_bria import BriaAttnProcessor +from ...models.transformers.transformer_flux import FluxTransformerBlock +from ...utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import maybe_allow_in_graph + + +logger = logging.get_logger(__name__) # pylint: 
disable=invalid-name + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=torch.float32, # torch.float32, torch.float64 +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a + frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta' + parameter scales the frequencies. The returned tensor contains complex values in complex64 data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + theta = theta * ntk_factor + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2] + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] + return freqs_cos, freqs_sin + elif use_real: + # stable audio, allegro + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +class EmbedND(torch.nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + freqs_dtype = torch.float32 if is_mps else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, 
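+        # per-axis cos/sin tables are concatenated along the channel dimension, so the rotary
+        # embedding spans sum(axes_dim) channels (128 with the default axes_dims_rope)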
dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +@maybe_allow_in_graph +class BriaFiboSingleTransformerBlock(nn.Module): + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + processor = BriaAttnProcessor() + + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm="rms_norm", + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states + + +class TextProjection(nn.Module): + def __init__(self, in_features, hidden_size): + super().__init__() + self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False) + + def forward(self, caption): + hidden_states = self.linear(caption) + return hidden_states + + +class Timesteps(nn.Module): + def __init__( + self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000 + ): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + self.time_theta = time_theta + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + max_period=self.time_theta, + ) + return t_emb + + +class TimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, time_theta): + super().__init__() + + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta + ) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D) + return timesteps_emb + + +class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + Parameters: + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. 
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. + ... + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = None, + guidance_embeds: bool = False, + axes_dims_rope: List[int] = [16, 56, 56], + rope_theta=10000, + time_theta=10000, + text_encoder_dim: int = 2048, + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope) + + self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta) + + if guidance_embeds: + self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + BriaFiboSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + caption_projection = [ + TextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2) + for i in range(self.config.num_layers + self.config.num_single_layers) + ] + self.caption_projection = nn.ModuleList(caption_projection) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + text_encoder_layers: list = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) + else: + guidance = None + + temb = self.time_embed(timestep, dtype=hidden_states.dtype) + + if guidance: + temb += self.guidance_embed(guidance, dtype=hidden_states.dtype) + + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if len(txt_ids.shape) == 3: + txt_ids = txt_ids[0] + + if len(img_ids.shape) == 3: + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + new_text_encoder_layers = [] + for i, text_encoder_layer in enumerate(text_encoder_layers): + text_encoder_layer = self.caption_projection[i](text_encoder_layer) + new_text_encoder_layers.append(text_encoder_layer) + text_encoder_layers = new_text_encoder_layers + + block_id = 0 + for index_block, block in enumerate(self.transformer_blocks): + current_text_encoder_layer = text_encoder_layers[block_id] + encoder_hidden_states = torch.cat( + [encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1 + ) + block_id += 1 + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + for index_block, block in enumerate(self.single_transformer_blocks): + current_text_encoder_layer = text_encoder_layers[block_id] + encoder_hidden_states = torch.cat( + [encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1 + ) + block_id += 1 + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + 
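+                    # gradient checkpointing: this block's activations are recomputed during the backward pass instead of being stored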
hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...] + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 86ed735134ff..91f9397629b7 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -45,6 +45,7 @@ "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] + _import_structure["bria_fibo"] = ["BriaFiboVLMPromptToJson", "BriaFiboGeminiPromptToJson"] _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"] _import_structure["flux"] = [ "FluxAutoBlocks", @@ -69,6 +70,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .bria_fibo import BriaFiboGeminiPromptToJson, BriaFiboVLMPromptToJson from .components_manager import ComponentsManager from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline from .modular_pipeline import ( diff --git a/src/diffusers/modular_pipelines/bria_fibo/__init__.py b/src/diffusers/modular_pipelines/bria_fibo/__init__.py new file mode 100644 index 000000000000..302d271e0c43 --- /dev/null +++ b/src/diffusers/modular_pipelines/bria_fibo/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["gemini_prompt_to_json"] = ["BriaFiboGeminiPromptToJson"] + _import_structure["fibo_vlm_prompt_to_json"] = ["BriaFiboVLMPromptToJson"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .gemini_prompt_to_json import BriaFiboGeminiPromptToJson + from .fibo_vlm_prompt_to_json import BriaFiboVLMPromptToJson +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py b/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py new file mode 100644 index 000000000000..c63c7f85d190 --- 
/dev/null +++ b/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py @@ -0,0 +1,377 @@ +import json +import math +import textwrap +from typing import Any, Dict, Iterable, List, Optional + +import torch +from boltons.iterutils import remap +from PIL import Image +from transformers import AutoModelForCausalLM, AutoProcessor, Qwen3VLForConditionalGeneration + +from .. import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState + + +def parse_aesthetic_score(record: dict) -> str: + ae = record["aesthetic_score"] + if ae < 5.5: + return "very low" + elif ae < 6: + return "low" + elif ae < 7: + return "medium" + elif ae < 7.6: + return "high" + else: + return "very high" + + +def parse_pickascore(record: dict) -> str: + ps = record["pickascore"] + if ps < 0.78: + return "very low" + elif ps < 0.82: + return "low" + elif ps < 0.87: + return "medium" + elif ps < 0.91: + return "high" + else: + return "very high" + + +def prepare_clean_caption(record: dict) -> str: + def keep(p, k, v): + is_none = v is None + is_empty_string = isinstance(v, str) and v == "" + is_empty_dict = isinstance(v, dict) and not v + is_empty_list = isinstance(v, list) and not v + is_nan = isinstance(v, float) and math.isnan(v) + if is_none or is_empty_string or is_empty_list or is_empty_dict or is_nan: + return False + return True + + try: + scores = {} + if "pickascore" in record: + scores["preference_score"] = parse_pickascore(record) + if "aesthetic_score" in record: + scores["aesthetic_score"] = parse_aesthetic_score(record) + + clean_caption_dict = remap(record, visit=keep) + + # Set aesthetics scores + if "aesthetics" not in clean_caption_dict: + if len(scores) > 0: + clean_caption_dict["aesthetics"] = scores + else: + clean_caption_dict["aesthetics"].update(scores) + + # Dumps clean structured caption as minimal json string (i.e. 
no newlines\whitespaces seps) + clean_caption_str = json.dumps(clean_caption_dict) + return clean_caption_str + except Exception as ex: + print("Error: ", ex) + raise ex + + +def _collect_images(messages: Iterable[Dict[str, Any]]) -> List[Image.Image]: + images: List[Image.Image] = [] + for message in messages: + content = message.get("content", []) + if not isinstance(content, list): + continue + for item in content: + if not isinstance(item, dict): + continue + if item.get("type") != "image": + continue + image_value = item.get("image") + if isinstance(image_value, Image.Image): + images.append(image_value) + else: + raise ValueError("Expected PIL.Image for image content in messages.") + return images + + +def _strip_stop_sequences(text: str, stop_sequences: Optional[List[str]]) -> str: + if not stop_sequences: + return text.strip() + cleaned = text + for stop in stop_sequences: + if not stop: + continue + index = cleaned.find(stop) + if index >= 0: + cleaned = cleaned[:index] + return cleaned.strip() + + +class TransformersEngine(torch.nn.Module): + """Inference wrapper using Hugging Face transformers.""" + + def __init__( + self, + model: str, + *, + processor_kwargs: Optional[Dict[str, Any]] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + super(TransformersEngine, self).__init__() + default_processor_kwargs: Dict[str, Any] = { + "min_pixels": 256 * 28 * 28, + "max_pixels": 1024 * 28 * 28, + } + processor_kwargs = {**default_processor_kwargs, **(processor_kwargs or {})} + model_kwargs = model_kwargs or {} + + self.processor = AutoProcessor.from_pretrained(model, **processor_kwargs) + + self.model = Qwen3VLForConditionalGeneration.from_pretrained( + model, + dtype=torch.bfloat16, + **model_kwargs, + ) + self.model.eval() + + tokenizer_obj = self.processor.tokenizer + if tokenizer_obj.pad_token_id is None: + tokenizer_obj.pad_token = tokenizer_obj.eos_token + self._pad_token_id = tokenizer_obj.pad_token_id + eos_token_id = tokenizer_obj.eos_token_id + if isinstance(eos_token_id, list) and eos_token_id: + self._eos_token_id = eos_token_id + elif eos_token_id is not None: + self._eos_token_id = [eos_token_id] + else: + raise ValueError("Tokenizer must define an EOS token for generation.") + + def dtype(self) -> torch.dtype: + return self.model.dtype + + def device(self) -> torch.device: + return self.model.device + + def _to_model_device(self, value: Any) -> Any: + if not isinstance(value, torch.Tensor): + return value + target_device = getattr(self.model, "device", None) + if target_device is None or target_device.type == "meta": + return value + if value.device == target_device: + return value + return value.to(target_device) + + def generate( + self, + messages: List[Dict[str, Any]], + top_p: float, + temperature: float, + max_tokens: int, + stop: Optional[List[str]] = None, + ) -> str: + tokenizer = self.processor.tokenizer + prompt_text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + processor_inputs: Dict[str, Any] = { + "text": [prompt_text], + "padding": True, + "return_tensors": "pt", + } + images = _collect_images(messages) + if images: + processor_inputs["images"] = images + inputs = self.processor(**processor_inputs) + inputs = {key: self._to_model_device(value) for key, value in inputs.items()} + + generation_kwargs: Dict[str, Any] = { + "max_new_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "do_sample": temperature > 0, + "eos_token_id": self._eos_token_id, + "pad_token_id": 
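+            # pad with the tokenizer's pad token; __init__ falls back to the EOS token when no pad token is defined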
self._pad_token_id, + } + + with torch.inference_mode(): + generated_ids = self.model.generate(**inputs, **generation_kwargs) + + input_ids = inputs.get("input_ids") + if input_ids is None: + raise RuntimeError("Processor did not return input_ids; cannot compute new tokens.") + new_token_ids = generated_ids[:, input_ids.shape[-1] :] + decoded = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True) + if not decoded: + return "" + text = decoded[0] + stripped_text = _strip_stop_sequences(text, stop) + json_prompt = json.loads(stripped_text) + return json_prompt + + +def generate_json_prompt( + vlm_processor: AutoModelForCausalLM, + top_p: float, + temperature: float, + max_tokens: int, + stop: List[str], + image: Optional[Image.Image] = None, + prompt: Optional[str] = None, + structured_prompt: Optional[str] = None, +): + if image is None and structured_prompt is None: + # only got prompt + task = "generate" + editing_instructions = None + elif image is None and structured_prompt is not None and prompt is not None: + # got structured prompt and prompt + task = "refine" + editing_instructions = prompt + elif image is not None and structured_prompt is None and prompt is not None: + # got image and prompt + task = "refine" + editing_instructions = prompt + elif image is not None and structured_prompt is None and prompt is None: + # only got image + task = "inspire" + editing_instructions = None + else: + raise ValueError("Invalid input") + + messages = build_messages( + task, + image=image, + prompt=prompt, + structured_prompt=structured_prompt, + editing_instructions=editing_instructions, + ) + + generated_prompt = vlm_processor.generate( + messages=messages, top_p=top_p, temperature=temperature, max_tokens=max_tokens, stop=stop + ) + cleaned_json_data = prepare_clean_caption(generated_prompt) + return cleaned_json_data + + +def build_messages( + task: str, + *, + image: Optional[Image.Image] = None, + refine_image: Optional[Image.Image] = None, + prompt: Optional[str] = None, + structured_prompt: Optional[str] = None, + editing_instructions: Optional[str] = None, +) -> List[Dict[str, Any]]: + user_content: List[Dict[str, Any]] = [] + + if task == "inspire": + user_content.append({"type": "image", "image": image}) + user_content.append({"type": "text", "text": ""}) + elif task == "generate": + text_value = (prompt or "").strip() + formatted = f"\n{text_value}" + user_content.append({"type": "text", "text": formatted}) + else: # refine + if refine_image is None: + base_prompt = (structured_prompt or "").strip() + edits = (editing_instructions or "").strip() + formatted = textwrap.dedent( + f""" Input: {base_prompt} Editing instructions: {edits}""" + ).strip() + user_content.append({"type": "text", "text": formatted}) + else: + user_content.append({"type": "image", "image": refine_image}) + edits = (editing_instructions or "").strip() + formatted = textwrap.dedent( + f""" Editing instructions: {edits}""" + ).strip() + user_content.append({"type": "text", "text": formatted}) + + messages: List[Dict[str, Any]] = [] + messages.append({"role": "user", "content": user_content}) + return messages + + +class BriaFiboVLMPromptToJson(ModularPipelineBlocks): + model_name = "BriaFibo" + + def __init__(self, model_id): + super().__init__() + self.engine = TransformersEngine(model_id) + self.engine.model.to("cuda") + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def inputs(self) -> List[InputParam]: + prompt_input = InputParam( + "prompt", + 
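+            # free-text prompt: on its own it selects 'generate' mode; combined with an image or
+            # json_prompt it selects 'refine' mode (an image alone selects 'inspire')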
type_hint=str, + required=False, + description="Prompt to use", + ) + image_input = InputParam( + name="image", type_hint=Image.Image, required=False, description="image for inspiration mode" + ) + json_prompt_input = InputParam( + name="json_prompt", type_hint=str, required=False, description="JSON prompt to use" + ) + sampling_top_p_input = InputParam( + name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=0.9 + ) + sampling_temperature_input = InputParam( + name="sampling_temperature", + type_hint=float, + required=False, + description="Sampling temperature", + default=0.2, + ) + sampling_max_tokens_input = InputParam( + name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=4096 + ) + return [ + prompt_input, + image_input, + json_prompt_input, + sampling_top_p_input, + sampling_temperature_input, + sampling_max_tokens_input, + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "json_prompt", + type_hint=str, + description="JSON prompt by the VLM", + ) + ] + + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt = block_state.prompt + image = block_state.image + json_prompt = block_state.json_prompt + block_state.json_prompt = generate_json_prompt( + vlm_processor=self.engine, + image=image, + prompt=prompt, + structured_prompt=json_prompt, + top_p=block_state.sampling_top_p, + temperature=block_state.sampling_temperature, + max_tokens=block_state.sampling_max_tokens, + stop=["<|im_end|>", "<|end_of_text|>"], + ) + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/bria_fibo/gemini_prompt_to_json.py b/src/diffusers/modular_pipelines/bria_fibo/gemini_prompt_to_json.py new file mode 100644 index 000000000000..dc690444839d --- /dev/null +++ b/src/diffusers/modular_pipelines/bria_fibo/gemini_prompt_to_json.py @@ -0,0 +1,804 @@ +import io +import json +import math +import os +from functools import cache +from typing import List, Optional, Tuple + +from boltons.iterutils import remap +from google import genai +from PIL import Image +from pydantic import BaseModel, Field + +from ...modular_pipelines import InputParam, ModularPipelineBlocks, OutputParam, PipelineState + + +class ObjectDescription(BaseModel): + description: str = Field(..., description="Short description of the object.") + location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.") + relationship: str = Field( + ..., description="Describe the relationship between the object and the other objects in the image." 
+ ) + relative_size: Optional[str] = Field(None, description="E.g., 'small', 'medium', 'large within frame'.") + shape_and_color: Optional[str] = Field(None, description="Describe the basic shape and dominant color.") + texture: Optional[str] = Field(None, description="E.g., 'smooth', 'rough', 'metallic', 'furry'.") + appearance_details: Optional[str] = Field(None, description="Any other notable visual details.") + # If cluster of object + number_of_objects: Optional[int] = Field(None, description="The number of objects in the cluster.") + # Human-specific fields + pose: Optional[str] = Field(None, description="Describe the body position.") + expression: Optional[str] = Field(None, description="Describe facial expression.") + clothing: Optional[str] = Field(None, description="Describe attire.") + action: Optional[str] = Field(None, description="Describe the action of the human.") + gender: Optional[str] = Field(None, description="Describe the gender of the human.") + skin_tone_and_texture: Optional[str] = Field(None, description="Describe the skin tone and texture.") + orientation: Optional[str] = Field(None, description="Describe the orientation of the human.") + + +class LightingDetails(BaseModel): + conditions: str = Field( + ..., description="E.g., 'bright daylight', 'dim indoor', 'studio lighting', 'golden hour'." + ) + direction: str = Field(..., description="E.g., 'front-lit', 'backlit', 'side-lit from left'.") + shadows: Optional[str] = Field(None, description="Describe the presence of shadows.") + + +class AestheticsDetails(BaseModel): + composition: str = Field(..., description="E.g., 'rule of thirds', 'symmetrical', 'centered', 'leading lines'.") + color_scheme: str = Field( + ..., description="E.g., 'monochromatic blue', 'warm complementary colors', 'high contrast'." 
+ ) + mood_atmosphere: str = Field(..., description="E.g., 'serene', 'energetic', 'mysterious', 'joyful'.") + + +class PhotographicCharacteristicsDetails(BaseModel): + depth_of_field: str = Field(..., description="E.g., 'shallow', 'deep', 'bokeh background'.") + focus: str = Field(..., description="E.g., 'sharp focus on subject', 'soft focus', 'motion blur'.") + camera_angle: str = Field(..., description="E.g., 'eye-level', 'low angle', 'high angle', 'dutch angle'.") + lens_focal_length: str = Field(..., description="E.g., 'wide-angle', 'telephoto', 'macro', 'fisheye'.") + + +class TextRender(BaseModel): + text: str = Field(..., description="The text content.") + location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.") + size: str = Field(..., description="E.g., 'small', 'medium', 'large within frame'.") + color: str = Field(..., description="E.g., 'red', 'blue', 'green'.") + font: str = Field(..., description="E.g., 'realistic', 'cartoonish', 'minimalist'.") + appearance_details: Optional[str] = Field(None, description="Any other notable visual details.") + + +class ImageAnalysis(BaseModel): + short_description: str = Field(..., description="A concise summary of the image content, 200 words maximum.") + objects: List[ObjectDescription] = Field(..., description="List of prominent foreground/midground objects.") + background_setting: str = Field( + ..., + description="Describe the overall environment, setting, or background, including any notable background elements.", + ) + lighting: LightingDetails = Field(..., description="Details about the lighting.") + aesthetics: AestheticsDetails = Field(..., description="Details about the image aesthetics.") + photographic_characteristics: Optional[PhotographicCharacteristicsDetails] = Field( + None, description="Details about photographic characteristics." + ) + style_medium: Optional[str] = Field(None, description="Identify the artistic style or medium.") + text_render: Optional[List[TextRender]] = Field(None, description="List of text renders in the image.") + context: str = Field(..., description="Provide any additional context that helps understand the image better.") + artistic_style: Optional[str] = Field( + None, description="describe specific artistic characteristics, 3 words maximum." 
+ ) + + +def get_gemini_output_schema() -> dict: + return { + "properties": { + "short_description": {"type": "STRING"}, + "objects": { + "items": { + "properties": { + "description": {"type": "STRING"}, + "location": {"type": "STRING"}, + "relationship": {"type": "STRING"}, + "relative_size": {"type": "STRING"}, + "shape_and_color": {"type": "STRING"}, + "texture": {"nullable": True, "type": "STRING"}, + "appearance_details": {"nullable": True, "type": "STRING"}, + "number_of_objects": {"nullable": True, "type": "INTEGER"}, + "pose": {"nullable": True, "type": "STRING"}, + "expression": {"nullable": True, "type": "STRING"}, + "clothing": {"nullable": True, "type": "STRING"}, + "action": {"nullable": True, "type": "STRING"}, + "gender": {"nullable": True, "type": "STRING"}, + "skin_tone_and_texture": {"nullable": True, "type": "STRING"}, + "orientation": {"nullable": True, "type": "STRING"}, + }, + "required": [ + "description", + "location", + "relationship", + "relative_size", + "shape_and_color", + "texture", + "appearance_details", + "number_of_objects", + "pose", + "expression", + "clothing", + "action", + "gender", + "skin_tone_and_texture", + "orientation", + ], + "type": "OBJECT", + }, + "type": "ARRAY", + }, + "background_setting": {"type": "STRING"}, + "lighting": { + "properties": { + "conditions": {"type": "STRING"}, + "direction": {"type": "STRING"}, + "shadows": {"nullable": True, "type": "STRING"}, + }, + "required": ["conditions", "direction", "shadows"], + "type": "OBJECT", + }, + "aesthetics": { + "properties": { + "composition": {"type": "STRING"}, + "color_scheme": {"type": "STRING"}, + "mood_atmosphere": {"type": "STRING"}, + }, + "required": ["composition", "color_scheme", "mood_atmosphere"], + "type": "OBJECT", + }, + "photographic_characteristics": { + "nullable": True, + "properties": { + "depth_of_field": {"type": "STRING"}, + "focus": {"type": "STRING"}, + "camera_angle": {"type": "STRING"}, + "lens_focal_length": {"type": "STRING"}, + }, + "required": [ + "depth_of_field", + "focus", + "camera_angle", + "lens_focal_length", + ], + "type": "OBJECT", + }, + "style_medium": {"type": "STRING"}, + "text_render": { + "items": { + "properties": { + "text": {"type": "STRING"}, + "location": {"type": "STRING"}, + "size": {"type": "STRING"}, + "color": {"type": "STRING"}, + "font": {"type": "STRING"}, + "appearance_details": {"nullable": True, "type": "STRING"}, + }, + "required": [ + "text", + "location", + "size", + "color", + "font", + "appearance_details", + ], + "type": "OBJECT", + }, + "type": "ARRAY", + }, + "context": {"type": "STRING"}, + "artistic_style": {"type": "STRING"}, + }, + "required": [ + "short_description", + "objects", + "background_setting", + "lighting", + "aesthetics", + "photographic_characteristics", + "style_medium", + "text_render", + "context", + "artistic_style", + ], + "type": "OBJECT", + } + + +json_schema_full = """1. `short_description`: (String) A concise summary of the imagined image content, 200 words maximum. +2. `objects`: (Array of Objects) List a maximum of 5 prominent objects. If the scene implies more than 5, creatively + choose the most important ones and describe the rest in the background. For each object, include: + * `description`: (String) A detailed description of the imagined object, 100 words maximum. + * `location`: (String) E.g., "center", "top-left", "bottom-right foreground". + * `relative_size`: (String) E.g., "small", "medium", "large within frame". 
(If a person is the main subject, this + should be "medium-to-large" or "large within frame"). + * `shape_and_color`: (String) Describe the basic shape and dominant color. + * `texture`: (String) E.g., "smooth", "rough", "metallic", "furry". + * `appearance_details`: (String) Any other notable visual details. + * `relationship`: (String) Describe the relationship between the object and the other objects in the image. + * `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45 + degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side". + * If the object is a human or a human-like object, include the following: + * `pose`: (String) Describe the body position. + * `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious", + "surprised", "calm". + * `clothing`: (String) Describe attire. + * `action`: (String) Describe the action of the human. + * `gender`: (String) Describe the gender of the human. + * `skin_tone_and_texture`: (String) Describe the skin tone and texture. + * If the object is a cluster of objects, include the following: + * `number_of_objects`: (Integer) The number of objects in the cluster. +3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable + background elements that are not part of the `objects` section. +4. `lighting`: (Object) + * `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour". + * `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left". + * `shadows`: (String) Describe the presence and quality of shadows, e.g., "long, soft shadows", "sharp, defined + shadows", "minimal shadows". +5. `aesthetics`: (Object) + * `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines". If people are the + main subject, specify the shot type, e.g., "medium shot", "close-up", "portrait composition". + * `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast". + * `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful". +6. `photographic_characteristics`: (Object) + * `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background". + * `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur". + * `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle". + * `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye". (If the main subject is a + person, prefer "standard lens (e.g., 35mm-50mm)" or "portrait lens (e.g., 50mm-85mm)" to ensure they are framed + more closely. Avoid "wide-angle" for people unless specified). +7. `style_medium`: (String) Identify the artistic style or medium based on the user's prompt or creative + interpretation (e.g., "photograph", "oil painting", "watercolor", "3D render", "digital illustration", "pencil + sketch"). +8. `artistic_style`: (String) If the style is not "photograph", describe its specific artistic characteristics, 3 + words maximum. (e.g., "impressionistic, vibrant, textured" for an oil painting). +9. `context`: (String) Provide a general description of the type of image this would be. For example: "This is a + concept for a high-fashion editorial photograph intended for a magazine spread," or "This describes a piece of + concept art for a fantasy video game." +10. 
`text_render`: (Array of Objects) By default, this array should be empty (`[]`). Only add text objects to this + array if the user's prompt explicitly specifies the exact text content to be rendered (e.g., user asks for "a + poster with the title 'Cosmic Dream'"). Do not invent titles, names, or slogans for concepts like book covers or + posters unless the user provides them. A rare exception is for universally recognized text that is integral to an + object (e.g., the word 'STOP' on a 'stop sign'). For all other cases, if the user does not provide text, this array + must be empty. + * `text`: (String) The exact text content provided by the user. NEVER use generic placeholders. + * `location`: (String) E.g., "center", "top-left", "bottom-right foreground". + * `size`: (String) E.g., "medium", "large", "large within frame". + * `color`: (String) E.g., "red", "blue", "green". + * `font`: (String) E.g., "realistic", "cartoonish", "minimalist", "serif typeface". + * `appearance_details`: (String) Any other notable visual details.""" + + +@cache +def get_instructions(mode: str) -> Tuple[str, str]: + system_prompts = {} + + system_prompts["Caption"] = """ +You are a meticulous and perceptive Visual Art Director working for a leading Generative AI company. Your expertise +lies in analyzing images and extracting detailed, structured information. Your primary task is to analyze provided +images and generate a comprehensive JSON object describing them. Adhere strictly to the following structure and +guidelines: The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g., +no "Here is the JSON:", no explanations, no apologies). IMPORTANT: When describing human body parts, positions, or +actions, always describe them from the PERSON'S OWN PERSPECTIVE, not from the observer's viewpoint. For example, if a +person's left arm is raised (from their own perspective), describe it as "left arm" even if it appears on the right +side of the image from the viewer's perspective. The JSON object must contain the following keys precisely: +1. `short_description`: (String) A concise summary of the image content, 200 words maximum. +2. `objects`: (Array of Objects) List a maximum of 5 prominent objects if there are more than 5, list them in the + background. For each object, include: + * `description`: (String) a detailed description of the object, 100 words maximum. + * `location`: (String) E.g., "center", "top-left", "bottom-right foreground". + * `relative_size`: (String) E.g., "small", "medium", "large within frame". + * `shape_and_color`: (String) Describe the basic shape and dominant color. + * `texture`: (String) E.g., "smooth", "rough", "metallic", "furry". + * `appearance_details`: (String) Any other notable visual details. + * `relationship`: (String) Describe the relationship between the object and the other objects in the image. + * `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45 + degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side". + if the object is a human or a human-like object, include the following: + * `pose`: (String) Describe the body position. + * `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious", + "surprised", "calm". + * `clothing`: (String) Describe attire. + * `action`: (String) Describe the action of the human. + * `gender`: (String) Describe the gender of the human. 
+ * `skin_tone_and_texture`: (String) Describe the skin tone and texture. + if the object is a cluster of objects, include the following: + * `number_of_objects`: (Integer) The number of objects in the cluster. +3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable + background elements that are not part of the objects section. +4. `lighting`: (Object) + * `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour". + * `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left". + * `shadows`: (String) Describe the presence of shadows. +5. `aesthetics`: (Object) + * `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines". + * `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast". + * `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful". +6. `photographic_characteristics`: (Object) + * `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background". + * `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur". + * `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle". + * `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye". +7. `style_medium`: (String) Identify the artistic style or medium (e.g., "photograph", "oil painting", "watercolor", + "3D render", "digital illustration", "pencil sketch") If the style is not "photograph", but artistic, please + describe the specific artistic characteristics under 'artistic_style', 50 words maximum. +8. `artistic_style`: (String) describe specific artistic characteristics, 3 words maximum. +9. `context`: (String) Provide any additional context that helps understand the image better. This should include a + general description of the type of image (e.g., Fashion Photography, Product Shot, Magazine Cover, Nature + Photography, Art Piece, etc.), as well as any other relevant contextual information that situates the image within a + broader category or intended use. For example: "This is a high-fashion editorial photograph intended for a magazine + spread" +10. `text_render`: (Array of Objects) List of a maximum of 5 most prominent text renders in the image. For each text + render, include: + * `text`: (String) The text content. + * `location`: (String) E.g., "center", "top-left", "bottom-right foreground". + * `size`: (String) E.g., "small", "medium", "large within frame". + * `color`: (String) E.g., "red", "blue", "green". + * `font`: (String) E.g., "realistic", "cartoonish", "minimalist". + * `appearance_details`: (String) Any other notable visual details. +Ensure the information within the JSON is accurate, detailed where specified, and avoids redundancy between fields. +""" + + system_prompts[ + "Generate" + ] = f"""You are a visionary and creative Visual Art Director at a leading Generative AI company. + +Your expertise lies in taking a user's textual concept and transforming it into a rich, detailed, and aesthetically +compelling visual scene. + +Your primary task is to receive a user's description of a desired image and generate a comprehensive JSON object that +describes this imagined scene in vivid detail. 
You must creatively infer and add details that are not explicitly +mentioned in the user's request, such as background elements, lighting conditions, composition, and mood, always aiming +for a high-quality, visually appealing result unless the user's prompt suggests otherwise. + +Adhere strictly to the following structure and guidelines: + +The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g., no "Here is +the JSON:", no explanations, no apologies). + +IMPORTANT: When describing human body parts, positions, or actions, always describe them from the PERSON'S OWN +PERSPECTIVE, not from the observer's viewpoint. For example, if a person's left arm is raised (from their own +perspective), describe it as "left arm" even if it appears on the right side of the image from the viewer's +perspective. + +RULE for Human Subjects: When the user's prompt features a person or people as the main subject, you MUST default to a +composition that frames them prominently. Aim for compositions where their face and upper body are a primary focus +(e.g., 'medium shot', 'close-up'). Avoid defaulting to 'wide-angle' or 'full-body' shots where the face is small, +unless the user's prompt specifically implies a large scene (e.g., "a person standing on a mountain"). + +Unless the user's prompt explicitly requests a different style (e.g., 'painting', 'cartoon', 'illustration'), you MUST +default to `style_medium: "photograph"` and aim for the highest degree of photorealism. In such cases, `artistic_style` +should be "realistic" or a similar descriptor. + +The JSON object must contain the following keys precisely: + +{json_schema_full} + +Ensure the information within the JSON is detailed, creative, internally consistent, and avoids redundancy between +fields.""" + + system_prompts[ + "RefineA" + ] = f"""You are a Meticulous Visual Editor and Senior Art Director at a leading Generative AI company. + +Your expertise is in refining and modifying existing visual concepts based on precise feedback. + +Your primary task is to receive an existing JSON object that describes a visual scene, along with a textual instruction +for how to change it. You must then generate a new, updated JSON object that perfectly incorporates the requested +changes. + +Adhere strictly to the following structure and guidelines: + +1. **Input:** You will receive two pieces of information: an existing JSON object and a textual instruction. +2. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text + before or after the JSON object. +3. **Modification Logic:** + * Carefully parse the user's textual instruction to understand the desired changes. + * Modify ONLY the fields in the JSON that are directly or logically affected by the instruction. + * All other fields not relevant to the change must be copied exactly from the original JSON. Do not alter or omit + them. +4. **Holistic Consistency (IMPORTANT):** Changes in one field must be logically reflected in others. For example: + * If the instruction is to "change the background to a snowy forest," you must update the `background_setting` + field, and also update the `short_description` to mention the new setting. The `mood_atmosphere` might also need + to change to "serene" or "wintry." + * If the instruction is to "add the text 'WINTER SALE' at the top," you must add a new entry to the `text_render` + array. 
+ * If the instruction is to "make the person smile," you must update the `expression` field for that object and + potentially update the overall `mood_atmosphere`. +5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below. + +The JSON object must contain the following keys precisely: + +{json_schema_full}""" + + system_prompts[ + "RefineB" + ] = f"""You are an advanced Multimodal Visual Specialist at a leading Generative AI company. + +Your unique expertise is in analyzing and editing visual concepts by processing an image, its corresponding JSON +metadata, and textual feedback simultaneously. + +Your primary task is to receive three inputs: an existing image, its descriptive JSON object, and a textual instruction +for a modification. You must use the image as the primary source of truth to understand the context of the requested +change and then generate a new, updated JSON object that accurately reflects that change. + +Adhere strictly to the following structure and guidelines: + +1. **Inputs:** You will receive an image, an existing JSON object, and a textual instruction. +2. **Visual Grounding (IMPORTANT):** The provided image is the ground truth. Use it to visually verify the contents of + the scene and to understand the context of the user's edit instruction. For example, if the instruction is "make + the car blue," visually locate the car in the image to inform your edits to the JSON. +3. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text + before or after the JSON object. +4. **Modification Logic:** + * Analyze the user's textual instruction in the context of what you see in the image. + * Modify ONLY the fields in the JSON that are directly or logically affected by the instruction. + * All other fields not relevant to the change must be copied exactly from the original JSON. +5. **Holistic Consistency:** Changes must be reflected logically across the JSON, consistent with a potential visual + change to the image. For instance, changing the lighting from 'daylight' to 'golden hour' should not only update + the `lighting` object but also the `mood_atmosphere`, `shadows`, and the `short_description`. +6. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below. + +The JSON object must contain the following keys precisely: + +{json_schema_full}""" + + system_prompts[ + "InspireA" + ] = f"""You are a highly skilled Creative Director for Visual Adaptation at a leading Generative AI company. + +Your expertise lies in using an existing image as a visual reference to create entirely new scenes. You can deconstruct +a reference image to understand its subject, pose, and style, and then reimagine it in a new context based on textual +instructions. + +Your primary task is to receive a reference image and a textual instruction. You will analyze the reference to extract +key visual information and then generate a comprehensive JSON object describing a new scene that creatively +incorporates the user's instructions. + +Adhere strictly to the following structure and guidelines: + +1. **Inputs:** You will receive a reference image and a textual instruction. You will NOT receive a starting JSON. +2. **Core Logic (Analyze and Synthesize):** + * **Analyze:** First, deeply analyze the provided reference image. Identify its primary subject(s), their specific + poses, expressions, and appearance. 
Also note the overall composition, lighting style, and artistic medium. + * **Synthesize:** Next, interpret the textual instruction to understand what elements to keep from the reference + and what to change. You will then construct a brand new JSON object from scratch that describes the desired final + scene. For example, if the instruction is "the same dog and pose, but at the beach," you must describe the dog + from the reference image in the `objects` array but create a new `background_setting` for a beach, with + appropriate `lighting` and `mood_atmosphere`. +3. **Output:** Your output MUST be ONLY a single, valid JSON object that describes the **new, imagined scene**. Do not + describe the original reference image. +4. **Holistic Consistency:** Ensure the generated JSON is internally consistent. A change in the environment should be + reflected logically across multiple fields, such as `background_setting`, `lighting`, `shadows`, and the + `short_description`. +5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below. + +The JSON object must contain the following keys precisely: + +{json_schema_full}""" + + system_prompts["InspireB"] = system_prompts["Caption"] + + final_prompts = {} + + final_prompts["Generate"] = ( + "Generate a detailed JSON object, adhering to the expected schema, for an imagined scene based on the following request: {user_prompt}." + ) + + final_prompts["RefineA"] = """ + [EXISTING JSON]: {json_data} + + [EDIT INSTRUCTIONS]: {user_prompt} + + [TASK]: Generate the new, updated JSON object that incorporates the edit instructions. Follow all system rules + for modification, consistency, and formatting. + """ + + final_prompts["RefineB"] = """ + [EXISTING JSON]: {json_data} + + [EDIT INSTRUCTIONS]: {user_prompt} + + [TASK]: Analyze the provided image and its contextual JSON. Then, generate the new, updated JSON object that + incorporates the edit instructions. Follow all your system rules for visual analysis, modification, and + consistency. + """ + + final_prompts["InspireA"] = """ + [EDIT INSTRUCTIONS]: {user_prompt} + + [TASK]: Use the provided image as a visual reference only. Analyze its key elements (like the subject and pose) + and then generate a new, detailed JSON object for the scene described in the instructions above. Do not + describe the reference image itself; describe the new scene. Follow all of your system rules. + """ + + final_prompts["Caption"] = ( + "Analyze the provided image and generate the detailed JSON object as specified in your instructions." + ) + final_prompts["InspireB"] = final_prompts["Caption"] + + return system_prompts.get(mode, ""), final_prompts.get(mode, "") + + +def keep(p, k, v): + is_none = v is None + is_empty_string = isinstance(v, str) and v == "" + is_empty_dict = isinstance(v, dict) and not v + is_empty_list = isinstance(v, list) and not v + is_nan = isinstance(v, float) and math.isnan(v) + if is_none or is_empty_string or is_empty_list or is_empty_dict: + return False + if is_nan: + return False + return True + + +def validate_json(json_data: dict) -> dict: + ia = ImageAnalysis.model_validate_json(json_data, strict=True) + return ia.model_dump(exclude_none=True) + + +def validate_structured_prompt_str(structured_prompt_str: str) -> str: + ia = ImageAnalysis.model_validate_json(structured_prompt_str, strict=True) + c = ia.model_dump(exclude_none=True) + return json.dumps(c) + + +def prepare_clean_caption(json_dump: dict) -> str: + # filter empty values recursivly (i.e. 
None, "", {}, [], float("nan")) + clean_caption_dict = remap(json_dump, visit=keep) + + scores = {"preference_score": "very high", "aesthetic_score": "very high"} + # Set aesthetics scores + if "aesthetics" not in clean_caption_dict: + clean_caption_dict["aesthetics"] = scores + else: + clean_caption_dict["aesthetics"].update(scores) + + # Dumps clean structured caption as minimal json string (i.e. no newlines\whitespaces seps) + clean_caption_str = json.dumps(clean_caption_dict) + return clean_caption_str + + +# resize an input image to have a specific number of pixels (1,048,576 or 1024×1024) +# while maintaining a certain aspect ratio and granularity (output width and height must be multiples of this number). +def resize_image_by_num_pixels( + image: Image.Image, pixel_number: int = 1048576, granularity_val: int = 64, target_ratio: float = 0.0 +) -> Image.Image: + if target_ratio != 0.0: + ratio = target_ratio + else: + ratio = image.size[0] / image.size[1] + width = int((pixel_number * ratio) ** 0.5) + width = width - (width % granularity_val) + height = int(pixel_number / width) + height = height - (height % granularity_val) + return image.resize((width, height)) + + +def infer_with_gemini( + client: genai.Client, + final_prompt: str, + system_prompt: str, + top_p: float, + temperature: float, + max_tokens: int, + seed: int = 42, + image: Optional[Image.Image] = None, + model: str = "gemini-2.5-flash", +) -> str: + """ + Calls Gemini API with the given prompt and returns the raw JSON response. + + Args: + final_prompt: The text prompt to send to Gemini + system_prompt: The system instruction for Gemini + existing_image_path: Optional path to an image file to include + model: The Gemini model to use + + Returns: + Raw JSON response text from Gemini + """ + parts = [{"text": final_prompt}] + if image: + # Save image into bytes + image = image.convert("RGB") # the model can't produce rgba so sending them as input has no effect + less_then = 262144 + if image.size[0] * image.size[1] > less_then: + image = resize_image_by_num_pixels( + image, pixel_number=less_then, granularity_val=1, target_ratio=0.0 + ) # 512x512 + buffer = io.BytesIO() + image.save(buffer, format="JPEG") + image_bytes = buffer.getvalue() + + img_part = { + "inlineData": { + "data": image_bytes, + "mimeType": "image/jpeg", + } + } + parts.append(img_part) + + contents = [{"role": "user", "parts": parts}] + + generationConfig = { + "temperature": temperature, + "topP": top_p, + "maxOutputTokens": max_tokens, + "response_mime_type": "application/json", + "response_schema": get_gemini_output_schema(), + "system_instruction": system_prompt, # len 5900 + "thinkingConfig": {"thinkingBudget": 0}, + "seed": seed, + } + + response = client.models.generate_content( + model=model, + contents=contents, + config=generationConfig, + ) + + if response.candidates[0].finish_reason == "MAX_TOKENS": + raise Exception("Max tokens") + + return response.candidates[0].content.parts[0].text + + +def get_default_negative_prompt(existing_json: dict) -> str: + negative_prompt = "" + style_medium = existing_json.get("style_medium", "").lower() + if style_medium in ["photograph", "photography", "photo"]: + negative_prompt = """{'style_medium': 'digital illustration', 'artistic_style': 'non-realistic'}""" + return negative_prompt + + +def json_promptify( + client: genai.Client, + model_id: str, + top_p: float, + temperature: float, + max_tokens: int, + user_prompt: Optional[str] = None, + existing_json: Optional[str] = None, + image: 
Optional[Image.Image] = None, + seed: int = 42, +) -> str: + if existing_json: + # make sure aesthetic scores are not in the existing json (will be added later) + existing_json = json.loads(existing_json) + if "aesthetics" in existing_json: + existing_json["aesthetics"].pop("aesthetic_score", None) + existing_json["aesthetics"].pop("preference_score", None) + existing_json = json.dumps(existing_json) + + if not user_prompt: + raise ValueError("user_prompt is required if existing_json is provided") + + if image: + mode = "RefineB" + system_prompt, final_prompt = get_instructions(mode) + final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json) + + else: + mode = "RefineA" + system_prompt, final_prompt = get_instructions(mode) + final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json) + elif image and user_prompt: + mode = "InspireA" + system_prompt, final_prompt = get_instructions(mode) + final_prompt = final_prompt.format(user_prompt=user_prompt) + elif image and not user_prompt: + mode = "Caption" + system_prompt, final_prompt = get_instructions(mode) + else: + mode = "Generate" + system_prompt, final_prompt = get_instructions(mode) + final_prompt = final_prompt.format(user_prompt=user_prompt) + + json_data = infer_with_gemini( + client=client, + model=model_id, + final_prompt=final_prompt, + system_prompt=system_prompt, + seed=seed, + image=image, + top_p=top_p, + temperature=temperature, + max_tokens=max_tokens, + ) + json_data = validate_json(json_data) + clean_caption = prepare_clean_caption(json_data) + + return clean_caption + + +class BriaFiboGeminiPromptToJson(ModularPipelineBlocks): + model_name = "BriaFibo" + + def __init__(self, model_id="gemini-2.5-flash"): + super().__init__() + api_key = os.getenv("GOOGLE_API_KEY") + if api_key is None: + raise ValueError("Must provide an API key for Gemini through the `GOOGLE_API_KEY` env variable.") + self.model_id = model_id + + @property + def expected_components(self): + return [] + + @property + def inputs(self) -> List[InputParam]: + task_input = InputParam("task", type_hint=str, required=False, description="VLM Task to execute") + prompt_input = InputParam( + "prompt", + type_hint=str, + required=False, + description="Prompt to use", + ) + image_input = InputParam( + name="image", type_hint=Image.Image, required=False, description="image for inspiration mode" + ) + json_prompt_input = InputParam( + name="json_prompt", type_hint=str, required=False, description="JSON prompt to use" + ) + sampling_top_p_input = InputParam( + name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=1.0 + ) + sampling_temperature_input = InputParam( + name="sampling_temperature", + type_hint=float, + required=False, + description="Sampling temperature", + default=0.2, + ) + sampling_max_tokens_input = InputParam( + name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=3000 + ) + return [ + task_input, + prompt_input, + image_input, + json_prompt_input, + sampling_top_p_input, + sampling_temperature_input, + sampling_max_tokens_input, + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "json_prompt", + type_hint=str, + description="JSON prompt by the VLM", + ) + ] + + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + prompt = 
block_state.prompt + image = block_state.image + json_prompt = block_state.json_prompt + client = genai.Client() + json_prompt = json_promptify( + client=client, + model_id=self.model_id, + top_p=block_state.sampling_top_p, + temperature=block_state.sampling_temperature, + max_tokens=block_state.sampling_max_tokens, + user_prompt=prompt, + existing_json=json_prompt, + image=image, + ) + block_state.json_prompt = json_prompt + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 65e8b2469ebf..1b96214196d5 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -128,6 +128,7 @@ "AnimateDiffVideoToVideoControlNetPipeline", ] _import_structure["bria"] = ["BriaPipeline"] + _import_structure["bria_fibo"] = ["BriaFiboPipeline"] _import_structure["flux"] = [ "FluxControlPipeline", "FluxControlInpaintPipeline", @@ -562,6 +563,7 @@ from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .bria import BriaPipeline + from .bria_fibo import BriaFiboPipelin from .chroma import ChromaImg2ImgPipeline, ChromaPipeline from .cogvideo import ( CogVideoXFunControlPipeline, diff --git a/src/diffusers/pipelines/bria_fibo/__init__.py b/src/diffusers/pipelines/bria_fibo/__init__.py new file mode 100644 index 000000000000..206a463b394b --- /dev/null +++ b/src/diffusers/pipelines/bria_fibo/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_bria_fibo import BriaFiboPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py new file mode 100644 index 000000000000..690b54607d2c --- /dev/null +++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py @@ -0,0 +1,826 @@ +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM + +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan +from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput +from 
...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + ... +""" + + +class BriaFiboPipeline(DiffusionPipeline): + r""" + Args: + transformer ([`GaiaTransformer2DModel`]): + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`SmolLM3ForCausalLM`]): + tokenizer (`AutoTokenizer`): + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: BriaFiboTransformer2DModel, + scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers], + vae: AutoencoderKLWan, + text_encoder: SmolLM3ForCausalLM, + tokenizer: AutoTokenizer, + ): + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor = 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 64 + + def get_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 2048, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + if not prompt: + raise ValueError("`prompt` must be a non-empty string or list of strings.") + + batch_size = len(prompt) + bot_token_id = 128000 + + text_encoder_device = device if device is not None else torch.device("cpu") + if not isinstance(text_encoder_device, torch.device): + text_encoder_device = torch.device(text_encoder_device) + + if all(p == "" for p in prompt): + input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device) + attention_mask = torch.ones_like(input_ids) + else: + tokenized = self.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokenized.input_ids.to(text_encoder_device) + attention_mask = tokenized.attention_mask.to(text_encoder_device) + + if any(p == "" for p in prompt): + empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device) + input_ids[empty_rows] = bot_token_id + attention_mask[empty_rows] = 1 + + encoder_outputs = self.text_encoder( + input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_outputs.hidden_states + + prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1) + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + + prompt_embeds = 
prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hidden_states = tuple( + layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states + ) + attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) + + return prompt_embeds, hidden_states, attention_mask + + @staticmethod + def pad_embedding(prompt_embeds, max_tokens, attention_mask=None): + # Pad embeddings to `max_tokens` while preserving the mask of real tokens. + batch_size, seq_len, dim = prompt_embeds.shape + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device) + else: + attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if max_tokens < seq_len: + raise ValueError("`max_tokens` must be greater or equal to the current sequence length.") + + if max_tokens > seq_len: + pad_length = max_tokens - seq_len + padding = torch.zeros( + (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.cat([prompt_embeds, padding], dim=1) + + mask_padding = torch.zeros( + (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + attention_mask = torch.cat([attention_mask, mask_padding], dim=1) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 3000, + lora_scale: Optional[float] = None, + ): + r""" + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
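+            max_sequence_length (`int`, *optional*, defaults to 3000):
+                Maximum number of text tokens used when tokenizing `prompt` (and `negative_prompt`); longer inputs
+                are truncated by the tokenizer.
+            lora_scale (`float`, *optional*):
+                A LoRA scale that is applied to all LoRA layers of the text encoder if LoRA layers are loaded.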
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_attention_mask = None + negative_prompt_attention_mask = None + if prompt_embeds is None: + prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers] + + if do_classifier_free_guidance: + if isinstance(negative_prompt, list) and negative_prompt[0] is None: + negative_prompt = "" + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype) + negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers] + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + # Pad to longest + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if negative_prompt_embeds is not None: + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype + ) + max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1]) + + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers] + + negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding( + negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask + ) + negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers] + else: + max_tokens = prompt_embeds.shape[1] + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + negative_prompt_layers = None + + dtype = self.text_encoder.dtype + text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype) + + return ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
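+    # As a quick sketch (matching the denoising loop further below), the guided prediction is
+    #   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+    # so for `guidance_scale <= 1` the unconditional branch is skipped entirely.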
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _unpack_latents_no_patch(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels) + latents = latents.permute(0, 3, 1, 2) + + return latents + + @staticmethod + def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width): + latents = latents.permute(0, 2, 3, 1) + latents = latents.reshape(batch_size, height * width, num_channels_latents) + return latents + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + do_patching=False, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if do_patching: + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + else: + latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents, latent_image_ids + + @staticmethod + def init_inference_scheduler(height, width, device, image_seq_len, num_inference_steps=1000, noise_scheduler=None): + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + assert height % 16 == 0 and width % 16 == 0 + + mu = calculate_shift( + image_seq_len, + noise_scheduler.config.base_image_seq_len, + noise_scheduler.config.max_image_seq_len, + noise_scheduler.config.base_shift, + noise_scheduler.config.max_shift, + ) + + # Init sigmas and timesteps according to shift size + # This changes the scheduler in-place according to the dynamic scheduling + timesteps, num_inference_steps = retrieve_timesteps( + noise_scheduler, + num_inference_steps=num_inference_steps, + device=device, + timesteps=None, + sigmas=sigmas, + mu=mu, + ) + + return noise_scheduler, timesteps, num_inference_steps, mu + + @staticmethod + def create_attention_matrix(attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 3000, + do_patching=False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+ timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + Examples: + Returns: + [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_images_per_prompt=num_images_per_prompt, + lora_scale=lora_scale, + ) + prompt_batch_size = prompt_embeds.shape[0] + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_layers = [ + torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers)) + ] + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + total_num_layers_transformer = len(self.transformer.transformer_blocks) + len( + self.transformer.single_transformer_blocks + ) + if len(prompt_layers) >= total_num_layers_transformer: + # remove first layers + prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :] + else: + # duplicate last layer + prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers)) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + if do_patching: + num_channels_latents = int(num_channels_latents / 4) + + latents, latent_image_ids = self.prepare_latents( + prompt_batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + do_patching, + ) + + latent_attention_mask = torch.ones( + [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device + ) + if self.do_classifier_free_guidance: + latent_attention_mask = latent_attention_mask.repeat(2, 1) + + attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1) + attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq + attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + self._joint_attention_kwargs["attention_mask"] = attention_mask + + # Adapt scheduler to dynamic shifting (resolution dependent) + + if do_patching: + seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2)) + else: + seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) + + self.noise_scheduler, timesteps, num_inference_steps, mu = self.init_inference_scheduler( + height=height, + width=width, + device=device, + num_inference_steps=num_inference_steps, + noise_scheduler=self.scheduler, + image_seq_len=seq_len, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Support old different diffusers versions + if len(latent_image_ids.shape) == 3: + latent_image_ids = latent_image_ids[0] + + if len(text_ids.shape) == 3: + text_ids = text_ids[0] + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + if type(self.scheduler) != FlowMatchEulerDiscreteScheduler: + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to( + device=latent_model_input.device, dtype=latent_model_input.dtype + ) + + # This is predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_encoder_layers=prompt_layers, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + if do_patching: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + else: + latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) + + latents = latents.to(dtype=self.vae.dtype) + latents = latents.unsqueeze(dim=2) + latents = list(torch.unbind(latents, dim=0)) + latents_device = latents[0].device + latents_dtype = latents[0].dtype + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents_device, latents_dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents_device, latents_dtype + ) + latents_scaled = [latent / latents_std + latents_mean for latent in latents] + latents_scaled = torch.cat(latents_scaled, dim=0) + image = [] + for scaled_latent in latents_scaled: + curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0] + curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type) + image.append(curr_image) + if len(image) == 1: + image = image[0] + else: + image = np.stack(image, axis=0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return BriaFiboPipelineOutput(images=image) + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 3000: + raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}") + + def to(self, *args, **kwargs): + DiffusionPipeline.to(self, *args, **kwargs) + # We use as float32 since wan22 in their repo use it like this + self.vae.to(dtype=torch.float32) + return self diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_output.py b/src/diffusers/pipelines/bria_fibo/pipeline_output.py new file mode 100644 index 000000000000..f459185a2c7c --- /dev/null +++ b/src/diffusers/pipelines/bria_fibo/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class BriaFiboPipelineOutput(BaseOutput): + """ + Output class for BriaFibo pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/tests/models/transformers/test_models_transformer_bria_fibo.py b/tests/models/transformers/test_models_transformer_bria_fibo.py new file mode 100644 index 000000000000..3df75ec15afa --- /dev/null +++ b/tests/models/transformers/test_models_transformer_bria_fibo.py @@ -0,0 +1,132 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import BriaFiboTransformer2DModel +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + + + +class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = BriaFiboTransformer2DModel + main_input_name = "hidden_states" + # We override the items here because the transformer under consideration is small. 
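+    # (`model_split_percents` is used by the shared offloading / model-parallelism tests in `ModelTesterMixin`.)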
+ model_split_percents = [0.8, 0.7, 0.7] + + # Skip setting testing with default: AttnProcessor + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_latent_channels = 48 + num_image_channels = 3 + height = width = 16 + sequence_length = 32 + embedding_dim = 64 + + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) + image_ids = torch.randn((height * width, num_image_channels)).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "img_ids": image_ids, + "txt_ids": text_ids, + "timestep": timestep, + "text_encoder_layers": [encoder_hidden_states[:,:,:32], encoder_hidden_states[:,:,:32]], + } + + @property + def input_shape(self): + return (16, 16) + + @property + def output_shape(self): + return (256, 48) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": 1, + "in_channels": 48, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 2, + "joint_attention_dim": 64, + "text_encoder_dim": 32, + "pooled_projection_dim": None, + "axes_dims_rope": [0, 4, 4], + } + + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_deprecated_inputs_img_txt_ids_3d(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output_1 = model(**inputs_dict)[0] + + # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated) + text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0) + image_ids_3d = inputs_dict["img_ids"].unsqueeze(0) + + assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor" + assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor" + + inputs_dict["txt_ids"] = text_ids_3d + inputs_dict["img_ids"] = image_ids_3d + + with torch.no_grad(): + output_2 = model(**inputs_dict)[0] + + self.assertEqual(output_1.shape, output_2.shape) + self.assertTrue( + torch.allclose(output_1, output_2, atol=1e-5), + msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", + ) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"BriaFiboTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class BriaFiboTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = BriaFiboTransformer2DModel + + def prepare_init_args_and_inputs_for_common(self): + return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common() + + +class BriaFiboTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): + model_class = BriaFiboTransformer2DModel + + def prepare_init_args_and_inputs_for_common(self): + return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/bria_fibo/__init__.py b/tests/pipelines/bria_fibo/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py new file mode 100644 index 000000000000..219dea0ba7a2 --- /dev/null +++ 
b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py @@ -0,0 +1,198 @@ +# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM + +from diffusers import ( + AutoencoderKLWan, + BriaFiboPipeline, + FlowMatchEulerDiscreteScheduler, +) +from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from tests.pipelines.test_pipelines_common import PipelineTesterMixin, to_np + +from ...testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) + + +enable_full_determinism() + + +class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = BriaFiboPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + test_xformers_attention = False + test_layerwise_casting = False + test_group_offloading = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = BriaFiboTransformer2DModel( + patch_size=1, + in_channels=16, + num_layers=1, + num_single_layers=1, + attention_head_dim=8, + num_attention_heads=2, + joint_attention_dim=64, + text_encoder_dim=32, + pooled_projection_dim=None, + axes_dims_rope=[0, 4, 4], + ) + + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=160, + decoder_base_dim=256, + num_res_blocks=2, + out_channels=12, + patch_size=2, + scale_factor_spatial=16, + scale_factor_temporal=4, + temperal_downsample=[False, True, True], + z_dim=16, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32)) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "{'text': 'A painting of a squirrel eating a burger'}", + "negative_prompt": "bad, ugly", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "output_type": "np", + } + return inputs + + def test_encode_prompt_works_in_isolation(self): + pass + + def test_bria_fibo_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = "a different prompt" + output_different_prompts = 
pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + assert max_diff > 1e-6 + + def test_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32)] + for height, width in height_width_pairs: + expected_height = height + expected_width = width + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_torch_accelerator + def test_save_load_float16(self, expected_max_diff=1e-2): + components = self.get_dummy_components() + for name, module in components.items(): + if hasattr(module, "half"): + components[name] = module.to(torch_device).half() + + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for name, component in pipe_loaded.components.items(): + if name == "vae": + continue + if hasattr(component, "dtype"): + self.assertTrue( + component.dtype == torch.float16, + f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess( + max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." 
+ ) + + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + + def test_save_load_dduf(self): + pass + + From 371e5f511ef1edb1a57076c8a9caf250dcfb4e46 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Sun, 26 Oct 2025 16:46:42 +0000 Subject: [PATCH 02/15] style fixs --- src/diffusers/__init__.py | 10 +++++----- src/diffusers/models/__init__.py | 2 +- src/diffusers/modular_pipelines/bria_fibo/__init__.py | 4 ++-- .../bria_fibo/fibo_vlm_prompt_to_json.py | 8 ++------ .../transformers/test_models_transformer_bria_fibo.py | 5 ++--- tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py | 5 ----- 6 files changed, 12 insertions(+), 22 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e0e31f58f5d2..efce36c0d4b8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -198,8 +198,8 @@ "AutoencoderOobleck", "AutoencoderTiny", "AutoModel", - "BriaTransformer2DModel", "BriaFiboTransformer2DModel", + "BriaTransformer2DModel", "CacheMixin", "ChromaTransformer2DModel", "CogVideoXTransformer3DModel", @@ -393,8 +393,8 @@ else: _import_structure["modular_pipelines"].extend( [ - "BriaFiboVLMPromptToJson", "BriaFiboGeminiPromptToJson", + "BriaFiboVLMPromptToJson", "FluxAutoBlocks", "FluxKontextAutoBlocks", "FluxKontextModularPipeline", @@ -433,8 +433,8 @@ "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", - "BriaPipeline", "BriaFiboPipeline", + "BriaPipeline", "ChromaImg2ImgPipeline", "ChromaPipeline", "CLIPImageProjection", @@ -905,8 +905,8 @@ AutoencoderOobleck, AutoencoderTiny, AutoModel, - BriaTransformer2DModel, BriaFiboTransformer2DModel, + BriaTransformer2DModel, CacheMixin, ChromaTransformer2DModel, CogVideoXTransformer3DModel, @@ -1108,8 +1108,8 @@ AudioLDM2UNet2DConditionModel, AudioLDMPipeline, AuraFlowPipeline, - BriaPipeline, BriaFiboPipeline, + BriaPipeline, ChromaImg2ImgPipeline, ChromaPipeline, CLIPImageProjection, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d25ef37b7cdf..e3b297464143 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -175,8 +175,8 @@ from .transformers import ( AllegroTransformer3DModel, AuraFlowTransformer2DModel, - BriaTransformer2DModel, BriaFiboTransformer2DModel, + BriaTransformer2DModel, ChromaTransformer2DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, diff --git a/src/diffusers/modular_pipelines/bria_fibo/__init__.py b/src/diffusers/modular_pipelines/bria_fibo/__init__.py index 302d271e0c43..770cb9391a38 100644 --- a/src/diffusers/modular_pipelines/bria_fibo/__init__.py +++ b/src/diffusers/modular_pipelines/bria_fibo/__init__.py @@ -21,8 +21,8 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["gemini_prompt_to_json"] = ["BriaFiboGeminiPromptToJson"] _import_structure["fibo_vlm_prompt_to_json"] = ["BriaFiboVLMPromptToJson"] + _import_structure["gemini_prompt_to_json"] = ["BriaFiboGeminiPromptToJson"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -31,8 +31,8 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .gemini_prompt_to_json import BriaFiboGeminiPromptToJson from 
.fibo_vlm_prompt_to_json import BriaFiboVLMPromptToJson + from .gemini_prompt_to_json import BriaFiboGeminiPromptToJson else: import sys diff --git a/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py b/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py index c63c7f85d190..689e4ae59ae3 100644 --- a/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py +++ b/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py @@ -277,16 +277,12 @@ def build_messages( if refine_image is None: base_prompt = (structured_prompt or "").strip() edits = (editing_instructions or "").strip() - formatted = textwrap.dedent( - f""" Input: {base_prompt} Editing instructions: {edits}""" - ).strip() + formatted = textwrap.dedent(f""" Input: {base_prompt} Editing instructions: {edits}""").strip() user_content.append({"type": "text", "text": formatted}) else: user_content.append({"type": "image", "image": refine_image}) edits = (editing_instructions or "").strip() - formatted = textwrap.dedent( - f""" Editing instructions: {edits}""" - ).strip() + formatted = textwrap.dedent(f""" Editing instructions: {edits}""").strip() user_content.append({"type": "text", "text": formatted}) messages: List[Dict[str, Any]] = [] diff --git a/tests/models/transformers/test_models_transformer_bria_fibo.py b/tests/models/transformers/test_models_transformer_bria_fibo.py index 3df75ec15afa..ad87b5710aeb 100644 --- a/tests/models/transformers/test_models_transformer_bria_fibo.py +++ b/tests/models/transformers/test_models_transformer_bria_fibo.py @@ -18,6 +18,7 @@ import torch from diffusers import BriaFiboTransformer2DModel + from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin @@ -25,8 +26,6 @@ enable_full_determinism() - - class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = BriaFiboTransformer2DModel main_input_name = "hidden_states" @@ -57,7 +56,7 @@ def dummy_input(self): "img_ids": image_ids, "txt_ids": text_ids, "timestep": timestep, - "text_encoder_layers": [encoder_hidden_states[:,:,:32], encoder_hidden_states[:,:,:32]], + "text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]], } @property diff --git a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py index 219dea0ba7a2..15cdb82fe128 100644 --- a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py +++ b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import gc import tempfile import unittest @@ -30,10 +29,8 @@ from tests.pipelines.test_pipelines_common import PipelineTesterMixin, to_np from ...testing_utils import ( - backend_empty_cache, enable_full_determinism, require_torch_accelerator, - slow, torch_device, ) @@ -194,5 +191,3 @@ def test_to_dtype(self): def test_save_load_dduf(self): pass - - From a617433aceb67d2bab269a7cfcb04f75d5443612 Mon Sep 17 00:00:00 2001 From: galbria Date: Mon, 27 Oct 2025 13:04:57 +0000 Subject: [PATCH 03/15] fix CR --- docs/source/en/_toctree.yml | 4 + .../transformers/transformer_bria_fibo.py | 353 ++++++-- .../modular_pipelines/bria_fibo/__init__.py | 47 - .../bria_fibo/fibo_vlm_prompt_to_json.py | 373 -------- .../bria_fibo/gemini_prompt_to_json.py | 804 ------------------ .../pipelines/bria_fibo/pipeline_bria_fibo.py | 144 ++-- .../test_models_transformer_bria_fibo.py | 44 +- .../bria_fibo/test_pipeline_bria_fibo.py | 19 +- 8 files changed, 377 insertions(+), 1411 deletions(-) delete mode 100644 src/diffusers/modular_pipelines/bria_fibo/__init__.py delete mode 100644 src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py delete mode 100644 src/diffusers/modular_pipelines/bria_fibo/gemini_prompt_to_json.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 540e99a2c609..a5f0efe02f60 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -323,6 +323,8 @@ title: AllegroTransformer3DModel - local: api/models/aura_flow_transformer2d title: AuraFlowTransformer2DModel + - local: api/models/transformer_bria_fibo + title: BriaFiboTransformer2DModel - local: api/models/bria_transformer title: BriaTransformer2DModel - local: api/models/chroma_transformer @@ -469,6 +471,8 @@ title: BLIP-Diffusion - local: api/pipelines/bria_3_2 title: Bria 3.2 + - local: api/pipelines/bria_fibo + title: Bria Fibo - local: api/pipelines/chroma title: Chroma - local: api/pipelines/cogview3 diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py index 9521b7f3dd72..e1bfde955527 100644 --- a/src/diffusers/models/transformers/transformer_bria_fibo.py +++ b/src/diffusers/models/transformers/transformer_bria_fibo.py @@ -1,18 +1,26 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. 
+import inspect from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...models.attention_processor import Attention -from ...models.embeddings import TimestepEmbedding, get_timestep_embedding +from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding from ...models.modeling_outputs import Transformer2DModelOutput from ...models.modeling_utils import ModelMixin -from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZeroSingle from ...models.transformers.transformer_bria import BriaAttnProcessor -from ...models.transformers.transformer_flux import FluxTransformerBlock from ...utils import ( USE_PEFT_BACKEND, logging, @@ -20,76 +28,193 @@ unscale_lora_layers, ) from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def get_1d_rotary_pos_embed( - dim: int, - pos: Union[np.ndarray, int], - theta: float = 10000.0, - use_real=False, - linear_factor=1.0, - ntk_factor=1.0, - repeat_interleave_real=True, - freqs_dtype=torch.float32, # torch.float32, torch.float64 -): - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a - frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta' - parameter scales the frequencies. The returned tensor contains complex values in complex64 data type. - - Args: - dim (`int`): Dimension of the frequency tensor. - pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar - theta (`float`, *optional*, defaults to 10000.0): - Scaling factor for frequency computation. Defaults to 10000.0. - use_real (`bool`, *optional*): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. - linear_factor (`float`, *optional*, defaults to 1.0): - Scaling factor for the context extrapolation. Defaults to 1.0. - ntk_factor (`float`, *optional*, defaults to 1.0): - Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. - repeat_interleave_real (`bool`, *optional*, defaults to `True`): - If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. - Otherwise, they are concateanted with themselves. - freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): - the dtype of the frequency tensor. - Returns: - `torch.Tensor`: Precomputed frequency tensor with complex exponentials. 
[S, D/2] - """ - assert dim % 2 == 0 - - if isinstance(pos, int): - pos = torch.arange(pos) - if isinstance(pos, np.ndarray): - pos = torch.from_numpy(pos) # type: ignore # [S] - - theta = theta * ntk_factor - freqs = ( - 1.0 - / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) - / linear_factor - ) # [D/2] - freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] - if use_real and repeat_interleave_real: - # flux, hunyuan-dit, cogvideox - freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] - freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] - return freqs_cos, freqs_sin - elif use_real: - # stable audio, allegro - freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] - freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] - return freqs_cos, freqs_sin - else: - # lumina - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] - return freqs_cis +def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class BriaFiboAttnProcessor: + # Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def __call__( + self, + attn: "BriaFiboAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + query = attn.norm_q(query) + key = attn.norm_k(key) -class EmbedND(torch.nn.Module): + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin): + # Copied from diffusers.models.transformers.transformer_flux.FluxAttention + _default_processor_cls = BriaFiboAttnProcessor + _available_processors = [ + BriaFiboAttnProcessor, + ] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + context_pre_only: Optional[bool] = None, + pre_only: bool = False, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, 
self.inner_dim, bias=bias) + + if not self.pre_only: + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +class FIBOEmbedND(torch.nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): super().__init__() @@ -182,7 +307,93 @@ def forward(self, caption): return hidden_states -class Timesteps(nn.Module): +@maybe_allow_in_graph +class BriaFiboTransformerBlock(nn.Module): + # Copied from diffusers.models.transformers.transformer_flux.FluxTransformerBlock + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = BriaFiboAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=BriaFiboAttnProcessor(), + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + 
joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class FIBOTimesteps(nn.Module): def __init__( self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000 ): @@ -209,7 +420,7 @@ class TimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, time_theta): super().__init__() - self.time_proj = Timesteps( + self.time_proj = FIBOTimesteps( num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta ) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) @@ -258,7 +469,7 @@ def __init__( self.out_channels = in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope) + self.pos_embed = FIBOEmbedND(theta=rope_theta, axes_dim=axes_dims_rope) self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta) @@ -270,7 +481,7 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ - FluxTransformerBlock( + BriaFiboTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, diff --git a/src/diffusers/modular_pipelines/bria_fibo/__init__.py b/src/diffusers/modular_pipelines/bria_fibo/__init__.py deleted file mode 100644 index 770cb9391a38..000000000000 --- a/src/diffusers/modular_pipelines/bria_fibo/__init__.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import TYPE_CHECKING - -from ...utils import ( - DIFFUSERS_SLOW_IMPORT, - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_torch_available, - is_transformers_available, -) - - -_dummy_objects = {} -_import_structure = {} - -try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() -except 
OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) -else: - _import_structure["fibo_vlm_prompt_to_json"] = ["BriaFiboVLMPromptToJson"] - _import_structure["gemini_prompt_to_json"] = ["BriaFiboGeminiPromptToJson"] - -if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 - else: - from .fibo_vlm_prompt_to_json import BriaFiboVLMPromptToJson - from .gemini_prompt_to_json import BriaFiboGeminiPromptToJson -else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) - - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py b/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py deleted file mode 100644 index 689e4ae59ae3..000000000000 --- a/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py +++ /dev/null @@ -1,373 +0,0 @@ -import json -import math -import textwrap -from typing import Any, Dict, Iterable, List, Optional - -import torch -from boltons.iterutils import remap -from PIL import Image -from transformers import AutoModelForCausalLM, AutoProcessor, Qwen3VLForConditionalGeneration - -from .. import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState - - -def parse_aesthetic_score(record: dict) -> str: - ae = record["aesthetic_score"] - if ae < 5.5: - return "very low" - elif ae < 6: - return "low" - elif ae < 7: - return "medium" - elif ae < 7.6: - return "high" - else: - return "very high" - - -def parse_pickascore(record: dict) -> str: - ps = record["pickascore"] - if ps < 0.78: - return "very low" - elif ps < 0.82: - return "low" - elif ps < 0.87: - return "medium" - elif ps < 0.91: - return "high" - else: - return "very high" - - -def prepare_clean_caption(record: dict) -> str: - def keep(p, k, v): - is_none = v is None - is_empty_string = isinstance(v, str) and v == "" - is_empty_dict = isinstance(v, dict) and not v - is_empty_list = isinstance(v, list) and not v - is_nan = isinstance(v, float) and math.isnan(v) - if is_none or is_empty_string or is_empty_list or is_empty_dict or is_nan: - return False - return True - - try: - scores = {} - if "pickascore" in record: - scores["preference_score"] = parse_pickascore(record) - if "aesthetic_score" in record: - scores["aesthetic_score"] = parse_aesthetic_score(record) - - clean_caption_dict = remap(record, visit=keep) - - # Set aesthetics scores - if "aesthetics" not in clean_caption_dict: - if len(scores) > 0: - clean_caption_dict["aesthetics"] = scores - else: - clean_caption_dict["aesthetics"].update(scores) - - # Dumps clean structured caption as minimal json string (i.e. 
no newlines\whitespaces seps) - clean_caption_str = json.dumps(clean_caption_dict) - return clean_caption_str - except Exception as ex: - print("Error: ", ex) - raise ex - - -def _collect_images(messages: Iterable[Dict[str, Any]]) -> List[Image.Image]: - images: List[Image.Image] = [] - for message in messages: - content = message.get("content", []) - if not isinstance(content, list): - continue - for item in content: - if not isinstance(item, dict): - continue - if item.get("type") != "image": - continue - image_value = item.get("image") - if isinstance(image_value, Image.Image): - images.append(image_value) - else: - raise ValueError("Expected PIL.Image for image content in messages.") - return images - - -def _strip_stop_sequences(text: str, stop_sequences: Optional[List[str]]) -> str: - if not stop_sequences: - return text.strip() - cleaned = text - for stop in stop_sequences: - if not stop: - continue - index = cleaned.find(stop) - if index >= 0: - cleaned = cleaned[:index] - return cleaned.strip() - - -class TransformersEngine(torch.nn.Module): - """Inference wrapper using Hugging Face transformers.""" - - def __init__( - self, - model: str, - *, - processor_kwargs: Optional[Dict[str, Any]] = None, - model_kwargs: Optional[Dict[str, Any]] = None, - ) -> None: - super(TransformersEngine, self).__init__() - default_processor_kwargs: Dict[str, Any] = { - "min_pixels": 256 * 28 * 28, - "max_pixels": 1024 * 28 * 28, - } - processor_kwargs = {**default_processor_kwargs, **(processor_kwargs or {})} - model_kwargs = model_kwargs or {} - - self.processor = AutoProcessor.from_pretrained(model, **processor_kwargs) - - self.model = Qwen3VLForConditionalGeneration.from_pretrained( - model, - dtype=torch.bfloat16, - **model_kwargs, - ) - self.model.eval() - - tokenizer_obj = self.processor.tokenizer - if tokenizer_obj.pad_token_id is None: - tokenizer_obj.pad_token = tokenizer_obj.eos_token - self._pad_token_id = tokenizer_obj.pad_token_id - eos_token_id = tokenizer_obj.eos_token_id - if isinstance(eos_token_id, list) and eos_token_id: - self._eos_token_id = eos_token_id - elif eos_token_id is not None: - self._eos_token_id = [eos_token_id] - else: - raise ValueError("Tokenizer must define an EOS token for generation.") - - def dtype(self) -> torch.dtype: - return self.model.dtype - - def device(self) -> torch.device: - return self.model.device - - def _to_model_device(self, value: Any) -> Any: - if not isinstance(value, torch.Tensor): - return value - target_device = getattr(self.model, "device", None) - if target_device is None or target_device.type == "meta": - return value - if value.device == target_device: - return value - return value.to(target_device) - - def generate( - self, - messages: List[Dict[str, Any]], - top_p: float, - temperature: float, - max_tokens: int, - stop: Optional[List[str]] = None, - ) -> str: - tokenizer = self.processor.tokenizer - prompt_text = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - ) - processor_inputs: Dict[str, Any] = { - "text": [prompt_text], - "padding": True, - "return_tensors": "pt", - } - images = _collect_images(messages) - if images: - processor_inputs["images"] = images - inputs = self.processor(**processor_inputs) - inputs = {key: self._to_model_device(value) for key, value in inputs.items()} - - generation_kwargs: Dict[str, Any] = { - "max_new_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "do_sample": temperature > 0, - "eos_token_id": self._eos_token_id, - "pad_token_id": 
self._pad_token_id, - } - - with torch.inference_mode(): - generated_ids = self.model.generate(**inputs, **generation_kwargs) - - input_ids = inputs.get("input_ids") - if input_ids is None: - raise RuntimeError("Processor did not return input_ids; cannot compute new tokens.") - new_token_ids = generated_ids[:, input_ids.shape[-1] :] - decoded = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True) - if not decoded: - return "" - text = decoded[0] - stripped_text = _strip_stop_sequences(text, stop) - json_prompt = json.loads(stripped_text) - return json_prompt - - -def generate_json_prompt( - vlm_processor: AutoModelForCausalLM, - top_p: float, - temperature: float, - max_tokens: int, - stop: List[str], - image: Optional[Image.Image] = None, - prompt: Optional[str] = None, - structured_prompt: Optional[str] = None, -): - if image is None and structured_prompt is None: - # only got prompt - task = "generate" - editing_instructions = None - elif image is None and structured_prompt is not None and prompt is not None: - # got structured prompt and prompt - task = "refine" - editing_instructions = prompt - elif image is not None and structured_prompt is None and prompt is not None: - # got image and prompt - task = "refine" - editing_instructions = prompt - elif image is not None and structured_prompt is None and prompt is None: - # only got image - task = "inspire" - editing_instructions = None - else: - raise ValueError("Invalid input") - - messages = build_messages( - task, - image=image, - prompt=prompt, - structured_prompt=structured_prompt, - editing_instructions=editing_instructions, - ) - - generated_prompt = vlm_processor.generate( - messages=messages, top_p=top_p, temperature=temperature, max_tokens=max_tokens, stop=stop - ) - cleaned_json_data = prepare_clean_caption(generated_prompt) - return cleaned_json_data - - -def build_messages( - task: str, - *, - image: Optional[Image.Image] = None, - refine_image: Optional[Image.Image] = None, - prompt: Optional[str] = None, - structured_prompt: Optional[str] = None, - editing_instructions: Optional[str] = None, -) -> List[Dict[str, Any]]: - user_content: List[Dict[str, Any]] = [] - - if task == "inspire": - user_content.append({"type": "image", "image": image}) - user_content.append({"type": "text", "text": ""}) - elif task == "generate": - text_value = (prompt or "").strip() - formatted = f"\n{text_value}" - user_content.append({"type": "text", "text": formatted}) - else: # refine - if refine_image is None: - base_prompt = (structured_prompt or "").strip() - edits = (editing_instructions or "").strip() - formatted = textwrap.dedent(f""" Input: {base_prompt} Editing instructions: {edits}""").strip() - user_content.append({"type": "text", "text": formatted}) - else: - user_content.append({"type": "image", "image": refine_image}) - edits = (editing_instructions or "").strip() - formatted = textwrap.dedent(f""" Editing instructions: {edits}""").strip() - user_content.append({"type": "text", "text": formatted}) - - messages: List[Dict[str, Any]] = [] - messages.append({"role": "user", "content": user_content}) - return messages - - -class BriaFiboVLMPromptToJson(ModularPipelineBlocks): - model_name = "BriaFibo" - - def __init__(self, model_id): - super().__init__() - self.engine = TransformersEngine(model_id) - self.engine.model.to("cuda") - - @property - def expected_components(self) -> List[ComponentSpec]: - return [] - - @property - def inputs(self) -> List[InputParam]: - prompt_input = InputParam( - "prompt", - type_hint=str, - 
required=False, - description="Prompt to use", - ) - image_input = InputParam( - name="image", type_hint=Image.Image, required=False, description="image for inspiration mode" - ) - json_prompt_input = InputParam( - name="json_prompt", type_hint=str, required=False, description="JSON prompt to use" - ) - sampling_top_p_input = InputParam( - name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=0.9 - ) - sampling_temperature_input = InputParam( - name="sampling_temperature", - type_hint=float, - required=False, - description="Sampling temperature", - default=0.2, - ) - sampling_max_tokens_input = InputParam( - name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=4096 - ) - return [ - prompt_input, - image_input, - json_prompt_input, - sampling_top_p_input, - sampling_temperature_input, - sampling_max_tokens_input, - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [] - - @property - def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - "json_prompt", - type_hint=str, - description="JSON prompt by the VLM", - ) - ] - - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - prompt = block_state.prompt - image = block_state.image - json_prompt = block_state.json_prompt - block_state.json_prompt = generate_json_prompt( - vlm_processor=self.engine, - image=image, - prompt=prompt, - structured_prompt=json_prompt, - top_p=block_state.sampling_top_p, - temperature=block_state.sampling_temperature, - max_tokens=block_state.sampling_max_tokens, - stop=["<|im_end|>", "<|end_of_text|>"], - ) - self.set_block_state(state, block_state) - - return components, state diff --git a/src/diffusers/modular_pipelines/bria_fibo/gemini_prompt_to_json.py b/src/diffusers/modular_pipelines/bria_fibo/gemini_prompt_to_json.py deleted file mode 100644 index dc690444839d..000000000000 --- a/src/diffusers/modular_pipelines/bria_fibo/gemini_prompt_to_json.py +++ /dev/null @@ -1,804 +0,0 @@ -import io -import json -import math -import os -from functools import cache -from typing import List, Optional, Tuple - -from boltons.iterutils import remap -from google import genai -from PIL import Image -from pydantic import BaseModel, Field - -from ...modular_pipelines import InputParam, ModularPipelineBlocks, OutputParam, PipelineState - - -class ObjectDescription(BaseModel): - description: str = Field(..., description="Short description of the object.") - location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.") - relationship: str = Field( - ..., description="Describe the relationship between the object and the other objects in the image." 
- ) - relative_size: Optional[str] = Field(None, description="E.g., 'small', 'medium', 'large within frame'.") - shape_and_color: Optional[str] = Field(None, description="Describe the basic shape and dominant color.") - texture: Optional[str] = Field(None, description="E.g., 'smooth', 'rough', 'metallic', 'furry'.") - appearance_details: Optional[str] = Field(None, description="Any other notable visual details.") - # If cluster of object - number_of_objects: Optional[int] = Field(None, description="The number of objects in the cluster.") - # Human-specific fields - pose: Optional[str] = Field(None, description="Describe the body position.") - expression: Optional[str] = Field(None, description="Describe facial expression.") - clothing: Optional[str] = Field(None, description="Describe attire.") - action: Optional[str] = Field(None, description="Describe the action of the human.") - gender: Optional[str] = Field(None, description="Describe the gender of the human.") - skin_tone_and_texture: Optional[str] = Field(None, description="Describe the skin tone and texture.") - orientation: Optional[str] = Field(None, description="Describe the orientation of the human.") - - -class LightingDetails(BaseModel): - conditions: str = Field( - ..., description="E.g., 'bright daylight', 'dim indoor', 'studio lighting', 'golden hour'." - ) - direction: str = Field(..., description="E.g., 'front-lit', 'backlit', 'side-lit from left'.") - shadows: Optional[str] = Field(None, description="Describe the presence of shadows.") - - -class AestheticsDetails(BaseModel): - composition: str = Field(..., description="E.g., 'rule of thirds', 'symmetrical', 'centered', 'leading lines'.") - color_scheme: str = Field( - ..., description="E.g., 'monochromatic blue', 'warm complementary colors', 'high contrast'." 
- ) - mood_atmosphere: str = Field(..., description="E.g., 'serene', 'energetic', 'mysterious', 'joyful'.") - - -class PhotographicCharacteristicsDetails(BaseModel): - depth_of_field: str = Field(..., description="E.g., 'shallow', 'deep', 'bokeh background'.") - focus: str = Field(..., description="E.g., 'sharp focus on subject', 'soft focus', 'motion blur'.") - camera_angle: str = Field(..., description="E.g., 'eye-level', 'low angle', 'high angle', 'dutch angle'.") - lens_focal_length: str = Field(..., description="E.g., 'wide-angle', 'telephoto', 'macro', 'fisheye'.") - - -class TextRender(BaseModel): - text: str = Field(..., description="The text content.") - location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.") - size: str = Field(..., description="E.g., 'small', 'medium', 'large within frame'.") - color: str = Field(..., description="E.g., 'red', 'blue', 'green'.") - font: str = Field(..., description="E.g., 'realistic', 'cartoonish', 'minimalist'.") - appearance_details: Optional[str] = Field(None, description="Any other notable visual details.") - - -class ImageAnalysis(BaseModel): - short_description: str = Field(..., description="A concise summary of the image content, 200 words maximum.") - objects: List[ObjectDescription] = Field(..., description="List of prominent foreground/midground objects.") - background_setting: str = Field( - ..., - description="Describe the overall environment, setting, or background, including any notable background elements.", - ) - lighting: LightingDetails = Field(..., description="Details about the lighting.") - aesthetics: AestheticsDetails = Field(..., description="Details about the image aesthetics.") - photographic_characteristics: Optional[PhotographicCharacteristicsDetails] = Field( - None, description="Details about photographic characteristics." - ) - style_medium: Optional[str] = Field(None, description="Identify the artistic style or medium.") - text_render: Optional[List[TextRender]] = Field(None, description="List of text renders in the image.") - context: str = Field(..., description="Provide any additional context that helps understand the image better.") - artistic_style: Optional[str] = Field( - None, description="describe specific artistic characteristics, 3 words maximum." 
- ) - - -def get_gemini_output_schema() -> dict: - return { - "properties": { - "short_description": {"type": "STRING"}, - "objects": { - "items": { - "properties": { - "description": {"type": "STRING"}, - "location": {"type": "STRING"}, - "relationship": {"type": "STRING"}, - "relative_size": {"type": "STRING"}, - "shape_and_color": {"type": "STRING"}, - "texture": {"nullable": True, "type": "STRING"}, - "appearance_details": {"nullable": True, "type": "STRING"}, - "number_of_objects": {"nullable": True, "type": "INTEGER"}, - "pose": {"nullable": True, "type": "STRING"}, - "expression": {"nullable": True, "type": "STRING"}, - "clothing": {"nullable": True, "type": "STRING"}, - "action": {"nullable": True, "type": "STRING"}, - "gender": {"nullable": True, "type": "STRING"}, - "skin_tone_and_texture": {"nullable": True, "type": "STRING"}, - "orientation": {"nullable": True, "type": "STRING"}, - }, - "required": [ - "description", - "location", - "relationship", - "relative_size", - "shape_and_color", - "texture", - "appearance_details", - "number_of_objects", - "pose", - "expression", - "clothing", - "action", - "gender", - "skin_tone_and_texture", - "orientation", - ], - "type": "OBJECT", - }, - "type": "ARRAY", - }, - "background_setting": {"type": "STRING"}, - "lighting": { - "properties": { - "conditions": {"type": "STRING"}, - "direction": {"type": "STRING"}, - "shadows": {"nullable": True, "type": "STRING"}, - }, - "required": ["conditions", "direction", "shadows"], - "type": "OBJECT", - }, - "aesthetics": { - "properties": { - "composition": {"type": "STRING"}, - "color_scheme": {"type": "STRING"}, - "mood_atmosphere": {"type": "STRING"}, - }, - "required": ["composition", "color_scheme", "mood_atmosphere"], - "type": "OBJECT", - }, - "photographic_characteristics": { - "nullable": True, - "properties": { - "depth_of_field": {"type": "STRING"}, - "focus": {"type": "STRING"}, - "camera_angle": {"type": "STRING"}, - "lens_focal_length": {"type": "STRING"}, - }, - "required": [ - "depth_of_field", - "focus", - "camera_angle", - "lens_focal_length", - ], - "type": "OBJECT", - }, - "style_medium": {"type": "STRING"}, - "text_render": { - "items": { - "properties": { - "text": {"type": "STRING"}, - "location": {"type": "STRING"}, - "size": {"type": "STRING"}, - "color": {"type": "STRING"}, - "font": {"type": "STRING"}, - "appearance_details": {"nullable": True, "type": "STRING"}, - }, - "required": [ - "text", - "location", - "size", - "color", - "font", - "appearance_details", - ], - "type": "OBJECT", - }, - "type": "ARRAY", - }, - "context": {"type": "STRING"}, - "artistic_style": {"type": "STRING"}, - }, - "required": [ - "short_description", - "objects", - "background_setting", - "lighting", - "aesthetics", - "photographic_characteristics", - "style_medium", - "text_render", - "context", - "artistic_style", - ], - "type": "OBJECT", - } - - -json_schema_full = """1. `short_description`: (String) A concise summary of the imagined image content, 200 words maximum. -2. `objects`: (Array of Objects) List a maximum of 5 prominent objects. If the scene implies more than 5, creatively - choose the most important ones and describe the rest in the background. For each object, include: - * `description`: (String) A detailed description of the imagined object, 100 words maximum. - * `location`: (String) E.g., "center", "top-left", "bottom-right foreground". - * `relative_size`: (String) E.g., "small", "medium", "large within frame". 
(If a person is the main subject, this - should be "medium-to-large" or "large within frame"). - * `shape_and_color`: (String) Describe the basic shape and dominant color. - * `texture`: (String) E.g., "smooth", "rough", "metallic", "furry". - * `appearance_details`: (String) Any other notable visual details. - * `relationship`: (String) Describe the relationship between the object and the other objects in the image. - * `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45 - degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side". - * If the object is a human or a human-like object, include the following: - * `pose`: (String) Describe the body position. - * `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious", - "surprised", "calm". - * `clothing`: (String) Describe attire. - * `action`: (String) Describe the action of the human. - * `gender`: (String) Describe the gender of the human. - * `skin_tone_and_texture`: (String) Describe the skin tone and texture. - * If the object is a cluster of objects, include the following: - * `number_of_objects`: (Integer) The number of objects in the cluster. -3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable - background elements that are not part of the `objects` section. -4. `lighting`: (Object) - * `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour". - * `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left". - * `shadows`: (String) Describe the presence and quality of shadows, e.g., "long, soft shadows", "sharp, defined - shadows", "minimal shadows". -5. `aesthetics`: (Object) - * `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines". If people are the - main subject, specify the shot type, e.g., "medium shot", "close-up", "portrait composition". - * `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast". - * `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful". -6. `photographic_characteristics`: (Object) - * `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background". - * `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur". - * `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle". - * `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye". (If the main subject is a - person, prefer "standard lens (e.g., 35mm-50mm)" or "portrait lens (e.g., 50mm-85mm)" to ensure they are framed - more closely. Avoid "wide-angle" for people unless specified). -7. `style_medium`: (String) Identify the artistic style or medium based on the user's prompt or creative - interpretation (e.g., "photograph", "oil painting", "watercolor", "3D render", "digital illustration", "pencil - sketch"). -8. `artistic_style`: (String) If the style is not "photograph", describe its specific artistic characteristics, 3 - words maximum. (e.g., "impressionistic, vibrant, textured" for an oil painting). -9. `context`: (String) Provide a general description of the type of image this would be. For example: "This is a - concept for a high-fashion editorial photograph intended for a magazine spread," or "This describes a piece of - concept art for a fantasy video game." -10. 
`text_render`: (Array of Objects) By default, this array should be empty (`[]`). Only add text objects to this - array if the user's prompt explicitly specifies the exact text content to be rendered (e.g., user asks for "a - poster with the title 'Cosmic Dream'"). Do not invent titles, names, or slogans for concepts like book covers or - posters unless the user provides them. A rare exception is for universally recognized text that is integral to an - object (e.g., the word 'STOP' on a 'stop sign'). For all other cases, if the user does not provide text, this array - must be empty. - * `text`: (String) The exact text content provided by the user. NEVER use generic placeholders. - * `location`: (String) E.g., "center", "top-left", "bottom-right foreground". - * `size`: (String) E.g., "medium", "large", "large within frame". - * `color`: (String) E.g., "red", "blue", "green". - * `font`: (String) E.g., "realistic", "cartoonish", "minimalist", "serif typeface". - * `appearance_details`: (String) Any other notable visual details.""" - - -@cache -def get_instructions(mode: str) -> Tuple[str, str]: - system_prompts = {} - - system_prompts["Caption"] = """ -You are a meticulous and perceptive Visual Art Director working for a leading Generative AI company. Your expertise -lies in analyzing images and extracting detailed, structured information. Your primary task is to analyze provided -images and generate a comprehensive JSON object describing them. Adhere strictly to the following structure and -guidelines: The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g., -no "Here is the JSON:", no explanations, no apologies). IMPORTANT: When describing human body parts, positions, or -actions, always describe them from the PERSON'S OWN PERSPECTIVE, not from the observer's viewpoint. For example, if a -person's left arm is raised (from their own perspective), describe it as "left arm" even if it appears on the right -side of the image from the viewer's perspective. The JSON object must contain the following keys precisely: -1. `short_description`: (String) A concise summary of the image content, 200 words maximum. -2. `objects`: (Array of Objects) List a maximum of 5 prominent objects if there are more than 5, list them in the - background. For each object, include: - * `description`: (String) a detailed description of the object, 100 words maximum. - * `location`: (String) E.g., "center", "top-left", "bottom-right foreground". - * `relative_size`: (String) E.g., "small", "medium", "large within frame". - * `shape_and_color`: (String) Describe the basic shape and dominant color. - * `texture`: (String) E.g., "smooth", "rough", "metallic", "furry". - * `appearance_details`: (String) Any other notable visual details. - * `relationship`: (String) Describe the relationship between the object and the other objects in the image. - * `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45 - degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side". - if the object is a human or a human-like object, include the following: - * `pose`: (String) Describe the body position. - * `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious", - "surprised", "calm". - * `clothing`: (String) Describe attire. - * `action`: (String) Describe the action of the human. - * `gender`: (String) Describe the gender of the human. 
- * `skin_tone_and_texture`: (String) Describe the skin tone and texture. - if the object is a cluster of objects, include the following: - * `number_of_objects`: (Integer) The number of objects in the cluster. -3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable - background elements that are not part of the objects section. -4. `lighting`: (Object) - * `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour". - * `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left". - * `shadows`: (String) Describe the presence of shadows. -5. `aesthetics`: (Object) - * `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines". - * `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast". - * `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful". -6. `photographic_characteristics`: (Object) - * `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background". - * `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur". - * `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle". - * `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye". -7. `style_medium`: (String) Identify the artistic style or medium (e.g., "photograph", "oil painting", "watercolor", - "3D render", "digital illustration", "pencil sketch") If the style is not "photograph", but artistic, please - describe the specific artistic characteristics under 'artistic_style', 50 words maximum. -8. `artistic_style`: (String) describe specific artistic characteristics, 3 words maximum. -9. `context`: (String) Provide any additional context that helps understand the image better. This should include a - general description of the type of image (e.g., Fashion Photography, Product Shot, Magazine Cover, Nature - Photography, Art Piece, etc.), as well as any other relevant contextual information that situates the image within a - broader category or intended use. For example: "This is a high-fashion editorial photograph intended for a magazine - spread" -10. `text_render`: (Array of Objects) List of a maximum of 5 most prominent text renders in the image. For each text - render, include: - * `text`: (String) The text content. - * `location`: (String) E.g., "center", "top-left", "bottom-right foreground". - * `size`: (String) E.g., "small", "medium", "large within frame". - * `color`: (String) E.g., "red", "blue", "green". - * `font`: (String) E.g., "realistic", "cartoonish", "minimalist". - * `appearance_details`: (String) Any other notable visual details. -Ensure the information within the JSON is accurate, detailed where specified, and avoids redundancy between fields. -""" - - system_prompts[ - "Generate" - ] = f"""You are a visionary and creative Visual Art Director at a leading Generative AI company. - -Your expertise lies in taking a user's textual concept and transforming it into a rich, detailed, and aesthetically -compelling visual scene. - -Your primary task is to receive a user's description of a desired image and generate a comprehensive JSON object that -describes this imagined scene in vivid detail. 
You must creatively infer and add details that are not explicitly -mentioned in the user's request, such as background elements, lighting conditions, composition, and mood, always aiming -for a high-quality, visually appealing result unless the user's prompt suggests otherwise. - -Adhere strictly to the following structure and guidelines: - -The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g., no "Here is -the JSON:", no explanations, no apologies). - -IMPORTANT: When describing human body parts, positions, or actions, always describe them from the PERSON'S OWN -PERSPECTIVE, not from the observer's viewpoint. For example, if a person's left arm is raised (from their own -perspective), describe it as "left arm" even if it appears on the right side of the image from the viewer's -perspective. - -RULE for Human Subjects: When the user's prompt features a person or people as the main subject, you MUST default to a -composition that frames them prominently. Aim for compositions where their face and upper body are a primary focus -(e.g., 'medium shot', 'close-up'). Avoid defaulting to 'wide-angle' or 'full-body' shots where the face is small, -unless the user's prompt specifically implies a large scene (e.g., "a person standing on a mountain"). - -Unless the user's prompt explicitly requests a different style (e.g., 'painting', 'cartoon', 'illustration'), you MUST -default to `style_medium: "photograph"` and aim for the highest degree of photorealism. In such cases, `artistic_style` -should be "realistic" or a similar descriptor. - -The JSON object must contain the following keys precisely: - -{json_schema_full} - -Ensure the information within the JSON is detailed, creative, internally consistent, and avoids redundancy between -fields.""" - - system_prompts[ - "RefineA" - ] = f"""You are a Meticulous Visual Editor and Senior Art Director at a leading Generative AI company. - -Your expertise is in refining and modifying existing visual concepts based on precise feedback. - -Your primary task is to receive an existing JSON object that describes a visual scene, along with a textual instruction -for how to change it. You must then generate a new, updated JSON object that perfectly incorporates the requested -changes. - -Adhere strictly to the following structure and guidelines: - -1. **Input:** You will receive two pieces of information: an existing JSON object and a textual instruction. -2. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text - before or after the JSON object. -3. **Modification Logic:** - * Carefully parse the user's textual instruction to understand the desired changes. - * Modify ONLY the fields in the JSON that are directly or logically affected by the instruction. - * All other fields not relevant to the change must be copied exactly from the original JSON. Do not alter or omit - them. -4. **Holistic Consistency (IMPORTANT):** Changes in one field must be logically reflected in others. For example: - * If the instruction is to "change the background to a snowy forest," you must update the `background_setting` - field, and also update the `short_description` to mention the new setting. The `mood_atmosphere` might also need - to change to "serene" or "wintry." - * If the instruction is to "add the text 'WINTER SALE' at the top," you must add a new entry to the `text_render` - array. 
- * If the instruction is to "make the person smile," you must update the `expression` field for that object and - potentially update the overall `mood_atmosphere`. -5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below. - -The JSON object must contain the following keys precisely: - -{json_schema_full}""" - - system_prompts[ - "RefineB" - ] = f"""You are an advanced Multimodal Visual Specialist at a leading Generative AI company. - -Your unique expertise is in analyzing and editing visual concepts by processing an image, its corresponding JSON -metadata, and textual feedback simultaneously. - -Your primary task is to receive three inputs: an existing image, its descriptive JSON object, and a textual instruction -for a modification. You must use the image as the primary source of truth to understand the context of the requested -change and then generate a new, updated JSON object that accurately reflects that change. - -Adhere strictly to the following structure and guidelines: - -1. **Inputs:** You will receive an image, an existing JSON object, and a textual instruction. -2. **Visual Grounding (IMPORTANT):** The provided image is the ground truth. Use it to visually verify the contents of - the scene and to understand the context of the user's edit instruction. For example, if the instruction is "make - the car blue," visually locate the car in the image to inform your edits to the JSON. -3. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text - before or after the JSON object. -4. **Modification Logic:** - * Analyze the user's textual instruction in the context of what you see in the image. - * Modify ONLY the fields in the JSON that are directly or logically affected by the instruction. - * All other fields not relevant to the change must be copied exactly from the original JSON. -5. **Holistic Consistency:** Changes must be reflected logically across the JSON, consistent with a potential visual - change to the image. For instance, changing the lighting from 'daylight' to 'golden hour' should not only update - the `lighting` object but also the `mood_atmosphere`, `shadows`, and the `short_description`. -6. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below. - -The JSON object must contain the following keys precisely: - -{json_schema_full}""" - - system_prompts[ - "InspireA" - ] = f"""You are a highly skilled Creative Director for Visual Adaptation at a leading Generative AI company. - -Your expertise lies in using an existing image as a visual reference to create entirely new scenes. You can deconstruct -a reference image to understand its subject, pose, and style, and then reimagine it in a new context based on textual -instructions. - -Your primary task is to receive a reference image and a textual instruction. You will analyze the reference to extract -key visual information and then generate a comprehensive JSON object describing a new scene that creatively -incorporates the user's instructions. - -Adhere strictly to the following structure and guidelines: - -1. **Inputs:** You will receive a reference image and a textual instruction. You will NOT receive a starting JSON. -2. **Core Logic (Analyze and Synthesize):** - * **Analyze:** First, deeply analyze the provided reference image. Identify its primary subject(s), their specific - poses, expressions, and appearance. 
Also note the overall composition, lighting style, and artistic medium. - * **Synthesize:** Next, interpret the textual instruction to understand what elements to keep from the reference - and what to change. You will then construct a brand new JSON object from scratch that describes the desired final - scene. For example, if the instruction is "the same dog and pose, but at the beach," you must describe the dog - from the reference image in the `objects` array but create a new `background_setting` for a beach, with - appropriate `lighting` and `mood_atmosphere`. -3. **Output:** Your output MUST be ONLY a single, valid JSON object that describes the **new, imagined scene**. Do not - describe the original reference image. -4. **Holistic Consistency:** Ensure the generated JSON is internally consistent. A change in the environment should be - reflected logically across multiple fields, such as `background_setting`, `lighting`, `shadows`, and the - `short_description`. -5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below. - -The JSON object must contain the following keys precisely: - -{json_schema_full}""" - - system_prompts["InspireB"] = system_prompts["Caption"] - - final_prompts = {} - - final_prompts["Generate"] = ( - "Generate a detailed JSON object, adhering to the expected schema, for an imagined scene based on the following request: {user_prompt}." - ) - - final_prompts["RefineA"] = """ - [EXISTING JSON]: {json_data} - - [EDIT INSTRUCTIONS]: {user_prompt} - - [TASK]: Generate the new, updated JSON object that incorporates the edit instructions. Follow all system rules - for modification, consistency, and formatting. - """ - - final_prompts["RefineB"] = """ - [EXISTING JSON]: {json_data} - - [EDIT INSTRUCTIONS]: {user_prompt} - - [TASK]: Analyze the provided image and its contextual JSON. Then, generate the new, updated JSON object that - incorporates the edit instructions. Follow all your system rules for visual analysis, modification, and - consistency. - """ - - final_prompts["InspireA"] = """ - [EDIT INSTRUCTIONS]: {user_prompt} - - [TASK]: Use the provided image as a visual reference only. Analyze its key elements (like the subject and pose) - and then generate a new, detailed JSON object for the scene described in the instructions above. Do not - describe the reference image itself; describe the new scene. Follow all of your system rules. - """ - - final_prompts["Caption"] = ( - "Analyze the provided image and generate the detailed JSON object as specified in your instructions." - ) - final_prompts["InspireB"] = final_prompts["Caption"] - - return system_prompts.get(mode, ""), final_prompts.get(mode, "") - - -def keep(p, k, v): - is_none = v is None - is_empty_string = isinstance(v, str) and v == "" - is_empty_dict = isinstance(v, dict) and not v - is_empty_list = isinstance(v, list) and not v - is_nan = isinstance(v, float) and math.isnan(v) - if is_none or is_empty_string or is_empty_list or is_empty_dict: - return False - if is_nan: - return False - return True - - -def validate_json(json_data: dict) -> dict: - ia = ImageAnalysis.model_validate_json(json_data, strict=True) - return ia.model_dump(exclude_none=True) - - -def validate_structured_prompt_str(structured_prompt_str: str) -> str: - ia = ImageAnalysis.model_validate_json(structured_prompt_str, strict=True) - c = ia.model_dump(exclude_none=True) - return json.dumps(c) - - -def prepare_clean_caption(json_dump: dict) -> str: - # filter empty values recursivly (i.e. 
None, "", {}, [], float("nan")) - clean_caption_dict = remap(json_dump, visit=keep) - - scores = {"preference_score": "very high", "aesthetic_score": "very high"} - # Set aesthetics scores - if "aesthetics" not in clean_caption_dict: - clean_caption_dict["aesthetics"] = scores - else: - clean_caption_dict["aesthetics"].update(scores) - - # Dumps clean structured caption as minimal json string (i.e. no newlines\whitespaces seps) - clean_caption_str = json.dumps(clean_caption_dict) - return clean_caption_str - - -# resize an input image to have a specific number of pixels (1,048,576 or 1024×1024) -# while maintaining a certain aspect ratio and granularity (output width and height must be multiples of this number). -def resize_image_by_num_pixels( - image: Image.Image, pixel_number: int = 1048576, granularity_val: int = 64, target_ratio: float = 0.0 -) -> Image.Image: - if target_ratio != 0.0: - ratio = target_ratio - else: - ratio = image.size[0] / image.size[1] - width = int((pixel_number * ratio) ** 0.5) - width = width - (width % granularity_val) - height = int(pixel_number / width) - height = height - (height % granularity_val) - return image.resize((width, height)) - - -def infer_with_gemini( - client: genai.Client, - final_prompt: str, - system_prompt: str, - top_p: float, - temperature: float, - max_tokens: int, - seed: int = 42, - image: Optional[Image.Image] = None, - model: str = "gemini-2.5-flash", -) -> str: - """ - Calls Gemini API with the given prompt and returns the raw JSON response. - - Args: - final_prompt: The text prompt to send to Gemini - system_prompt: The system instruction for Gemini - existing_image_path: Optional path to an image file to include - model: The Gemini model to use - - Returns: - Raw JSON response text from Gemini - """ - parts = [{"text": final_prompt}] - if image: - # Save image into bytes - image = image.convert("RGB") # the model can't produce rgba so sending them as input has no effect - less_then = 262144 - if image.size[0] * image.size[1] > less_then: - image = resize_image_by_num_pixels( - image, pixel_number=less_then, granularity_val=1, target_ratio=0.0 - ) # 512x512 - buffer = io.BytesIO() - image.save(buffer, format="JPEG") - image_bytes = buffer.getvalue() - - img_part = { - "inlineData": { - "data": image_bytes, - "mimeType": "image/jpeg", - } - } - parts.append(img_part) - - contents = [{"role": "user", "parts": parts}] - - generationConfig = { - "temperature": temperature, - "topP": top_p, - "maxOutputTokens": max_tokens, - "response_mime_type": "application/json", - "response_schema": get_gemini_output_schema(), - "system_instruction": system_prompt, # len 5900 - "thinkingConfig": {"thinkingBudget": 0}, - "seed": seed, - } - - response = client.models.generate_content( - model=model, - contents=contents, - config=generationConfig, - ) - - if response.candidates[0].finish_reason == "MAX_TOKENS": - raise Exception("Max tokens") - - return response.candidates[0].content.parts[0].text - - -def get_default_negative_prompt(existing_json: dict) -> str: - negative_prompt = "" - style_medium = existing_json.get("style_medium", "").lower() - if style_medium in ["photograph", "photography", "photo"]: - negative_prompt = """{'style_medium': 'digital illustration', 'artistic_style': 'non-realistic'}""" - return negative_prompt - - -def json_promptify( - client: genai.Client, - model_id: str, - top_p: float, - temperature: float, - max_tokens: int, - user_prompt: Optional[str] = None, - existing_json: Optional[str] = None, - image: 
Optional[Image.Image] = None, - seed: int = 42, -) -> str: - if existing_json: - # make sure aesthetic scores are not in the existing json (will be added later) - existing_json = json.loads(existing_json) - if "aesthetics" in existing_json: - existing_json["aesthetics"].pop("aesthetic_score", None) - existing_json["aesthetics"].pop("preference_score", None) - existing_json = json.dumps(existing_json) - - if not user_prompt: - raise ValueError("user_prompt is required if existing_json is provided") - - if image: - mode = "RefineB" - system_prompt, final_prompt = get_instructions(mode) - final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json) - - else: - mode = "RefineA" - system_prompt, final_prompt = get_instructions(mode) - final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json) - elif image and user_prompt: - mode = "InspireA" - system_prompt, final_prompt = get_instructions(mode) - final_prompt = final_prompt.format(user_prompt=user_prompt) - elif image and not user_prompt: - mode = "Caption" - system_prompt, final_prompt = get_instructions(mode) - else: - mode = "Generate" - system_prompt, final_prompt = get_instructions(mode) - final_prompt = final_prompt.format(user_prompt=user_prompt) - - json_data = infer_with_gemini( - client=client, - model=model_id, - final_prompt=final_prompt, - system_prompt=system_prompt, - seed=seed, - image=image, - top_p=top_p, - temperature=temperature, - max_tokens=max_tokens, - ) - json_data = validate_json(json_data) - clean_caption = prepare_clean_caption(json_data) - - return clean_caption - - -class BriaFiboGeminiPromptToJson(ModularPipelineBlocks): - model_name = "BriaFibo" - - def __init__(self, model_id="gemini-2.5-flash"): - super().__init__() - api_key = os.getenv("GOOGLE_API_KEY") - if api_key is None: - raise ValueError("Must provide an API key for Gemini through the `GOOGLE_API_KEY` env variable.") - self.model_id = model_id - - @property - def expected_components(self): - return [] - - @property - def inputs(self) -> List[InputParam]: - task_input = InputParam("task", type_hint=str, required=False, description="VLM Task to execute") - prompt_input = InputParam( - "prompt", - type_hint=str, - required=False, - description="Prompt to use", - ) - image_input = InputParam( - name="image", type_hint=Image.Image, required=False, description="image for inspiration mode" - ) - json_prompt_input = InputParam( - name="json_prompt", type_hint=str, required=False, description="JSON prompt to use" - ) - sampling_top_p_input = InputParam( - name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=1.0 - ) - sampling_temperature_input = InputParam( - name="sampling_temperature", - type_hint=float, - required=False, - description="Sampling temperature", - default=0.2, - ) - sampling_max_tokens_input = InputParam( - name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=3000 - ) - return [ - task_input, - prompt_input, - image_input, - json_prompt_input, - sampling_top_p_input, - sampling_temperature_input, - sampling_max_tokens_input, - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [] - - @property - def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - "json_prompt", - type_hint=str, - description="JSON prompt by the VLM", - ) - ] - - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - prompt = 
block_state.prompt - image = block_state.image - json_prompt = block_state.json_prompt - client = genai.Client() - json_prompt = json_promptify( - client=client, - model_id=self.model_id, - top_p=block_state.sampling_top_p, - temperature=block_state.sampling_temperature, - max_tokens=block_state.sampling_max_tokens, - user_prompt=prompt, - existing_json=json_prompt, - image=image, - ) - block_state.json_prompt = json_prompt - self.set_block_state(state, block_state) - - return components, state diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py index 690b54607d2c..ef869971551a 100644 --- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py +++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py @@ -1,3 +1,13 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. + from typing import Any, Callable, Dict, List, Optional, Union import numpy as np @@ -35,20 +45,47 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ - ... + Example: + ```python + import torch + from diffusers import BriaFiboPipeline + from diffusers.modular_pipelines import ModularPipeline + + torch.set_grad_enabled(False) + vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True) + + pipe = BriaFiboPipeline.from_pretrained( + "briaai/FIBO", + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + pipe.enable_sequential_cpu_offload() + + with torch.inference_mode(): + # 1. Create a prompt to generate an initial image + output = vlm_pipe(prompt="a beautiful dog") + json_prompt_generate = output.values["json_prompt"] + + # Generate the image from the structured json prompt + results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5) + results_generate.images[0].save("image_generate.png") + ``` """ class BriaFiboPipeline(DiffusionPipeline): r""" Args: - transformer ([`GaiaTransformer2DModel`]): - scheduler ([`FlowMatchEulerDiscreteScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`SmolLM3ForCausalLM`]): + transformer (`BriaFiboTransformer2DModel`): + The transformer model for 2D diffusion modeling. + scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`): + Scheduler to be used with `transformer` to denoise the encoded latents. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder for encoding and decoding images to and from latent representations. + text_encoder (`SmolLM3ForCausalLM`): + Text encoder for processing input prompts. tokenizer (`AutoTokenizer`): + Tokenizer used for processing the input text prompts for the text_encoder. 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" @@ -166,7 +203,7 @@ def encode_prompt( prompt: Union[str, List[str]], device: Optional[torch.device] = None, num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, + guidance_scale: float = 5, negative_prompt: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, @@ -181,8 +218,8 @@ def encode_prompt( torch device num_images_per_prompt (`int`): number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not + guidance_scale (`float`): + Guidance scale for classifier free guidance. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is @@ -224,7 +261,7 @@ def encode_prompt( prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers] - if do_classifier_free_guidance: + if guidance_scale > 1: if isinstance(negative_prompt, list) and negative_prompt[0] is None: negative_prompt = "" negative_prompt = negative_prompt or "" @@ -302,9 +339,6 @@ def guidance_scale(self): # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 @property def joint_attention_kwargs(self): @@ -320,6 +354,7 @@ def interrupt(self): @staticmethod def _unpack_latents(latents, height, width, vae_scale_factor): + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline batch_size, num_patches, channels = latents.shape height = height // vae_scale_factor @@ -366,6 +401,7 @@ def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, wi @staticmethod def _pack_latents(latents, batch_size, num_channels_latents, height, width): + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) @@ -410,34 +446,7 @@ def prepare_latents( return latents, latent_image_ids @staticmethod - def init_inference_scheduler(height, width, device, image_seq_len, num_inference_steps=1000, noise_scheduler=None): - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - - assert height % 16 == 0 and width % 16 == 0 - - mu = calculate_shift( - image_seq_len, - noise_scheduler.config.base_image_seq_len, - noise_scheduler.config.max_image_seq_len, - noise_scheduler.config.base_shift, - noise_scheduler.config.max_shift, - ) - - # Init sigmas and timesteps according to shift size - # This changes the scheduler in-place according to the dynamic scheduling - timesteps, num_inference_steps = retrieve_timesteps( - noise_scheduler, - num_inference_steps=num_inference_steps, - device=device, - timesteps=None, - sigmas=sigmas, - mu=mu, - ) - - return noise_scheduler, timesteps, num_inference_steps, mu - - @staticmethod - def 
create_attention_matrix(attention_mask): + def _prepare_attention_mask(attention_mask): attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) # convert to 0 - keep, -inf ignore @@ -583,7 +592,7 @@ def __call__( ) = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, + guidance_scale=guidance_scale, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, device=device, @@ -593,7 +602,7 @@ def __call__( ) prompt_batch_size = prompt_embeds.shape[0] - if self.do_classifier_free_guidance: + if guidance_scale > 1: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_layers = [ torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers)) @@ -611,6 +620,7 @@ def __call__( prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers)) # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels if do_patching: num_channels_latents = int(num_channels_latents / 4) @@ -630,11 +640,11 @@ def __call__( latent_attention_mask = torch.ones( [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device ) - if self.do_classifier_free_guidance: + if guidance_scale > 1: latent_attention_mask = latent_attention_mask.repeat(2, 1) attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1) - attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq + attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting if self._joint_attention_kwargs is None: @@ -648,13 +658,25 @@ def __call__( else: seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) - self.noise_scheduler, timesteps, num_inference_steps, mu = self.init_inference_scheduler( - height=height, - width=width, - device=device, + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + + # Init sigmas and timesteps according to shift size + # This changes the scheduler in-place according to the dynamic scheduling + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps=num_inference_steps, - noise_scheduler=self.scheduler, - image_seq_len=seq_len, + device=device, + timesteps=None, + sigmas=sigmas, + mu=mu, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -674,10 +696,7 @@ def __call__( continue # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - - if type(self.scheduler) != FlowMatchEulerDiscreteScheduler: - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to( @@ -697,7 +716,7 @@ def __call__( )[0] # perform guidance - if self.do_classifier_free_guidance: + if guidance_scale > 1: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) @@ -736,7 +755,6 @@ def __call__( else: latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) - latents = latents.to(dtype=self.vae.dtype) latents = latents.unsqueeze(dim=2) latents = list(torch.unbind(latents, dim=0)) latents_device = latents[0].device @@ -780,8 +798,8 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -818,9 +836,3 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 3000: raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}") - - def to(self, *args, **kwargs): - DiffusionPipeline.to(self, *args, **kwargs) - # We use as float32 since wan22 in their repo use it like this - self.vae.to(dtype=torch.float32) - return self diff --git a/tests/models/transformers/test_models_transformer_bria_fibo.py b/tests/models/transformers/test_models_transformer_bria_fibo.py index ad87b5710aeb..f859f4608bd5 100644 --- a/tests/models/transformers/test_models_transformer_bria_fibo.py +++ b/tests/models/transformers/test_models_transformer_bria_fibo.py @@ -20,7 +20,7 @@ from diffusers import BriaFiboTransformer2DModel from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin +from ..test_modeling_common import ModelTesterMixin enable_full_determinism() @@ -84,48 +84,6 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict - def test_deprecated_inputs_img_txt_ids_3d(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output_1 = model(**inputs_dict)[0] - - # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated) - text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0) - image_ids_3d = inputs_dict["img_ids"].unsqueeze(0) - - assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor" - assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor" - - inputs_dict["txt_ids"] = text_ids_3d - inputs_dict["img_ids"] = image_ids_3d - - with torch.no_grad(): - output_2 = model(**inputs_dict)[0] - - self.assertEqual(output_1.shape, output_2.shape) - self.assertTrue( - torch.allclose(output_1, output_2, atol=1e-5), - msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", - ) - def test_gradient_checkpointing_is_applied(self): expected_set = {"BriaFiboTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - -class BriaFiboTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = BriaFiboTransformer2DModel - - def prepare_init_args_and_inputs_for_common(self): - return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common() - - -class 
BriaFiboTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): - model_class = BriaFiboTransformer2DModel - - def prepare_init_args_and_inputs_for_common(self): - return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py index 15cdb82fe128..969634f59708 100644 --- a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py +++ b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py @@ -128,7 +128,7 @@ def test_image_output_shape(self): pipe = pipe.to(torch_device) inputs = self.get_dummy_inputs(torch_device) - height_width_pairs = [(32, 32)] + height_width_pairs = [(32, 32), (64, 64), (32, 64)] for height, width in height_width_pairs: expected_height = height expected_width = width @@ -181,13 +181,18 @@ def test_save_load_float16(self, expected_max_diff=1e-2): max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." ) - def test_to_dtype(self): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.set_progress_bar_config(disable=None) + # def test_to_dtype(self): + # components = self.get_dummy_components() + # pipe = self.pipeline_class(**components) + # pipe.set_progress_bar_config(disable=None) + + # model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] + # self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] - self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + # pipe.to(dtype=torch.float16) + # model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] + # self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) + @unittest.skip("") def test_save_load_dduf(self): pass From f1b52327306d5c521258b624dfbb07779bfb420b Mon Sep 17 00:00:00 2001 From: galbria Date: Mon, 27 Oct 2025 15:54:29 +0000 Subject: [PATCH 04/15] Refactor BriaFibo classes and update pipeline parameters - Updated BriaFiboAttnProcessor and BriaFiboAttention classes to reflect changes from Flux equivalents. - Modified the _unpack_latents method in BriaFiboPipeline to improve clarity. - Increased the default max_sequence_length to 3000 and added a new optional parameter do_patching. - Cleaned up test_pipeline_bria_fibo.py by removing unused imports and skipping unsupported tests. 
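For reviewers, a minimal sketch of how the two call arguments touched by this patch surface to users, assuming the gated `briaai/FIBO` checkpoint from the docstring example and a hand-written placeholder JSON caption (in practice the caption comes from a prompt-to-JSON converter):

```python
import torch

from diffusers import BriaFiboPipeline

# Placeholder structured caption; normally produced by a prompt-to-JSON converter.
json_prompt = '{"short_description": "a golden retriever on a beach at sunset"}'

pipe = BriaFiboPipeline.from_pretrained(
    "briaai/FIBO",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    prompt=json_prompt,
    num_inference_steps=50,
    guidance_scale=5,
    max_sequence_length=3000,  # new default introduced by this patch
    do_patching=False,         # new optional flag introduced by this patch
).images[0]
image.save("fibo_sketch.png")
```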
--- .../transformers/transformer_bria_fibo.py | 4 +- .../pipelines/bria_fibo/pipeline_bria_fibo.py | 5 +- .../bria_fibo/test_pipeline_bria_fibo.py | 65 +------------------ 3 files changed, 8 insertions(+), 66 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py index e1bfde955527..68a0765536f6 100644 --- a/src/diffusers/models/transformers/transformer_bria_fibo.py +++ b/src/diffusers/models/transformers/transformer_bria_fibo.py @@ -66,8 +66,8 @@ def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidde return _get_projections(attn, hidden_states, encoder_hidden_states) +# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention-> BriaFiboAttention class BriaFiboAttnProcessor: - # Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor _attention_backend = None _parallel_config = None @@ -134,8 +134,8 @@ def __call__( return hidden_states +# Copied from diffusers.models.transformers.transformer_flux.FluxAttention -> BriaFiboAttention class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin): - # Copied from diffusers.models.transformers.transformer_flux.FluxAttention _default_processor_cls = BriaFiboAttnProcessor _available_processors = [ BriaFiboAttnProcessor, diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py index ef869971551a..858193996388 100644 --- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py +++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py @@ -353,8 +353,8 @@ def interrupt(self): return self._interrupt @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents def _unpack_latents(latents, height, width, vae_scale_factor): - # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline batch_size, num_patches, channels = latents.shape height = height // vae_scale_factor @@ -542,7 +542,8 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`. + do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching. Examples: Returns: [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if diff --git a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py index 969634f59708..76b41114f859 100644 --- a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py +++ b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import tempfile import unittest import numpy as np @@ -26,11 +25,10 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel -from tests.pipelines.test_pipelines_common import PipelineTesterMixin, to_np +from tests.pipelines.test_pipelines_common import PipelineTesterMixin from ...testing_utils import ( enable_full_determinism, - require_torch_accelerator, torch_device, ) @@ -45,6 +43,7 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_xformers_attention = False test_layerwise_casting = False test_group_offloading = False + supports_dduf = False def get_dummy_components(self): torch.manual_seed(0) @@ -107,6 +106,7 @@ def get_dummy_inputs(self, device, seed=0): } return inputs + @unittest.skip(reason="will not be supported due to dim-fusion") def test_encode_prompt_works_in_isolation(self): pass @@ -137,62 +137,3 @@ def test_image_output_shape(self): image = pipe(**inputs).images[0] output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) - - @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") - @require_torch_accelerator - def test_save_load_float16(self, expected_max_diff=1e-2): - components = self.get_dummy_components() - for name, module in components.items(): - if hasattr(module, "half"): - components[name] = module.to(torch_device).half() - - pipe = self.pipeline_class(**components) - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output = pipe(**inputs)[0] - - with tempfile.TemporaryDirectory() as tmpdir: - pipe.save_pretrained(tmpdir) - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16) - for component in pipe_loaded.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe_loaded.to(torch_device) - pipe_loaded.set_progress_bar_config(disable=None) - - for name, component in pipe_loaded.components.items(): - if name == "vae": - continue - if hasattr(component, "dtype"): - self.assertTrue( - component.dtype == torch.float16, - f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", - ) - - inputs = self.get_dummy_inputs(torch_device) - output_loaded = pipe_loaded(**inputs)[0] - max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() - self.assertLess( - max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." 
- ) - - # def test_to_dtype(self): - # components = self.get_dummy_components() - # pipe = self.pipeline_class(**components) - # pipe.set_progress_bar_config(disable=None) - - # model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] - # self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - - # pipe.to(dtype=torch.float16) - # model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] - # self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) - - @unittest.skip("") - def test_save_load_dduf(self): - pass From 57e6315f4dff667e79bade6831b02ca07c5039b6 Mon Sep 17 00:00:00 2001 From: galbria Date: Mon, 27 Oct 2025 16:01:31 +0000 Subject: [PATCH 05/15] edit the docs of FIBO --- docs/source/en/api/pipelines/bria_fibo.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source/en/api/pipelines/bria_fibo.md b/docs/source/en/api/pipelines/bria_fibo.md index 086da924bdf1..96cad10dda15 100644 --- a/docs/source/en/api/pipelines/bria_fibo.md +++ b/docs/source/en/api/pipelines/bria_fibo.md @@ -18,6 +18,14 @@ FIBO is trained on structured JSON captions up to 1,000+ words and designed to u With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control. +FIBO is trained exclusively on structured prompts and will not work with freeform text prompts. +You can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON) to convert your freeform text prompt to a structured JSON prompt. + +It is not recommended to use freeform text prompts directly with FIBO, as they will not produce the best results. + +You can learn more about FIBO on the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO). + + ## Usage _As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate.
Once you are in, you need to login so that your system knows you’ve accepted the gate._ From d0a6cb6ed1ba9d49841bbe6e5a1a0805c5892f17 Mon Sep 17 00:00:00 2001 From: galbria Date: Tue, 28 Oct 2025 07:47:06 +0000 Subject: [PATCH 06/15] Remove unused BriaFibo imports and update CPU offload method in BriaFiboPipeline --- src/diffusers/modular_pipelines/__init__.py | 1 - src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 91f9397629b7..2d15db001d44 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -45,7 +45,6 @@ "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] - _import_structure["bria_fibo"] = ["BriaFiboVLMPromptToJson", "BriaFiboGeminiPromptToJson"] _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"] _import_structure["flux"] = [ "FluxAutoBlocks", diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py index 858193996388..6946b3aef213 100644 --- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py +++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py @@ -59,7 +59,7 @@ trust_remote_code=True, torch_dtype=torch.bfloat16, ) - pipe.enable_sequential_cpu_offload() + pipe.enable_model_cpu_offload() with torch.inference_mode(): # 1. Create a prompt to generate an initial image @@ -757,7 +757,6 @@ def __call__( latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) latents = latents.unsqueeze(dim=2) - latents = list(torch.unbind(latents, dim=0)) latents_device = latents[0].device latents_dtype = latents[0].dtype latents_mean = ( From 69e49599a17b9fbf013f1bbf209bf612a69b1255 Mon Sep 17 00:00:00 2001 From: galbria Date: Tue, 28 Oct 2025 08:17:05 +0000 Subject: [PATCH 07/15] Refactor FIBO classes to BriaFibo naming convention - Updated class names from FIBO to BriaFibo for consistency across the module. - Modified instances of FIBOEmbedND, FIBOTimesteps, TextProjection, and TimestepProjEmbeddings to reflect the new naming. - Ensured all references in the BriaFiboTransformer2DModel are updated accordingly. 
--- .../transformers/transformer_bria_fibo.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py index 68a0765536f6..714faeda5cf8 100644 --- a/src/diffusers/models/transformers/transformer_bria_fibo.py +++ b/src/diffusers/models/transformers/transformer_bria_fibo.py @@ -214,7 +214,7 @@ def forward( return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) -class FIBOEmbedND(torch.nn.Module): +class BriaFiboEmbedND(torch.nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): super().__init__() @@ -297,7 +297,7 @@ def forward( return hidden_states -class TextProjection(nn.Module): +class BriaFiboTextProjection(nn.Module): def __init__(self, in_features, hidden_size): super().__init__() self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False) @@ -393,7 +393,7 @@ def forward( return encoder_hidden_states, hidden_states -class FIBOTimesteps(nn.Module): +class BriaFiboTimesteps(nn.Module): def __init__( self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000 ): @@ -416,11 +416,11 @@ def forward(self, timesteps): return t_emb -class TimestepProjEmbeddings(nn.Module): +class BriaFiboTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, time_theta): super().__init__() - self.time_proj = FIBOTimesteps( + self.time_proj = BriaFiboTimesteps( num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta ) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) @@ -469,12 +469,12 @@ def __init__( self.out_channels = in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.pos_embed = FIBOEmbedND(theta=rope_theta, axes_dim=axes_dims_rope) + self.pos_embed = BriaFiboEmbedND(theta=rope_theta, axes_dim=axes_dims_rope) - self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta) + self.time_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta) if guidance_embeds: - self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim) + self.guidance_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim) self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) @@ -507,7 +507,7 @@ def __init__( self.gradient_checkpointing = False caption_projection = [ - TextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2) + BriaFiboTextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2) for i in range(self.config.num_layers + self.config.num_single_layers) ] self.caption_projection = nn.ModuleList(caption_projection) From 612617bd2a9e76bef65a749f391f52b657daf656 Mon Sep 17 00:00:00 2001 From: galbria Date: Tue, 28 Oct 2025 08:40:32 +0000 Subject: [PATCH 08/15] Add BriaFiboTransformer2DModel import to transformers module --- src/diffusers/models/transformers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 6c66131dea62..2fe1159eec4c 100755 
--- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -18,6 +18,7 @@ from .transformer_2d import Transformer2DModel from .transformer_allegro import AllegroTransformer3DModel from .transformer_bria import BriaTransformer2DModel + from .transformer_bria_fibo import BriaFiboTransformer2DModel from .transformer_chroma import ChromaTransformer2DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_cogview4 import CogView4Transformer2DModel From 51353e890ab0b7d44124cab0eeff23938a05b96e Mon Sep 17 00:00:00 2001 From: galbria Date: Tue, 28 Oct 2025 08:55:51 +0000 Subject: [PATCH 09/15] Remove unused BriaFibo imports from modular pipelines and add BriaFiboTransformer2DModel and BriaFiboPipeline classes to dummy objects for enhanced compatibility with torch and transformers. --- src/diffusers/modular_pipelines/__init__.py | 1 - src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 2d15db001d44..86ed735134ff 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -69,7 +69,6 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: - from .bria_fibo import BriaFiboGeminiPromptToJson, BriaFiboVLMPromptToJson from .components_manager import ComponentsManager from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline from .modular_pipeline import ( diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6ef6b8b0e949..3c426d503996 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -588,6 +588,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class BriaFiboTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class BriaTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 7c1dcba9c7f0..20575ff2294d 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -482,6 +482,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class BriaFiboPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class BriaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 4c6f59f944beae59b107519a9c927ebaebb352bc Mon Sep 17 00:00:00 2001 From: galbria Date: Tue, 28 Oct 2025 09:08:29 +0000 Subject: [PATCH 10/15] Update BriaFibo 
classes with copied documentation and fix import typo in pipeline module - Added documentation comments indicating the source of copied code in BriaFiboTransformerBlock and _pack_latents methods. - Corrected the import statement for BriaFiboPipeline in the pipelines module. --- src/diffusers/models/transformers/transformer_bria_fibo.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py index 714faeda5cf8..343a7136a3b4 100644 --- a/src/diffusers/models/transformers/transformer_bria_fibo.py +++ b/src/diffusers/models/transformers/transformer_bria_fibo.py @@ -308,8 +308,8 @@ def forward(self, caption): @maybe_allow_in_graph +# Copied from diffusers.models.transformers.transformer_flux.FluxTransformerBlock class BriaFiboTransformerBlock(nn.Module): - # Copied from diffusers.models.transformers.transformer_flux.FluxTransformerBlock def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 ): diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1b96214196d5..db357669b6f3 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -563,7 +563,7 @@ from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .bria import BriaPipeline - from .bria_fibo import BriaFiboPipelin + from .bria_fibo import BriaFiboPipeline from .chroma import ChromaImg2ImgPipeline, ChromaPipeline from .cogvideo import ( CogVideoXFunControlPipeline, diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py index 6946b3aef213..97e98c5f8de3 100644 --- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py +++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py @@ -400,8 +400,8 @@ def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, wi return latents @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline def _pack_latents(latents, batch_size, num_channels_latents, height, width): - # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) From 455ae70b11de8f8ffa941e3ed537b0cf0c99590d Mon Sep 17 00:00:00 2001 From: galbria Date: Tue, 28 Oct 2025 09:47:16 +0000 Subject: [PATCH 11/15] Remove unused BriaFibo imports from __init__.py to streamline modular pipelines. 
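A minimal sketch of what this removal implies for callers: the prompt-to-JSON step is loaded from the Hub (mirroring the pipeline docstring example, whose repo id is reused here) rather than imported from the top-level `diffusers` namespace:

```python
from diffusers.modular_pipelines import ModularPipeline

# With the top-level exports gone, the converter block is pulled from the Hub
# instead of being imported as `diffusers.BriaFiboVLMPromptToJson`.
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)

output = vlm_pipe(prompt="a beautiful dog")
json_prompt = output.values["json_prompt"]
print(json_prompt)
```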
--- src/diffusers/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index efce36c0d4b8..94104667b541 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -393,8 +393,6 @@ else: _import_structure["modular_pipelines"].extend( [ - "BriaFiboGeminiPromptToJson", - "BriaFiboVLMPromptToJson", "FluxAutoBlocks", "FluxKontextAutoBlocks", "FluxKontextModularPipeline", From 94abe1c3b975ebdffa7273912e7750722b1247c3 Mon Sep 17 00:00:00 2001 From: galbria Date: Tue, 28 Oct 2025 10:01:07 +0000 Subject: [PATCH 12/15] Refactor documentation comments in BriaFibo classes to indicate inspiration from existing implementations - Updated comments in BriaFiboAttnProcessor, BriaFiboAttention, and BriaFiboPipeline to reflect that the code is inspired by other modules rather than copied. - Enhanced clarity on the origins of the methods to maintain proper attribution. --- src/diffusers/models/transformers/transformer_bria_fibo.py | 4 ++-- src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py index 343a7136a3b4..a4d33a233246 100644 --- a/src/diffusers/models/transformers/transformer_bria_fibo.py +++ b/src/diffusers/models/transformers/transformer_bria_fibo.py @@ -66,7 +66,7 @@ def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidde return _get_projections(attn, hidden_states, encoder_hidden_states) -# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention-> BriaFiboAttention +# Inspired by from diffusers.models.transformers.transformer_flux.FluxAttnProcessor FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention-> BriaFiboAttention class BriaFiboAttnProcessor: _attention_backend = None _parallel_config = None @@ -134,7 +134,7 @@ def __call__( return hidden_states -# Copied from diffusers.models.transformers.transformer_flux.FluxAttention -> BriaFiboAttention +# Inspired by from diffusers.models.transformers.transformer_flux.FluxAttention -> BriaFiboAttention class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin): _default_processor_cls = BriaFiboAttnProcessor _available_processors = [ diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py index 97e98c5f8de3..ee7f9e20c0e3 100644 --- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py +++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py @@ -353,7 +353,7 @@ def interrupt(self): return self._interrupt @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + # Inspired by from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents FluxPipeline-> BriaFiboPipeline _unpack_latents def _unpack_latents(latents, height, width, vae_scale_factor): batch_size, num_patches, channels = latents.shape @@ -400,7 +400,7 @@ def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, wi return latents @staticmethod - # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline + # Inspired by from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline def _pack_latents(latents, batch_size, num_channels_latents, height, width): latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) 
         latents = latents.permute(0, 2, 4, 1, 3, 5)

From dc77b316c5c7f4119d2b0dc0c3a60c048143137b Mon Sep 17 00:00:00 2001
From: galbria
Date: Tue, 28 Oct 2025 10:11:46 +0000
Subject: [PATCH 13/15] change Inspired by to Based on

---
 src/diffusers/models/transformers/transformer_bria_fibo.py | 6 +++---
 src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py    | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py
index a4d33a233246..0920d48822d1 100644
--- a/src/diffusers/models/transformers/transformer_bria_fibo.py
+++ b/src/diffusers/models/transformers/transformer_bria_fibo.py
@@ -66,7 +66,7 @@ def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidde
     return _get_projections(attn, hidden_states, encoder_hidden_states)
 
 
-# Inspired by from diffusers.models.transformers.transformer_flux.FluxAttnProcessor FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention-> BriaFiboAttention
+# Based on from diffusers.models.transformers.transformer_flux.FluxAttnProcessor
 class BriaFiboAttnProcessor:
     _attention_backend = None
     _parallel_config = None
@@ -134,7 +134,7 @@ def __call__(
         return hidden_states
 
 
-# Inspired by from diffusers.models.transformers.transformer_flux.FluxAttention -> BriaFiboAttention
+# Based on from diffusers.models.transformers.transformer_flux.FluxAttention
 class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
     _default_processor_cls = BriaFiboAttnProcessor
     _available_processors = [
@@ -308,7 +308,7 @@ def forward(self, caption):
 
 
 @maybe_allow_in_graph
-# Copied from diffusers.models.transformers.transformer_flux.FluxTransformerBlock
+# Based on from diffusers.models.transformers.transformer_flux.FluxTransformerBlock
 class BriaFiboTransformerBlock(nn.Module):
     def __init__(
         self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
index ee7f9e20c0e3..f35eb50e3469 100644
--- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
+++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
@@ -353,7 +353,7 @@ def interrupt(self):
         return self._interrupt
 
     @staticmethod
-    # Inspired by from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents FluxPipeline-> BriaFiboPipeline _unpack_latents
+    # Based on from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
     def _unpack_latents(latents, height, width, vae_scale_factor):
         batch_size, num_patches, channels = latents.shape
 
@@ -400,7 +400,7 @@ def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, wi
         return latents
 
     @staticmethod
-    # Inspired by from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline
+    # Based on from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
         latents = latents.permute(0, 2, 4, 1, 3, 5)

From a3837719a30eac874858ed65db4d6a4a0df267ad Mon Sep 17 00:00:00 2001
From: galbria
Date: Tue, 28 Oct 2025 10:17:21 +0000
Subject: [PATCH 14/15] add reference link and fix trailing whitespace

---
 src/diffusers/models/transformers/transformer_bria_fibo.py | 4 ++--
 src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py    | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py
index 0920d48822d1..f47ea7a776c4 100644
--- a/src/diffusers/models/transformers/transformer_bria_fibo.py
+++ b/src/diffusers/models/transformers/transformer_bria_fibo.py
@@ -66,7 +66,7 @@ def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidde
     return _get_projections(attn, hidden_states, encoder_hidden_states)
 
 
-# Based on from diffusers.models.transformers.transformer_flux.FluxAttnProcessor
+# Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py
 class BriaFiboAttnProcessor:
     _attention_backend = None
     _parallel_config = None
@@ -134,7 +134,7 @@ def __call__(
         return hidden_states
 
 
-# Based on from diffusers.models.transformers.transformer_flux.FluxAttention
+# Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py
 class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
     _default_processor_cls = BriaFiboAttnProcessor
     _available_processors = [
diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
index f35eb50e3469..f7cfc29b259f 100644
--- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
+++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
@@ -353,7 +353,7 @@ def interrupt(self):
         return self._interrupt
 
     @staticmethod
-    # Based on from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+    # Based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py
     def _unpack_latents(latents, height, width, vae_scale_factor):
         batch_size, num_patches, channels = latents.shape
 
@@ -400,7 +400,7 @@ def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, wi
         return latents
 
     @staticmethod
-    # Based on from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline
+    # Based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
         latents = latents.permute(0, 2, 4, 1, 3, 5)

From 7f3dd1dc4d110ca1406f55cd430f38442f7b0c57 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 28 Oct 2025 15:54:01 +0530
Subject: [PATCH 15/15] Add BriaFiboTransformer2DModel documentation and
 update comments in BriaFibo classes

- Introduced a new documentation file for BriaFiboTransformer2DModel.
- Updated comments in BriaFiboAttnProcessor, BriaFiboAttention, and BriaFiboPipeline to clarify the origins of the code, indicating copied sources for better attribution.
---
 .../en/api/models/transformer_bria_fibo.md    | 19 +++++++++++++++++++
 .../transformers/transformer_bria_fibo.py     |  6 ++----
 .../pipelines/bria_fibo/pipeline_bria_fibo.py |  6 +++---
 3 files changed, 24 insertions(+), 7 deletions(-)
 create mode 100644 docs/source/en/api/models/transformer_bria_fibo.md

diff --git a/docs/source/en/api/models/transformer_bria_fibo.md b/docs/source/en/api/models/transformer_bria_fibo.md
new file mode 100644
index 000000000000..5691746ccd78
--- /dev/null
+++ b/docs/source/en/api/models/transformer_bria_fibo.md
@@ -0,0 +1,19 @@
+
+
+# BriaFiboTransformer2DModel
+
+A modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO)
+
+## BriaFiboTransformer2DModel
+
+[[autodoc]] BriaFiboTransformer2DModel
diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py
index f47ea7a776c4..09f79619320d 100644
--- a/src/diffusers/models/transformers/transformer_bria_fibo.py
+++ b/src/diffusers/models/transformers/transformer_bria_fibo.py
@@ -66,7 +66,7 @@ def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidde
     return _get_projections(attn, hidden_states, encoder_hidden_states)
 
 
-# Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py
+# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor with FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention->BriaFiboAttention
 class BriaFiboAttnProcessor:
     _attention_backend = None
     _parallel_config = None
@@ -137,9 +137,7 @@ def __call__(
 # Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py
 class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
     _default_processor_cls = BriaFiboAttnProcessor
-    _available_processors = [
-        BriaFiboAttnProcessor,
-    ]
+    _available_processors = [BriaFiboAttnProcessor]
 
     def __init__(
         self,
diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
index f7cfc29b259f..85d29029e667 100644
--- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
+++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
@@ -353,7 +353,7 @@ def interrupt(self):
         return self._interrupt
 
     @staticmethod
-    # Based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py
+    # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
     def _unpack_latents(latents, height, width, vae_scale_factor):
         batch_size, num_patches, channels = latents.shape
 
@@ -364,10 +364,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
         latents = latents.permute(0, 3, 1, 4, 2, 5)
         latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
 
-
         return latents
 
     @staticmethod
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
         latent_image_ids = torch.zeros(height, width, 3)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
@@ -400,7 +400,7 @@ def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, wi
         return latents
 
     @staticmethod
-    # Based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
         latents = latents.permute(0, 2, 4, 1, 3, 5)
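---

Reviewer note: the `_pack_latents` / `_unpack_latents` hunks above all revolve around the same 2x2 latent patchification. The snippet below is a minimal, standalone sketch of that round trip using plain tensors with made-up shapes; it is not pipeline code, and `BriaFiboPipeline` state (VAE outputs, dtypes, devices) is deliberately left out.

```python
import torch

# Illustrative shapes only; in the pipeline these come from the VAE latent.
batch_size, num_channels_latents, height, width = 1, 16, 64, 64
latents = torch.randn(batch_size, num_channels_latents, height, width)

# Pack: fold each 2x2 spatial patch into one token with 4x the channel count,
# mirroring the _pack_latents body shown in the hunks above.
packed = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

# Unpack: invert the packing, following the Flux-style _unpack_latents logic.
unpacked = packed.view(batch_size, height // 2, width // 2, num_channels_latents, 2, 2)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5)
unpacked = unpacked.reshape(batch_size, num_channels_latents, height, width)

assert torch.equal(unpacked, latents)  # the round trip is lossless
```

Packing this way lets the transformer treat the latent image as a flat sequence of (height / 2) * (width / 2) tokens, which is the convention the Flux-derived blocks referenced throughout these patches expect.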