|
| 1 | +import Foundation |
| 2 | +import CoreML |
| 3 | + |
| 4 | +/// A random source consistent with NVIDIA curandom |
| 5 | +/// |
| 6 | +/// This implementation references to: |
| 7 | +/// https://github.com/dsnz/random/blob/master/philox.py for Philox_M4_32 configuration. |
| 8 | +/// |
| 9 | +@available(iOS 16.2, macOS 13.1, *) |
| 10 | +struct NvRandomSource: RandomSource { |
| 11 | + public let seed: UInt64 |
| 12 | + private var offset: UInt32 |
| 13 | + |
| 14 | + /// Initialize with a random seed |
| 15 | + /// |
| 16 | + /// - Parameters |
| 17 | + /// - seed: Seed for underlying Philox M4 32 generator |
| 18 | + /// - Returns random source |
| 19 | + public init(seed: UInt32) { |
| 20 | + self.seed = UInt64(seed) |
| 21 | + offset = 0 |
| 22 | + } |
| 23 | + |
| 24 | + static private let PHILOX_M4_32: (UInt32, UInt32) = (0xD251_1F53, 0xCD9E_8D57) |
| 25 | + static private let PHILOX_W_32: (UInt32, UInt32) = (0x9E37_79B9, 0xBB67_AE85) |
| 26 | + |
| 27 | + static private func philox4Round(counter: inout [[UInt32]], key: [[UInt32]]) { |
| 28 | + for i in 0..<counter[0].count { |
| 29 | + let v1: UInt64 = UInt64(counter[0][i]) * UInt64(PHILOX_M4_32.0) |
| 30 | + let v2: UInt64 = UInt64(counter[2][i]) * UInt64(PHILOX_M4_32.1) |
| 31 | + counter[0][i] = UInt32(v2 >> 32) ^ counter[1][i] ^ key[0][i] |
| 32 | + counter[1][i] = UInt32(v2 & 0xffff_ffff) |
| 33 | + counter[2][i] = UInt32(v1 >> 32) ^ counter[3][i] ^ key[1][i] |
| 34 | + counter[3][i] = UInt32(v1 & 0xffff_ffff) |
| 35 | + } |
| 36 | + } |
| 37 | + |
| 38 | + static private func philox4Bumpkey(key: inout [[UInt32]]) { |
| 39 | + for (i, element) in key[0].enumerated() { |
| 40 | + key[0][i] = element &+ PHILOX_W_32.0 |
| 41 | + } |
| 42 | + for (i, element) in key[1].enumerated() { |
| 43 | + key[1][i] = element &+ PHILOX_W_32.1 |
| 44 | + } |
| 45 | + } |
| 46 | + |
| 47 | + static private func philox4_32(counter: inout [[UInt32]], key: inout [[UInt32]], rounds: Int = 10) { |
| 48 | + for _ in 0..<(rounds - 1) { |
| 49 | + philox4Round(counter: &counter, key: key) |
| 50 | + philox4Bumpkey(key: &key) |
| 51 | + } |
| 52 | + philox4Round(counter: &counter, key: key) |
| 53 | + } |
| 54 | + |
| 55 | + private func boxMuller(_ counter1: [UInt32], _ counter2: [UInt32], mean: Double, stdev: Double) -> [Double] { |
| 56 | + // Box-Muller transform |
| 57 | + return zip(counter1, counter2).map { |
| 58 | + let u: Double = Double($0) / 4294967296.0 + (1.0 / 8589934592.0) |
| 59 | + let v: Double = Double($1) * (.pi / 2147483648.0) + (.pi / 4294967296.0) |
| 60 | + let radius = stdev * sqrt(-2.0 * log(u)) |
| 61 | + return radius * sin(v) + mean |
| 62 | + } |
| 63 | + } |
| 64 | + |
| 65 | + private mutating func normalArray(count: Int, mean: Double, stdev: Double) -> [Double] { |
| 66 | + var counter: [[UInt32]] = [ |
| 67 | + Array(repeating: offset, count: count), |
| 68 | + Array(repeating: 0, count: count), |
| 69 | + Array(0..<UInt32(count)), |
| 70 | + Array(repeating: 0, count: count), |
| 71 | + ] |
| 72 | + offset += 1 |
| 73 | + var key: [[UInt32]] = [ |
| 74 | + Array(repeating: UInt32(seed & 0xffff_ffff), count: count), |
| 75 | + Array(repeating: UInt32(seed >> 32), count: count), |
| 76 | + ] |
| 77 | + Self.philox4_32(counter: &counter, key: &key) |
| 78 | + return boxMuller(counter[0], counter[1], mean: mean, stdev: stdev) |
| 79 | + } |
| 80 | + |
| 81 | + /// Generates a random value from a normal distribution with given mean and standard deviation. |
| 82 | + mutating func nextNormal(mean: Double = 0.0, stdev: Double = 1.0) -> Double { |
| 83 | + return normalArray(count: 1, mean: mean, stdev: stdev)[0] |
| 84 | + } |
| 85 | + |
| 86 | + /// Generate a shaped array with scalars from a normal distribution with given mean and standard deviation. |
| 87 | + mutating func normalShapedArray(_ shape: [Int], mean: Double = 0.0, stdev: Double = 1.0) -> MLShapedArray<Double> { |
| 88 | + let count = shape.reduce(1, *) |
| 89 | + return .init(scalars: normalArray(count: count, mean: mean, stdev: stdev), shape: shape) |
| 90 | + } |
| 91 | +} |
0 commit comments