Skip to content

Commit 66547b9

Browse files
Add more karras schedulers (#6695)
## Summary Add karras variants of `deis`, `unipc`, `kdpm2` and `kdpm_2_a` schedulers. Also added `dpmpp_3` schedulers, but `dpmpp_3s` currently bugged, so added only 3m: huggingface/diffusers#9007 ## Related Issues / Discussions \- ## QA Instructions \- ## Merge Plan ~@psychedelicious We need to decide what to do with schedulers order, as it looks a bit broken:~ ![image](https://github.com/user-attachments/assets/e41674af-d87c-4432-8014-c90bd86965a6) ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_
2 parents 2755316 + 328e58b commit 66547b9

File tree

4 files changed

+63
-30
lines changed

4 files changed

+63
-30
lines changed

invokeai/backend/stable_diffusion/schedulers/schedulers.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@
2020
)
2121
from diffusers.schedulers.scheduling_utils import SchedulerMixin
2222

23+
# TODO: add dpmpp_3s/dpmpp_3s_k when fix released
24+
# https://github.com/huggingface/diffusers/issues/9007
25+
2326
SCHEDULER_NAME_VALUES = Literal[
2427
"ddim",
2528
"ddpm",
2629
"deis",
30+
"deis_k",
2731
"lms",
2832
"lms_k",
2933
"pndm",
@@ -33,24 +37,30 @@
3337
"euler_k",
3438
"euler_a",
3539
"kdpm_2",
40+
"kdpm_2_k",
3641
"kdpm_2_a",
42+
"kdpm_2_a_k",
3743
"dpmpp_2s",
3844
"dpmpp_2s_k",
3945
"dpmpp_2m",
4046
"dpmpp_2m_k",
4147
"dpmpp_2m_sde",
4248
"dpmpp_2m_sde_k",
49+
"dpmpp_3m",
50+
"dpmpp_3m_k",
4351
"dpmpp_sde",
4452
"dpmpp_sde_k",
4553
"unipc",
54+
"unipc_k",
4655
"lcm",
4756
"tcd",
4857
]
4958

5059
SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str, Any]]] = {
5160
"ddim": (DDIMScheduler, {}),
5261
"ddpm": (DDPMScheduler, {}),
53-
"deis": (DEISMultistepScheduler, {}),
62+
"deis": (DEISMultistepScheduler, {"use_karras_sigmas": False}),
63+
"deis_k": (DEISMultistepScheduler, {"use_karras_sigmas": True}),
5464
"lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}),
5565
"lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
5666
"pndm": (PNDMScheduler, {}),
@@ -59,17 +69,28 @@
5969
"euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
6070
"euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
6171
"euler_a": (EulerAncestralDiscreteScheduler, {}),
62-
"kdpm_2": (KDPM2DiscreteScheduler, {}),
63-
"kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
64-
"dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
65-
"dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
66-
"dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}),
67-
"dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
68-
"dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
69-
"dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
72+
"kdpm_2": (KDPM2DiscreteScheduler, {"use_karras_sigmas": False}),
73+
"kdpm_2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
74+
"kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": False}),
75+
"kdpm_2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
76+
"dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False, "solver_order": 2}),
77+
"dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, "solver_order": 2}),
78+
"dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "solver_order": 2}),
79+
"dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 2}),
80+
"dpmpp_2m_sde": (
81+
DPMSolverMultistepScheduler,
82+
{"use_karras_sigmas": False, "solver_order": 2, "algorithm_type": "sde-dpmsolver++"},
83+
),
84+
"dpmpp_2m_sde_k": (
85+
DPMSolverMultistepScheduler,
86+
{"use_karras_sigmas": True, "solver_order": 2, "algorithm_type": "sde-dpmsolver++"},
87+
),
88+
"dpmpp_3m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "solver_order": 3}),
89+
"dpmpp_3m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 3}),
7090
"dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}),
7191
"dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
72-
"unipc": (UniPCMultistepScheduler, {"cpu_only": True}),
92+
"unipc": (UniPCMultistepScheduler, {"use_karras_sigmas": False, "cpu_only": True}),
93+
"unipc_k": (UniPCMultistepScheduler, {"use_karras_sigmas": True, "cpu_only": True}),
7394
"lcm": (LCMScheduler, {}),
7495
"tcd": (TCDScheduler, {}),
7596
}

