Skip to content

Commit 7b59c2b

Browse files
authored
Add NvRandomSource (#257)
Draw Things 1.20230913.1 released with NVIDIA GPU Compatible mode. See example: https://twitter.com/drawthingsapp/status/1702387783166476689 This is the implementation that should be compatible with the ml-stable-diffusion project.
1 parent c506322 commit 7b59c2b

File tree

3 files changed

+98
-2
lines changed

3 files changed

+98
-2
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import Foundation
2+
import CoreML
3+
4+
/// A random source consistent with NVIDIA curandom
5+
///
6+
/// This implementation references to:
7+
/// https://github.com/dsnz/random/blob/master/philox.py for Philox_M4_32 configuration.
8+
///
9+
@available(iOS 16.2, macOS 13.1, *)
10+
struct NvRandomSource: RandomSource {
11+
public let seed: UInt64
12+
private var offset: UInt32
13+
14+
/// Initialize with a random seed
15+
///
16+
/// - Parameters
17+
/// - seed: Seed for underlying Philox M4 32 generator
18+
/// - Returns random source
19+
public init(seed: UInt32) {
20+
self.seed = UInt64(seed)
21+
offset = 0
22+
}
23+
24+
static private let PHILOX_M4_32: (UInt32, UInt32) = (0xD251_1F53, 0xCD9E_8D57)
25+
static private let PHILOX_W_32: (UInt32, UInt32) = (0x9E37_79B9, 0xBB67_AE85)
26+
27+
static private func philox4Round(counter: inout [[UInt32]], key: [[UInt32]]) {
28+
for i in 0..<counter[0].count {
29+
let v1: UInt64 = UInt64(counter[0][i]) * UInt64(PHILOX_M4_32.0)
30+
let v2: UInt64 = UInt64(counter[2][i]) * UInt64(PHILOX_M4_32.1)
31+
counter[0][i] = UInt32(v2 >> 32) ^ counter[1][i] ^ key[0][i]
32+
counter[1][i] = UInt32(v2 & 0xffff_ffff)
33+
counter[2][i] = UInt32(v1 >> 32) ^ counter[3][i] ^ key[1][i]
34+
counter[3][i] = UInt32(v1 & 0xffff_ffff)
35+
}
36+
}
37+
38+
static private func philox4Bumpkey(key: inout [[UInt32]]) {
39+
for (i, element) in key[0].enumerated() {
40+
key[0][i] = element &+ PHILOX_W_32.0
41+
}
42+
for (i, element) in key[1].enumerated() {
43+
key[1][i] = element &+ PHILOX_W_32.1
44+
}
45+
}
46+
47+
static private func philox4_32(counter: inout [[UInt32]], key: inout [[UInt32]], rounds: Int = 10) {
48+
for _ in 0..<(rounds - 1) {
49+
philox4Round(counter: &counter, key: key)
50+
philox4Bumpkey(key: &key)
51+
}
52+
philox4Round(counter: &counter, key: key)
53+
}
54+
55+
private func boxMuller(_ counter1: [UInt32], _ counter2: [UInt32], mean: Double, stdev: Double) -> [Double] {
56+
// Box-Muller transform
57+
return zip(counter1, counter2).map {
58+
let u: Double = Double($0) / 4294967296.0 + (1.0 / 8589934592.0)
59+
let v: Double = Double($1) * (.pi / 2147483648.0) + (.pi / 4294967296.0)
60+
let radius = stdev * sqrt(-2.0 * log(u))
61+
return radius * sin(v) + mean
62+
}
63+
}
64+
65+
private mutating func normalArray(count: Int, mean: Double, stdev: Double) -> [Double] {
66+
var counter: [[UInt32]] = [
67+
Array(repeating: offset, count: count),
68+
Array(repeating: 0, count: count),
69+
Array(0..<UInt32(count)),
70+
Array(repeating: 0, count: count),
71+
]
72+
offset += 1
73+
var key: [[UInt32]] = [
74+
Array(repeating: UInt32(seed & 0xffff_ffff), count: count),
75+
Array(repeating: UInt32(seed >> 32), count: count),
76+
]
77+
Self.philox4_32(counter: &counter, key: &key)
78+
return boxMuller(counter[0], counter[1], mean: mean, stdev: stdev)
79+
}
80+
81+
/// Generates a random value from a normal distribution with given mean and standard deviation.
82+
mutating func nextNormal(mean: Double = 0.0, stdev: Double = 1.0) -> Double {
83+
return normalArray(count: 1, mean: mean, stdev: stdev)[0]
84+
}
85+
86+
/// Generate a shaped array with scalars from a normal distribution with given mean and standard deviation.
87+
mutating func normalShapedArray(_ shape: [Int], mean: Double = 0.0, stdev: Double = 1.0) -> MLShapedArray<Double> {
88+
let count = shape.reduce(1, *)
89+
return .init(scalars: normalArray(count: count, mean: mean, stdev: stdev), shape: shape)
90+
}
91+
}

swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ public enum StableDiffusionRNG {
2121
case numpyRNG
2222
/// RNG that matches PyTorch CPU implementation.
2323
case torchRNG
24+
/// RNG that matches PyTorch CUDA implementation.
25+
case nvidiaRNG
2426
}
2527

2628
public enum PipelineError: String, Swift.Error {
@@ -398,6 +400,8 @@ extension StableDiffusionPipelineProtocol {
398400
return NumPyRandomSource(seed: seed)
399401
case .torchRNG:
400402
return TorchRandomSource(seed: seed)
403+
case .nvidiaRNG:
404+
return NvRandomSource(seed: seed)
401405
}
402406
}
403407

swift/StableDiffusionCLI/main.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ struct StableDiffusionSample: ParsableCommand {
7474
@Option(help: "Scheduler to use, one of {pndm, dpmpp}")
7575
var scheduler: SchedulerOption = .pndm
7676

77-
@Option(help: "Random number generator to use, one of {numpy, torch}")
77+
@Option(help: "Random number generator to use, one of {numpy, torch, nvidia}")
7878
var rng: RNGOption = .numpy
7979

8080
@Option(
@@ -336,11 +336,12 @@ enum SchedulerOption: String, ExpressibleByArgument {
336336

337337
@available(iOS 16.2, macOS 13.1, *)
338338
enum RNGOption: String, ExpressibleByArgument {
339-
case numpy, torch
339+
case numpy, torch, nvidia
340340
var stableDiffusionRNG: StableDiffusionRNG {
341341
switch self {
342342
case .numpy: return .numpyRNG
343343
case .torch: return .torchRNG
344+
case .nvidia: return .nvidiaRNG
344345
}
345346
}
346347
}

0 commit comments

Comments
 (0)