Skip to content

Commit 35f8781

Browse files
RyanJDickhipsterusername
authored andcommitted
Fix static type errors with SCHEDULER_NAME_VALUES. And, avoid bi-directional cross-directory imports, which contribute to circular import issues.
1 parent 3a24d70 commit 35f8781

File tree

8 files changed

+49
-11
lines changed

8 files changed

+49
-11
lines changed

invokeai/app/invocations/constants.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Literal
22

3-
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
43
from invokeai.backend.util.devices import TorchDevice
54

65
LATENT_SCALE_FACTOR = 8
@@ -11,9 +10,6 @@
1110
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
1211
"""
1312

14-
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
15-
"""A literal type representing the valid scheduler names."""
16-
1713
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
1814
"""A literal type for PIL image modes supported by Invoke"""
1915

invokeai/app/invocations/denoise_latents.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from transformers import CLIPVisionModelWithProjection
1818

1919
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
20-
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
20+
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
2121
from invokeai.app.invocations.controlnet_image_processors import ControlField
2222
from invokeai.app.invocations.fields import (
2323
ConditioningField,
@@ -54,6 +54,7 @@
5454
TextConditioningRegions,
5555
)
5656
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
57+
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
5758
from invokeai.backend.util.devices import TorchDevice
5859
from invokeai.backend.util.hotfixes import ControlNetModel
5960
from invokeai.backend.util.mask import to_standard_float_mask

invokeai/app/invocations/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
2-
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
32
from invokeai.app.invocations.fields import (
43
FieldDescriptions,
54
InputField,
65
OutputField,
76
UIType,
87
)
98
from invokeai.app.services.shared.invocation_context import InvocationContext
9+
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
1010

1111

1212
@invocation_output("scheduler_output")

invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pydantic import field_validator
99

1010
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
11-
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
11+
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
1212
from invokeai.app.invocations.controlnet_image_processors import ControlField
1313
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
1414
from invokeai.app.invocations.fields import (
@@ -29,6 +29,7 @@
2929
MultiDiffusionPipeline,
3030
MultiDiffusionRegionConditioning,
3131
)
32+
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
3233
from invokeai.backend.tiles.tiles import (
3334
calc_tiles_min_overlap,
3435
)

invokeai/backend/model_manager/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@
3030
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
3131
from typing_extensions import Annotated, Any, Dict
3232

33-
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
3433
from invokeai.app.util.misc import uuid_string
3534
from invokeai.backend.model_hash.hash_validator import validate_hash
3635
from invokeai.backend.raw_model import RawModel
36+
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
3737

3838
# ModelMixin is the base class for all diffusers and transformers models
3939
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime

invokeai/backend/stable_diffusion/schedulers/schedulers.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, Literal, Type
2+
13
from diffusers import (
24
DDIMScheduler,
35
DDPMScheduler,
@@ -16,8 +18,36 @@
1618
TCDScheduler,
1719
UniPCMultistepScheduler,
1820
)
21+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
22+
23+
SCHEDULER_NAME_VALUES = Literal[
24+
"ddim",
25+
"ddpm",
26+
"deis",
27+
"lms",
28+
"lms_k",
29+
"pndm",
30+
"heun",
31+
"heun_k",
32+
"euler",
33+
"euler_k",
34+
"euler_a",
35+
"kdpm_2",
36+
"kdpm_2_a",
37+
"dpmpp_2s",
38+
"dpmpp_2s_k",
39+
"dpmpp_2m",
40+
"dpmpp_2m_k",
41+
"dpmpp_2m_sde",
42+
"dpmpp_2m_sde_k",
43+
"dpmpp_sde",
44+
"dpmpp_sde_k",
45+
"unipc",
46+
"lcm",
47+
"tcd",
48+
]
1949

20-
SCHEDULER_MAP = {
50+
SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str, Any]]] = {
2151
"ddim": (DDIMScheduler, {}),
2252
"ddpm": (DDPMScheduler, {}),
2353
"deis": (DEISMultistepScheduler, {}),

invokeai/invocation_api/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
invocation,
1212
invocation_output,
1313
)
14-
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
1514
from invokeai.app.invocations.fields import (
1615
BoardField,
1716
ColorField,
@@ -78,6 +77,7 @@
7877
ConditioningFieldData,
7978
SDXLConditioningInfo,
8079
)
80+
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
8181
from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device
8282
from invokeai.version import __version__
8383

@@ -163,7 +163,7 @@
163163
"BaseModelType",
164164
"ModelType",
165165
"SubModelType",
166-
# invokeai.app.invocations.constants
166+
# invokeai.backend.stable_diffusion.schedulers.schedulers
167167
"SCHEDULER_NAME_VALUES",
168168
# invokeai.version
169169
"__version__",
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from typing import get_args
2+
3+
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_MAP, SCHEDULER_NAME_VALUES
4+
5+
6+
def test_scheduler_map_has_all_keys():
7+
# Assert that SCHEDULER_MAP has all keys from SCHEDULER_NAME_VALUES.
8+
# TODO(ryand): This feels like it should be a type check, but I couldn't find a clean way to do this and didn't want
9+
# to spend more time on it.
10+
assert set(SCHEDULER_MAP.keys()) == set(get_args(SCHEDULER_NAME_VALUES))

0 commit comments

Comments
 (0)