invokeai/frontend/web/src/features/nodes/types/common.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ export const zSchedulerField = z.enum([
3232
'ddpm',
3333
'dpmpp_2s',
3434
'dpmpp_2m',
35+
'dpmpp_3m',
3536
'dpmpp_2m_sde',
3637
'dpmpp_sde',
3738
'heun',
@@ -40,12 +41,17 @@ export const zSchedulerField = z.enum([
4041
'pndm',
4142
'unipc',
4243
'euler_k',
44+
'deis_k',
4345
'dpmpp_2s_k',
4446
'dpmpp_2m_k',
47+
'dpmpp_3m_k',
4548
'dpmpp_2m_sde_k',
4649
'dpmpp_sde_k',
4750
'heun_k',
51+
'kdpm_2_k',
52+
'kdpm_2_a_k',
4853
'lms_k',
54+
'unipc_k',
4955
'euler_a',
5056
'kdpm_2_a',
5157
'lcm',

invokeai/frontend/web/src/features/parameters/types/constants.ts

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,28 +52,34 @@ export const CLIP_SKIP_MAP = {
5252
* Mapping of schedulers to human readable name
5353
*/
5454
export const SCHEDULER_OPTIONS: ComboboxOption[] = [
55-
{ value: 'euler', label: 'Euler' },
56-
{ value: 'deis', label: 'DEIS' },
5755
{ value: 'ddim', label: 'DDIM' },
5856
{ value: 'ddpm', label: 'DDPM' },
59-
{ value: 'dpmpp_sde', label: 'DPM++ SDE' },
57+
{ value: 'deis', label: 'DEIS' },
58+
{ value: 'deis_k', label: 'DEIS Karras' },
6059
{ value: 'dpmpp_2s', label: 'DPM++ 2S' },
61-
{ value: 'dpmpp_2m', label: 'DPM++ 2M' },
62-
{ value: 'dpmpp_2m_sde', label: 'DPM++ 2M SDE' },
63-
{ value: 'heun', label: 'Heun' },
64-
{ value: 'kdpm_2', label: 'KDPM 2' },
65-
{ value: 'lms', label: 'LMS' },
66-
{ value: 'pndm', label: 'PNDM' },
67-
{ value: 'unipc', label: 'UniPC' },
68-
{ value: 'euler_k', label: 'Euler Karras' },
69-
{ value: 'dpmpp_sde_k', label: 'DPM++ SDE Karras' },
7060
{ value: 'dpmpp_2s_k', label: 'DPM++ 2S Karras' },
61+
{ value: 'dpmpp_2m', label: 'DPM++ 2M' },
7162
{ value: 'dpmpp_2m_k', label: 'DPM++ 2M Karras' },
63+
{ value: 'dpmpp_2m_sde', label: 'DPM++ 2M SDE' },
7264
{ value: 'dpmpp_2m_sde_k', label: 'DPM++ 2M SDE Karras' },
73-
{ value: 'heun_k', label: 'Heun Karras' },
74-
{ value: 'lms_k', label: 'LMS Karras' },
65+
{ value: 'dpmpp_3m', label: 'DPM++ 3M' },
66+
{ value: 'dpmpp_3m_k', label: 'DPM++ 3M Karras' },
67+
{ value: 'dpmpp_sde', label: 'DPM++ SDE' },
68+
{ value: 'dpmpp_sde_k', label: 'DPM++ SDE Karras' },
69+
{ value: 'euler', label: 'Euler' },
70+
{ value: 'euler_k', label: 'Euler Karras' },
7571
{ value: 'euler_a', label: 'Euler Ancestral' },
72+
{ value: 'heun', label: 'Heun' },
73+
{ value: 'heun_k', label: 'Heun Karras' },
74+
{ value: 'kdpm_2', label: 'KDPM 2' },
75+
{ value: 'kdpm_2_k', label: 'KDPM 2 Karras' },
7676
{ value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' },
77+
{ value: 'kdpm_2_a_k', label: 'KDPM 2 Ancestral Karras' },
7778
{ value: 'lcm', label: 'LCM' },
79+
{ value: 'lms', label: 'LMS' },
80+
{ value: 'lms_k', label: 'LMS Karras' },
81+
{ value: 'pndm', label: 'PNDM' },
7882
{ value: 'tcd', label: 'TCD' },
79-
].sort((a, b) => a.label.localeCompare(b.label));
83+
{ value: 'unipc', label: 'UniPC' },
84+
{ value: 'unipc_k', label: 'UniPC Karras' },
85+
];

invokeai/frontend/web/src/services/api/schema.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3553,7 +3553,7 @@ export type components = {
35533553
* @default euler
35543554
* @enum {string}
35553555
*/
3556-
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
3556+
scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
35573557
/**
35583558
* UNet
35593559
* @description UNet (scheduler, LoRAs)
@@ -8553,7 +8553,7 @@ export type components = {
85538553
* Scheduler
85548554
* @description Default scheduler for this model
85558555
*/
8556-
scheduler?: ("ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd") | null;
8556+
scheduler?: ("ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd") | null;
85578557
/**
85588558
* Steps
85598559
* @description Default number of steps for this model
@@ -11467,7 +11467,7 @@ export type components = {
1146711467
* @default euler
1146811468
* @enum {string}
1146911469
*/
11470-
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
11470+
scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
1147111471
/**
1147211472
* type
1147311473
* @default scheduler
@@ -11483,7 +11483,7 @@ export type components = {
1148311483
* @description Scheduler to use during inference
1148411484
* @enum {string}
1148511485
*/
11486-
scheduler: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
11486+
scheduler: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
1148711487
/**
1148811488
* type
1148911489
* @default scheduler_output
@@ -13261,7 +13261,7 @@ export type components = {
1326113261
* @default euler
1326213262
* @enum {string}
1326313263
*/
13264-
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
13264+
scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
1326513265
/**
1326613266
* UNet
1326713267
* @description UNet (scheduler, LoRAs)

0 commit comments

Comments
 (0)