
Commit b64ade5

feature: support TAESD - Tiny Autoencoder for Stable Diffusion (#4316)
[TAESD - Tiny Autoencoder for Stable Diffusion](https://github.com/madebyollin/taesd) is a tiny VAE that produces significantly better results than my single-multiplication hack while remaining very fast. The entire TAESD model weighs in at under 10 MB!

This PR requires diffusers 0.20:

- [x] #4311

## To Do

Test with:

- [x] SD 1.x
- [ ] SD 2.x: #4415
- [x] SDXL

## Have you discussed this change with the InvokeAI team?

- See [TAESD Invocation API](https://discord.com/channels/1020123559063990373/1137857402453119166)

## Have you updated all relevant documentation?

- [ ] No

## QA Instructions, Screenshots, Recordings

You should be able to import these models:

- [madebyollin/taesd](https://huggingface.co/madebyollin/taesd)
- [madebyollin/taesdxl](https://huggingface.co/madebyollin/taesdxl)

and use them as a VAE.

## Added/updated tests?

- [x] Some. There are new tests for `VaeFolderProbe` based on VAE configurations, but no tests that require the full model weights.
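For context, this is roughly how TAESD plugs into a diffusers pipeline outside InvokeAI. A minimal sketch against the public diffusers ≥ 0.20 API; the pipeline checkpoint, dtype, device, and prompt are illustrative choices, not part of this PR:

```python
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

# Standard SD 1.5 pipeline with its full-size KL VAE swapped for TAESD.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained(
    "madebyollin/taesd", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

image = pipe("a watercolor fox in a forest").images[0]
```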
2 parents 24d0901 + 3c44a74 commit b64ade5

7 files changed: +202 −8 lines changed

invokeai/app/invocations/latent.py

Lines changed: 15 additions & 2 deletions
```diff
@@ -1,12 +1,14 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 
 from contextlib import ExitStack
+from functools import singledispatchmethod
 from typing import List, Literal, Optional, Union
 
 import einops
 import numpy as np
 import torch
 import torchvision.transforms as T
+from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import UNet2DConditionModel
 from diffusers.models.attention_processor import (
@@ -857,8 +859,7 @@ def vae_encode(vae_info, upcast, tiled, image_tensor):
             # non_noised_latents_from_image
             image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
             with torch.inference_mode():
-                image_tensor_dist = vae.encode(image_tensor).latent_dist
-                latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+                latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
 
             latents = vae.config.scaling_factor * latents
             latents = latents.to(dtype=orig_dtype)
@@ -885,6 +886,18 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents, seed=None)
 
+    @singledispatchmethod
+    @staticmethod
+    def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        image_tensor_dist = vae.encode(image_tensor).latent_dist
+        latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+        return latents
+
+    @_encode_to_tensor.register
+    @staticmethod
+    def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        return vae.encode(image_tensor).latents
+
 
 @invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
 class BlendLatentsInvocation(BaseInvocation):
```
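The new `_encode_to_tensor` method exists because the two VAE classes disagree about what `encode()` returns: `AutoencoderKL` hands back a `latent_dist` that must be sampled, while `AutoencoderTiny` exposes `.latents` directly. In the PR, `singledispatchmethod` is stacked with `@staticmethod`, so dispatch happens on the first argument, the VAE instance. A self-contained sketch of the same dispatch pattern, using stand-in classes rather than the real diffusers types:

```python
from functools import singledispatch


class KLLikeVae:
    """Stand-in for AutoencoderKL: encode() returns a distribution to sample."""


class TinyLikeVae:
    """Stand-in for AutoencoderTiny: encode() returns latents directly."""


@singledispatch
def encode_to_tensor(vae, image_tensor):
    raise NotImplementedError(f"No encoder registered for {type(vae).__name__}")


@encode_to_tensor.register
def _(vae: KLLikeVae, image_tensor):
    # KL path: sample from the posterior distribution.
    return f"sampled latents from a distribution for {image_tensor!r}"


@encode_to_tensor.register
def _(vae: TinyLikeVae, image_tensor):
    # Tiny path: latents come back directly, no sampling step.
    return f"direct latents for {image_tensor!r}"


print(encode_to_tensor(KLLikeVae(), "img"))    # dispatches to the KL branch
print(encode_to_tensor(TinyLikeVae(), "img"))  # dispatches to the Tiny branch
```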

invokeai/backend/model_management/model_probe.py

Lines changed: 31 additions & 6 deletions
```diff
@@ -1,4 +1,5 @@
 import json
+import re
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Callable, Dict, Literal, Optional, Union
@@ -53,6 +54,7 @@ class ModelProbe(object):
         "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
         "StableDiffusionXLInpaintPipeline": ModelType.Main,
         "AutoencoderKL": ModelType.Vae,
+        "AutoencoderTiny": ModelType.Vae,
         "ControlNetModel": ModelType.ControlNet,
         "CLIPVisionModelWithProjection": ModelType.CLIPVision,
     }
@@ -177,6 +179,7 @@ def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
         Get the model type of a hugging-face style folder.
         """
         class_name = None
+        error_hint = None
         if model:
             class_name = model.__class__.__name__
         else:
@@ -202,12 +205,18 @@ def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
                     class_name = conf["architectures"][0]
                 else:
                     class_name = None
+            else:
+                error_hint = f"No model_index.json or config.json found in {folder_path}."
 
         if class_name and (type := cls.CLASS2TYPE.get(class_name)):
             return type
+        else:
+            error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
 
         # give up
-        raise InvalidModelException(f"Unable to determine model type for {folder_path}")
+        raise InvalidModelException(
+            f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
+        )
 
     @classmethod
     def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
@@ -461,16 +470,32 @@ def get_variant_type(self) -> ModelVariantType:
 
 class VaeFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
+        if self._config_looks_like_sdxl():
+            return BaseModelType.StableDiffusionXL
+        elif self._name_looks_like_sdxl():
+            # Since SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
+            # by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
+            return BaseModelType.StableDiffusionXL
+        else:
+            return BaseModelType.StableDiffusion1
+
+    def _config_looks_like_sdxl(self) -> bool:
+        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
         config_file = self.folder_path / "config.json"
         if not config_file.exists():
             raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
         with open(config_file, "r") as file:
             config = json.load(file)
-        return (
-            BaseModelType.StableDiffusionXL
-            if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
-            else BaseModelType.StableDiffusion1
-        )
+        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
+
+    def _name_looks_like_sdxl(self) -> bool:
+        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
+
+    def _guess_name(self) -> str:
+        name = self.folder_path.name
+        if name == "vae":
+            name = self.folder_path.parent.name
+        return name
 
 
 class TextualInversionFolderProbe(FolderProbeBase):
```
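The name heuristic comes down to `re.search(r"xl\b", name, re.IGNORECASE)`: "xl" must end at a word boundary. A quick check with illustrative folder names:

```python
import re


def name_looks_like_sdxl(name: str) -> bool:
    return bool(re.search(r"xl\b", name, re.IGNORECASE))


for name in ["sdxl-vae", "taesdxl", "sd-vae-ft-mse", "xlarge-vae"]:
    print(f"{name}: {name_looks_like_sdxl(name)}")
# sdxl-vae: True       ("xl" followed by "-" is a word boundary)
# taesdxl: True        ("xl" at the end of the string)
# sd-vae-ft-mse: False (no "xl" at all)
# xlarge-vae: False    ("xl" followed by a letter is not a boundary)
```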

tests/test_model_probe.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -0,0 +1,22 @@
+from pathlib import Path
+
+import pytest
+
+from invokeai.backend import BaseModelType
+from invokeai.backend.model_management.model_probe import VaeFolderProbe
+
+
+@pytest.mark.parametrize(
+    "vae_path,expected_type",
+    [
+        ("sd-vae-ft-mse", BaseModelType.StableDiffusion1),
+        ("sdxl-vae", BaseModelType.StableDiffusionXL),
+        ("taesd", BaseModelType.StableDiffusion1),
+        ("taesdxl", BaseModelType.StableDiffusionXL),
+    ],
+)
+def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Path):
+    sd1_vae_path = datadir / "vae" / vae_path
+    probe = VaeFolderProbe(sd1_vae_path)
+    base_type = probe.get_base_type()
+    assert base_type == expected_type
```
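The `datadir` fixture is not defined in this file; it presumably comes from the pytest-datadir plugin, which serves files from a directory named after the test module, hence the four `config.json` fixtures below. The probe is also easy to run by hand; a sketch, where `models/taesd` stands in for any diffusers-style VAE folder containing a `config.json`:

```python
from pathlib import Path

from invokeai.backend.model_management.model_probe import VaeFolderProbe

# Illustrative path: any folder holding a diffusers-style VAE config.json.
probe = VaeFolderProbe(Path("models/taesd"))
print(probe.get_base_type())  # BaseModelType.StableDiffusion1 for TAESD
```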
Lines changed: 29 additions & 0 deletions
New file: a diffusers-style `config.json` for an `AutoencoderKL` VAE (the file path is not shown in this view; given its `sample_size` of 256, this is presumably the `sd-vae-ft-mse` test fixture):

```json
{
  "_class_name": "AutoencoderKL",
  "_diffusers_version": "0.4.2",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    512,
    512
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "out_channels": 3,
  "sample_size": 256,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ]
}
```
Lines changed: 31 additions & 0 deletions
New file: another `AutoencoderKL` config (path not shown; presumably the `sdxl-vae` fixture, given `scaling_factor: 0.13025` and `sample_size: 1024`):

```json
{
  "_class_name": "AutoencoderKL",
  "_diffusers_version": "0.18.0.dev0",
  "_name_or_path": ".",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    512,
    512
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "out_channels": 3,
  "sample_size": 1024,
  "scaling_factor": 0.13025,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ]
}
```
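These two `AutoencoderKL` fixtures exercise the config side of the heuristic: only the SDXL VAE carries `scaling_factor: 0.13025` together with a `sample_size` of 512 or 1024. A standalone version of the check, mirroring `_config_looks_like_sdxl` above:

```python
import json
from pathlib import Path


def config_looks_like_sdxl(config_file: Path) -> bool:
    # Same test the probe applies: SDXL-specific scaling factor plus sample size.
    config = json.loads(config_file.read_text())
    return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
```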
Lines changed: 37 additions & 0 deletions
New file: an `AutoencoderTiny` config (path not shown; presumably one of the `taesd`/`taesdxl` fixtures, which are identical apart from their paths):

```json
{
  "_class_name": "AutoencoderTiny",
  "_diffusers_version": "0.20.0.dev0",
  "act_fn": "relu",
  "decoder_block_out_channels": [
    64,
    64,
    64,
    64
  ],
  "encoder_block_out_channels": [
    64,
    64,
    64,
    64
  ],
  "force_upcast": false,
  "in_channels": 3,
  "latent_channels": 4,
  "latent_magnitude": 3,
  "latent_shift": 0.5,
  "num_decoder_blocks": [
    3,
    3,
    3,
    1
  ],
  "num_encoder_blocks": [
    1,
    3,
    3,
    3
  ],
  "out_channels": 3,
  "scaling_factor": 1.0,
  "upsampling_scaling_factor": 2
}
```
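Both TAESD fixtures set `scaling_factor` to 1.0, which makes the `latents = vae.config.scaling_factor * latents` step in `vae_encode` a no-op for them: TAESD emits latents already at the scale the UNet expects. The 8x spatial reduction noted in `VaeFolderProbe.get_base_type` can also be read off the config; a sketch under the assumption that each encoder stage after the first downsamples by the same factor the decoder upsamples by:

```python
config = {"num_encoder_blocks": [1, 3, 3, 3], "upsampling_scaling_factor": 2}

# Three downsampling stages at 2x each gives an 8x total reduction,
# matching the factor-of-8 comment in VaeFolderProbe.get_base_type.
downscale = config["upsampling_scaling_factor"] ** (len(config["num_encoder_blocks"]) - 1)
print(downscale)  # 8
```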
Lines changed: 37 additions & 0 deletions
New file: the second, identical `AutoencoderTiny` config (the other of the `taesd`/`taesdxl` fixture pair):

```json
{
  "_class_name": "AutoencoderTiny",
  "_diffusers_version": "0.20.0.dev0",
  "act_fn": "relu",
  "decoder_block_out_channels": [
    64,
    64,
    64,
    64
  ],
  "encoder_block_out_channels": [
    64,
    64,
    64,
    64
  ],
  "force_upcast": false,
  "in_channels": 3,
  "latent_channels": 4,
  "latent_magnitude": 3,
  "latent_shift": 0.5,
  "num_decoder_blocks": [
    3,
    3,
    3,
    1
  ],
  "num_encoder_blocks": [
    1,
    3,
    3,
    3
  ],
  "out_channels": 3,
  "scaling_factor": 1.0,
  "upsampling_scaling_factor": 2
}
```
