 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Note: At this time, the intent is to use the T5 encoder mentioned
-# below, with zero changes.
-# Therefore, the model deliberately does not store the T5 encoder model bytes,
-# (Since they are not unique!)
-# but instead takes advantage of huggingface hub cache loading
-
-T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
-
-
-# Caller is expected to load this, or equivalent, as model name for now
-# eg: pipe = StableDiffusionXL_T5Pipeline(SDXL_NAME)
-SDXL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
-
 
+from typing import Optional
 
-from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
-from transformers import T5Tokenizer, T5EncoderModel
+import torch.nn as nn
 from transformers import (
     CLIPImageProcessor,
-    CLIPTextModel,
-    CLIPTextModelWithProjection,
     CLIPTokenizer,
     CLIPVisionModelWithProjection,
+    T5EncoderModel,
 )
 
-from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 
 
-from typing import Optional
+# Note: At this time, the intent is to use the T5 encoder mentioned
+# below, with zero changes.
+# Therefore, the model deliberately does not store the T5 encoder model bytes,
+# (Since they are not unique!)
+# but instead takes advantage of huggingface hub cache loading
 
-import torch.nn as nn, torch, types
+T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
+
+# Caller is expected to load this, or equivalent, as model name for now
+# eg: pipe = StableDiffusionXL_T5Pipeline(SDXL_NAME)
+SDXL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
 
-import torch.nn as nn
 
 class LinearWithDtype(nn.Linear):
     @property
@@ -56,14 +50,23 @@ def dtype(self):
 
 class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
     _expected_modules = [
-        "vae", "unet", "scheduler", "tokenizer",
-        "image_encoder", "feature_extractor",
-        "t5_encoder", "t5_projection", "t5_pooled_projection",
+        "vae",
+        "unet",
+        "scheduler",
+        "tokenizer",
+        "image_encoder",
+        "feature_extractor",
+        "t5_encoder",
+        "t5_projection",
+        "t5_pooled_projection",
     ]
 
     _optional_components = [
-        "image_encoder", "feature_extractor",
-        "t5_encoder", "t5_projection", "t5_pooled_projection",
+        "image_encoder",
+        "feature_extractor",
+        "t5_encoder",
+        "t5_projection",
+        "t5_pooled_projection",
     ]
 
     def __init__(
@@ -83,25 +86,24 @@ def __init__(
         DiffusionPipeline.__init__(self)
 
         if t5_encoder is None:
-            self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME,
-                                                             torch_dtype=unet.dtype)
+            self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)
         else:
-            self.t5_encoder = t5_encoder
+            self.t5_encoder = t5_encoder
 
         # ----- build T5 4096 => 2048 dim projection -----
         if t5_projection is None:
-            self.t5_projection = LinearWithDtype(4096, 2048)  # trainable
+            self.t5_projection = LinearWithDtype(4096, 2048)  # trainable
         else:
-            self.t5_projection = t5_projection
+            self.t5_projection = t5_projection
         self.t5_projection.to(dtype=unet.dtype)
         # ----- build T5 4096 => 1280 dim projection -----
         if t5_pooled_projection is None:
-            self.t5_pooled_projection = LinearWithDtype(4096, 1280)  # trainable
+            self.t5_pooled_projection = LinearWithDtype(4096, 1280)  # trainable
         else:
-            self.t5_pooled_projection = t5_pooled_projection
+            self.t5_pooled_projection = t5_pooled_projection
         self.t5_pooled_projection.to(dtype=unet.dtype)
 
-        print("dtype of Linear is ",self.t5_projection.dtype)
+        print("dtype of Linear is ", self.t5_projection.dtype)
 
         self.register_modules(
             vae=vae,
@@ -165,13 +167,13 @@ def _tok(text: str):
 
         # ---------- positive stream -------------------------------------
         ids, mask = _tok(prompt)
-        h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state  # [b, T, 4096]
-        tok_pos = self.t5_projection(h_pos)  # [b, T, 2048]
-        pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1))  # [b, 1280]
+        h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state  # [b, T, 4096]
+        tok_pos = self.t5_projection(h_pos)  # [b, T, 2048]
+        pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1))  # [b, 1280]
 
         # expand for multiple images per prompt
-        tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
-        pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)
+        tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
+        pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)
 
         # ---------- negative / CFG stream --------------------------------
         if do_classifier_free_guidance:
@@ -181,7 +183,7 @@ def _tok(text: str):
             tok_neg = self.t5_projection(h_neg)
             pool_neg = self.t5_pooled_projection(h_neg.mean(dim=1))
 
-            tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)
+            tok_neg  = tok_neg.repeat_interleave(num_images_per_prompt, 0)
             pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0)
         else:
             tok_neg = pool_neg = None
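
For reference, a minimal standalone sketch (not part of the commit) of the tensor shapes the two projection layers in this diff handle: T5-XXL's 4096-dim hidden states are mapped to the 2048-dim token embeddings and the 1280-dim pooled embedding used by the SDXL UNet. Plain nn.Linear stands in for LinearWithDtype, and a random tensor stands in for t5_encoder(...).last_hidden_state.

import torch
import torch.nn as nn

t5_projection = nn.Linear(4096, 2048)         # token-level projection, as in the diff
t5_pooled_projection = nn.Linear(4096, 1280)  # pooled projection, as in the diff

h = torch.randn(2, 77, 4096)                  # stand-in for last_hidden_state: [b, T, 4096]
tok = t5_projection(h)                        # [2, 77, 2048] token stream
pool = t5_pooled_projection(h.mean(dim=1))    # [2, 1280] pooled vector
print(tok.shape, pool.shape)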