Skip to content

Commit 80924ff

Browse files
committed
Refactor Predictor class to use FluxOminiKontextPipeline; update LoRA model weights and re-enable product insertion. Add debug prints for optimised reference image handling.
1 parent 5de4adb commit 80924ff

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

predict.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from cog import BasePredictor, Input, Path, Secret
77
import torch
88
from PIL import Image, ImageChops
9-
from src.pipeline_qwen_omini_image_edit import QwenOminiImageEditPipeline
9+
from src.pipeline_flux_omini_kontext import FluxOminiKontextPipeline
1010
import random
1111
import json
1212

@@ -22,25 +22,25 @@
2222
},
2323
"spatial_character_insertion": {
2424
"lora_path": "saquiboye/omini-kontext",
25-
"weight_name": "qwen/character_spatial_1000.safetensors",
25+
"weight_name": "spatial-character-test.safetensors",
2626
},
2727
"character_insertion": {
2828
"lora_path": "saquiboye/omini-kontext",
29-
"weight_name": "qwen/character_1000.safetensors",
29+
"weight_name": "character_3000.safetensors",
3030
},
31-
# "product_insertion": {
32-
# "lora_path": "saquiboye/omini-kontext",
33-
# "weight_name": "product_2000.safetensors",
34-
# }
31+
"product_insertion": {
32+
"lora_path": "saquiboye/omini-kontext",
33+
"weight_name": "product_2000.safetensors",
34+
}
3535
}
3636

3737
class Predictor(BasePredictor):
3838
def setup(self) -> None:
3939
"""Load the model into memory to make running multiple predictions efficient"""
4040

4141
ensure_hf_login()
42-
self.pipe = QwenOminiImageEditPipeline.from_pretrained(
43-
"Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16
42+
self.pipe = FluxOminiKontextPipeline.from_pretrained(
43+
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
4444
).to("cuda")
4545

4646
def predict(
@@ -133,11 +133,13 @@ def predict(
133133
reference_image = reference_image.resize((width, height), Image.LANCZOS)
134134

135135
try:
136-
print("has_reference: ", has_reference)
137-
print("reference_image: ", reference_image)
138-
print("delta: ", delta)
139-
# if has_reference:
140-
# optimised_reference, new_reference_delta = optimise_image_condition(reference_image, delta)
136+
if has_reference:
137+
optimised_reference, new_reference_delta = optimise_image_condition(reference_image, delta)
138+
print("optimised_reference: ", optimised_reference)
139+
print("new_reference_delta: ", new_reference_delta)
140+
o = "tmp/optimised_reference.png"
141+
optimised_reference.save(o)
142+
print("saved optimised_reference to: ", Path(o))
141143
result_img = self.pipe(
142144
prompt=prompt,
143145
image=image,
@@ -147,8 +149,8 @@ def predict(
147149
height=height,
148150
width=width,
149151
generator=generator,
150-
# _auto_resize=False,
151-
# max_area=width*height,
152+
_auto_resize=False,
153+
max_area=width*height,
152154
guidance_scale=guidance_scale
153155
).images[0]
154156

0 commit comments

Comments
 (0)