Skip to content

Commit 6fcec9d

Browse files
committed
fix
1 parent 3ae0d67 commit 6fcec9d

File tree

105 files changed

+769
-776
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

105 files changed

+769
-776
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -839,9 +839,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
839839
idx = 0
840840
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
841841
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
842-
assert all(
843-
isinstance(tok, str) for tok in inserting_toks
844-
), "All elements in inserting_toks should be strings."
842+
assert all(isinstance(tok, str) for tok in inserting_toks), (
843+
"All elements in inserting_toks should be strings."
844+
)
845845

846846
self.inserting_toks = inserting_toks
847847
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -725,9 +725,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
725725
idx = 0
726726
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
727727
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
728-
assert all(
729-
isinstance(tok, str) for tok in inserting_toks
730-
), "All elements in inserting_toks should be strings."
728+
assert all(isinstance(tok, str) for tok in inserting_toks), (
729+
"All elements in inserting_toks should be strings."
730+
)
731731

732732
self.inserting_toks = inserting_toks
733733
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -747,9 +747,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
747747
.to(dtype=self.dtype)
748748
* std_token_embedding
749749
)
750-
self.embeddings_settings[
751-
f"original_embeddings_{idx}"
752-
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
750+
self.embeddings_settings[f"original_embeddings_{idx}"] = (
751+
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
752+
)
753753
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
754754

755755
inu = torch.ones((len(tokenizer),), dtype=torch.bool)

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -890,9 +890,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
890890
idx = 0
891891
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
892892
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
893-
assert all(
894-
isinstance(tok, str) for tok in inserting_toks
895-
), "All elements in inserting_toks should be strings."
893+
assert all(isinstance(tok, str) for tok in inserting_toks), (
894+
"All elements in inserting_toks should be strings."
895+
)
896896

897897
self.inserting_toks = inserting_toks
898898
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -912,9 +912,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
912912
.to(dtype=self.dtype)
913913
* std_token_embedding
914914
)
915-
self.embeddings_settings[
916-
f"original_embeddings_{idx}"
917-
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
915+
self.embeddings_settings[f"original_embeddings_{idx}"] = (
916+
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
917+
)
918918
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
919919

920920
inu = torch.ones((len(tokenizer),), dtype=torch.bool)

examples/community/pipeline_prompt2prompt.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -907,12 +907,12 @@ def create_controller(
907907

908908
# reweight
909909
if edit_type == "reweight":
910-
assert (
911-
equalizer_words is not None and equalizer_strengths is not None
912-
), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
913-
assert len(equalizer_words) == len(
914-
equalizer_strengths
915-
), "equalizer_words and equalizer_strengths must be of same length."
910+
assert equalizer_words is not None and equalizer_strengths is not None, (
911+
"To use reweight edit, please specify equalizer_words and equalizer_strengths."
912+
)
913+
assert len(equalizer_words) == len(equalizer_strengths), (
914+
"equalizer_words and equalizer_strengths must be of same length."
915+
)
916916
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
917917
return AttentionReweight(
918918
prompts,

examples/custom_diffusion/train_custom_diffusion.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -731,18 +731,18 @@ def main(args):
731731
if not class_images_dir.exists():
732732
class_images_dir.mkdir(parents=True, exist_ok=True)
733733
if args.real_prior:
734-
assert (
735-
class_images_dir / "images"
736-
).exists(), f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
737-
assert (
738-
len(list((class_images_dir / "images").iterdir())) == args.num_class_images
739-
), f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
740-
assert (
741-
class_images_dir / "caption.txt"
742-
).exists(), f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
743-
assert (
744-
class_images_dir / "images.txt"
745-
).exists(), f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
734+
assert (class_images_dir / "images").exists(), (
735+
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
736+
)
737+
assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, (
738+
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
739+
)
740+
assert (class_images_dir / "caption.txt").exists(), (
741+
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
742+
)
743+
assert (class_images_dir / "images.txt").exists(), (
744+
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
745+
)
746746
concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
747747
concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
748748
args.concepts_list[i] = concept

examples/flux-control/train_control_lora_flux.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
9191
torch_dtype=weight_dtype,
9292
)
9393
pipeline.load_lora_weights(args.output_dir)
94-
assert (
95-
pipeline.transformer.config.in_channels == initial_channels * 2
96-
), f"{pipeline.transformer.config.in_channels=}"
94+
assert pipeline.transformer.config.in_channels == initial_channels * 2, (
95+
f"{pipeline.transformer.config.in_channels=}"
96+
)
9797

9898
pipeline.to(accelerator.device)
9999
pipeline.set_progress_bar_config(disable=True)

examples/model_search/pipeline_easy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,9 +1081,9 @@ def auto_load_textual_inversion(
10811081
f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
10821082
)
10831083

1084-
pretrained_model_name_or_paths[
1085-
pretrained_model_name_or_paths.index(search_word)
1086-
] = textual_inversion_path.model_path
1084+
pretrained_model_name_or_paths[pretrained_model_name_or_paths.index(search_word)] = (
1085+
textual_inversion_path.model_path
1086+
)
10871087

10881088
self.load_textual_inversion(
10891089
pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs

examples/research_projects/anytext/anytext.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
187187
return_tensors="pt",
188188
)
189189
tokens = batch_encoding["input_ids"]
190-
assert (
191-
torch.count_nonzero(tokens - 49407) == 2
192-
), f"String '{string}' maps to more than a single token. Please use another string"
190+
assert torch.count_nonzero(tokens - 49407) == 2, (
191+
f"String '{string}' maps to more than a single token. Please use another string"
192+
)
193193
return tokens[0, 1]
194194

195195

examples/research_projects/anytext/ocr_recog/RecSVTR.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,9 @@ def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2):
312312

313313
def forward(self, x):
314314
B, C, H, W = x.shape
315-
assert (
316-
H == self.img_size[0] and W == self.img_size[1]
317-
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
315+
assert H == self.img_size[0] and W == self.img_size[1], (
316+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
317+
)
318318
x = self.proj(x).flatten(2).permute(0, 2, 1)
319319
return x
320320

examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -763,9 +763,9 @@ def main(args):
763763
# Parse instance and class inputs, and double check that lengths match
764764
instance_data_dir = args.instance_data_dir.split(",")
765765
instance_prompt = args.instance_prompt.split(",")
766-
assert all(
767-
x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
768-
), "Instance data dir and prompt inputs are not of the same length."
766+
assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
767+
"Instance data dir and prompt inputs are not of the same length."
768+
)
769769

770770
if args.with_prior_preservation:
771771
class_data_dir = args.class_data_dir.split(",")
@@ -788,9 +788,9 @@ def main(args):
788788
negative_validation_prompts.append(None)
789789
args.validation_negative_prompt = negative_validation_prompts
790790

791-
assert num_of_validation_prompts == len(
792-
negative_validation_prompts
793-
), "The length of negative prompts for validation is greater than the number of validation prompts."
791+
assert num_of_validation_prompts == len(negative_validation_prompts), (
792+
"The length of negative prompts for validation is greater than the number of validation prompts."
793+
)
794794
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
795795
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
796796

0 commit comments

Comments
 (0)