Skip to content

Commit d06c6bc

Browse files
author
toilaluan
committed
fix taylor precision
1 parent 309ce72 commit d06c6bc

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/diffusers/hooks/taylorseer_cache.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -137,8 +137,7 @@ def update(
137137
self.inactive_shapes = tuple(output.shape for output in outputs)
138138
else:
139139
self.taylor_factors = {}
140-
for i, output in enumerate(outputs):
141-
features = output.to(self.taylor_factors_dtype)
140+
for i, features in enumerate(outputs):
142141
new_factors: Dict[int, torch.Tensor] = {0: features}
143142
is_first_update = self.last_update_step is None
144143
if not is_first_update:
@@ -152,8 +151,8 @@ def update(
152151
prev = prev_factors.get(j)
153152
if prev is None:
154153
break
155-
new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step
156-
self.taylor_factors[i] = new_factors
154+
new_factors[j + 1] = (new_factors[j] - prev.to(features.dtype)) / delta_step
155+
self.taylor_factors[i] = {order: factor.to(self.taylor_factors_dtype) for order, factor in new_factors.items()}
157156

158157
self.last_update_step = self.current_step
159158

@@ -179,14 +178,15 @@ def predict(self) -> List[torch.Tensor]:
179178
if not self.taylor_factors:
180179
raise ValueError("Taylor factors empty during prediction.")
181180
for i in range(len(self.module_dtypes)):
181+
output_dtype = self.module_dtypes[i]
182182
taylor_factors = self.taylor_factors[i]
183183
# Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
184-
output = torch.zeros_like(taylor_factors[0])
184+
output = torch.zeros_like(taylor_factors[0], dtype=output_dtype)
185185
for order, factor in taylor_factors.items():
186186
# Note: order starts at 0
187187
coeff = (step_offset**order) / math.factorial(order)
188-
output = output + factor * coeff
189-
outputs.append(output.to(self.module_dtypes[i]))
188+
output = output + factor.to(output_dtype) * coeff
189+
outputs.append(output)
190190

191191
return outputs
192192

0 commit comments

Comments
 (0)