Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions src/diffusers/schedulers/deprecated/scheduling_karras_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def step(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)

@torch.jit.ignore
def step_correct(
self,
model_output: torch.Tensor,
Expand Down Expand Up @@ -228,9 +229,29 @@ def step_correct(
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO

"""
pred_original_sample = sample_prev + sigma_prev * model_output
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
# Fuse all arithmetic into minimal number of kernels for speed and memory efficiency
# This also reduces allocation overhead and makes use of in-place operations where safe

# pred_original_sample = sample_prev + sigma_prev * model_output
# derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
# sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
# Fused for optimal performance:

# Step 1: pred_original_sample
pred_original_sample = torch.add(sample_prev, model_output, alpha=sigma_prev)

# Step 2: derivative_corr; exploit distributivity for clarity (No change, just explicit)
# derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
# = (sample_prev - (sample_prev + sigma_prev * model_output)) / sigma_prev
# = (-sigma_prev * model_output) / sigma_prev = -model_output
derivative_corr = -model_output

# Step 3: sample_prev update, combine where possible
# (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
coeff = 0.5 * (derivative + derivative_corr)
diff_sigma = sigma_prev - sigma_hat
# Prefer fused version
sample_prev = torch.add(sample_hat, coeff, alpha=diff_sigma)

if not return_dict:
return (sample_prev, derivative)
Expand Down