@@ -454,21 +454,16 @@ def __call__(
454454 # pooled_prompt_embeds is 768, clip text encoder hidden size
455455 pooled_prompt_embeds = torch .zeros ((batch_size , 768 ), device = device , dtype = image_embeds .dtype )
456456
457- print ("1 prompt_embeds.shape" , prompt_embeds .shape )
458- prompt_embeds_scale = torch .tensor (prompt_embeds_scale , device = device , dtype = image_embeds .dtype )[:, None , None ]
459- pooled_prompt_embeds_scale = torch .tensor (pooled_prompt_embeds_scale , device = device , dtype = image_embeds .dtype )[:, None ]
457+ # scale & concatenate image and text embeddings
458+ prompt_embeds = torch .cat ([prompt_embeds , image_embeds ], dim = 1 )
460459
460+ prompt_embeds *= torch .tensor (prompt_embeds_scale , device = device , dtype = image_embeds .dtype )[:, None , None ]
461+ pooled_prompt_embeds *= torch .tensor (pooled_prompt_embeds_scale , device = device , dtype = image_embeds .dtype )[:, None ]
462+
463+ # weighted sum
464+ prompt_embeds = torch .sum (prompt_embeds , dim = 0 , keepdim = True )
465+ pooled_prompt_embeds = torch .sum (pooled_prompt_embeds , dim = 0 , keepdim = True )
461466
462- # Concatenate image and text embeddings
463- prompt_embeds = torch .cat ([prompt_embeds , image_embeds ], dim = 1 )
464- print ("2 prompt_embeds.shape" , prompt_embeds .shape )
465- prompt_embeds *= prompt_embeds_scale
466- pooled_prompt_embeds *= pooled_prompt_embeds_scale
467- print ("3 prompt_embeds.shape" , prompt_embeds .shape )
468-
469- prompt_embeds = torch .sum (prompt_embeds , dim = 0 )
470- pooled_prompt_embeds = torch .sum (pooled_prompt_embeds , dim = 0 )
471- print ("4 prompt_embeds.shape" , prompt_embeds .shape )
472467 # Offload all models
473468 self .maybe_free_model_hooks ()
474469
0 commit comments