     T5TokenizerFast,
 )
 
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
@@ -184,7 +185,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle |
 
     model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
     _optional_components = ["image_encoder", "feature_extractor"]
-    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]
 
     def __init__(
         self,
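
Note: `_callback_tensor_inputs` is the allow-list that `check_inputs` uses to validate `callback_on_step_end_tensor_inputs`, so after this change a step-end callback can request `latents`, `prompt_embeds`, and `pooled_prompt_embeds` by default. A minimal check, assuming a `diffusers` build that includes this patch:

```python
from diffusers import StableDiffusion3Pipeline

# Class-level allow-list of tensors a step-end callback may request.
# With this change it should read: ["latents", "prompt_embeds", "pooled_prompt_embeds"]
print(StableDiffusion3Pipeline._callback_tensor_inputs)
```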
@@ -923,6 +924,9 @@ def __call__( |
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
 
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
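
The new `isinstance` branch lets a callback object declare its own tensor requirements: when a `PipelineCallback` (or a `MultiPipelineCallbacks` bundle) is passed as `callback_on_step_end`, its `tensor_inputs` attribute replaces whatever was given in `callback_on_step_end_tensor_inputs`. A hedged sketch of a custom callback, assuming the `PipelineCallback` interface from `diffusers.callbacks` (a `tensor_inputs` class attribute plus a `callback_fn(pipeline, step_index, timestep, callback_kwargs)` hook); the class name is illustrative:

```python
from diffusers.callbacks import PipelineCallback


class LatentNormLogger(PipelineCallback):
    # Tensors this callback needs each step; picked up automatically by the
    # isinstance() branch added above, so no callback_on_step_end_tensor_inputs
    # argument is required at call time.
    tensor_inputs = ["latents"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
        # Log the latent norm and return the kwargs unchanged, so denoising
        # continues with the same tensors.
        latents = callback_kwargs["latents"]
        print(f"step {step_index}: latent norm = {latents.norm().item():.3f}")
        return callback_kwargs
```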
@@ -1109,10 +1113,7 @@ def __call__( |
 
                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-                    negative_pooled_prompt_embeds = callback_outputs.pop(
-                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                    )
+                    pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
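
With the loop now handing back `pooled_prompt_embeds` instead of the negative embeddings, a callback can override all three exposed tensors mid-run. A usage sketch with a plain function callback, following the standard `callback_on_step_end(pipe, step_index, timestep, callback_kwargs)` signature; the checkpoint name and the guidance-cutoff logic are illustrative only:

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")


def cutoff_cfg(pipe, step_index, timestep, callback_kwargs):
    # Illustrative: after half of the steps, drop classifier-free guidance and
    # keep only the conditional half of the batched embeddings. Only names
    # requested via callback_on_step_end_tensor_inputs appear in callback_kwargs.
    if step_index == int(pipe.num_timesteps * 0.5):
        pipe._guidance_scale = 0.0
        callback_kwargs["prompt_embeds"] = callback_kwargs["prompt_embeds"][-1:]
        callback_kwargs["pooled_prompt_embeds"] = callback_kwargs["pooled_prompt_embeds"][-1:]
    return callback_kwargs


image = pipe(
    "a photo of an astronaut riding a horse on the moon",
    callback_on_step_end=cutoff_cfg,
    callback_on_step_end_tensor_inputs=["prompt_embeds", "pooled_prompt_embeds"],
).images[0]
```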