
Commit 169d45c

update

1 parent 5c669f8 commit 169d45c

106 files changed: +782 −769 lines changed
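
Every hunk in this commit applies the same mechanical reflow: long assert statements and subscript assignments are rewrapped so that the parentheses enclose the condition (or the key) rather than the message string (or the right-hand side). The commit appears to be a pure formatting pass; the two wrappings are semantically identical. A minimal, self-contained Python sketch of the assert pattern (the inserting_toks value below is illustrative only):

# Old wrapping: the message string is parenthesized.
inserting_toks = ["<s0>", "<s1>"]
assert all(isinstance(tok, str) for tok in inserting_toks), (
    "All elements in inserting_toks should be strings."
)

# New wrapping: the condition is parenthesized and the message
# follows the closing parenthesis on the same line.
assert all(
    isinstance(tok, str) for tok in inserting_toks
), "All elements in inserting_toks should be strings."

In both forms the comma separating condition from message sits outside the parentheses, so the assert still tests the boolean condition, not a tuple.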

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 3 additions & 3 deletions
@@ -839,9 +839,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
         idx = 0
         for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
             assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
-            assert all(isinstance(tok, str) for tok in inserting_toks), (
-                "All elements in inserting_toks should be strings."
-            )
+            assert all(
+                isinstance(tok, str) for tok in inserting_toks
+            ), "All elements in inserting_toks should be strings."
 
             self.inserting_toks = inserting_toks
             special_tokens_dict = {"additional_special_tokens": self.inserting_toks}

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 6 additions & 6 deletions
@@ -725,9 +725,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
         idx = 0
         for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
             assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
-            assert all(isinstance(tok, str) for tok in inserting_toks), (
-                "All elements in inserting_toks should be strings."
-            )
+            assert all(
+                isinstance(tok, str) for tok in inserting_toks
+            ), "All elements in inserting_toks should be strings."
 
             self.inserting_toks = inserting_toks
             special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -747,9 +747,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
                 .to(dtype=self.dtype)
                 * std_token_embedding
             )
-            self.embeddings_settings[f"original_embeddings_{idx}"] = (
-                text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
-            )
+            self.embeddings_settings[
+                f"original_embeddings_{idx}"
+            ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
             self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
 
             inu = torch.ones((len(tokenizer),), dtype=torch.bool)
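
The embeddings hunk above shows the companion pattern for assignments: when the value no longer fits, the subscript key is split across lines instead of parenthesizing the right-hand side. A small standalone sketch (the dictionary and the cloned_weights value are stand-ins, not the training script's real objects):

# Old wrapping: the assigned value is parenthesized.
embeddings_settings = {}
idx = 0
cloned_weights = [0.1, 0.2, 0.3]  # stand-in for token_embedding.weight.data.clone()
embeddings_settings[f"original_embeddings_{idx}"] = (
    cloned_weights
)

# New wrapping: the subscript key is broken across lines instead.
embeddings_settings[
    f"original_embeddings_{idx}"
] = cloned_weights

Both assignments store the same value under the same key.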

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 6 additions & 6 deletions
@@ -890,9 +890,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
         idx = 0
         for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
             assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
-            assert all(isinstance(tok, str) for tok in inserting_toks), (
-                "All elements in inserting_toks should be strings."
-            )
+            assert all(
+                isinstance(tok, str) for tok in inserting_toks
+            ), "All elements in inserting_toks should be strings."
 
             self.inserting_toks = inserting_toks
             special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -912,9 +912,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
                 .to(dtype=self.dtype)
                 * std_token_embedding
             )
-            self.embeddings_settings[f"original_embeddings_{idx}"] = (
-                text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
-            )
+            self.embeddings_settings[
+                f"original_embeddings_{idx}"
+            ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
             self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
 
             inu = torch.ones((len(tokenizer),), dtype=torch.bool)

examples/community/pipeline_prompt2prompt.py

Lines changed: 6 additions & 6 deletions
@@ -907,12 +907,12 @@ def create_controller(
 
     # reweight
     if edit_type == "reweight":
-        assert equalizer_words is not None and equalizer_strengths is not None, (
-            "To use reweight edit, please specify equalizer_words and equalizer_strengths."
-        )
-        assert len(equalizer_words) == len(equalizer_strengths), (
-            "equalizer_words and equalizer_strengths must be of same length."
-        )
+        assert (
+            equalizer_words is not None and equalizer_strengths is not None
+        ), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
+        assert len(equalizer_words) == len(
+            equalizer_strengths
+        ), "equalizer_words and equalizer_strengths must be of same length."
         equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
         return AttentionReweight(
             prompts,

examples/custom_diffusion/train_custom_diffusion.py

Lines changed: 12 additions & 12 deletions
@@ -731,18 +731,18 @@ def main(args):
             if not class_images_dir.exists():
                 class_images_dir.mkdir(parents=True, exist_ok=True)
             if args.real_prior:
-                assert (class_images_dir / "images").exists(), (
-                    f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
-                )
-                assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, (
-                    f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
-                )
-                assert (class_images_dir / "caption.txt").exists(), (
-                    f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
-                )
-                assert (class_images_dir / "images.txt").exists(), (
-                    f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
-                )
+                assert (
+                    class_images_dir / "images"
+                ).exists(), f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
+                assert (
+                    len(list((class_images_dir / "images").iterdir())) == args.num_class_images
+                ), f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
+                assert (
+                    class_images_dir / "caption.txt"
+                ).exists(), f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
+                assert (
+                    class_images_dir / "images.txt"
+                ).exists(), f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
             concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
             concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
             args.concepts_list[i] = concept

examples/flux-control/train_control_lora_flux.py

Lines changed: 3 additions & 3 deletions
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
             torch_dtype=weight_dtype,
         )
         pipeline.load_lora_weights(args.output_dir)
-        assert pipeline.transformer.config.in_channels == initial_channels * 2, (
-            f"{pipeline.transformer.config.in_channels=}"
-        )
+        assert (
+            pipeline.transformer.config.in_channels == initial_channels * 2
+        ), f"{pipeline.transformer.config.in_channels=}"
 
     pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)
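
The assertion message in this hunk uses Python 3.8+ self-documenting f-strings: f"{expr=}" expands to the expression's source text followed by its value, which is why a bare f"{pipeline.transformer.config.in_channels=}" makes a useful failure message. A quick standalone illustration (the in_channels value here is made up):

in_channels = 128
print(f"{in_channels=}")  # prints: in_channels=128
# On failure, the message would read e.g. "in_channels=127".
assert in_channels % 2 == 0, f"{in_channels=}"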

examples/model_search/pipeline_easy.py

Lines changed: 3 additions & 3 deletions
@@ -1081,9 +1081,9 @@ def auto_load_textual_inversion(
                 f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
             )
 
-            pretrained_model_name_or_paths[pretrained_model_name_or_paths.index(search_word)] = (
-                textual_inversion_path.model_path
-            )
+            pretrained_model_name_or_paths[
+                pretrained_model_name_or_paths.index(search_word)
+            ] = textual_inversion_path.model_path
 
         self.load_textual_inversion(
             pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs

examples/research_projects/anytext/anytext.py

Lines changed: 3 additions & 3 deletions
@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
         return_tensors="pt",
     )
     tokens = batch_encoding["input_ids"]
-    assert torch.count_nonzero(tokens - 49407) == 2, (
-        f"String '{string}' maps to more than a single token. Please use another string"
-    )
+    assert (
+        torch.count_nonzero(tokens - 49407) == 2
+    ), f"String '{string}' maps to more than a single token. Please use another string"
     return tokens[0, 1]
examples/research_projects/anytext/ocr_recog/RecSVTR.py

Lines changed: 3 additions & 3 deletions
@@ -312,9 +312,9 @@ def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2):
 
     def forward(self, x):
         B, C, H, W = x.shape
-        assert H == self.img_size[0] and W == self.img_size[1], (
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
-        )
+        assert (
+            H == self.img_size[0] and W == self.img_size[1]
+        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
         x = self.proj(x).flatten(2).permute(0, 2, 1)
         return x

examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py

Lines changed: 6 additions & 6 deletions
@@ -763,9 +763,9 @@ def main(args):
     # Parse instance and class inputs, and double check that lengths match
     instance_data_dir = args.instance_data_dir.split(",")
     instance_prompt = args.instance_prompt.split(",")
-    assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
-        "Instance data dir and prompt inputs are not of the same length."
-    )
+    assert all(
+        x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
+    ), "Instance data dir and prompt inputs are not of the same length."
 
     if args.with_prior_preservation:
         class_data_dir = args.class_data_dir.split(",")
@@ -788,9 +788,9 @@ def main(args):
            negative_validation_prompts.append(None)
        args.validation_negative_prompt = negative_validation_prompts
 
-        assert num_of_validation_prompts == len(negative_validation_prompts), (
-            "The length of negative prompts for validation is greater than the number of validation prompts."
-        )
+        assert num_of_validation_prompts == len(
+            negative_validation_prompts
+        ), "The length of negative prompts for validation is greater than the number of validation prompts."
         args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
         args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
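
One caution when reviewing or hand-applying reflows like the ones above: the comma separating condition from message must stay outside the parentheses. If it slips inside, the assert tests a two-element tuple, which is always truthy, so the check silently never fires. A short demonstration:

try:
    assert (
        1 == 2
    ), "correct wrapping: comma outside the parentheses"
except AssertionError as exc:
    print(exc)  # the assertion fires as expected

# Buggy wrapping: a non-empty tuple is always truthy, so this passes silently.
# (CPython emits a SyntaxWarning for exactly this mistake.)
assert (1 == 2, "this message is never shown")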
