
Commit 3758abc

Implement DPM-Solver++ scheduler (#59)
1 parent d6a54fc commit 3758abc

4 files changed (+291, -51 lines changed)

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. and The HuggingFace Team. All Rights Reserved.

import Accelerate
import CoreML

/// A scheduler used to compute a de-noised image
///
/// This implementation matches:
/// [Hugging Face Diffusers DPMSolverMultistepScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py)
///
/// It uses the DPM-Solver++ algorithm: [code](https://github.com/LuChengTHU/dpm-solver) [paper](https://arxiv.org/abs/2211.01095).
/// Limitations:
/// - Only implemented for DPM-Solver++ algorithm (not DPM-Solver).
/// - Second order only.
/// - Assumes the model predicts epsilon.
/// - No dynamic thresholding.
/// - `midpoint` solver algorithm.
@available(iOS 16.2, macOS 13.1, *)
public final class DPMSolverMultistepScheduler: Scheduler {
    public let trainStepCount: Int
    public let inferenceStepCount: Int
    public let betas: [Float]
    public let alphas: [Float]
    public let alphasCumProd: [Float]
    public let timeSteps: [Int]

    public let alpha_t: [Float]
    public let sigma_t: [Float]
    public let lambda_t: [Float]

    public let solverOrder = 2
    private(set) var lowerOrderStepped = 0

    /// Whether to use lower-order solvers in the final steps. Only valid for less than 15 inference steps.
    /// We empirically find this trick can stabilize the sampling of DPM-Solver, especially with 10 or fewer steps.
    public let useLowerOrderFinal = true

    // Stores solverOrder (2) items
    private(set) var modelOutputs: [MLShapedArray<Float32>] = []

    /// Create a scheduler that uses a second order DPM-Solver++ algorithm.
    ///
    /// - Parameters:
    ///   - stepCount: Number of inference steps to schedule
    ///   - trainStepCount: Number of training diffusion steps
    ///   - betaSchedule: Method to schedule betas from betaStart to betaEnd
    ///   - betaStart: The starting value of beta for inference
    ///   - betaEnd: The end value for beta for inference
    /// - Returns: A scheduler ready for its first step
    public init(
        stepCount: Int = 50,
        trainStepCount: Int = 1000,
        betaSchedule: BetaSchedule = .scaledLinear,
        betaStart: Float = 0.00085,
        betaEnd: Float = 0.012
    ) {
        self.trainStepCount = trainStepCount
        self.inferenceStepCount = stepCount

        switch betaSchedule {
        case .linear:
            self.betas = linspace(betaStart, betaEnd, trainStepCount)
        case .scaledLinear:
            self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 })
        }

        self.alphas = betas.map({ 1.0 - $0 })
        var alphasCumProd = self.alphas
        for i in 1..<alphasCumProd.count {
            alphasCumProd[i] *= alphasCumProd[i - 1]
        }
        self.alphasCumProd = alphasCumProd

        // Currently we only support VP-type noise schedule
        self.alpha_t = vForce.sqrt(self.alphasCumProd)
        self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd))
        self.lambda_t = zip(self.alpha_t, self.sigma_t).map { α, σ in log(α) - log(σ) }

        self.timeSteps = linspace(0, Float(self.trainStepCount-1), stepCount).reversed().map { Int(round($0)) }
    }
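    // With this VP-type schedule, alpha_t[i] = sqrt(alphasCumProd[i]) and sigma_t[i] = sqrt(1 - alphasCumProd[i]),
    // so lambda_t[i] = log(alpha_t[i]) - log(sigma_t[i]) is one half of the log signal-to-noise ratio,
    // the variable in which the first- and second-order updates below are expressed.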

    /// Convert the model output to the corresponding type the algorithm needs.
    /// This implementation is for second-order DPM-Solver++ assuming epsilon prediction.
    func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        assert(modelOutput.scalars.count == sample.scalars.count)
        let (alpha_t, sigma_t) = (self.alpha_t[timestep], self.sigma_t[timestep])

        // This could be optimized with a Metal kernel if we find we need to
        let x0_scalars = zip(modelOutput.scalars, sample.scalars).map { m, s in
            (s - m * sigma_t) / alpha_t
        }
        return MLShapedArray(scalars: x0_scalars, shape: modelOutput.shape)
    }

    /// One step for the first-order DPM-Solver (equivalent to DDIM).
    /// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
    /// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
    func firstOrderUpdate(
        modelOutput: MLShapedArray<Float32>,
        timestep: Int,
        prevTimestep: Int,
        sample: MLShapedArray<Float32>
    ) -> MLShapedArray<Float32> {
        let (p_lambda_t, lambda_s) = (Double(lambda_t[prevTimestep]), Double(lambda_t[timestep]))
        let p_alpha_t = Double(alpha_t[prevTimestep])
        let (p_sigma_t, sigma_s) = (Double(sigma_t[prevTimestep]), Double(sigma_t[timestep]))
        let h = p_lambda_t - lambda_s
        // x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
        let x_t = weightedSum(
            [p_sigma_t / sigma_s, -p_alpha_t * (exp(-h) - 1)],
            [sample, modelOutput]
        )
        return x_t
    }

    /// One step for the second-order multistep DPM-Solver++ algorithm, using the midpoint method.
    /// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
    func secondOrderUpdate(
        modelOutputs: [MLShapedArray<Float32>],
        timesteps: [Int],
        prevTimestep t: Int,
        sample: MLShapedArray<Float32>
    ) -> MLShapedArray<Float32> {
        let (s0, s1) = (timesteps[back: 1], timesteps[back: 2])
        let (m0, m1) = (modelOutputs[back: 1], modelOutputs[back: 2])
        let (p_lambda_t, lambda_s0, lambda_s1) = (Double(lambda_t[t]), Double(lambda_t[s0]), Double(lambda_t[s1]))
        let p_alpha_t = Double(alpha_t[t])
        let (p_sigma_t, sigma_s0) = (Double(sigma_t[t]), Double(sigma_t[s0]))
        let (h, h_0) = (p_lambda_t - lambda_s0, lambda_s0 - lambda_s1)
        let r0 = h_0 / h
        let D0 = m0

        // D1 = (1.0 / r0) * (m0 - m1)
        let D1 = weightedSum(
            [1/r0, -1/r0],
            [m0, m1]
        )

        // See https://arxiv.org/abs/2211.01095 for detailed derivations
        // x_t = (
        //     (sigma_t / sigma_s0) * sample
        //     - (alpha_t * (torch.exp(-h) - 1.0)) * D0
        //     - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
        // )
        let x_t = weightedSum(
            [p_sigma_t/sigma_s0, -p_alpha_t * (exp(-h) - 1), -0.5 * p_alpha_t * (exp(-h) - 1)],
            [sample, D0, D1]
        )
        return x_t
    }

    public func step(output: MLShapedArray<Float32>, timeStep t: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        let stepIndex = timeSteps.firstIndex(of: t) ?? timeSteps.count - 1
        let prevTimestep = stepIndex == timeSteps.count - 1 ? 0 : timeSteps[stepIndex + 1]

        let lowerOrderFinal = useLowerOrderFinal && stepIndex == timeSteps.count - 1 && timeSteps.count < 15
        let lowerOrderSecond = useLowerOrderFinal && stepIndex == timeSteps.count - 2 && timeSteps.count < 15
        let lowerOrder = lowerOrderStepped < 1 || lowerOrderFinal || lowerOrderSecond

        let modelOutput = convertModelOutput(modelOutput: output, timestep: t, sample: sample)
        if modelOutputs.count == solverOrder { modelOutputs.removeFirst() }
        modelOutputs.append(modelOutput)

        let prevSample: MLShapedArray<Float32>
        if lowerOrder {
            prevSample = firstOrderUpdate(modelOutput: modelOutput, timestep: t, prevTimestep: prevTimestep, sample: sample)
        } else {
            prevSample = secondOrderUpdate(
                modelOutputs: modelOutputs,
                timesteps: [timeSteps[stepIndex - 1], t],
                prevTimestep: prevTimestep,
                sample: sample
            )
        }
        if lowerOrderStepped < solverOrder {
            lowerOrderStepped += 1
        }

        return prevSample
    }
}
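For orientation, here is a minimal sketch of how a scheduler like this is typically driven during sampling. The `predictNoise` closure and the `initialNoise` latent are placeholders standing in for the UNet's epsilon prediction and the seeded starting noise; they are not part of this commit:

    let scheduler = DPMSolverMultistepScheduler(stepCount: 25)
    var latent = initialNoise                           // seeded Gaussian noise, scaled by scheduler.initNoiseSigma
    for t in scheduler.timeSteps {                      // timeSteps runs from high noise to low noise
        let noiseResidual = predictNoise(latent, t)     // epsilon prediction from the diffusion model
        latent = scheduler.step(output: noiseResidual, timeStep: t, sample: latent)
    }
    // latent now holds the de-noised sample, ready for decoding into an image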

swift/StableDiffusion/pipeline/Scheduler.swift

Lines changed: 78 additions & 49 deletions
@@ -3,34 +3,98 @@
 
 import CoreML
 
-/// A scheduler used to compute a de-noised image
-///
-/// This implementation matches:
-/// [Hugging Face Diffusers PNDMScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py)
-///
-/// It uses the pseudo linear multi-step (PLMS) method only, skipping pseudo Runge-Kutta (PRK) steps
 @available(iOS 16.2, macOS 13.1, *)
-public final class Scheduler {
+public protocol Scheduler {
     /// Number of diffusion steps performed during training
-    public let trainStepCount: Int
+    var trainStepCount: Int { get }
 
     /// Number of inference steps to be performed
-    public let inferenceStepCount: Int
+    var inferenceStepCount: Int { get }
 
     /// Training diffusion time steps index by inference time step
-    public let timeSteps: [Int]
+    var timeSteps: [Int] { get }
 
     /// Schedule of betas which controls the amount of noise added at each timestep
-    public let betas: [Float]
+    var betas: [Float] { get }
 
     /// 1 - betas
-    let alphas: [Float]
+    var alphas: [Float] { get }
 
     /// Cached cumulative product of alphas
-    let alphasCumProd: [Float]
+    var alphasCumProd: [Float] { get }
 
     /// Standard deviation of the initial noise distribution
-    public let initNoiseSigma: Float
+    var initNoiseSigma: Float { get }
+
+    /// Compute a de-noised image sample and step scheduler state
+    ///
+    /// - Parameters:
+    ///   - output: The predicted residual noise output of learned diffusion model
+    ///   - timeStep: The current time step in the diffusion chain
+    ///   - sample: The current input sample to the diffusion model
+    /// - Returns: Predicted de-noised sample at the previous time step
+    /// - Postcondition: The scheduler state is updated.
+    ///   The state holds the current sample and history of model output noise residuals
+    func step(
+        output: MLShapedArray<Float32>,
+        timeStep t: Int,
+        sample s: MLShapedArray<Float32>
+    ) -> MLShapedArray<Float32>
+}
+
+@available(iOS 16.2, macOS 13.1, *)
+public extension Scheduler {
+    var initNoiseSigma: Float { 1 }
+}
+
+@available(iOS 16.2, macOS 13.1, *)
+public extension Scheduler {
+    /// Compute weighted sum of shaped arrays of equal shapes
+    ///
+    /// - Parameters:
+    ///   - weights: The weights each array is multiplied by
+    ///   - values: The arrays to be weighted and summed
+    /// - Returns: sum_i weights[i]*values[i]
+    func weightedSum(_ weights: [Double], _ values: [MLShapedArray<Float32>]) -> MLShapedArray<Float32> {
+        assert(weights.count > 1 && values.count == weights.count)
+        assert(values.allSatisfy({ $0.scalarCount == values.first!.scalarCount }))
+        var w = Float(weights.first!)
+        var scalars = values.first!.scalars.map({ $0 * w })
+        for next in 1 ..< values.count {
+            w = Float(weights[next])
+            let nextScalars = values[next].scalars
+            for i in 0 ..< scalars.count {
+                scalars[i] += w * nextScalars[i]
+            }
+        }
+        return MLShapedArray(scalars: scalars, shape: values.first!.shape)
+    }
+}
+
+/// How to map a beta range to a sequence of betas to step over
+@available(iOS 16.2, macOS 13.1, *)
+public enum BetaSchedule {
+    /// Linear stepping between start and end
+    case linear
+    /// Steps using linspace(sqrt(start),sqrt(end))^2
+    case scaledLinear
+}
+
+
+/// A scheduler used to compute a de-noised image
+///
+/// This implementation matches:
+/// [Hugging Face Diffusers PNDMScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py)
+///
+/// This scheduler uses the pseudo linear multi-step (PLMS) method only, skipping pseudo Runge-Kutta (PRK) steps
+@available(iOS 16.2, macOS 13.1, *)
+public final class PNDMScheduler: Scheduler {
+    public let trainStepCount: Int
+    public let inferenceStepCount: Int
+    public let betas: [Float]
+    public let alphas: [Float]
+    public let alphasCumProd: [Float]
+    public let timeSteps: [Int]
 
     // Internal state
     var counter: Int
@@ -62,15 +126,12 @@ public final class Scheduler {
         case .scaledLinear:
             self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 })
         }
-
         self.alphas = betas.map({ 1.0 - $0 })
-        self.initNoiseSigma = 1.0
         var alphasCumProd = self.alphas
         for i in 1..<alphasCumProd.count {
            alphasCumProd[i] *= alphasCumProd[i - 1]
         }
         self.alphasCumProd = alphasCumProd
-
         let stepsOffset = 1 // For stable diffusion
         let stepRatio = Float(trainStepCount / stepCount )
         let forwardSteps = (0..<stepCount).map {
@@ -152,27 +213,6 @@ public final class Scheduler {
         return prevSample
     }
 
-    /// Compute weighted sum of shaped arrays of equal shapes
-    ///
-    /// - Parameters:
-    ///   - weights: The weights each array is multiplied by
-    ///   - values: The arrays to be weighted and summed
-    /// - Returns: sum_i weights[i]*values[i]
-    func weightedSum(_ weights: [Double], _ values: [MLShapedArray<Float32>]) -> MLShapedArray<Float32> {
-        assert(weights.count > 1 && values.count == weights.count)
-        assert(values.allSatisfy({$0.scalarCount == values.first!.scalarCount}))
-        var w = Float(weights.first!)
-        var scalars = values.first!.scalars.map({ $0 * w })
-        for next in 1 ..< values.count {
-            w = Float(weights[next])
-            let nextScalars = values[next].scalars
-            for i in 0 ..< scalars.count {
-                scalars[i] += w * nextScalars[i]
-            }
-        }
-        return MLShapedArray(scalars: scalars, shape: values.first!.shape)
-    }
-
     /// Compute sample (denoised image) at previous step given a current time step
     ///
     /// - Parameters:
@@ -225,17 +265,6 @@ public final class Scheduler {
     }
 }
 
-@available(iOS 16.2, macOS 13.1, *)
-extension Scheduler {
-    /// How to map a beta range to a sequence of betas to step over
-    public enum BetaSchedule {
-        /// Linear stepping between start and end
-        case linear
-        /// Steps using linspace(sqrt(start),sqrt(end))^2
-        case scaledLinear
-    }
-}
-
 /// Evenly spaced floats between specified interval
 ///
 /// - Parameters:
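To make the shared `weightedSum` helper concrete, here is a small illustrative call; the array values are invented for the example and `scheduler` can be any value conforming to `Scheduler`:

    let a = MLShapedArray<Float32>(scalars: [1, 2], shape: [2])
    let b = MLShapedArray<Float32>(scalars: [10, 20], shape: [2])
    let mixed = scheduler.weightedSum([0.5, 2.0], [a, b])
    // mixed.scalars == [20.5, 41.0], i.e. 0.5 * a + 2.0 * b elementwise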

swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift

Lines changed: 15 additions & 1 deletion
@@ -6,6 +6,14 @@ import CoreML
 import Accelerate
 import CoreGraphics
 
+/// Schedulers compatible with StableDiffusionPipeline
+public enum StableDiffusionScheduler {
+    /// Scheduler that uses a pseudo-linear multi-step (PLMS) method
+    case pndmScheduler
+    /// Scheduler that uses a second order DPM-Solver++ algorithm
+    case dpmSolverMultistepScheduler
+}
+
 /// A pipeline used to generate image samples from text input using stable diffusion
 ///
 /// This implementation matches:
@@ -113,6 +121,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
         stepCount: Int = 50,
         seed: Int = 0,
         disableSafety: Bool = false,
+        scheduler: StableDiffusionScheduler = .pndmScheduler,
         progressHandler: (Progress) -> Bool = { _ in true }
     ) throws -> [CGImage?] {
 
@@ -133,7 +142,12 @@ public struct StableDiffusionPipeline: ResourceManaging {
         let hiddenStates = toHiddenStates(concatEmbedding)
 
         /// Setup schedulers
-        let scheduler = (0..<imageCount).map { _ in Scheduler(stepCount: stepCount) }
+        let scheduler: [Scheduler] = (0..<imageCount).map { _ in
+            switch scheduler {
+            case .pndmScheduler: return PNDMScheduler(stepCount: stepCount)
+            case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: stepCount)
+            }
+        }
         let stdev = scheduler[0].initNoiseSigma
 
         // Generate random latent samples from specified seed
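With the new enum, callers can opt into DPM-Solver++ per generation call, typically with a lower step count than PLMS needs. A rough sketch of the call site; the method name `generateImages` and the `prompt` parameter are not shown in this diff and are assumed here for illustration:

    let images = try pipeline.generateImages(
        prompt: "a photo of an astronaut riding a horse on mars",
        stepCount: 25,                       // DPM-Solver++ usually reaches good quality in fewer steps than PLMS
        seed: 42,
        scheduler: .dpmSolverMultistepScheduler
    )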
