@@ -369,7 +369,8 @@ def __call__(
369369 prompt_2 : Optional [Union [str , List [str ]]] = None ,
370370 prompt_embeds : Optional [torch .FloatTensor ] = None ,
371371 pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
372- scales : Optional [Union [float , List [float ]]] = None ,
372+ prompt_embeds_scale : Optional [Union [float , List [float ]]] = 1. ,
373+ pooled_prompt_embeds_scale : Optional [Union [float , List [float ]]] = 1. ,
373374 return_dict : bool = True ,
374375 ):
375376 r"""
@@ -418,6 +419,10 @@ def __call__(
418419 batch_size = image .shape [0 ]
419420 if prompt is not None and isinstance (prompt , str ):
420421 prompt = batch_size * [prompt ]
422+ if isinstance (prompt_embeds_scale , float ):
423+ prompt_embeds_scale = batch_size * [prompt_embeds_scale ]
424+ if isinstance (pooled_prompt_embeds_scale , float ):
425+ pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale ]
421426
422427 device = self ._execution_device
423428
@@ -449,9 +454,21 @@ def __call__(
449454 # pooled_prompt_embeds is 768, clip text encoder hidden size
450455 pooled_prompt_embeds = torch .zeros ((batch_size , 768 ), device = device , dtype = image_embeds .dtype )
451456
457+ print ("1 prompt_embeds.shape" , prompt_embeds .shape )
458+ prompt_embeds_scale = torch .tensor (prompt_embeds_scale , device = device , dtype = image_embeds .dtype )[:, None , None ]
459+ pooled_prompt_embeds_scale = torch .tensor (pooled_prompt_embeds_scale , device = device , dtype = image_embeds .dtype )[:, None ]
460+
461+
452462 # Concatenate image and text embeddings
453463 prompt_embeds = torch .cat ([prompt_embeds , image_embeds ], dim = 1 )
454-
464+ print ("2 prompt_embeds.shape" , prompt_embeds .shape )
465+ prompt_embeds *= prompt_embeds_scale
466+ pooled_prompt_embeds *= pooled_prompt_embeds_scale
467+ print ("3 prompt_embeds.shape" , prompt_embeds .shape )
468+
469+ prompt_embeds = torch .sum (prompt_embeds , dim = 0 )
470+ pooled_prompt_embeds = torch .sum (pooled_prompt_embeds , dim = 0 )
471+ print ("4 prompt_embeds.shape" , prompt_embeds .shape )
455472 # Offload all models
456473 self .maybe_free_model_hooks ()
457474
0 commit comments