Skip to content

Commit 53733c7

Browse files
committed
Add community class StableDiffusionXL_T5Pipeline
Will be used with base model opendiffusionai/stablediffusionxl_t5
1 parent f161e27 commit 53733c7

File tree

1 file changed

+186
-0
lines changed

1 file changed

+186
-0
lines changed
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright Philip Brown, ppbrown@github
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Note: At this time, the intent is to use the T5 encoder mentioned
16+
# below, with zero changes.
17+
# Therefore, the model deliberately does not store the T5 encoder model bytes,
18+
# (Since they are not unique!)
19+
# but instead takes advantage of huggingface hub cache loading
20+
21+
T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
22+
23+
24+
# Caller is expected to load this, or equivalent, as model name for now
25+
# eg: pipe = StableDiffusionXL_T5Pipeline(SDXL_NAME)
26+
SDXL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
27+
28+
29+
30+
from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
31+
from transformers import T5Tokenizer, T5EncoderModel
32+
from transformers import (
33+
CLIPImageProcessor,
34+
CLIPTextModel,
35+
CLIPTextModelWithProjection,
36+
CLIPTokenizer,
37+
CLIPVisionModelWithProjection,
38+
)
39+
40+
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
41+
from diffusers.schedulers import KarrasDiffusionSchedulers
42+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
43+
44+
45+
from typing import Optional
46+
47+
import torch.nn as nn, torch, types
48+
49+
import torch.nn as nn
50+
51+
class LinearWithDtype(nn.Linear):
52+
@property
53+
def dtype(self):
54+
return self.weight.dtype
55+
56+
57+
class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
58+
_expected_modules = [
59+
"vae", "unet", "scheduler", "tokenizer",
60+
"image_encoder", "feature_extractor",
61+
"t5_encoder", "t5_projection",
62+
]
63+
64+
_optional_components = [
65+
"image_encoder", "feature_extractor",
66+
"t5_encoder", "t5_projection",
67+
]
68+
69+
def __init__(
70+
self,
71+
vae: AutoencoderKL,
72+
unet: UNet2DConditionModel,
73+
scheduler: KarrasDiffusionSchedulers,
74+
tokenizer: CLIPTokenizer,
75+
t5_encoder=None,
76+
t5_projection=None,
77+
image_encoder: CLIPVisionModelWithProjection = None,
78+
feature_extractor: CLIPImageProcessor = None,
79+
force_zeros_for_empty_prompt: bool = True,
80+
add_watermarker: Optional[bool] = None,
81+
):
82+
DiffusionPipeline.__init__(self)
83+
84+
if t5_encoder is None:
85+
self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME,
86+
torch_dtype=unet.dtype)
87+
else:
88+
self.t5_encoder = t5_encoder
89+
90+
# ----- build T5 4096 => 2048 dim projection -----
91+
if t5_projection is None:
92+
self.t5_projection = LinearWithDtype(4096, 2048) # trainable
93+
else:
94+
self.t5_projection = t5_projection
95+
self.t5_projection.to(dtype=unet.dtype)
96+
97+
print("dtype of Linear is ",self.t5_projection.dtype)
98+
99+
self.register_modules(
100+
vae=vae,
101+
unet=unet,
102+
scheduler=scheduler,
103+
tokenizer=tokenizer,
104+
t5_encoder=self.t5_encoder,
105+
t5_projection=self.t5_projection,
106+
image_encoder=image_encoder,
107+
feature_extractor=feature_extractor,
108+
)
109+
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
110+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
111+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
112+
113+
self.default_sample_size = (
114+
self.unet.config.sample_size
115+
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
116+
else 128
117+
)
118+
119+
self.watermark = None
120+
121+
# Parts of original SDXL class complain if these attributes are not
122+
# at least PRESENT
123+
self.text_encoder = self.text_encoder_2 = None
124+
125+
# ------------------------------------------------------------------
126+
# Encode a text prompt (T5-XXL + 4096→2048 projection)
127+
# Returns exactly four tensors in the order SDXL’s __call__ expects.
128+
# ------------------------------------------------------------------
129+
def encode_prompt(
130+
self,
131+
prompt,
132+
num_images_per_prompt: int = 1,
133+
do_classifier_free_guidance: bool = True,
134+
negative_prompt: str | None = None,
135+
**_,
136+
):
137+
"""
138+
Returns
139+
-------
140+
prompt_embeds : Tensor [B, T, 2048]
141+
negative_prompt_embeds : Tensor [B, T, 2048] | None
142+
pooled_prompt_embeds : Tensor [B, 1280]
143+
negative_pooled_prompt_embeds: Tensor [B, 1280] | None
144+
where B = batch * num_images_per_prompt
145+
"""
146+
147+
# --- helper to tokenize on the pipeline’s device ----------------
148+
def _tok(text: str):
149+
tok_out = self.tokenizer(
150+
text,
151+
return_tensors="pt",
152+
padding="max_length",
153+
max_length=self.tokenizer.model_max_length,
154+
truncation=True,
155+
).to(self.device)
156+
return tok_out.input_ids, tok_out.attention_mask
157+
158+
# ---------- positive stream -------------------------------------
159+
ids, mask = _tok(prompt)
160+
h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state # [b, T, 4096]
161+
tok_pos = self.t5_projection(h_pos) # [b, T, 2048]
162+
pool_pos = tok_pos.mean(dim=1)[:, :1280] # [b, 1280]
163+
164+
# expand for multiple images per prompt
165+
tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
166+
pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)
167+
168+
# ---------- negative / CFG stream --------------------------------
169+
if do_classifier_free_guidance:
170+
neg_text = "" if negative_prompt is None else negative_prompt
171+
ids_n, mask_n = _tok(neg_text)
172+
h_neg = self.t5_encoder(ids_n, attention_mask=mask_n).last_hidden_state
173+
tok_neg = self.t5_projection(h_neg)
174+
pool_neg = tok_neg.mean(dim=1)[:, :1280]
175+
176+
tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)
177+
pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0)
178+
else:
179+
tok_neg = pool_neg = None
180+
181+
# ----------------- final ordered return --------------------------
182+
# 1) positive token embeddings
183+
# 2) negative token embeddings (or None)
184+
# 3) positive pooled embeddings
185+
# 4) negative pooled embeddings (or None)
186+
return tok_pos, tok_neg, pool_pos, pool_neg

0 commit comments

Comments
 (0)