diff --git a/pyproject.toml b/pyproject.toml index 72dbc51d4..092cc5983 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ mflux-generate-kontext = "mflux.models.flux.cli.flux_generate_kontext:main" mflux-generate-qwen = "mflux.models.qwen.cli.qwen_image_generate:main" mflux-generate-qwen-edit = "mflux.models.qwen.cli.qwen_image_edit_generate:main" mflux-generate-fibo = "mflux.models.fibo.cli.fibo_generate:main" +mflux-generate-fibo-edit = "mflux.models.fibo.cli.fibo_edit_generate:main" mflux-generate-z-image = "mflux.models.z_image.cli.z_image_generate:main" mflux-generate-z-image-turbo = "mflux.models.z_image.cli.z_image_turbo_generate:main" mflux-refine-fibo = "mflux.models.fibo_vlm.cli.fibo_refine:main" diff --git a/src/mflux/cli/defaults/defaults.py b/src/mflux/cli/defaults/defaults.py index 8729f41aa..128bad9e6 100644 --- a/src/mflux/cli/defaults/defaults.py +++ b/src/mflux/cli/defaults/defaults.py @@ -19,6 +19,7 @@ "dev-krea", "qwen", "fibo", + "fibo-edit", "z-image", "z-image-turbo", "flux2-klein-4b", @@ -34,6 +35,7 @@ "qwen-image": 20, "qwen-image-edit": 20, "fibo": 20, + "fibo-edit": 20, "z-image": 50, "z-image-turbo": 9, "flux2-klein-4b": 4, diff --git a/src/mflux/models/common/config/model_config.py b/src/mflux/models/common/config/model_config.py index a8ba9829b..7038a4088 100644 --- a/src/mflux/models/common/config/model_config.py +++ b/src/mflux/models/common/config/model_config.py @@ -137,6 +137,11 @@ def qwen_image_edit() -> "ModelConfig": def fibo() -> "ModelConfig": return AVAILABLE_MODELS["fibo"] + @staticmethod + @lru_cache + def fibo_edit() -> "ModelConfig": + return AVAILABLE_MODELS["fibo-edit"] + @staticmethod @lru_cache def z_image_turbo() -> "ModelConfig": @@ -453,8 +458,20 @@ def from_name( supports_guidance=True, requires_sigma_shift=False, ), - "z-image": ModelConfig( + "fibo-edit": ModelConfig( priority=18, + aliases=["fibo-edit", "fiboedit"], + model_name="briaai/Fibo-Edit", + base_model=None, + controlnet_model=None, 
+ custom_transformer_model=None, + num_train_steps=1000, + max_sequence_length=512, + supports_guidance=True, + requires_sigma_shift=False, + ), + "z-image": ModelConfig( + priority=19, aliases=["z-image", "zimage"], model_name="Tongyi-MAI/Z-Image", base_model=None, @@ -466,7 +483,7 @@ def from_name( requires_sigma_shift=True, ), "z-image-turbo": ModelConfig( - priority=19, + priority=20, aliases=["z-image-turbo", "zimage-turbo"], model_name="Tongyi-MAI/Z-Image-Turbo", base_model=None, @@ -478,7 +495,7 @@ def from_name( requires_sigma_shift=True, ), "seedvr2-3b": ModelConfig( - priority=20, + priority=21, aliases=["seedvr2-3b", "seedvr2"], model_name="numz/SeedVR2_comfyUI", base_model=None, diff --git a/src/mflux/models/fibo/README.md b/src/mflux/models/fibo/README.md index b117b6867..1f8e4fabe 100644 --- a/src/mflux/models/fibo/README.md +++ b/src/mflux/models/fibo/README.md @@ -13,7 +13,7 @@ Most text-to-image models excel at imagination—but not control. FIBO is traine - **Strong prompt adherence**: High alignment on PRISM-style evaluations - **Enterprise-grade**: 100% licensed data with governance, repeatability, and legal clarity -## The three modes: Generate, Refine, and Inspire +## The four modes: Generate, Edit, Refine, and Inspire ### Generate While the actual prompt input to FIBO is a structured JSON file, the generate command provides an interface to input pure text prompts. These are then expanded into structured JSON prompts using FIBO's Vision-Language Model (VLM) before being passed to the diffusion model for image generation. @@ -259,6 +259,35 @@ image.save("owl_white.png") It is worth noting that refine does not work the same way as other editing techniques like Flux Kontext or Qwen Image Edit. Instead of modifying an existing image, it modifies the underlying **structured prompt** to produce a new image. +### Edit +FIBO Edit supports direct image-conditioned editing using a structured JSON prompt that includes an `edit_instruction` field. 
def main():
    """Command-line entry point for FIBO Edit image generation.

    Parses the CLI arguments, builds the structured JSON prompt (making sure
    it carries an ``edit_instruction``), loads the ``FIBOEdit`` model and
    renders one edited image per requested seed.

    Errors that the user can fix (missing edit instruction, invalid prompt
    JSON, unreadable prompt file, interrupted generation) are printed as a
    plain message rather than surfacing as a traceback.
    """
    parser = CommandLineParser(description="Generate an edited image using Bria FIBO Edit.")
    parser.add_general_arguments()
    parser.add_model_arguments(require_model_arg=False)
    parser.add_lora_arguments()
    parser.add_image_generator_arguments(supports_metadata_config=True, supports_dimension_scale_factor=True)
    parser.add_argument("--image-path", type=Path, required=True, help="Local path to source image for editing.")
    parser.add_argument("--mask-path", type=Path, default=None, help="Optional mask image path for localized edits.")
    parser.add_argument(
        "--edit-instruction",
        type=str,
        default=None,
        help="Optional edit instruction. Used when prompt JSON does not already include `edit_instruction`.",
    )
    parser.add_output_arguments()
    args = parser.parse_args()

    if args.guidance is None:
        args.guidance = ui_defaults.GUIDANCE_SCALE

    json_prompt = FiboUtil.get_json_prompt(args, quantize=args.quantize)

    # Validate the prompt before the (expensive) model load. The ValueError
    # raised for a missing/blank edit instruction is reported the same way as
    # generation-time errors below instead of escaping as a traceback.
    try:
        json_prompt = FiboEditUtil.ensure_edit_instruction(json_prompt, edit_instruction=args.edit_instruction)
    except ValueError as exc:
        print(exc)
        return

    fibo_edit = FIBOEdit(
        quantize=args.quantize,
        model_path=args.model_path,
        lora_paths=args.lora_paths,
        lora_scales=args.lora_scales,
    )

    memory_saver = CallbackManager.register_callbacks(
        args=args,
        model=fibo_edit,
        latent_creator=FiboLatentCreator,
    )

    try:
        # Output dimensions are resolved against the source image so that
        # omitted or relative width/height values follow the reference image.
        width, height = DimensionResolver.resolve(
            width=args.width,
            height=args.height,
            reference_image_path=args.image_path,
        )
        for seed in args.seed:
            image = fibo_edit.generate_image(
                seed=seed,
                prompt=json_prompt,
                image_path=args.image_path,
                mask_path=args.mask_path,
                width=width,
                height=height,
                guidance=args.guidance,
                num_inference_steps=args.steps,
                scheduler="flow_match_euler_discrete",
                negative_prompt=PromptUtil.read_negative_prompt(args),
            )
            # The output path is a template; `{seed}` is substituted per image.
            image.save(path=args.output.format(seed=seed), export_json_metadata=args.metadata)
    except (StopImageGenerationException, PromptFileReadError, ValueError) as exc:
        print(exc)
    finally:
        # memory_saver may be falsy when no memory-saving callback was registered.
        if memory_saver:
            print(memory_saver.memory_stats())
a/src/mflux/models/fibo/model/fibo_transformer/transformer.py b/src/mflux/models/fibo/model/fibo_transformer/transformer.py index ed7056485..d93fb1d0f 100644 --- a/src/mflux/models/fibo/model/fibo_transformer/transformer.py +++ b/src/mflux/models/fibo/model/fibo_transformer/transformer.py @@ -35,13 +35,21 @@ def __call__( hidden_states: mx.array, encoder_hidden_states: mx.array, text_encoder_layers: list[mx.array], + conditioning_seq_len: int = 0, + conditioning_image_ids: mx.array | None = None, ) -> mx.array: # 1. Create embeddings hidden_states = FiboTransformer._handle_classifier_free_guidance(hidden_states, encoder_hidden_states) hidden_states = self.x_embedder(hidden_states) encoder_hidden_states = self.context_embedder(encoder_hidden_states) time_embeddings = FiboTransformer._compute_time_embeddings(t, config, hidden_states.shape[0], hidden_states.dtype, self.time_embed) # fmt: off - image_rotary_emb = FiboTransformer._compute_rotary_embeddings(encoder_hidden_states, self.pos_embed, config, hidden_states.dtype) # fmt: off + image_rotary_emb = FiboTransformer._compute_rotary_embeddings( + encoder_hidden_states=encoder_hidden_states, + pos_embed=self.pos_embed, + config=config, + dtype=hidden_states.dtype, + conditioning_image_ids=conditioning_image_ids, + ) # 2. Compute attention mask attention_mask = FiboTransformer._compute_attention_mask( @@ -49,6 +57,7 @@ def __call__( batch_size=hidden_states.shape[0], encoder_hidden_states=encoder_hidden_states, max_tokens=encoder_hidden_states.shape[1], + conditioning_seq_len=conditioning_seq_len, ) # 3. 
Project the fibo-specific text encoder layers @@ -165,10 +174,15 @@ def _compute_rotary_embeddings( pos_embed: FiboEmbedND, config: Config, dtype: mx.Dtype, + conditioning_image_ids: mx.array | None = None, ) -> mx.array: max_tokens = encoder_hidden_states.shape[1] txt_ids = mx.zeros((max_tokens, 3), dtype=dtype) img_ids = FiboTransformer._prepare_latent_image_ids(height=config.height, width=config.width, dtype=dtype) + if conditioning_image_ids is not None: + if conditioning_image_ids.ndim == 3: + conditioning_image_ids = conditioning_image_ids[0] + img_ids = mx.concatenate([img_ids, conditioning_image_ids.astype(dtype)], axis=0) if txt_ids.ndim == 3 and txt_ids.shape[0] == 1: txt_ids = txt_ids[0] @@ -212,6 +226,7 @@ def _compute_attention_mask( config: Config, encoder_hidden_states: mx.array, max_tokens: int, + conditioning_seq_len: int = 0, ) -> mx.array: vae_scale_factor = 16 latent_height = config.height // vae_scale_factor @@ -219,7 +234,13 @@ def _compute_attention_mask( latent_seq_len = latent_height * latent_width prompt_attention_mask = mx.ones((batch_size, max_tokens), dtype=mx.float32) latent_attention_mask = mx.ones((batch_size, latent_seq_len), dtype=mx.float32) - attention_mask_2d = mx.concatenate([prompt_attention_mask, latent_attention_mask], axis=1) + if conditioning_seq_len > 0: + conditioning_attention_mask = mx.ones((batch_size, conditioning_seq_len), dtype=mx.float32) + attention_mask_2d = mx.concatenate( + [prompt_attention_mask, latent_attention_mask, conditioning_attention_mask], axis=1 + ) + else: + attention_mask_2d = mx.concatenate([prompt_attention_mask, latent_attention_mask], axis=1) attention_mask = FiboTransformer._prepare_attention_mask(attention_mask_2d) attention_mask = attention_mask.astype(encoder_hidden_states.dtype) return attention_mask diff --git a/src/mflux/models/fibo/variants/edit/__init__.py b/src/mflux/models/fibo/variants/edit/__init__.py new file mode 100644 index 000000000..8520e9c23 --- /dev/null +++ 
class FIBOEdit(nn.Module):
    """Image-conditioned editing pipeline for the FIBO Edit model.

    The source image is encoded into packed "conditioning" latents that are
    appended to the noisy latents along axis 1 on every denoising step; only
    the leading (noisy-latent) part of the transformer output is kept.

    NOTE(review): attributes read below but not assigned in this class
    (`tokenizers`, `tiling_config`, `callbacks`, `bits`, `model_config`) are
    presumably populated by `FIBOInitializer.init` - confirm against the
    initializer.
    """

    vae: Wan2_2_VAE
    transformer: FiboTransformer
    text_encoder: SmolLM3_3B_TextEncoder

    def __init__(
        self,
        quantize: int | None = None,
        model_path: str | None = None,
        lora_paths: list[str] | None = None,
        lora_scales: list[float] | None = None,
        # The default is evaluated once, when the class body is executed.
        model_config: ModelConfig = ModelConfig.fibo_edit(),
    ):
        """Create the model and delegate all setup to `FIBOInitializer.init`.

        Args:
            quantize: Optional quantization setting forwarded to the initializer.
            model_path: Optional model location forwarded to the initializer.
            lora_paths: Optional LoRA paths forwarded to the initializer.
            lora_scales: Optional LoRA scales forwarded to the initializer.
            model_config: Model configuration; defaults to the "fibo-edit" entry.
        """
        super().__init__()
        FIBOInitializer.init(
            model=self,
            quantize=quantize,
            model_path=model_path,
            lora_paths=lora_paths,
            lora_scales=lora_scales,
            model_config=model_config,
        )

    def generate_image(
        self,
        seed: int,
        prompt: str,
        image_path: Path | str,
        mask_path: Path | str | None = None,
        num_inference_steps: int = 20,
        height: int = 1024,
        width: int = 1024,
        guidance: float = 4.0,
        scheduler: str = "flow_match_euler_discrete",
        negative_prompt: str | None = None,
    ) -> GeneratedImage:
        """Generate one edited image from a source image and a JSON prompt.

        Args:
            seed: Seed used to create the initial noise latents.
            prompt: JSON prompt string; it must already contain a usable
                `edit_instruction` because no fallback is supplied here.
            image_path: Source image to edit.
            mask_path: Optional mask; when given, the mask is composited onto
                the source image before it is encoded.
            num_inference_steps: Number of inference steps passed to `Config`.
            height: Output height passed to `Config`.
            width: Output width passed to `Config`.
            guidance: Guidance scale used when combining the two noise halves.
            scheduler: Scheduler name passed to `Config`.
            negative_prompt: Optional negative prompt passed to the prompt encoder.

        Returns:
            The decoded image wrapped by `ImageUtil.to_image`.

        Raises:
            ValueError: If the prompt is not a JSON object with an edit
                instruction, or the mask and image sizes differ.
            StopImageGenerationException: If generation is interrupted with
                Ctrl-C (KeyboardInterrupt) during the denoising loop.
        """
        # Re-validates the prompt so direct API callers get the same check as the CLI.
        prompt = FiboEditUtil.ensure_edit_instruction(prompt)

        config = Config(
            width=width,
            height=height,
            guidance=guidance,
            scheduler=scheduler,
            image_path=image_path,
            model_config=self.model_config,
            num_inference_steps=num_inference_steps,
        )
        # Only schedulers that expose this hook are told the image sequence length.
        if hasattr(config.scheduler, "set_image_seq_len"):
            config.scheduler.set_image_seq_len(config.image_seq_len)

        latents = FiboLatentCreator.create_noise(seed=seed, width=config.width, height=config.height)
        json_prompt, encoder_hidden_states, text_encoder_layers = PromptEncoder.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            tokenizer=self.tokenizers["fibo"],
            text_encoder=self.text_encoder,
        )

        # Prepare the conditioning inputs once; they do not change between steps.
        edit_image = FiboEditUtil.load_edit_image(
            image_path=image_path,
            width=config.width,
            height=config.height,
            mask_path=mask_path,
        )
        conditioning_latents = FiboEditUtil.encode_conditioning_image(
            vae=self.vae,
            image=edit_image,
            height=config.height,
            width=config.width,
            tiling_config=self.tiling_config,
        )
        conditioning_image_ids = FiboEditUtil.create_conditioning_image_ids(
            height=config.height,
            width=config.width,
            dtype=encoder_hidden_states.dtype,
        )

        ctx = self.callbacks.start(seed=seed, prompt=json_prompt, config=config)
        ctx.before_loop(latents)

        for t in config.time_steps:
            try:
                # Conditioning tokens are appended after the noisy latents on axis 1.
                hidden_states = mx.concatenate([latents, conditioning_latents], axis=1)
                noise = self.transformer(
                    t=t,
                    config=config,
                    hidden_states=hidden_states,
                    text_encoder_layers=text_encoder_layers,
                    encoder_hidden_states=encoder_hidden_states,
                    conditioning_seq_len=conditioning_latents.shape[1],
                    conditioning_image_ids=conditioning_image_ids,
                )
                # Keep only the first `latents.shape[1]` positions on axis 1,
                # i.e. drop the output for the appended conditioning tokens.
                noise = FIBOEdit._apply_classifier_free_guidance(noise[:, : latents.shape[1]], config.guidance)
                latents = config.scheduler.step(noise=noise, timestep=t, latents=latents)
                ctx.in_loop(t, latents)
                mx.eval(latents)
            except KeyboardInterrupt:  # noqa: PERF203
                # Let callbacks observe the partial result, then convert the
                # interrupt into the project's stop exception.
                ctx.interruption(t, latents)
                raise StopImageGenerationException(
                    f"Stopping image generation at step {t + 1}/{config.num_inference_steps}"
                )

        ctx.after_loop(latents)

        latents = FiboLatentCreator.unpack_latents(latents, config.height, config.width)
        decoded = VAEUtil.decode(vae=self.vae, latent=latents, tiling_config=self.tiling_config)
        return ImageUtil.to_image(
            decoded_latents=decoded,
            config=config,
            seed=seed,
            prompt=json_prompt,
            quantization=self.bits,
            image_path=config.image_path,
            masked_image_path=mask_path,
            generation_time=config.time_steps.format_dict["elapsed"],
            negative_prompt=negative_prompt,
        )

    @staticmethod
    def _apply_classifier_free_guidance(noise: mx.array, guidance: float) -> mx.array:
        """Combine the two batch halves of `noise` using the guidance scale.

        The first half of axis 0 is treated as the unconditional prediction and
        the second half as the text-conditioned one; the result is
        `uncond + guidance * (text - uncond)`.
        """
        half = noise.shape[0] // 2
        noise_uncond = noise[:half]
        noise_text = noise[half:]
        return noise_uncond + guidance * (noise_text - noise_uncond)

    def save_model(self, base_path: str) -> None:
        """Save this model under `base_path` via `ModelSaver` using the FIBO weight definition."""
        ModelSaver.save_model(
            model=self,
            bits=self.bits,
            base_path=base_path,
            weight_definition=FIBOWeightDefinition,
        )
class FiboEditUtil:
    """Stateless helpers that prepare FIBO Edit inputs.

    Covers prompt JSON validation, loading (and optionally masking) the source
    image, encoding it into packed conditioning latents, and building the
    position ids for those conditioning tokens.
    """

    @staticmethod
    def parse_json_prompt(prompt: str) -> dict:
        """Parse `prompt` as JSON and return it as a dict.

        Raises:
            ValueError: If `prompt` is not valid JSON, is not a string/bytes
                value at all, or decodes to something other than a JSON object.
        """
        try:
            value = json.loads(prompt)
        except (json.JSONDecodeError, TypeError) as exc:
            # json.loads raises TypeError for non-str input (e.g. None); callers
            # only handle ValueError, so both cases are normalized here.
            raise ValueError("FIBO edit prompt must be a valid JSON string.") from exc

        if not isinstance(value, dict):
            raise ValueError("FIBO edit prompt JSON must be an object.")
        return value

    @staticmethod
    def ensure_edit_instruction(prompt: str, edit_instruction: str | None = None) -> str:
        """Return the prompt JSON with a usable `edit_instruction` field.

        An instruction already present in the prompt wins and the prompt is
        returned re-serialized but otherwise unchanged. A missing, empty or
        whitespace-only instruction is replaced by the stripped
        `edit_instruction` argument.

        Raises:
            ValueError: If the prompt is not a JSON object, or neither the
                prompt nor `edit_instruction` provides a non-blank instruction.
        """
        prompt_dict = FiboEditUtil.parse_json_prompt(prompt)

        existing = prompt_dict.get("edit_instruction")
        # A blank string is not an instruction; apply the same rule that is
        # used for the fallback argument below.
        has_existing = bool(existing.strip()) if isinstance(existing, str) else bool(existing)
        if has_existing:
            return json.dumps(prompt_dict)

        if edit_instruction is None or not edit_instruction.strip():
            raise ValueError("FIBO edit prompt JSON must include `edit_instruction` (or provide --edit-instruction).")

        prompt_dict["edit_instruction"] = edit_instruction.strip()
        return json.dumps(prompt_dict)

    @staticmethod
    def load_edit_image(
        image_path: Path | str,
        width: int,
        height: int,
        mask_path: Path | str | None = None,
    ) -> Image.Image:
        """Load the source image, optionally gray out the masked area, and scale it.

        Args:
            image_path: Path of the image to edit.
            width: Target width passed to `ImageUtil.scale_to_dimensions`.
            height: Target height passed to `ImageUtil.scale_to_dimensions`.
            mask_path: Optional mask image; it must have the same pixel size as
                the (unscaled) source image.

        Raises:
            ValueError: If the mask size differs from the source image size.
        """
        image = ImageUtil.load_image(image_path)
        if mask_path is None:
            return ImageUtil.scale_to_dimensions(image, width, height)

        # convert() returns an independent image, so the file handle can be
        # released as soon as the grayscale copy exists.
        with Image.open(mask_path) as mask_file:
            mask_image = mask_file.convert("L")
        if mask_image.size != image.size:
            raise ValueError("Mask and image must have the same size.")

        masked_image = FiboEditUtil._composite_mask_on_image(mask=mask_image, image=image)
        return ImageUtil.scale_to_dimensions(masked_image, width, height)

    @staticmethod
    def encode_conditioning_image(
        vae: Wan2_2_VAE,
        image: Image.Image,
        height: int,
        width: int,
        tiling_config=None,
    ) -> mx.array:
        """Encode `image` with the VAE and pack the latents for the transformer."""
        image_array = ImageUtil.to_array(image=image)
        image_latents = VAEUtil.encode(vae=vae, image=image_array, tiling_config=tiling_config)
        return FiboLatentCreator.pack_latents(latents=image_latents, height=height, width=width)

    @staticmethod
    def create_conditioning_image_ids(height: int, width: int, dtype: mx.Dtype) -> mx.array:
        """Build position ids for the conditioning-image tokens.

        Returns an array of shape `(1, (height // 16) * (width // 16), 3)`
        whose last axis holds `(1, row_index, column_index)` for each cell of
        the latent grid, in row-major order.
        """
        # NOTE(review): 16 presumably mirrors the transformer's latent
        # downscale factor - confirm it stays in sync.
        latent_height = height // 16
        latent_width = width // 16
        grid_shape = (latent_height, latent_width)

        row_indices = mx.broadcast_to(mx.arange(0, latent_height, dtype=dtype)[:, None], grid_shape)
        col_indices = mx.broadcast_to(mx.arange(0, latent_width, dtype=dtype)[None, :], grid_shape)
        # Constant 1 in the first channel for every conditioning token.
        ones_channel = mx.ones(grid_shape, dtype=dtype)

        latent_image_ids = mx.stack([ones_channel, row_indices, col_indices], axis=-1)
        return mx.reshape(latent_image_ids, (1, latent_height * latent_width, 3))

    @staticmethod
    def _composite_mask_on_image(mask: Image.Image, image: Image.Image) -> Image.Image:
        """Return `image` with mid-gray (128, 128, 128) painted where `mask` is set."""
        gray_img = Image.new("RGB", image.size, (128, 128, 128))
        return Image.composite(gray_img, image.convert("RGB"), mask.convert("L"))
FiboEditUtil.ensure_edit_instruction(prompt, edit_instruction=None) + + +def test_load_edit_image_raises_for_mask_size_mismatch(tmp_path): + image_path = tmp_path / "image.png" + mask_path = tmp_path / "mask.png" + Image.new("RGB", (64, 64), (255, 255, 255)).save(image_path) + Image.new("L", (32, 32), 255).save(mask_path) + + with pytest.raises(ValueError, match="Mask and image must have the same size"): + FiboEditUtil.load_edit_image(image_path=image_path, width=64, height=64, mask_path=mask_path)