
Commit f55873b

Authored by lawrence-cj, sayakpaul, and yiyixuxu
Fix PixArt 256px inference (#6789)
Fix PixArt 256px inference (#6789)

* feat 256px diffusers inference bug
* change the max_length of T5 to pipeline config file
* fix bug in convert_pixart_alpha_to_diffusers.py
* Update scripts/convert_pixart_alpha_to_diffusers.py (Co-authored-by: Sayak Paul <[email protected]>)
* remove multi_scale_train parser
* Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py (Co-authored-by: YiYi Xu <[email protected]>)
* Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py (Co-authored-by: YiYi Xu <[email protected]>)
* styling
* change `model_token_max_length` to call argument.
* Refactoring
* add: max_sequence_length to the docstring.

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
1 parent ccb93dc commit f55873b

File tree (3 files changed: +59, -11 lines)

scripts/convert_pixart_alpha_to_diffusers.py
src/diffusers/models/transformers/transformer_2d.py
src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

scripts/convert_pixart_alpha_to_diffusers.py
Lines changed: 4 additions & 4 deletions

@@ -9,11 +9,11 @@

 ckpt_id = "PixArt-alpha/PixArt-alpha"
 # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
-interpolation_scale = {512: 1, 1024: 2}
+interpolation_scale = {256: 0.5, 512: 1, 1024: 2}


 def main(args):
-    all_state_dict = torch.load(args.orig_ckpt_path)
+    all_state_dict = torch.load(args.orig_ckpt_path, map_location="cpu")
     state_dict = all_state_dict.pop("state_dict")
     converted_state_dict = {}

@@ -22,7 +22,6 @@ def main(args):
     converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

     # Caption projection.
-    converted_state_dict["caption_projection.y_embedding"] = state_dict.pop("y_embedder.y_embedding")
     converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
     converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
     converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")

@@ -155,6 +154,7 @@ def main(args):

     assert transformer.pos_embed.pos_embed is not None
     state_dict.pop("pos_embed")
+    state_dict.pop("y_embedder.y_embedding")
     assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

     num_model_params = sum(p.numel() for p in transformer.parameters())

@@ -187,7 +187,7 @@
         "--image_size",
         default=1024,
         type=int,
-        choices=[512, 1024],
+        choices=[256, 512, 1024],
         required=False,
         help="Image size of pretrained model, either 512 or 1024.",
     )
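For reference, a minimal sketch of what the new mapping means when converting a 256px checkpoint. The checkpoint filename below is a placeholder and the snippet only shows the resolution bookkeeping; the actual script builds the full Transformer2DModel from many more keys.

```python
import torch

# Resolution -> positional-embedding interpolation scale, as added above.
interpolation_scale = {256: 0.5, 512: 1, 1024: 2}

image_size = 256
sample_size = image_size // 8            # PixArt latents are 1/8 of the pixel size: 256px -> 32
scale = interpolation_scale[image_size]  # 0.5 for the 256px checkpoint

# Loading on CPU (as in the fix above) avoids needing a GPU just to convert weights.
# Placeholder path: point this at the original PixArt 256px checkpoint.
all_state_dict = torch.load("/path/to/PixArt-XL-2-256x256.pth", map_location="cpu")
state_dict = all_state_dict.pop("state_dict")
```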

src/diffusers/models/transformers/transformer_2d.py
Lines changed: 4 additions & 2 deletions

@@ -97,6 +97,7 @@ def __init__(
         norm_eps: float = 1e-5,
         attention_type: str = "default",
         caption_channels: int = None,
+        interpolation_scale: float = None,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection

@@ -168,8 +169,9 @@ def __init__(
             self.width = sample_size

             self.patch_size = patch_size
-            interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
-            interpolation_scale = max(interpolation_scale, 1)
+            interpolation_scale = (
+                interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1)
+            )
             self.pos_embed = PatchEmbed(
                 height=sample_size,
                 width=sample_size,
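The explicit `interpolation_scale` argument matters because the old heuristic derives the scale from `sample_size` and clamps it to at least 1, so it can never produce the 0.5 that the 256px model (sample_size 32) needs. A small sketch of the before/after behaviour:

```python
def old_heuristic(sample_size: int) -> int:
    # previous behaviour: 64 (= 512px PixArt) -> 1, clamped to a minimum of 1
    return max(sample_size // 64, 1)

def new_rule(sample_size: int, interpolation_scale=None):
    # new behaviour: an explicit config value wins, otherwise fall back to the heuristic
    return interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)

assert old_heuristic(32) == 1                        # wrong for the 256px checkpoint
assert new_rule(32, interpolation_scale=0.5) == 0.5  # value written by the conversion script
assert new_rule(64) == 1 and new_rule(128) == 2      # 512px / 1024px behaviour unchanged
```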

src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
Lines changed: 51 additions & 5 deletions

@@ -133,6 +133,42 @@
     "4.0": [1024.0, 256.0],
 }

+ASPECT_RATIO_256_BIN = {
+    "0.25": [128.0, 512.0],
+    "0.28": [128.0, 464.0],
+    "0.32": [144.0, 448.0],
+    "0.33": [144.0, 432.0],
+    "0.35": [144.0, 416.0],
+    "0.4": [160.0, 400.0],
+    "0.42": [160.0, 384.0],
+    "0.48": [176.0, 368.0],
+    "0.5": [176.0, 352.0],
+    "0.52": [176.0, 336.0],
+    "0.57": [192.0, 336.0],
+    "0.6": [192.0, 320.0],
+    "0.68": [208.0, 304.0],
+    "0.72": [208.0, 288.0],
+    "0.78": [224.0, 288.0],
+    "0.82": [224.0, 272.0],
+    "0.88": [240.0, 272.0],
+    "0.94": [240.0, 256.0],
+    "1.0": [256.0, 256.0],
+    "1.07": [256.0, 240.0],
+    "1.13": [272.0, 240.0],
+    "1.21": [272.0, 224.0],
+    "1.29": [288.0, 224.0],
+    "1.38": [288.0, 208.0],
+    "1.46": [304.0, 208.0],
+    "1.67": [320.0, 192.0],
+    "1.75": [336.0, 192.0],
+    "2.0": [352.0, 176.0],
+    "2.09": [368.0, 176.0],
+    "2.4": [384.0, 160.0],
+    "2.5": [400.0, 160.0],
+    "3.0": [432.0, 144.0],
+    "4.0": [512.0, 128.0],
+}
+

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(

@@ -260,6 +296,7 @@ def encode_prompt(
         prompt_attention_mask: Optional[torch.FloatTensor] = None,
         negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         clean_caption: bool = False,
+        max_sequence_length: int = 120,
         **kwargs,
     ):
         r"""

@@ -284,8 +321,9 @@ def encode_prompt(
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
                 string.
-            clean_caption (bool, defaults to `False`):
+            clean_caption (`bool`, defaults to `False`):
                 If `True`, the function will preprocess and clean the provided caption before encoding.
+            max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
         """

         if "mask_feature" in kwargs:

@@ -303,7 +341,7 @@ def encode_prompt(
             batch_size = prompt_embeds.shape[0]

         # See Section 3.1. of the paper.
-        max_length = 120
+        max_length = max_sequence_length

         if prompt_embeds is None:
             prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)

@@ -688,6 +726,7 @@ def __call__(
         callback_steps: int = 1,
         clean_caption: bool = True,
         use_resolution_binning: bool = True,
+        max_sequence_length: int = 120,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """

@@ -757,6 +796,7 @@ def __call__(
                 If set to `True`, the requested height and width are first mapped to the closest resolutions using
                 `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
                 the requested resolution. Useful for generating non-square images.
+            max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.

         Examples:

@@ -772,9 +812,14 @@
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
         if use_resolution_binning:
-            aspect_ratio_bin = (
-                ASPECT_RATIO_1024_BIN if self.transformer.config.sample_size == 128 else ASPECT_RATIO_512_BIN
-            )
+            if self.transformer.config.sample_size == 128:
+                aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+            elif self.transformer.config.sample_size == 64:
+                aspect_ratio_bin = ASPECT_RATIO_512_BIN
+            elif self.transformer.config.sample_size == 32:
+                aspect_ratio_bin = ASPECT_RATIO_256_BIN
+            else:
+                raise ValueError("Invalid sample size")
             orig_height, orig_width = height, width
             height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)

@@ -822,6 +867,7 @@ def __call__(
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             clean_caption=clean_caption,
+            max_sequence_length=max_sequence_length,
         )
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
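Putting the pieces together, a converted 256px checkpoint can then be run through `PixArtAlphaPipeline` as usual. A minimal sketch: the model path is a placeholder for a locally converted checkpoint, and `max_sequence_length=120` simply restates the default added by this commit.

```python
import torch
from diffusers import PixArtAlphaPipeline

# Placeholder path: point this at the output of convert_pixart_alpha_to_diffusers.py.
pipe = PixArtAlphaPipeline.from_pretrained(
    "path/to/pixart-256-diffusers", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="A small cactus wearing a straw hat in the desert",
    max_sequence_length=120,      # new call argument introduced by this commit
    use_resolution_binning=True,  # with sample_size == 32 this now selects ASPECT_RATIO_256_BIN
).images[0]
image.save("pixart_256.png")
```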
