Skip to content

Commit ddefb61

Browse files
authored
Add random source that matches PyTorch (#124)
* Add random source that matches PyTorch This added random source that matches PyTorch on CPU. In particular, it matches: `torch.randn([], dtype=torch.float)` result. PyTorch's RNG is a bit convoluted and not claimed to be version-stable (will open a separate issue in PyTorch repo on this). However, the current implementation on CPU is fairly straightforward^*. 1. If it is less than 16 elements, it uses Gaussian distribution sampled from MT19937 for double + Box-Muller transformation. 2. If it is more than 16 (16 included), it first do uniform sampling with whatever the resulting data type would be (in this case, torch.float), and then apply Box-Muller transformation over 16-element segment at a type, treating the first floating-point and the 8th as a pair, so on so forth. 3. If it is not a multiple of 16, trace back from the end for 16 elements and redo step 2. * Update with configuration available in SwiftDiffusionCLI * Fix the RNG is not passed into pipelineConfig.
1 parent 00390a6 commit ddefb61

File tree

6 files changed

+199
-8
lines changed

6 files changed

+199
-8
lines changed

swift/StableDiffusion/pipeline/Random.swift renamed to swift/StableDiffusion/pipeline/NumPyRandomSource.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import CoreML
1010
/// [NumPy's older randomkit.c](https://github.com/numpy/numpy/blob/v1.0/numpy/random/mtrand/randomkit.c)
1111
///
1212
@available(iOS 16.2, macOS 13.1, *)
13-
struct NumPyRandomSource: RandomNumberGenerator {
13+
struct NumPyRandomSource: RandomNumberGenerator, RandomSource {
1414

1515
struct State {
1616
var key = [UInt32](repeating: 0, count: 624)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import CoreML
2+
3+
@available(iOS 16.2, macOS 13.1, *)
4+
public protocol RandomSource {
5+
mutating func normalShapedArray(_ shape: [Int], mean: Double, stdev: Double) -> MLShapedArray<Double>
6+
}

swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ extension StableDiffusionPipeline {
3737
public var disableSafety: Bool = false
3838
/// The type of Scheduler to use.
3939
public var schedulerType: StableDiffusionScheduler = .pndmScheduler
40+
/// The type of RNG to use
41+
public var rngType: StableDiffusionRNG = .numpyRNG
4042

4143
/// Given the configuration, what mode will be used for generation
4244
public var mode: Mode {

swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ public enum StableDiffusionScheduler {
1414
case dpmSolverMultistepScheduler
1515
}
1616

17+
/// RNG compatible with StableDiffusionPipeline
18+
public enum StableDiffusionRNG {
19+
/// RNG that matches numpy implementation
20+
case numpyRNG
21+
/// RNG that matches PyTorch CPU implementation.
22+
case torchRNG
23+
}
24+
1725
/// A pipeline used to generate image samples from text input using stable diffusion
1826
///
1927
/// This implementation matches:
@@ -157,7 +165,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
157165
throw Error.startingImageProvidedWithoutEncoder
158166
}
159167

160-
let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, stdev: 1, seed: config.seed)
168+
let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, rng: config.rngType, stdev: 1, seed: config.seed)
161169
latents = try noiseTuples.map({
162170
try encoder.encode(
163171
image: startingImage,
@@ -168,7 +176,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
168176
} else {
169177
timestepStrength = nil
170178
// Generate random latent samples from specified seed
171-
latents = generateLatentSamples(config.imageCount, stdev: stdev, seed: config.seed)
179+
latents = generateLatentSamples(config.imageCount, rng: config.rngType, stdev: stdev, seed: config.seed)
172180
}
173181

174182
// De-noising loop
@@ -224,11 +232,19 @@ public struct StableDiffusionPipeline: ResourceManaging {
224232
return try decodeToImages(latents, disableSafety: config.disableSafety)
225233
}
226234

227-
func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {
235+
private func randomSource(from rng: StableDiffusionRNG, seed: UInt32) -> RandomSource {
236+
switch rng {
237+
case .numpyRNG:
238+
return NumPyRandomSource(seed: seed)
239+
case .torchRNG:
240+
return TorchRandomSource(seed: seed)
241+
}
242+
}
243+
244+
func generateLatentSamples(_ count: Int, rng: StableDiffusionRNG, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {
228245
var sampleShape = unet.latentSampleShape
229246
sampleShape[0] = 1
230-
231-
var random = NumPyRandomSource(seed: seed)
247+
var random = randomSource(from: rng, seed: seed)
232248
let samples = (0..<count).map { _ in
233249
MLShapedArray<Float32>(
234250
converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev)))
@@ -245,11 +261,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
245261
/// - diagonalAndLatentNoiseIsSame: Diffusions library does not seem to use the same noise for the `DiagonalGaussianDistribution` operation,
246262
/// but I have seen implementations of pipelines where it is the same.
247263
/// - Returns: An array of tuples of noise values with length of batch size.
248-
func generateImage2ImageLatentSamples(_ count: Int, stdev: Float, seed: UInt32, diagonalAndLatentNoiseIsSame: Bool = false) -> [(diagonal: MLShapedArray<Float32>, latentNoise: MLShapedArray<Float32>)] {
264+
func generateImage2ImageLatentSamples(_ count: Int, rng: StableDiffusionRNG, stdev: Float, seed: UInt32, diagonalAndLatentNoiseIsSame: Bool = false) -> [(diagonal: MLShapedArray<Float32>, latentNoise: MLShapedArray<Float32>)] {
249265
var sampleShape = unet.latentSampleShape
250266
sampleShape[0] = 1
251267

252-
var random = NumPyRandomSource(seed: UInt32(truncatingIfNeeded: seed))
268+
var random = randomSource(from: rng, seed: seed)
253269
let samples = (0..<count).map { _ in
254270
if diagonalAndLatentNoiseIsSame {
255271
let noise = MLShapedArray<Float32>(
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
// For licensing see accompanying LICENSE.md file.
2+
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
3+
4+
import Foundation
5+
import CoreML
6+
7+
/// A random source consistent with PyTorch
8+
///
9+
/// This implementation matches:
10+
/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/DistributionsHelper.h
11+
/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/DistributionTemplates.h
12+
/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/DistributionKernels.cpp
13+
/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/TransformationHelper.h
14+
///
15+
@available(iOS 16.2, macOS 13.1, *)
16+
struct TorchRandomSource: RandomNumberGenerator, RandomSource {
17+
18+
struct State {
19+
var key = [UInt32](repeating: 0, count: 624)
20+
var pos: Int = 0
21+
var nextGauss: Double? = nil
22+
}
23+
24+
var state: State
25+
26+
/// Initialize with a random seed
27+
///
28+
/// - Parameters
29+
/// - seed: Seed for underlying Mersenne Twister 19937 generator
30+
/// - Returns random source
31+
init(seed: UInt32) {
32+
state = .init()
33+
var s = seed & 0xffff_ffff
34+
for i in 0..<state.key.count {
35+
state.key[i] = s
36+
s = UInt32((UInt64(1_812_433_253) * UInt64(s ^ (s >> 30)) + UInt64(i) + 1) & 0xffff_ffff)
37+
}
38+
state.pos = state.key.count
39+
state.nextGauss = nil
40+
}
41+
42+
/// Generate next UInt32 using fast 32bit Mersenne Twister
43+
mutating func nextUInt32() -> UInt32 {
44+
let n = 624
45+
let m = 397
46+
let matrixA: UInt64 = 0x9908_b0df
47+
let upperMask: UInt32 = 0x8000_0000
48+
let lowerMask: UInt32 = 0x7fff_ffff
49+
50+
var y: UInt32
51+
if state.pos == state.key.count {
52+
for i in 0..<(n - m) {
53+
y = (state.key[i] & upperMask) | (state.key[i + 1] & lowerMask)
54+
state.key[i] = state.key[i + m] ^ (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA)
55+
}
56+
for i in (n - m)..<(n - 1) {
57+
y = (state.key[i] & upperMask) | (state.key[i + 1] & lowerMask)
58+
state.key[i] = state.key[i + (m - n)] ^ (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA)
59+
}
60+
y = (state.key[n - 1] & upperMask) | (state.key[0] & lowerMask)
61+
state.key[n - 1] = state.key[m - 1] ^ (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA)
62+
state.pos = 0
63+
}
64+
y = state.key[state.pos]
65+
state.pos += 1
66+
67+
y ^= (y >> 11)
68+
y ^= (y << 7) & 0x9d2c_5680
69+
y ^= (y << 15) & 0xefc6_0000
70+
y ^= (y >> 18)
71+
72+
return y
73+
}
74+
75+
mutating func next() -> UInt64 {
76+
let high = nextUInt32()
77+
let low = nextUInt32()
78+
return (UInt64(high) << 32) | UInt64(low)
79+
}
80+
81+
/// Generate next random double value
82+
mutating func nextDouble() -> Double {
83+
let a = next()
84+
return Double(a & 9_007_199_254_740_991) * (1.0 / 9007199254740992.0)
85+
}
86+
87+
/// Generate next random float value
88+
mutating func nextFloat() -> Float {
89+
let a = nextUInt32()
90+
return Float(a & 16_777_215) * (1.0 / 16777216.0)
91+
}
92+
93+
/// Generate next random value from a standard normal
94+
mutating func nextGauss() -> Double {
95+
if let nextGauss = state.nextGauss {
96+
state.nextGauss = nil
97+
return nextGauss
98+
}
99+
// Box-Muller transform
100+
let u1: Double = nextDouble()
101+
let u2: Double = 1 - nextDouble()
102+
let radius = sqrt(-2.0 * log(u2))
103+
let theta = 2.0 * .pi * u1
104+
state.nextGauss = radius * sin(theta)
105+
return radius * cos(theta)
106+
}
107+
108+
/// Generates an array of random values from a normal distribution with given mean and standard deviation.
109+
/// This simulates torch.randn([1, 4, 64, 64], dtype=torch.float), note that for dtype=torch.double, it
110+
/// will be slightly different.
111+
mutating func normalArray(count: Int, mean: Double = 0.0, stdev: Double = 1.0) -> [Double] {
112+
// If it is smaller than 16 elements, Torch generates from Box-Muller transform directly.
113+
// Note that even if this is used to generate Float, it will use Double underneath.
114+
guard count >= 16 else {
115+
return (0..<count).map { _ in nextGauss() * stdev + mean }
116+
}
117+
// Otherwise, Torch first fill a uniform distribution array, then do Box-Muller
118+
// transformation over this array.
119+
var data = (0..<count).map { _ in Double(nextFloat()) }
120+
for i in stride(from: 0, to: count - 15, by: 16) {
121+
for j in 0..<8 {
122+
let u1 = 1 - data[i + j]
123+
let u2 = data[i + j + 8]
124+
let radius = sqrt(-2.0 * log(u1))
125+
let theta = 2.0 * .pi * u2
126+
data[i + j] = radius * cos(theta) * stdev + mean
127+
data[i + j + 8] = radius * sin(theta) * stdev + mean
128+
}
129+
}
130+
if count % 16 != 0 {
131+
for i in (count - 16)..<count {
132+
data[i] = nextDouble()
133+
}
134+
let i = count - 16
135+
for j in 0..<8 {
136+
let u1 = 1 - data[i + j]
137+
let u2 = data[i + j + 8]
138+
let radius = sqrt(-2.0 * log(u1))
139+
let theta = 2.0 * .pi * u2
140+
data[i + j] = radius * cos(theta) * stdev + mean
141+
data[i + j + 8] = radius * sin(theta) * stdev + mean
142+
}
143+
}
144+
return data
145+
}
146+
147+
/// Generate a shaped array with scalars from a normal distribution with given mean and standard deviation.
148+
mutating func normalShapedArray(_ shape: [Int], mean: Double = 0.0, stdev: Double = 1.0) -> MLShapedArray<Double> {
149+
let count = shape.reduce(1, *)
150+
return .init(scalars: normalArray(count: count, mean: mean, stdev: stdev), shape: shape)
151+
}
152+
}

swift/StableDiffusionCLI/main.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ struct StableDiffusionSample: ParsableCommand {
6969
@Option(help: "Scheduler to use, one of {pndm, dpmpp}")
7070
var scheduler: SchedulerOption = .pndm
7171

72+
@Option(help: "Random number generator to use, one of {numpy, torch}")
73+
var rng: RNGOption = .numpy
74+
7275
@Flag(help: "Disable safety checking")
7376
var disableSafety: Bool = false
7477

@@ -126,6 +129,7 @@ struct StableDiffusionSample: ParsableCommand {
126129
pipelineConfig.seed = seed
127130
pipelineConfig.guidanceScale = guidanceScale
128131
pipelineConfig.schedulerType = scheduler.stableDiffusionScheduler
132+
pipelineConfig.rngType = rng.stableDiffusionRNG
129133

130134
let images = try pipeline.generateImages(
131135
configuration: pipelineConfig,
@@ -250,6 +254,17 @@ enum SchedulerOption: String, ExpressibleByArgument {
250254
}
251255
}
252256

257+
@available(iOS 16.2, macOS 13.1, *)
258+
enum RNGOption: String, ExpressibleByArgument {
259+
case numpy, torch
260+
var stableDiffusionRNG: StableDiffusionRNG {
261+
switch self {
262+
case .numpy: return .numpyRNG
263+
case .torch: return .torchRNG
264+
}
265+
}
266+
}
267+
253268
if #available(iOS 16.2, macOS 13.1, *) {
254269
StableDiffusionSample.main()
255270
} else {

0 commit comments

Comments
 (0)