Skip to content

Commit f188e80

Browse files
committed
fixes
1 parent f46330b commit f188e80

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

examples/control-lora/train_control_flux.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,12 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
7878
torch_dtype=weight_dtype,
7979
)
8080
else:
81-
transformer = FluxTransformer2DModel.from_pretrained(
82-
args.pretrained_model_name_or_path, subfolder="transformer", torch_dtype=weight_dtype
83-
)
84-
initial_channels = transformer.config.in_channels
81+
transformer = FluxTransformer2DModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
8582
pipeline = FluxControlPipeline.from_pretrained(
8683
args.pretrained_model_name_or_path,
8784
transformer=transformer,
8885
torch_dtype=weight_dtype,
8986
)
90-
assert (
91-
pipeline.transformer.config.in_channels == initial_channels * 2
92-
), f"{pipeline.transformer.config.in_channels=}"
9387

9488
pipeline.to(accelerator.device)
9589
pipeline.set_progress_bar_config(disable=True)

0 commit comments

Comments
 (0)