@@ -182,6 +182,7 @@ def _get_glm_embeds(
         prompt: Union[str, List[str]] = None,
         num_images_per_prompt: int = 1,
         max_sequence_length: int = 1024,
+        padding_type: str = "longest",
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -193,7 +194,7 @@ def _get_glm_embeds(
 
         text_inputs = self.tokenizer(
             prompt,
-            padding="longest",  # not use max length
+            padding=padding_type,
             max_length=max_sequence_length,
             truncation=True,
             add_special_tokens=True,
@@ -239,6 +240,7 @@ def encode_prompt(
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         max_sequence_length: int = 1024,
+        padding_type: str = "longest",
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -275,9 +277,8 @@ def encode_prompt(
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
-
         if prompt_embeds is None:
-            prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype)
+            prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, padding_type, device, dtype)
 
         if do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
@@ -296,7 +297,7 @@ def encode_prompt(
                 )
 
             negative_prompt_embeds = self._get_glm_embeds(
-                negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype
+                negative_prompt, num_images_per_prompt, max_sequence_length, "longest", device, dtype
             )
 
         return prompt_embeds, negative_prompt_embeds
@@ -450,6 +451,7 @@ def __call__(
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 1024,
+        padding_type: str = "longest",  # downstream tasks can switch this to "max_length"
     ) -> Union[CogView4PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -579,7 +581,8 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
+            padding_type=padding_type,
             device=device,
         )
 
         # Prepare latents
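
For reference, the two padding modes behave as follows at the tokenizer level. This is a standalone sketch, not part of the patch; the checkpoint id and `tokenizer` subfolder layout are assumptions based on how diffusers pipelines typically ship their tokenizer.

```python
from transformers import AutoTokenizer

# Illustrative checkpoint; the pipeline's actual tokenizer is whatever the
# loaded CogView4 checkpoint provides in its `tokenizer` subfolder.
tokenizer = AutoTokenizer.from_pretrained("THUDM/CogView4-6B", subfolder="tokenizer")

prompts = ["a cat", "a highly detailed oil painting of a cat"]

# "longest" (the default, and the previously hard-coded behavior): pad only
# up to the longest prompt in the batch, so sequence length varies per call.
batch_longest = tokenizer(
    prompts, padding="longest", max_length=1024, truncation=True, return_tensors="pt"
)

# "max_length": pad every prompt to `max_length`, yielding fixed-shape
# tensors regardless of the batch contents.
batch_fixed = tokenizer(
    prompts, padding="max_length", max_length=1024, truncation=True, return_tensors="pt"
)

print(batch_longest.input_ids.shape)  # (2, length of the longest prompt)
print(batch_fixed.input_ids.shape)    # (2, 1024)
```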
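End to end, the new keyword is threaded from `__call__` through `encode_prompt` into `_get_glm_embeds`. A minimal usage sketch, assuming this patch is applied; the checkpoint id is illustrative:

```python
import torch
from diffusers import CogView4Pipeline

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# padding_type defaults to "longest", so existing callers are unaffected.
# Downstream tasks that need fixed-shape prompt embeddings can opt in:
image = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    padding_type="max_length",
    max_sequence_length=1024,
).images[0]
```

Note that `encode_prompt` still passes a hard-coded `"longest"` when embedding the negative prompt, so only the positive prompt embeddings change shape when `"max_length"` is requested.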