Skip to content

Commit 911c30e

Browse files
committed
initial commit
1 parent 5ecf0ed commit 911c30e

File tree

1 file changed

+4
-13
lines changed

1 file changed

+4
-13
lines changed

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def save_model_card(
8585
images=None,
8686
base_model: str = None,
8787
instance_prompt=None,
88-
system_prompt=None,
8988
validation_prompt=None,
9089
repo_folder=None,
9190
):
@@ -113,8 +112,6 @@ def save_model_card(
113112
114113
You should use `{instance_prompt}` to trigger the image generation.
115114
116-
The following `system_prompt` was also used used during training (ignore if `None`): {system_prompt}.
117-
118115
## Download model
119116
120117
[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
@@ -324,12 +321,7 @@ def parse_args(input_args=None):
324321
default=256,
325322
help="Maximum sequence length to use with with the Gemma2 model",
326323
)
327-
parser.add_argument(
328-
"--system_prompt",
329-
type=str,
330-
default=None,
331-
help="System prompt to use during inference to give the Gemma2 model certain characteristics.",
332-
)
324+
333325
parser.add_argument(
334326
"--validation_prompt",
335327
type=str,
@@ -382,7 +374,7 @@ def parse_args(input_args=None):
382374
parser.add_argument(
383375
"--output_dir",
384376
type=str,
385-
default="lumina2-dreambooth-lora",
377+
default="hidream-dreambooth-lora",
386378
help="The output directory where the model predictions and checkpoints will be written.",
387379
)
388380
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
@@ -1755,7 +1747,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17551747
variant=args.variant,
17561748
torch_dtype=weight_dtype,
17571749
)
1758-
pipeline_args = {"prompt": args.validation_prompt, "system_prompt": args.system_prompt}
1750+
pipeline_args = {"prompt": args.validation_prompt}
17591751
images = log_validation(
17601752
pipeline=pipeline,
17611753
args=args,
@@ -1799,7 +1791,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17991791
if (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt):
18001792
prompt_to_use = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
18011793
args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
1802-
pipeline_args = {"prompt": prompt_to_use, "system_prompt": args.system_prompt}
1794+
pipeline_args = {"prompt": prompt_to_use, "num_images_per_prompt": args.num_validation_images}
18031795
images = log_validation(
18041796
pipeline=pipeline,
18051797
args=args,
@@ -1816,7 +1808,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18161808
images=images,
18171809
base_model=args.pretrained_model_name_or_path,
18181810
instance_prompt=args.instance_prompt,
1819-
system_prompt=args.system_prompt,
18201811
validation_prompt=validation_prpmpt,
18211812
repo_folder=args.output_dir,
18221813
)

0 commit comments

Comments
 (0)