Skip to content

Commit 5f0fe3c

Browse files
Suggested changes
Co-Authored-By: Ryan Dick <[email protected]>
1 parent 1748848 commit 5f0fe3c

File tree

3 files changed

+14
-21
lines changed

3 files changed

+14
-21
lines changed

invokeai/backend/stable_diffusion/diffusion_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
100100
if isinstance(guidance_scale, list):
101101
guidance_scale = guidance_scale[ctx.step_index]
102102

103-
# Note: Although logically it same, it seams that precision errors differs.
104-
# This sometimes results in slightly different output.
103+
# Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
104+
# in slightly different outputs. It is suspected that this is caused by small precision differences.
105105
# return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
106106
return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
107107

invokeai/backend/stable_diffusion/extensions/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from contextlib import contextmanager
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Callable, Dict, List
5+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
66

77
import torch
88
from diffusers import UNet2DConditionModel
@@ -56,5 +56,5 @@ def patch_extension(self, context: DenoiseContext):
5656
yield None
5757

5858
@contextmanager
59-
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
59+
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
6060
yield None

invokeai/backend/stable_diffusion/extensions/freeu.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,21 @@
1515
class FreeUExt(ExtensionBase):
1616
def __init__(
1717
self,
18-
freeu_config: Optional[FreeUConfig],
18+
freeu_config: FreeUConfig,
1919
):
2020
super().__init__()
21-
self.freeu_config = freeu_config
21+
self._freeu_config = freeu_config
2222

2323
@contextmanager
2424
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
25-
did_apply_freeu = False
26-
try:
27-
assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?
28-
if self.freeu_config is not None:
29-
unet.enable_freeu(
30-
b1=self.freeu_config.b1,
31-
b2=self.freeu_config.b2,
32-
s1=self.freeu_config.s1,
33-
s2=self.freeu_config.s2,
34-
)
35-
did_apply_freeu = True
25+
unet.enable_freeu(
26+
b1=self._freeu_config.b1,
27+
b2=self._freeu_config.b2,
28+
s1=self._freeu_config.s1,
29+
s2=self._freeu_config.s2,
30+
)
3631

32+
try:
3733
yield
38-
3934
finally:
40-
assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute?
41-
if did_apply_freeu:
42-
unet.disable_freeu()
35+
unet.disable_freeu()

0 commit comments

Comments
 (0)