@@ -241,7 +241,45 @@ from diffusers import StableDiffusionPipeline
 from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
 from diffusers.configuration_utils import register_to_config
 import torch
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Tuple, Union
+
+
+class SDPromptSchedulingCallback(PipelineCallback):
+    @register_to_config
+    def __init__(
+        self,
+        encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        cutoff_step_ratio=None,
+        cutoff_step_index=None,
+    ):
+        super().__init__(
+            cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
+        )
+
+    tensor_inputs = ["prompt_embeds"]
+
+    def callback_fn(
+        self, pipeline, step_index, timestep, callback_kwargs
+    ) -> Dict[str, Any]:
+        cutoff_step_ratio = self.config.cutoff_step_ratio
+        cutoff_step_index = self.config.cutoff_step_index
+        if isinstance(self.config.encoded_prompt, tuple):
+            prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
+        else:
+            prompt_embeds = self.config.encoded_prompt
+
+        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+        cutoff_step = (
+            cutoff_step_index
+            if cutoff_step_index is not None
+            else int(pipeline.num_timesteps * cutoff_step_ratio)
+        )
+
+        if step_index == cutoff_step:
+            if pipeline.do_classifier_free_guidance:
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+        return callback_kwargs
 
 
 pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
@@ -253,28 +291,73 @@ pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
 pipeline.safety_checker = None
 pipeline.requires_safety_checker = False
 
+callback = MultiPipelineCallbacks(
+    [
+        SDPromptSchedulingCallback(
+            encoded_prompt=pipeline.encode_prompt(
+                prompt=f"prompt {index}",
+                negative_prompt=f"negative prompt {index}",
+                device=pipeline._execution_device,
+                num_images_per_prompt=1,
+                # pipeline.do_classifier_free_guidance isn't set until the pipeline has run, so pass it explicitly
+                do_classifier_free_guidance=True,
+            ),
+            cutoff_step_index=index,
+        ) for index in range(1, 20)
+    ]
+)
+
+image = pipeline(
+    prompt="prompt",
+    negative_prompt="negative prompt",
+    callback_on_step_end=callback,
+    callback_on_step_end_tensor_inputs=["prompt_embeds"],
+).images[0]
+torch.cuda.empty_cache()
+image.save('image.png')
+```
 
-class SDPromptScheduleCallback(PipelineCallback):
+```python
+from diffusers import StableDiffusionXLPipeline
+from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
+from diffusers.configuration_utils import register_to_config
+import torch
+from typing import Any, Dict, Tuple, Union
+
+
+class SDXLPromptSchedulingCallback(PipelineCallback):
     @register_to_config
     def __init__(
         self,
-        prompt: str,
-        negative_prompt: Optional[str] = None,
-        num_images_per_prompt: int = 1,
-        cutoff_step_ratio=1.0,
+        encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        add_text_embeds: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        add_time_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        cutoff_step_ratio=None,
         cutoff_step_index=None,
     ):
         super().__init__(
             cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
         )
 
-    tensor_inputs = ["prompt_embeds"]
+    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
 
     def callback_fn(
         self, pipeline, step_index, timestep, callback_kwargs
     ) -> Dict[str, Any]:
         cutoff_step_ratio = self.config.cutoff_step_ratio
         cutoff_step_index = self.config.cutoff_step_index
+        if isinstance(self.config.encoded_prompt, tuple):
+            prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
+        else:
+            prompt_embeds = self.config.encoded_prompt
+        if isinstance(self.config.add_text_embeds, tuple):
+            add_text_embeds, negative_add_text_embeds = self.config.add_text_embeds
+        else:
+            add_text_embeds = self.config.add_text_embeds
+        if isinstance(self.config.add_time_ids, tuple):
+            add_time_ids, negative_add_time_ids = self.config.add_time_ids
+        else:
+            add_time_ids = self.config.add_time_ids
 
         # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
         cutoff_step = (
@@ -284,34 +367,73 @@ class SDPromptScheduleCallback(PipelineCallback):
         )
 
         if step_index == cutoff_step:
-            prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
-                prompt=self.config.prompt,
-                negative_prompt=self.config.negative_prompt,
-                device=pipeline._execution_device,
-                num_images_per_prompt=self.config.num_images_per_prompt,
-                do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
-            )
             if pipeline.do_classifier_free_guidance:
                 prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+                add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds])
+                add_time_ids = torch.cat([negative_add_time_ids, add_time_ids])
             callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+            callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
+            callback_kwargs[self.tensor_inputs[2]] = add_time_ids
         return callback_kwargs
 
-callback = MultiPipelineCallbacks(
-    [
-        SDPromptScheduleCallback(
-            prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
-            negative_prompt="Deformed, ugly, bad anatomy",
-            cutoff_step_ratio=0.25,
+
+pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16,
+    variant="fp16",
+    use_safetensors=True,
+).to("cuda")
+
+callbacks = []
+for index in range(1, 20):
+    (
+        prompt_embeds,
+        negative_prompt_embeds,
+        pooled_prompt_embeds,
+        negative_pooled_prompt_embeds,
+    ) = pipeline.encode_prompt(
+        prompt=f"prompt {index}",
+        negative_prompt=f"negative prompt {index}",
+        device=pipeline._execution_device,
+        num_images_per_prompt=1,
+        # pipeline.do_classifier_free_guidance isn't set until the pipeline has run, so pass it explicitly
+        do_classifier_free_guidance=True,
+    )
+    text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+    add_time_ids = pipeline._get_add_time_ids(
+        (1024, 1024),
+        (0, 0),
+        (1024, 1024),
+        dtype=prompt_embeds.dtype,
+        text_encoder_projection_dim=text_encoder_projection_dim,
+    )
+    negative_add_time_ids = pipeline._get_add_time_ids(
+        (1024, 1024),
+        (0, 0),
+        (1024, 1024),
+        dtype=prompt_embeds.dtype,
+        text_encoder_projection_dim=text_encoder_projection_dim,
+    )
+    callbacks.append(
+        SDXLPromptSchedulingCallback(
+            encoded_prompt=(prompt_embeds, negative_prompt_embeds),
+            add_text_embeds=(pooled_prompt_embeds, negative_pooled_prompt_embeds),
+            add_time_ids=(add_time_ids, negative_add_time_ids),
+            cutoff_step_index=index,
         )
-    ]
-)
+    )
+
+
+callback = MultiPipelineCallbacks(callbacks)
 
 image = pipeline(
-    prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
-    negative_prompt="Deformed, ugly, bad anatomy",
+    prompt="prompt",
+    negative_prompt="negative prompt",
     callback_on_step_end=callback,
-    callback_on_step_end_tensor_inputs=["prompt_embeds"],
+    callback_on_step_end_tensor_inputs=[
+        "prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+    ],
 ).images[0]
-torch.cuda.empty_cache()
-image.save('image.png')
 ```
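
If you prefer to schedule the prompt swap as a fraction of the sampling run rather than at a fixed step, the same callback also accepts `cutoff_step_ratio` (it resolves to `int(pipeline.num_timesteps * cutoff_step_ratio)`). Below is a minimal usage sketch, assuming the `SDPromptSchedulingCallback` class from the example above is in scope; the checkpoint id and prompts are illustrative placeholders.

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder checkpoint; any Stable Diffusion 1.x checkpoint works the same way
pipeline = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

callback = SDPromptSchedulingCallback(
    # encode_prompt returns (prompt_embeds, negative_prompt_embeds) for SD pipelines
    encoded_prompt=pipeline.encode_prompt(
        prompt="a watercolor painting of a lighthouse",
        negative_prompt="blurry, low quality",
        device=pipeline._execution_device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    ),
    # Swap in the scheduled embeddings 40% of the way through pipeline.num_timesteps
    cutoff_step_ratio=0.4,
)

image = pipeline(
    prompt="an oil painting of a lighthouse",
    negative_prompt="blurry, low quality",
    callback_on_step_end=callback,
    callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]
```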