5050 ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
5151 ... )
5252 >>> prompt = "A bird in space"
53- >>> image = pipe(
54- ... prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5)
55- ... ).images[0]
53+ >>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0]
5654 >>> image.save("cogview4-control.png")
5755 ```
5856"""
5957
58+
6059# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift
6160def calculate_shift (
6261 image_seq_len ,
@@ -101,19 +100,10 @@ def retrieve_timesteps(
101100 `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
102101 second element is the number of inference steps.
103102 """
104- accepts_timesteps = "timesteps" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
105- accepts_sigmas = "sigmas" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
106-
107103 if timesteps is not None and sigmas is not None :
108- if not accepts_timesteps and not accepts_sigmas :
109- raise ValueError (
110- f"The current scheduler class { scheduler .__class__ } 's `set_timesteps` does not support custom"
111- f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
112- )
113- scheduler .set_timesteps (timesteps = timesteps , sigmas = sigmas , device = device , ** kwargs )
114- timesteps = scheduler .timesteps
115- num_inference_steps = len (timesteps )
116- elif timesteps is not None and sigmas is None :
104+ raise ValueError ("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" )
105+ if timesteps is not None :
106+ accepts_timesteps = "timesteps" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
117107 if not accepts_timesteps :
118108 raise ValueError (
119109 f"The current scheduler class { scheduler .__class__ } 's `set_timesteps` does not support custom"
@@ -122,8 +112,9 @@ def retrieve_timesteps(
122112 scheduler .set_timesteps (timesteps = timesteps , device = device , ** kwargs )
123113 timesteps = scheduler .timesteps
124114 num_inference_steps = len (timesteps )
125- elif timesteps is None and sigmas is not None :
126- if not accepts_sigmas :
115+ elif sigmas is not None :
116+ accept_sigmas = "sigmas" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
117+ if not accept_sigmas :
127118 raise ValueError (
128119 f"The current scheduler class { scheduler .__class__ } 's `set_timesteps` does not support custom"
129120 f" sigmas schedules. Please check whether you are using the correct scheduler."
@@ -182,7 +173,6 @@ def __init__(
182173 def _get_glm_embeds (
183174 self ,
184175 prompt : Union [str , List [str ]] = None ,
185- num_images_per_prompt : int = 1 ,
186176 max_sequence_length : int = 1024 ,
187177 device : Optional [torch .device ] = None ,
188178 dtype : Optional [torch .dtype ] = None ,
@@ -191,7 +181,6 @@ def _get_glm_embeds(
191181 dtype = dtype or self .text_encoder .dtype
192182
193183 prompt = [prompt ] if isinstance (prompt , str ) else prompt
194- batch_size = len (prompt )
195184
196185 text_inputs = self .tokenizer (
197186 prompt ,
@@ -224,9 +213,6 @@ def _get_glm_embeds(
224213 ).hidden_states [- 2 ]
225214
226215 prompt_embeds = prompt_embeds .to (dtype = dtype , device = device )
227- _ , seq_len , _ = prompt_embeds .shape
228- prompt_embeds = prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
229- prompt_embeds = prompt_embeds .view (batch_size * num_images_per_prompt , seq_len , - 1 )
230216 return prompt_embeds
231217
232218 # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt
@@ -277,8 +263,13 @@ def encode_prompt(
277263 batch_size = len (prompt )
278264 else :
279265 batch_size = prompt_embeds .shape [0 ]
266+
280267 if prompt_embeds is None :
281- prompt_embeds = self ._get_glm_embeds (prompt , num_images_per_prompt , max_sequence_length , device , dtype )
268+ prompt_embeds = self ._get_glm_embeds (prompt , max_sequence_length , device , dtype )
269+
270+ seq_len = prompt_embeds .size (1 )
271+ prompt_embeds = prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
272+ prompt_embeds = prompt_embeds .view (batch_size * num_images_per_prompt , seq_len , - 1 )
282273
283274 if do_classifier_free_guidance and negative_prompt_embeds is None :
284275 negative_prompt = negative_prompt or ""
@@ -296,9 +287,11 @@ def encode_prompt(
296287 " the batch size of `prompt`."
297288 )
298289
299- negative_prompt_embeds = self ._get_glm_embeds (
300- negative_prompt , num_images_per_prompt , max_sequence_length , device , dtype
301- )
290+ negative_prompt_embeds = self ._get_glm_embeds (negative_prompt , max_sequence_length , device , dtype )
291+
292+ seq_len = negative_prompt_embeds .size (1 )
293+ negative_prompt_embeds = negative_prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
294+ negative_prompt_embeds = negative_prompt_embeds .view (batch_size * num_images_per_prompt , seq_len , - 1 )
302295
303296 return prompt_embeds , negative_prompt_embeds
304297
0 commit comments