@@ -58,12 +58,12 @@ class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
5858    _expected_modules  =  [
5959        "vae" , "unet" , "scheduler" , "tokenizer" ,
6060        "image_encoder" , "feature_extractor" ,
61-         "t5_encoder" , "t5_projection" ,
61+         "t5_encoder" , "t5_projection" ,  "t5_pooled_projection" , 
6262    ]
6363
6464    _optional_components  =  [
6565        "image_encoder" , "feature_extractor" ,
66-         "t5_encoder" , "t5_projection" ,
66+         "t5_encoder" , "t5_projection" ,  "t5_pooled_projection" , 
6767    ]
6868
6969    def  __init__ (
@@ -74,6 +74,7 @@ def __init__(
7474        tokenizer : CLIPTokenizer ,
7575        t5_encoder = None ,
7676        t5_projection = None ,
77+         t5_pooled_projection = None ,
7778        image_encoder : CLIPVisionModelWithProjection  =  None ,
7879        feature_extractor : CLIPImageProcessor  =  None ,
7980        force_zeros_for_empty_prompt : bool  =  True ,
@@ -93,6 +94,12 @@ def __init__(
9394        else :
9495            self .t5_projection   =  t5_projection 
9596        self .t5_projection .to (dtype = unet .dtype )
97+         # ----- build T5 4096 => 1280 dim projection ----- 
98+         if  t5_pooled_projection  is  None :
99+             self .t5_pooled_projection   =  LinearWithDtype (4096 , 1280 )   # trainable 
100+         else :
101+             self .t5_pooled_projection   =  t5_pooled_projection 
102+         self .t5_pooled_projection .to (dtype = unet .dtype )
96103
97104        print ("dtype of Linear is " ,self .t5_projection .dtype )
98105
@@ -103,6 +110,7 @@ def __init__(
103110            tokenizer = tokenizer ,
104111            t5_encoder = self .t5_encoder ,
105112            t5_projection = self .t5_projection ,
113+             t5_pooled_projection = self .t5_pooled_projection ,
106114            image_encoder = image_encoder ,
107115            feature_extractor = feature_extractor ,
108116        )
@@ -157,9 +165,9 @@ def _tok(text: str):
157165
158166        # ---------- positive stream ------------------------------------- 
159167        ids , mask  =  _tok (prompt )
160-         h_pos  =  self .t5_encoder (ids , attention_mask = mask ).last_hidden_state       # [b, T, 4096] 
161-         tok_pos  =  self .t5_projection (h_pos )                                      # [b, T, 2048] 
162-         pool_pos  =  tok_pos . mean (dim = 1 )[:, : 1280 ]                                 # [b, 1280] 
168+         h_pos  =  self .t5_encoder (ids , attention_mask = mask ).last_hidden_state    # [b, T, 4096] 
169+         tok_pos  =  self .t5_projection (h_pos )                                   # [b, T, 2048] 
170+         pool_pos  =  self . t5_pooled_projection ( h_pos . mean (dim = 1 ))                # [b, 1280] 
163171
164172        # expand for multiple images per prompt 
165173        tok_pos    =  tok_pos .repeat_interleave (num_images_per_prompt , 0 )
@@ -171,7 +179,7 @@ def _tok(text: str):
171179            ids_n , mask_n  =  _tok (neg_text )
172180            h_neg  =  self .t5_encoder (ids_n , attention_mask = mask_n ).last_hidden_state 
173181            tok_neg  =  self .t5_projection (h_neg )
174-             pool_neg  =  tok_neg . mean (dim = 1 )[:, : 1280 ] 
182+             pool_neg  =  self . t5_pooled_projection ( h_neg . mean (dim = 1 )) 
175183
176184            tok_neg   =  tok_neg .repeat_interleave (num_images_per_prompt , 0 )
177185            pool_neg  =  pool_neg .repeat_interleave (num_images_per_prompt , 0 )
0 commit comments