Skip to content

Commit 23daa81

Browse files
committed
Changed pooled_embeds to use a learned projection instead of a slice
1 parent 53733c7 commit 23daa81

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

examples/community/pipeline_stable_diffusion_xl_t5.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
5858
_expected_modules = [
5959
"vae", "unet", "scheduler", "tokenizer",
6060
"image_encoder", "feature_extractor",
61-
"t5_encoder", "t5_projection",
61+
"t5_encoder", "t5_projection", "t5_pooled_projection",
6262
]
6363

6464
_optional_components = [
6565
"image_encoder", "feature_extractor",
66-
"t5_encoder", "t5_projection",
66+
"t5_encoder", "t5_projection", "t5_pooled_projection",
6767
]
6868

6969
def __init__(
@@ -74,6 +74,7 @@ def __init__(
7474
tokenizer: CLIPTokenizer,
7575
t5_encoder=None,
7676
t5_projection=None,
77+
t5_pooled_projection=None,
7778
image_encoder: CLIPVisionModelWithProjection = None,
7879
feature_extractor: CLIPImageProcessor = None,
7980
force_zeros_for_empty_prompt: bool = True,
@@ -93,6 +94,12 @@ def __init__(
9394
else:
9495
self.t5_projection = t5_projection
9596
self.t5_projection.to(dtype=unet.dtype)
97+
# ----- build T5 4096 => 1280 dim projection -----
98+
if t5_pooled_projection is None:
99+
self.t5_pooled_projection = LinearWithDtype(4096, 1280) # trainable
100+
else:
101+
self.t5_pooled_projection = t5_pooled_projection
102+
self.t5_pooled_projection.to(dtype=unet.dtype)
96103

97104
print("dtype of Linear is ",self.t5_projection.dtype)
98105

@@ -103,6 +110,7 @@ def __init__(
103110
tokenizer=tokenizer,
104111
t5_encoder=self.t5_encoder,
105112
t5_projection=self.t5_projection,
113+
t5_pooled_projection=self.t5_pooled_projection,
106114
image_encoder=image_encoder,
107115
feature_extractor=feature_extractor,
108116
)
@@ -157,9 +165,9 @@ def _tok(text: str):
157165

158166
# ---------- positive stream -------------------------------------
159167
ids, mask = _tok(prompt)
160-
h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state # [b, T, 4096]
161-
tok_pos = self.t5_projection(h_pos) # [b, T, 2048]
162-
pool_pos = tok_pos.mean(dim=1)[:, :1280] # [b, 1280]
168+
h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state # [b, T, 4096]
169+
tok_pos = self.t5_projection(h_pos) # [b, T, 2048]
170+
pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1)) # [b, 1280]
163171

164172
# expand for multiple images per prompt
165173
tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
@@ -171,7 +179,7 @@ def _tok(text: str):
171179
ids_n, mask_n = _tok(neg_text)
172180
h_neg = self.t5_encoder(ids_n, attention_mask=mask_n).last_hidden_state
173181
tok_neg = self.t5_projection(h_neg)
174-
pool_neg = tok_neg.mean(dim=1)[:, :1280]
182+
pool_neg = self.t5_pooled_projection(h_neg.mean(dim=1))
175183

176184
tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)
177185
pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0)

0 commit comments

Comments
 (0)