Skip to content

Commit 259952a

Browse files
committed
Add config to control whether operations are upcast to fp64
1 parent 0a3f908 commit 259952a

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

src/diffusers/guiders/frequency_decoupled_guidance.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,25 @@
3737
build_laplacian_pyramid_func = None
3838

3939

40-
def project(v0: torch.Tensor, v1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
40+
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
4141
"""
4242
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from
4343
paper (Algorithm 2).
4444
"""
4545
# v0 shape: [B, ...]
4646
# v1 shape: [B, ...]
47-
dtype = v0.dtype
4847
# Assume first dim is a batch dim and all other dims are channel or "spatial" dims
4948
all_dims_but_first = list(range(1, len(v0.shape)))
50-
v0, v1 = v0.double(), v1.double()
49+
if upcast_to_double:
50+
dtype = v0.dtype
51+
v0, v1 = v0.double(), v1.double()
5152
v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
5253
v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
5354
v0_orthogonal = v0 - v0_parallel
54-
return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
55+
if upcast_to_double:
56+
v0_parallel = v0_parallel.to(dtype)
57+
v0_orthogonal = v0_orthogonal.to(dtype)
58+
return v0_parallel, v0_orthogonal
5559

5660

5761
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
@@ -127,6 +131,9 @@ class FrequencyDecoupledGuidance(BaseGuidance):
127131
`"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
128132
speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
129133
will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
134+
upcast_to_double (`bool`, defaults to `True`):
135+
Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
136+
float64 when performing guidance. This may result in better performance at the cost of increased runtime.
130137
"""
131138

132139
_input_predictions = ["pred_cond", "pred_uncond"]
@@ -141,6 +148,7 @@ def __init__(
141148
start: Union[float, List[float], Tuple[float]] = 0.0,
142149
stop: Union[float, List[float], Tuple[float]] = 1.0,
143150
guidance_rescale_space: str = "data",
151+
upcast_to_double: bool = True,
144152
):
145153
if not _CAN_USE_KORNIA:
146154
raise ImportError(
@@ -188,6 +196,7 @@ def __init__(
188196
)
189197

190198
self.use_original_formulation = use_original_formulation
199+
self.upcast_to_double = upcast_to_double
191200

192201
if isinstance(start, float):
193202
self.guidance_start = [start] * self.levels
@@ -244,7 +253,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
244253

245254
# Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
246255
if not math.isclose(parallel_weight, 1.0):
247-
shift_parallel, shift_orthogonal = project(shift, pred_cond_freq)
256+
shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
248257
shift = parallel_weight * shift_parallel + shift_orthogonal
249258

250259
# Apply CFG update for the current frequency level

0 commit comments

Comments
 (0)