Skip to content

Commit c576c58

Browse files
committed
support z-image with lora
1 parent 827eabf commit c576c58

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

lightx2v/models/runners/z_image/z_image_runner.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from loguru import logger
88

99
from lightx2v.models.input_encoders.hf.z_image.qwen3_model import Qwen3Model_TextEncoder
10+
from lightx2v.models.networks.lora_adapter import LoraAdapter
1011
from lightx2v.models.networks.z_image.model import ZImageTransformerModel
1112
from lightx2v.models.runners.default_runner import DefaultRunner
1213
from lightx2v.models.schedulers.z_image.scheduler import ZImageScheduler
@@ -30,6 +31,24 @@ def calculate_dimensions(target_area, ratio):
3031
return width, height, None
3132

3233

34+
def build_z_image_model_with_lora(z_image_module, config, model_kwargs, lora_configs):
35+
lora_dynamic_apply = config.get("lora_dynamic_apply", False)
36+
37+
if lora_dynamic_apply:
38+
lora_path = lora_configs[0]["path"]
39+
lora_strength = lora_configs[0]["strength"]
40+
model_kwargs["lora_path"] = lora_path
41+
model_kwargs["lora_strength"] = lora_strength
42+
model = z_image_module(**model_kwargs)
43+
else:
44+
assert not config.get("dit_quantized", False), "Online LoRA only for quantized models; merging LoRA is unsupported."
45+
assert not config.get("lazy_load", False), "Lazy load mode does not support LoRA merging."
46+
model = z_image_module(**model_kwargs)
47+
lora_adapter = LoraAdapter(model)
48+
lora_adapter.apply_lora(lora_configs)
49+
return model
50+
51+
3352
@RUNNER_REGISTER("z_image")
3453
class ZImageRunner(DefaultRunner):
3554
model_cpu_offload_seq = "text_encoder->transformer->vae"
@@ -45,7 +64,16 @@ def load_model(self):
4564
self.vae = self.load_vae()
4665

4766
def load_transformer(self):
48-
model = ZImageTransformerModel(os.path.join(self.config["model_path"], "transformer"), self.config, self.init_device)
67+
z_image_model_kwargs = {
68+
"model_path": os.path.join(self.config["model_path"], "transformer"),
69+
"config": self.config,
70+
"device": self.init_device,
71+
}
72+
lora_configs = self.config.get("lora_configs")
73+
if not lora_configs:
74+
model = ZImageTransformerModel(**z_image_model_kwargs)
75+
else:
76+
model = build_z_image_model_with_lora(ZImageTransformerModel, self.config, z_image_model_kwargs, lora_configs)
4977
return model
5078

5179
def load_text_encoder(self):

0 commit comments

Comments
 (0)