# Copyright Philip Brown, ppbrown@github
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

###########################################################################
# This pipeline attempts to use a model that combines an SDXL VAE, a T5
# text encoder, and an SDXL UNet.
# At present (2025/06/10) there are no pretrained models that give
# pleasing output, so this pipeline is somewhat of a tech demo proving
# that the pieces can at least be put together.
# Hopefully, it will encourage someone with the hardware available to
# throw enough resources into training one up.


from typing import Optional

import torch.nn as nn
from transformers import (
    CLIPImageProcessor,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
    T5EncoderModel,
)

from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers


# Note: at this time, the intent is to use the T5 encoder named below
# with zero changes. The model therefore deliberately does not store its
# own copy of the T5 encoder weights (since they are not unique!) and
# instead takes advantage of Hugging Face Hub cache loading.

T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"

# Caller is expected to load this, or an equivalent, as the model name for now,
#   e.g.: pipe = StableDiffusionXL_T5Pipeline.from_pretrained(SDXL_NAME)
SDXL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"

class LinearWithDtype(nn.Linear):
    """nn.Linear that exposes a .dtype attribute.

    The diffusers pipeline machinery expects registered modules to report
    a dtype, which a bare nn.Linear does not, so derive it from the weight.
    """

    @property
    def dtype(self):
        return self.weight.dtype


class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
    _expected_modules = [
        "vae",
        "unet",
        "scheduler",
        "tokenizer",
        "image_encoder",
        "feature_extractor",
        "t5_encoder",
        "t5_projection",
        "t5_pooled_projection",
    ]

    _optional_components = [
        "image_encoder",
        "feature_extractor",
        "t5_encoder",
        "t5_projection",
        "t5_pooled_projection",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        tokenizer: CLIPTokenizer,
        t5_encoder=None,
        t5_projection=None,
        t5_pooled_projection=None,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        DiffusionPipeline.__init__(self)

        if t5_encoder is None:
            self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)
        else:
            self.t5_encoder = t5_encoder

        # ----- build T5 4096 => 2048 dim projection -----
        if t5_projection is None:
            self.t5_projection = LinearWithDtype(4096, 2048)  # trainable
        else:
            self.t5_projection = t5_projection
        self.t5_projection.to(dtype=unet.dtype)
        # ----- build T5 4096 => 1280 dim projection -----
        if t5_pooled_projection is None:
            self.t5_pooled_projection = LinearWithDtype(4096, 1280)  # trainable
        else:
            self.t5_pooled_projection = t5_pooled_projection
        self.t5_pooled_projection.to(dtype=unet.dtype)
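        # Note: 2048 matches the SDXL UNet's cross-attention width, and
        # 1280 matches the pooled "text_embeds" vector consumed by its
        # add_embedding conditioning path.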

        print("dtype of Linear is ", self.t5_projection.dtype)

        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            tokenizer=tokenizer,
            t5_encoder=self.t5_encoder,
            t5_projection=self.t5_projection,
            t5_pooled_projection=self.t5_pooled_projection,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        self.default_sample_size = (
            self.unet.config.sample_size
            if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
            else 128
        )

        self.watermark = None

        # Parts of the original SDXL class complain if these attributes
        # are not at least present.
        self.text_encoder = self.text_encoder_2 = None

    # ------------------------------------------------------------------
    #  Encode a text prompt (T5-XXL + 4096 => 2048 projection)
    #  Returns exactly four tensors in the order SDXL's __call__ expects.
    # ------------------------------------------------------------------
    def encode_prompt(
        self,
        prompt,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: str | None = None,
        **_,
    ):
        """
        Returns
        -------
        prompt_embeds                : Tensor [B, T, 2048]
        negative_prompt_embeds       : Tensor [B, T, 2048] | None
        pooled_prompt_embeds         : Tensor [B, 1280]
        negative_pooled_prompt_embeds: Tensor [B, 1280]    | None

        where B = batch * num_images_per_prompt
        """

        # --- helper to tokenize on the pipeline's device ----------------
        def _tok(text: str):
            tok_out = self.tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            ).to(self.device)
            return tok_out.input_ids, tok_out.attention_mask

        # ---------- positive stream -------------------------------------
        ids, mask = _tok(prompt)
        h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state  # [b, T, 4096]
        tok_pos = self.t5_projection(h_pos)  # [b, T, 2048]
        pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1))  # [b, 1280]; mean over all T positions, padding included

        # expand for multiple images per prompt
        tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
        pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)

        # ---------- negative / CFG stream --------------------------------
        if do_classifier_free_guidance:
            neg_text = "" if negative_prompt is None else negative_prompt
            ids_n, mask_n = _tok(neg_text)
            h_neg = self.t5_encoder(ids_n, attention_mask=mask_n).last_hidden_state
            tok_neg = self.t5_projection(h_neg)
            pool_neg = self.t5_pooled_projection(h_neg.mean(dim=1))

            tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)
            pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0)
        else:
            tok_neg = pool_neg = None

        # ----------------- final ordered return --------------------------
        # 1) positive token embeddings
        # 2) negative token embeddings (or None)
        # 3) positive pooled embeddings
        # 4) negative pooled embeddings (or None)
        return tok_pos, tok_neg, pool_pos, pool_neg
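
# A sketch of calling encode_prompt directly (assumes `pipe` was built as
# in the example near SDXL_NAME and moved to the desired device):
#
#   pe, neg_pe, pooled, neg_pooled = pipe.encode_prompt(
#       "a watercolor fox", num_images_per_prompt=2
#   )
#   # pe:     [2, 77, 2048]   neg_pe:     [2, 77, 2048]
#   # pooled: [2, 1280]       neg_pooled: [2, 1280]
#
# These four tensors line up with the prompt_embeds, negative_prompt_embeds,
# pooled_prompt_embeds, and negative_pooled_prompt_embeds arguments of the
# inherited SDXL __call__.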