import inspect
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

# NOTE: the relative imports below assume this module lives inside
# diffusers' `modular_pipelines/qwenimage/` package.
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageModularPipeline

# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    # Linearly interpolate the flow-matching time shift `mu` between
    # `base_shift` and `max_shift` as the image token sequence length grows.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
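
# Worked example (illustrative, not part of the pipeline): with the defaults,
# the slope is m = (1.15 - 0.5) / (4096 - 256) ≈ 1.693e-4 and the intercept is
# b = 0.5 - m * 256 ≈ 0.4567, so a packed sequence of 1024 tokens gives
# mu = 1024 * m + b = 0.63:
#
#     assert abs(calculate_shift(1024) - 0.63) < 1e-9
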
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
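
# Usage sketch (illustrative values): extra keyword arguments such as `mu` are
# forwarded to `scheduler.set_timesteps`, so a custom sigma schedule and a
# resolution-dependent shift can be set together:
#
#     sigmas = np.linspace(1.0, 1 / 28, 28)
#     timesteps, num_steps = retrieve_timesteps(scheduler, device=device, sigmas=sigmas, mu=0.63)
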
def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # Split the latent grid into 2x2 spatial patches and fold each patch into
    # the channel dimension: (B, C, H, W) -> (B, (H//2) * (W//2), C * 4).
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

    return latents

def unpack_latents(latents, height, width, vae_scale_factor):
    batch_size, num_patches, channels = latents.shape

    # VAE applies 8x compression on images but we must also account for packing which requires
    # latent height and width to be divisible by 2.
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))

    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)

    # Restore a video-style layout with a singleton frame dimension:
    # (B, C, 1, latent_height, latent_width), as expected by the QwenImage VAE.
    latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)

    return latents
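
# Shape round trip (illustrative sizes): a 1024x1024 image with
# vae_scale_factor=8 and 16 latent channels has a 128x128 latent grid,
# so packing yields a sequence of 64 * 64 = 4096 tokens of width 64:
#
#     x = torch.randn(1, 16, 128, 128)
#     packed = pack_latents(x, 1, 16, 128, 128)         # (1, 4096, 64)
#     restored = unpack_latents(packed, 1024, 1024, 8)  # (1, 16, 1, 128, 128)
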
class QwenImagePrepareLatentsStep(PipelineBlock):

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Prepare latents step that prepares the latents for the text-to-image generation process"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="height"),
            InputParam(name="width"),
            InputParam(name="latents"),
            InputParam(name="num_images_per_prompt", default=1),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                name="batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
            ),
            InputParam(name="generator"),
            InputParam(name="dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(name="latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"),
        ]

    def check_inputs(self, height, width, components):
        if height is not None and height % (components.vae_scale_factor * 2) != 0:
            raise ValueError(f"Height must be divisible by {components.vae_scale_factor * 2} but is {height}")

        if width is not None and width % (components.vae_scale_factor * 2) != 0:
            raise ValueError(f"Width must be divisible by {components.vae_scale_factor * 2} but is {width}")

    @staticmethod
    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.prepare_latents with self->components
    def prepare_latents(
        components,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
    ):
        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (components.vae_scale_factor * 2))
        width = 2 * (int(width) // (components.vae_scale_factor * 2))

        shape = (batch_size, 1, num_channels_latents, height, width)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = pack_latents(latents, batch_size, num_channels_latents, height, width)

        return latents

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        device = components._execution_device
        dtype = block_state.dtype

        height = block_state.height or components.default_height
        width = block_state.width or components.default_width
        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

        self.check_inputs(height, width, components)

        # Store the packed latents on the block state so they are written back
        # into the pipeline state as this block's `latents` output.
        block_state.latents = self.prepare_latents(
            components=components,
            batch_size=final_batch_size,
            num_channels_latents=components.num_channels_latents,
            height=height,
            width=width,
            dtype=dtype,
            device=device,
            generator=block_state.generator,
        )

        self.set_block_state(state, block_state)

        return components, state

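
# Usage sketch (hypothetical wiring; assumes a loaded `QwenImageModularPipeline`
# named `pipe` and a state already populated by the text-encoding/input blocks):
# each block reads from and writes to a shared `PipelineState`, so the latents
# prepared above feed the timestep step below:
#
#     state = PipelineState()
#     _, state = QwenImagePrepareLatentsStep()(pipe, state)
#     _, state = QwenImageSetTimestepsStep()(pipe, state)
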
class QwenImageSetTimestepsStep(PipelineBlock):

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Step that sets the scheduler's timesteps for inference"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(name="scheduler", type_hint=FlowMatchEulerDiscreteScheduler),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="num_inference_steps", default=50),
            InputParam(name="sigmas"),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam(name="latents", required=True, type_hint=torch.Tensor, description="The latents to use for the denoising process"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"),
            OutputParam(name="num_inference_steps", type_hint=int, description="The number of inference steps to use for the denoising process"),
        ]

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        device = components._execution_device

        # Default to a linearly decreasing sigma schedule when none is provided.
        sigmas = (
            np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
            if block_state.sigmas is None
            else block_state.sigmas
        )

        mu = calculate_shift(
            image_seq_len=block_state.latents.shape[1],
            base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
            max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
            base_shift=components.scheduler.config.get("base_shift", 0.5),
            max_shift=components.scheduler.config.get("max_shift", 1.15),
        )
        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
            scheduler=components.scheduler,
            num_inference_steps=block_state.num_inference_steps,
            device=device,
            sigmas=sigmas,
            mu=mu,
        )

        self.set_block_state(state, block_state)

        return components, state

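
# Note: `mu` travels through `retrieve_timesteps`'s **kwargs into
# `FlowMatchEulerDiscreteScheduler.set_timesteps`, which applies the
# resolution-dependent time shift when `use_dynamic_shifting` is enabled
# in the scheduler config.
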
class QwenImagePrepareAdditionalConditioningStep(PipelineBlock):

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Step that prepares the additional conditioning for the text-to-image generation process"

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(name="img_shapes", type_hint=List[Tuple[int, int, int]], description="The per-sample latent grid shapes used to compute RoPE inputs"),
            OutputParam(name="txt_seq_lens", type_hint=List[int], description="The sequence lengths of the prompt embeds"),
            OutputParam(name="negative_txt_seq_lens", type_hint=List[int], description="The sequence lengths of the negative prompt embeds"),
        ]

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        height = block_state.height or components.default_height
        width = block_state.width or components.default_width

        # One (frames, latent_height // 2, latent_width // 2) entry per sample in the batch.
        block_state.img_shapes = [
            (1, height // components.vae_scale_factor // 2, width // components.vae_scale_factor // 2)
        ] * block_state.final_batch_size
        block_state.txt_seq_lens = (
            block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
        )
        block_state.negative_txt_seq_lens = (
            block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
            if block_state.negative_prompt_embeds_mask is not None
            else None
        )

        self.set_block_state(state, block_state)

        return components, state
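
# Note (descriptive, not additional behavior): `img_shapes`, `txt_seq_lens`,
# and `negative_txt_seq_lens` are consumed later by the QwenImage transformer
# to build rotary position embeddings over the joint image/text sequence.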