66from cog import BasePredictor , Input , Path , Secret
77import torch
88from PIL import Image , ImageChops
9- from src .pipeline_qwen_omini_image_edit import QwenOminiImageEditPipeline
9+ from src .pipeline_flux_omini_kontext import FluxOminiKontextPipeline
1010import random
1111import json
1212
2222 },
2323 "spatial_character_insertion" : {
2424 "lora_path" : "saquiboye/omini-kontext" ,
25- "weight_name" : "qwen/character_spatial_1000 .safetensors" ,
25+ "weight_name" : "spatial-character-test .safetensors" ,
2626 },
2727 "character_insertion" : {
2828 "lora_path" : "saquiboye/omini-kontext" ,
29- "weight_name" : "qwen/character_1000 .safetensors" ,
29+ "weight_name" : "character_3000 .safetensors" ,
3030 },
31- # "product_insertion": {
32- # "lora_path": "saquiboye/omini-kontext",
33- # "weight_name": "product_2000.safetensors",
34- # }
31+ "product_insertion" : {
32+ "lora_path" : "saquiboye/omini-kontext" ,
33+ "weight_name" : "product_2000.safetensors" ,
34+ }
3535}
3636
3737class Predictor (BasePredictor ):
3838 def setup (self ) -> None :
3939 """Load the model into memory to make running multiple predictions efficient"""
4040
4141 ensure_hf_login ()
42- self .pipe = QwenOminiImageEditPipeline .from_pretrained (
43- "Qwen/Qwen-Image-Edit " , torch_dtype = torch .bfloat16
42+ self .pipe = FluxOminiKontextPipeline .from_pretrained (
43+ "black-forest-labs/FLUX.1-Kontext-dev " , torch_dtype = torch .bfloat16
4444 ).to ("cuda" )
4545
4646 def predict (
@@ -133,11 +133,13 @@ def predict(
133133 reference_image = reference_image .resize ((width , height ), Image .LANCZOS )
134134
135135 try :
136- print ("has_reference: " , has_reference )
137- print ("reference_image: " , reference_image )
138- print ("delta: " , delta )
139- # if has_reference:
140- # optimised_reference, new_reference_delta = optimise_image_condition(reference_image, delta)
136+ if has_reference :
137+ optimised_reference , new_reference_delta = optimise_image_condition (reference_image , delta )
138+ print ("optimised_reference: " , optimised_reference )
139+ print ("new_reference_delta: " , new_reference_delta )
140+ o = "tmp/optimised_reference.png"
141+ optimised_reference .save (o )
142+ print ("saved optimised_reference to: " , Path (o ))
141143 result_img = self .pipe (
142144 prompt = prompt ,
143145 image = image ,
@@ -147,8 +149,8 @@ def predict(
147149 height = height ,
148150 width = width ,
149151 generator = generator ,
150- # _auto_resize=False,
151- # max_area=width*height,
152+ _auto_resize = False ,
153+ max_area = width * height ,
152154 guidance_scale = guidance_scale
153155 ).images [0 ]
154156
0 commit comments