Commit 685ad1b

"make style" tweaks

1 parent 7a2b028 commit 685ad1b
1 file changed: +43, -41 lines
examples/community/pipeline_stable_diffusion_xl_t5.py

@@ -12,41 +12,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Note: At this time, the intent is to use the T5 encoder mentioned
-# below, with zero changes.
-# Therefore, the model deliberately does not store the T5 encoder model bytes,
-# (Since they are not unique!)
-# but instead takes advantage of huggingface hub cache loading
-
-T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
-
-
-# Caller is expected to load this, or equivalent, as model name for now
-# eg: pipe = StableDiffusionXL_T5Pipeline(SDXL_NAME)
-SDXL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
-
 
+from typing import Optional
 
-from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
-from transformers import T5Tokenizer, T5EncoderModel
+import torch.nn as nn
 from transformers import (
     CLIPImageProcessor,
-    CLIPTextModel,
-    CLIPTextModelWithProjection,
     CLIPTokenizer,
     CLIPVisionModelWithProjection,
+    T5EncoderModel,
 )
 
-from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 
 
-from typing import Optional
+# Note: At this time, the intent is to use the T5 encoder mentioned
+# below, with zero changes.
+# Therefore, the model deliberately does not store the T5 encoder model bytes,
+# (Since they are not unique!)
+# but instead takes advantage of huggingface hub cache loading
 
-import torch.nn as nn, torch, types
+T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
+
+# Caller is expected to load this, or equivalent, as model name for now
+# eg: pipe = StableDiffusionXL_T5Pipeline(SDXL_NAME)
+SDXL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
 
-import torch.nn as nn
 
 class LinearWithDtype(nn.Linear):
     @property
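Note on the class this hunk opens: the next hunk's header shows the property is def dtype(self). nn.Linear does not expose a dtype attribute of its own, and the pipeline later reads and prints self.t5_projection.dtype, so the property most plausibly forwards the weight's dtype. A minimal sketch of that assumed body (the commit itself does not show it):

import torch.nn as nn


class LinearWithDtype(nn.Linear):
    # Assumed implementation: expose the parameter dtype so the pipeline's
    # dtype bookkeeping can treat this layer like its other components.
    @property
    def dtype(self):
        return self.weight.dtype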
@@ -56,14 +50,23 @@ def dtype(self):
 
 class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
     _expected_modules = [
-        "vae", "unet", "scheduler", "tokenizer",
-        "image_encoder", "feature_extractor",
-        "t5_encoder", "t5_projection", "t5_pooled_projection",
+        "vae",
+        "unet",
+        "scheduler",
+        "tokenizer",
+        "image_encoder",
+        "feature_extractor",
+        "t5_encoder",
+        "t5_projection",
+        "t5_pooled_projection",
     ]
 
     _optional_components = [
-        "image_encoder", "feature_extractor",
-        "t5_encoder", "t5_projection", "t5_pooled_projection",
+        "image_encoder",
+        "feature_extractor",
+        "t5_encoder",
+        "t5_projection",
+        "t5_pooled_projection",
     ]
 
     def __init__(
@@ -83,25 +86,24 @@ def __init__(
         DiffusionPipeline.__init__(self)
 
         if t5_encoder is None:
-            self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME,
-                                                             torch_dtype=unet.dtype)
+            self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)
         else:
-            self.t5_encoder = t5_encoder
+            self.t5_encoder = t5_encoder
 
         # ----- build T5 4096 => 2048 dim projection -----
         if t5_projection is None:
-            self.t5_projection = LinearWithDtype(4096, 2048) # trainable
+            self.t5_projection = LinearWithDtype(4096, 2048)  # trainable
         else:
-            self.t5_projection = t5_projection
+            self.t5_projection = t5_projection
         self.t5_projection.to(dtype=unet.dtype)
         # ----- build T5 4096 => 1280 dim projection -----
         if t5_pooled_projection is None:
-            self.t5_pooled_projection = LinearWithDtype(4096, 1280) # trainable
+            self.t5_pooled_projection = LinearWithDtype(4096, 1280)  # trainable
         else:
-            self.t5_pooled_projection = t5_pooled_projection
+            self.t5_pooled_projection = t5_pooled_projection
         self.t5_pooled_projection.to(dtype=unet.dtype)
 
-        print("dtype of Linear is ",self.t5_projection.dtype)
+        print("dtype of Linear is ", self.t5_projection.dtype)
 
         self.register_modules(
             vae=vae,
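For context on the two projections restyled above: T5-XXL hidden states are 4096-dimensional, SDXL's UNet cross-attention expects 2048-dimensional token embeddings, and its added conditioning expects a 1280-dimensional pooled text embedding. A standalone shape check, using plain nn.Linear in place of LinearWithDtype for illustration:

import torch
import torch.nn as nn

t5_projection = nn.Linear(4096, 2048)         # token-level projection
t5_pooled_projection = nn.Linear(4096, 1280)  # pooled projection

h = torch.randn(2, 77, 4096)                  # [batch, tokens, T5 hidden]
tok = t5_projection(h)                        # -> [2, 77, 2048]
pool = t5_pooled_projection(h.mean(dim=1))    # -> [2, 1280]
print(tok.shape, pool.shape)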
@@ -165,13 +167,13 @@ def _tok(text: str):
 
         # ---------- positive stream -------------------------------------
         ids, mask = _tok(prompt)
-        h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state # [b, T, 4096]
-        tok_pos = self.t5_projection(h_pos) # [b, T, 2048]
-        pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1)) # [b, 1280]
+        h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state  # [b, T, 4096]
+        tok_pos = self.t5_projection(h_pos)  # [b, T, 2048]
+        pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1))  # [b, 1280]
 
         # expand for multiple images per prompt
-        tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
-        pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)
+        tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
+        pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)
 
         # ---------- negative / CFG stream --------------------------------
         if do_classifier_free_guidance:
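The repeat_interleave(num_images_per_prompt, 0) calls being realigned here duplicate each prompt's embeddings along the batch axis so that every requested image gets its own copy. A tiny demonstration:

import torch

emb = torch.tensor([[1.0], [2.0]])  # embeddings for two prompts
out = emb.repeat_interleave(3, 0)   # num_images_per_prompt = 3
print(out.squeeze(1))               # tensor([1., 1., 1., 2., 2., 2.])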
@@ -181,7 +183,7 @@ def _tok(text: str):
             tok_neg = self.t5_projection(h_neg)
             pool_neg = self.t5_pooled_projection(h_neg.mean(dim=1))
 
-            tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)
+            tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)
             pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0)
         else:
             tok_neg = pool_neg = None
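As the module comment says, the caller is expected to supply the SDXL model name. A hypothetical loading sketch, assuming the file is consumed through diffusers' standard custom_pipeline mechanism for community pipelines (not shown in this commit):

import torch

from diffusers import DiffusionPipeline

# custom_pipeline points at this community file; the model name is SDXL_NAME.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="pipeline_stable_diffusion_xl_t5",
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a watercolor fox in a forest").images[0]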
