Skip to content

Commit 0d3ba2f

Browse files
committed
BNNS Random extension
1 parent 02c1225 commit 0d3ba2f

File tree

6 files changed

+336
-3
lines changed

6 files changed

+336
-3
lines changed

Project.toml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,24 @@ uuid = "13e28ba4-7ad8-5781-acae-3021b1ed3924"
33
version = "0.4.1"
44

55
[deps]
6+
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
7+
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
68
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
79
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
810

11+
[weakdeps]
12+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
13+
14+
[extensions]
15+
RandomExt = "Random"
16+
917
[compat]
18+
BFloat16s = "0.5.0"
19+
CEnum = "0.5.0"
1020
julia = "1.9"
1121

1222
[extras]
23+
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
1324
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1425
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1526
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -18,7 +29,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1829
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
1930
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2031
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
21-
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
2232
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2333
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2434

ext/RandomExt.jl

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
module RandomExt
2+
using BFloat16s
3+
using AppleAccelerate: BNNS
4+
using .BNNS: BNNSFilterParameters,
5+
BNNSRandomGeneratorMethodAES_CTR,
6+
BNNSCreateRandomGenerator,
7+
BNNSCreateRandomGeneratorWithSeed,
8+
BNNSRandomGeneratorStateSize,
9+
BNNSRandomGeneratorSetState,
10+
BNNSRandomGeneratorGetState,
11+
BNNSNDArrayDescriptor,
12+
BNNSRandomFillNormalFloat,
13+
BNNSRandomFillUniformFloat,
14+
BNNSRandomFillUniformInt
15+
using Random: Random, AbstractRNG
16+
17+
"""
18+
RNG()
19+
20+
A random number generator using AppleAccelerate's BNNS functionality.
21+
"""
22+
mutable struct RNG <: AbstractRNG
23+
ptr::Ptr{Nothing}
24+
function RNG(filter_parameters::Union{Nothing, BNNSFilterParameters}=nothing)
25+
params = isnothing(filter_parameters) ? Ptr{BNNSFilterParameters}(0) : [filter_parameters]
26+
res = new(BNNSCreateRandomGenerator(BNNSRandomGeneratorMethodAES_CTR, params))
27+
# finalizer(res) do
28+
# BNNSDestroyRandomGenerator(res.ptr)
29+
# end
30+
return res
31+
end
32+
function RNG(seed::Integer, filter_parameters::Union{Nothing, BNNSFilterParameters}=nothing)
33+
seed = seed%UInt64
34+
params = isnothing(filter_parameters) ? Ptr{BNNSFilterParameters}(0) : [filter_parameters]
35+
res = new(BNNSCreateRandomGeneratorWithSeed(BNNSRandomGeneratorMethodAES_CTR, seed, params))
36+
# finalizer(res) do
37+
# BNNSDestroyRandomGenerator(res.ptr)
38+
# end
39+
return res
40+
end
41+
end
42+
43+
BNNS.bnns_rng() = RNG()
44+
BNNS.bnns_rng(seed::Integer) = RNG(seed)
45+
46+
function _get_rng_state(rng::RNG)
47+
stateSize = BNNSRandomGeneratorStateSize(rng.ptr)
48+
state = Memory{UInt8}(undef, Int64(stateSize))
49+
BNNSRandomGeneratorGetState(rng.ptr, stateSize, state)
50+
return state
51+
end
52+
53+
function Base.copy!(dest::RNG, src::RNG)
54+
state = _get_rng_state(src)
55+
BNNSRandomGeneratorSetState(dest.ptr, length(state), state)
56+
return dest
57+
end
58+
59+
function Base.copy(rng::RNG)
60+
newrng = RNG()
61+
return copy!(newrng, rng)
62+
end
63+
64+
Base.:(==)(rng1::RNG, rng2::RNG) = _get_rng_state(rng1) == _get_rng_state(rng2)
65+
66+
function Random.seed!(rng::RNG, seed::Integer)
67+
return copy!(rng, RNG(seed))
68+
end
69+
70+
function Random.seed!(rng::RNG)
71+
return copy!(rng, RNG())
72+
end
73+
74+
const GLOBAL_RNG = Ref{RNG}()
75+
function BNNS.default_rng()
76+
if !isassigned(GLOBAL_RNG)
77+
GLOBAL_RNG[] = BNNS.bnns_rng()
78+
end
79+
return GLOBAL_RNG[]
80+
end
81+
82+
const BNNSInt = Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}
83+
const BNNSFloat = Union{Float16, Float32, BFloat16}
84+
85+
const BNNSUniform = Union{<:BNNSInt,<:BNNSFloat}
86+
const BNNSNormal = BNNSFloat
87+
88+
function Random.rand!(rng::RNG, A::DenseArray{T}) where {T<:BNNSInt}
89+
isempty(A) && return A
90+
desc = Ref(BNNSNDArrayDescriptor(A))
91+
res = BNNSRandomFillUniformInt(rng.ptr, desc, typemin(signed(T)), typemax(signed(T)))
92+
@assert res == 0
93+
return A
94+
end
95+
function Random.rand!(rng::RNG, A::DenseArray{T}) where {T<:BNNSFloat}
96+
isempty(A) && return A
97+
desc = Ref(BNNSNDArrayDescriptor(A))
98+
res = BNNSRandomFillUniformFloat(rng.ptr, desc, T(0), T(1))
99+
@assert res == 0
100+
return A
101+
end
102+
function Random.randn!(rng::RNG, A::DenseArray{T}) where {T<:BNNSFloat}
103+
isempty(A) && return A
104+
desc = Ref(BNNSNDArrayDescriptor(A))
105+
res = BNNSRandomFillNormalFloat(rng.ptr, desc, Float32(0), Float32(1))
106+
@assert res == 0
107+
return A
108+
end
109+
110+
# Out of place
111+
Random.rand(rng::RNG, ::Type{T}, dims::Dims) where T <: BNNSUniform =
112+
Random.rand!(rng, Array{T,length(dims)}(undef, dims...))
113+
Random.randn(rng::RNG, ::Type{T}, dims::Dims) where T <: BNNSNormal =
114+
Random.randn!(rng, Array{T,length(dims)}(undef, dims...))
115+
116+
# support all dimension specifications
117+
Random.rand(rng::RNG, ::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSUniform =
118+
Random.rand!(rng, Array{T,length(dims) + 1}(undef, dim1, dims...))
119+
Random.randn(rng::RNG, ::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSNormal =
120+
Random.randn!(rng, Array{T,length(dims) + 1}(undef, dim1, dims...))
121+
122+
# untyped out-of-place
123+
Random.rand(rng::RNG, dim1::Integer, dims::Integer...) =
124+
Random.rand!(rng, Array{Float32,length(dims) + 1}(undef, dim1, dims...))
125+
Random.randn(rng::RNG, dim1::Integer, dims::Integer...) =
126+
Random.randn!(rng, Array{Float32,length(dims) + 1}(undef, dim1, dims...))
127+
128+
# scalars
129+
Random.rand(rng::RNG, T::Union{Type{Float16}, Type{Float32}, Type{BFloat16},
130+
Type{Int8}, Type{UInt8},
131+
Type{Int16}, Type{UInt16},
132+
Type{Int32}, Type{UInt32},
133+
Type{Int64}, Type{UInt64}}=Float32) = Random.rand(rng, T, 1)[1]
134+
135+
# This is the only way I could fix method ambiguity
136+
Random.randn(rng::RNG, T::Type{BFloat16}) = Random.randn(rng, T, 1)[1]
137+
Random.randn(rng::RNG, T::Type{Float16}) = Random.randn(rng, T, 1)[1]
138+
Random.randn(rng::RNG, T::Type{Float32}) = Random.randn(rng, T, 1)[1]
139+
Random.randn(rng::RNG) = Random.randn(rng, Float32)
140+
141+
142+
# GPUArrays out-of-place
143+
function BNNS.rand(::Type{T}, dims::Dims) where T <: BNNSUniform
144+
return Random.rand!(BNNS.default_rng(), Array{T,length(dims)}(undef, dims...))
145+
end
146+
function BNNS.randn(::Type{T}, dims::Dims) where T <: BNNSNormal
147+
return Random.randn!(BNNS.default_rng(), Array{T,length(dims)}(undef, dims...))
148+
end
149+
150+
# support all dimension specifications
151+
function BNNS.rand(::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSUniform
152+
return Random.rand!(BNNS.default_rng(), Array{T,length(dims) + 1}(undef, dim1, dims...))
153+
end
154+
function BNNS.randn(::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSNormal
155+
return Random.randn!(BNNS.default_rng(), Array{T,length(dims) + 1}(undef, dim1, dims...))
156+
end
157+
158+
# untyped out-of-place
159+
BNNS.rand(dim1::Integer, dims::Integer...) =
160+
Random.rand!(BNNS.default_rng(), Array{Float32,length(dims) + 1}(undef, dim1, dims...))
161+
BNNS.randn(dim1::Integer, dims::Integer...) =
162+
Random.randn!(BNNS.default_rng(), Array{Float32,length(dims) + 1}(undef, dim1, dims...))
163+
164+
# scalars
165+
BNNS.rand(T::Type=Float32) = BNNS.rand(T, 1)[1]
166+
BNNS.randn(T::Type=Float32) = BNNS.randn(T, 1)[1]
167+
168+
# seeding
169+
function BNNS.seed!(seed=Base.rand(UInt64))
170+
Random.seed!(BNNS.default_rng(), seed)
171+
end
172+
173+
174+
end # module

lib/BNNS/BNNS.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
module BNNS
2+
using BFloat16s
3+
4+
include("libBNNS.jl")
5+
6+
bnnsdatatype_modifier(::Type{T}) where {T <: Union{AbstractFloat, Bool}} = BNNSDataTypeFloatBit
7+
bnnsdatatype_modifier(::Type{T}) where {T <: Signed} = BNNSDataTypeIntBit
8+
bnnsdatatype_modifier(::Type{T}) where {T <: Unsigned} = BNNSDataTypeUIntBit
9+
bnnsdatatype_modifier(::Type{Bool}) = BNNSDataTypeMiscellaneousBit
10+
bnnsdatatype_modifier(::Type{BFloat16}) = 0x18000
11+
12+
Base.convert(::Type{BNNSDataType}, T) = BNNSDataType(bnnsdatatype_modifier(T) | UInt32(sizeof(T)*8))
13+
14+
function BNNSNDArrayDescriptor(arr::AbstractArray{T, N}) where {T,N}
15+
N > 8 && throw(ArgumentError("BNNSNDArrays do not support more than 8 dimensions."))
16+
17+
18+
layout = BNNSDataLayout(UInt32(N) * UInt32(BNNSDataLayoutVector) | 0x8000)
19+
# layout = datalayout[N]
20+
sz = ntuple(Val(8)) do i
21+
Csize_t(get(size(arr), i, 0))
22+
end
23+
stride = ntuple(_ -> Csize_t(0), Val(8))
24+
return GC.@preserve arr BNNSNDArrayDescriptor(BNNSNDArrayFlagBackpropSet,
25+
layout,
26+
sz,
27+
stride,
28+
Ptr{Nothing}(pointer(arr)),
29+
T,
30+
0,
31+
T,
32+
1,
33+
0)
34+
end
35+
36+
# Definitions for the Random extension
37+
function bnns_rng end
38+
function default_rng end
39+
function rand end
40+
function randn end
41+
function rand! end
42+
function randn! end
43+
function seed! end
44+
45+
end

src/AppleAccelerate.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,11 @@ function __init__()
9191
load_accelerate(; load_ilp64=true, use_external_lapack=false)
9292
end
9393

94-
if Sys.isapple()
94+
@static if Sys.isapple()
9595
include("Util.jl")
9696
include("Array.jl")
9797
include("DSP.jl")
98+
include("../lib/BNNS/BNNS.jl")
9899
end
99100

100101
end # module

test/BNNS.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
const RAND_TYPES = [BFloat16, Float16, Float32, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64,
2+
UInt64]
3+
const RANDN_TYPES = [BFloat16, Float16, Float32]
4+
const INPLACE_TUPLES = [[(rand!, T) for T in RAND_TYPES];
5+
[(randn!, T) for T in RANDN_TYPES]]
6+
const OOPLACE_TUPLES = [[(BNNS.rand, rand, T) for T in RAND_TYPES];
7+
[(BNNS.randn, rand, T) for T in RANDN_TYPES]]
8+
9+
@testset "random" begin
10+
# in-place
11+
@testset "in-place" begin
12+
rng = BNNS.bnns_rng()
13+
14+
@testset "$f with $T" for (f, T) in INPLACE_TUPLES
15+
# d == 2 and d == 3 are to hit the test cases where sizeof(A) <= 4
16+
@testset "$d" for d in (2, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000))
17+
A = Array{T}(undef, d)
18+
19+
# specifie BNNS rng
20+
fill!(A, T(0))
21+
f(rng, A)
22+
@test !iszero(collect(A))
23+
end
24+
25+
@testset "0" begin
26+
A = Array{T}(undef, 0)
27+
28+
# specified BNNS rng
29+
fill!(A, T(0))
30+
f(rng, A)
31+
@test Array(A) == fill(1, 0)
32+
end
33+
end
34+
end
35+
# out-of-place
36+
@testset "out-of-place" begin
37+
@testset "$fr with implicit type" for (fm, fr, T) in
38+
((BNNS.rand, Random.rand, Float32), (BNNS.randn, Random.randn, Float32))
39+
rng = BNNS.bnns_rng()
40+
@testset "args" for args in ((0,), (1,), (3,), (3, 3), (16,), (16, 16), (1000,), (1000,1000))
41+
# default_rng
42+
A = fm(args...)
43+
@test eltype(A) == T
44+
45+
# specified MPS rng
46+
B = fr(rng, args...)
47+
@test eltype(B) == T
48+
end
49+
50+
@testset "scalar" begin
51+
a = fm()
52+
@test typeof(a) == T
53+
b = fr(rng)
54+
@test typeof(b) == T
55+
end
56+
end
57+
58+
# out-of-place, with type specified
59+
@testset "$fr with $T" for (fm, fr, T) in OOPLACE_TUPLES
60+
rng = BNNS.bnns_rng()
61+
@testset "$args" for args in ((T, 0),
62+
(T, 1),
63+
(T, 3),
64+
(T, 3, 3),
65+
(T, (3, 3)),
66+
(T, 16),
67+
(T, 16, 16),
68+
(T, (16, 16)),
69+
(T, 1000),
70+
(T, 1000, 1000),)
71+
# default_rng
72+
A = fm(args...)
73+
@test eltype(A) == T
74+
75+
# specified RNG rng
76+
B = fr(rng, args...)
77+
@test eltype(B) == T
78+
end
79+
80+
@testset "scalar" begin
81+
a = fm(T)
82+
@test typeof(a) == T
83+
b = fr(rng, T)
84+
@test typeof(b) == T
85+
end
86+
end
87+
end
88+
89+
## seeding
90+
@testset "Seeding" begin
91+
@testset "$d" for d in (1, 3, (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000), (3,3,3,3), (3,3,3,3,3), (3,3,3,3,3,3))
92+
rng = BNNS.bnns_rng(1)
93+
a = rand(rng, Float32, d)
94+
Random.seed!(rng, 1)
95+
b = rand(rng, Float32, d)
96+
@test a == b
97+
end
98+
end
99+
end # testset

test/runtests.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearAlgebra
22
using AppleAccelerate
3-
using DSP, Test, Random, Statistics
3+
using DSP, Test, Random, Statistics, BFloat16s
44

55
if !Sys.isapple()
66
@info("AppleAccelerate.jl will be tested only on macOS. Exiting.")
@@ -165,6 +165,10 @@ for T in (Float32, Float64)
165165
end
166166
end
167167

168+
@testset "BNNS" begin
169+
include("BNNS.jl")
170+
end
171+
168172

169173
@testset "DCT::Float32" begin
170174
r=rand(Float32,2^16)

0 commit comments

Comments
 (0)