@@ -241,7 +241,45 @@ from diffusers import StableDiffusionPipeline
 from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
 from diffusers.configuration_utils import register_to_config
 import torch
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Tuple, Union
+
+
+class SDPromptSchedulingCallback(PipelineCallback):
+    @register_to_config
+    def __init__(
+        self,
+        encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        cutoff_step_ratio=None,
+        cutoff_step_index=None,
+    ):
+        super().__init__(
+            cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
+        )
+
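+    # Tensors this callback modifies via callback_kwargs; they must also be listed in
+    # callback_on_step_end_tensor_inputs when calling the pipeline.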
+    tensor_inputs = ["prompt_embeds"]
+
+    def callback_fn(
+        self, pipeline, step_index, timestep, callback_kwargs
+    ) -> Dict[str, Any]:
+        cutoff_step_ratio = self.config.cutoff_step_ratio
+        cutoff_step_index = self.config.cutoff_step_index
+        if isinstance(self.config.encoded_prompt, tuple):
+            prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
+        else:
+            prompt_embeds = self.config.encoded_prompt
+
+        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+        cutoff_step = (
+            cutoff_step_index
+            if cutoff_step_index is not None
+            else int(pipeline.num_timesteps * cutoff_step_ratio)
+        )
+
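+        # At this callback's cutoff step, overwrite prompt_embeds so the remaining denoising steps
+        # run with the new prompt (negative embeds are concatenated first when CFG is enabled).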
+        if step_index == cutoff_step:
+            if pipeline.do_classifier_free_guidance:
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+        return callback_kwargs

 
 pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
@@ -253,28 +291,73 @@ pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
 pipeline.safety_checker = None
 pipeline.requires_safety_checker = False

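+# One SDPromptSchedulingCallback per step index: each callback swaps in its own encoded prompt
+# once its cutoff step is reached, so the effective prompt changes at every denoising step.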
+callback = MultiPipelineCallbacks(
+    [
+        SDPromptSchedulingCallback(
+            encoded_prompt=pipeline.encode_prompt(
+                prompt=f"prompt {index}",
+                negative_prompt=f"negative prompt {index}",
+                device=pipeline._execution_device,
+                num_images_per_prompt=1,
+                # pipeline.do_classifier_free_guidance can't be accessed until after the pipeline has been run
+                do_classifier_free_guidance=True,
+            ),
+            cutoff_step_index=index,
+        ) for index in range(1, 20)
+    ]
+)
+
+image = pipeline(
+    prompt="prompt",
+    negative_prompt="negative prompt",
+    callback_on_step_end=callback,
+    callback_on_step_end_tensor_inputs=["prompt_embeds"],
+).images[0]
+torch.cuda.empty_cache()
+image.save('image.png')
+```

-class SDPromptScheduleCallback(PipelineCallback):
+```python
+from diffusers import StableDiffusionXLPipeline
+from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
+from diffusers.configuration_utils import register_to_config
+import torch
+from typing import Any, Dict, Tuple, Union
+
+
+class SDXLPromptSchedulingCallback(PipelineCallback):
     @register_to_config
     def __init__(
         self,
-        prompt: str,
-        negative_prompt: Optional[str] = None,
-        num_images_per_prompt: int = 1,
-        cutoff_step_ratio=1.0,
+        encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        add_text_embeds: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        add_time_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        cutoff_step_ratio=None,
         cutoff_step_index=None,
     ):
         super().__init__(
             cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
         )

-    tensor_inputs = ["prompt_embeds"]
+    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]

     def callback_fn(
         self, pipeline, step_index, timestep, callback_kwargs
     ) -> Dict[str, Any]:
         cutoff_step_ratio = self.config.cutoff_step_ratio
         cutoff_step_index = self.config.cutoff_step_index
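+        # Unpack the (positive, negative) pair for each conditioning tensor when it was passed as a tuple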
+        if isinstance(self.config.encoded_prompt, tuple):
+            prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
+        else:
+            prompt_embeds = self.config.encoded_prompt
+        if isinstance(self.config.add_text_embeds, tuple):
+            add_text_embeds, negative_add_text_embeds = self.config.add_text_embeds
+        else:
+            add_text_embeds = self.config.add_text_embeds
+        if isinstance(self.config.add_time_ids, tuple):
+            add_time_ids, negative_add_time_ids = self.config.add_time_ids
+        else:
+            add_time_ids = self.config.add_time_ids

         # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
         cutoff_step = (
@@ -284,34 +367,73 @@ class SDPromptScheduleCallback(PipelineCallback):
         )

         if step_index == cutoff_step:
-            prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
-                prompt=self.config.prompt,
-                negative_prompt=self.config.negative_prompt,
-                device=pipeline._execution_device,
-                num_images_per_prompt=self.config.num_images_per_prompt,
-                do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
-            )
             if pipeline.do_classifier_free_guidance:
                 prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+                add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds])
+                add_time_ids = torch.cat([negative_add_time_ids, add_time_ids])
             callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+            callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
+            callback_kwargs[self.tensor_inputs[2]] = add_time_ids
         return callback_kwargs

-callback = MultiPipelineCallbacks(
-    [
-        SDPromptScheduleCallback(
-            prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
-            negative_prompt="Deformed, ugly, bad anatomy",
-            cutoff_step_ratio=0.25,
+
+pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16,
+    variant="fp16",
+    use_safetensors=True,
+).to("cuda")
+
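+# Precompute prompt embeddings, pooled embeddings, and time ids for every scheduled prompt up
+# front; each callback gets its own (positive, negative) pair of each conditioning tensor.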
+callbacks = []
+for index in range(1, 20):
+    (
+        prompt_embeds,
+        negative_prompt_embeds,
+        pooled_prompt_embeds,
+        negative_pooled_prompt_embeds,
+    ) = pipeline.encode_prompt(
+        prompt=f"prompt {index}",
+        negative_prompt=f"prompt {index}",
+        device=pipeline._execution_device,
+        num_images_per_prompt=1,
+        # pipeline.do_classifier_free_guidance can't be accessed until after the pipeline has been run
+        do_classifier_free_guidance=True,
+    )
+    text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
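+    # SDXL micro-conditioning time ids: original size, crop top-left corner, and target size for 1024x1024 output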
+    add_time_ids = pipeline._get_add_time_ids(
+        (1024, 1024),
+        (0, 0),
+        (1024, 1024),
+        dtype=prompt_embeds.dtype,
+        text_encoder_projection_dim=text_encoder_projection_dim,
+    )
+    negative_add_time_ids = pipeline._get_add_time_ids(
+        (1024, 1024),
+        (0, 0),
+        (1024, 1024),
+        dtype=prompt_embeds.dtype,
+        text_encoder_projection_dim=text_encoder_projection_dim,
+    )
+    callbacks.append(
+        SDXLPromptSchedulingCallback(
+            encoded_prompt=(prompt_embeds, negative_prompt_embeds),
+            add_text_embeds=(pooled_prompt_embeds, negative_pooled_prompt_embeds),
+            add_time_ids=(add_time_ids, negative_add_time_ids),
+            cutoff_step_index=index,
         )
-    ]
-)
+    )
+
+
+callback = MultiPipelineCallbacks(callbacks)

 image = pipeline(
-    prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
-    negative_prompt="Deformed, ugly, bad anatomy",
+    prompt="prompt",
+    negative_prompt="negative prompt",
     callback_on_step_end=callback,
-    callback_on_step_end_tensor_inputs=["prompt_embeds"],
+    callback_on_step_end_tensor_inputs=[
+        "prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+    ],
 ).images[0]
-torch.cuda.empty_cache()
-image.save('image.png')
 ```