diff --git a/Project.toml b/Project.toml index d5a81e2..7d09e1c 100644 --- a/Project.toml +++ b/Project.toml @@ -3,13 +3,24 @@ uuid = "13e28ba4-7ad8-5781-acae-3021b1ed3924" version = "0.4.1" [deps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[weakdeps] +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[extensions] +RandomExt = "Random" + [compat] +BFloat16s = "0.5.0" +CEnum = "0.5.0" julia = "1.9" [extras] +DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -18,7 +29,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/ext/RandomExt.jl b/ext/RandomExt.jl new file mode 100644 index 0000000..61549e2 --- /dev/null +++ b/ext/RandomExt.jl @@ -0,0 +1,186 @@ +module RandomExt + +@static if Sys.isapple() + +using BFloat16s +using AppleAccelerate: BNNS +using .BNNS: BNNSFilterParameters, + BNNSRandomGeneratorMethodAES_CTR, + BNNSCreateRandomGenerator, + BNNSCreateRandomGeneratorWithSeed, + BNNSRandomGeneratorStateSize, + BNNSRandomGeneratorSetState, + BNNSRandomGeneratorGetState, + BNNSNDArrayDescriptor, + BNNSRandomFillNormalFloat, + BNNSRandomFillUniformFloat, + BNNSRandomFillUniformInt +using Random: Random, AbstractRNG + +""" + RNG() + +A random number generator using AppleAccelerate's BNNS functionality. +""" +mutable struct RNG <: AbstractRNG + ptr::Ptr{Nothing} + function RNG(filter_parameters::Union{Nothing, BNNSFilterParameters}=nothing) + params = isnothing(filter_parameters) ? Ptr{BNNSFilterParameters}(0) : [filter_parameters] + res = new(BNNSCreateRandomGenerator(BNNSRandomGeneratorMethodAES_CTR, params)) + # finalizer(res) do + # BNNSDestroyRandomGenerator(res.ptr) + # end + return res + end + function RNG(seed::Integer, filter_parameters::Union{Nothing, BNNSFilterParameters}=nothing) + seed = seed%UInt64 + params = isnothing(filter_parameters) ? Ptr{BNNSFilterParameters}(0) : [filter_parameters] + res = new(BNNSCreateRandomGeneratorWithSeed(BNNSRandomGeneratorMethodAES_CTR, seed, params)) + # finalizer(res) do + # BNNSDestroyRandomGenerator(res.ptr) + # end + return res + end +end + +BNNS.bnns_rng() = RNG() +BNNS.bnns_rng(seed::Integer) = RNG(seed) + +@static if isdefined(Base, :Memory) #VERSION >= v"1.11" + function _get_rng_state(rng::RNG) + stateSize = BNNSRandomGeneratorStateSize(rng.ptr) + state = Memory{UInt8}(undef, Int64(stateSize)) + BNNSRandomGeneratorGetState(rng.ptr, stateSize, state) + return state + end +else + function _get_rng_state(rng::RNG) + stateSize = BNNSRandomGeneratorStateSize(rng.ptr) + state = Vector{UInt8}(undef, Int64(stateSize)) + BNNSRandomGeneratorGetState(rng.ptr, stateSize, state) + return state + end +end + +function Base.copy!(dest::RNG, src::RNG) + state = _get_rng_state(src) + BNNSRandomGeneratorSetState(dest.ptr, length(state), state) + return dest +end + +function Base.copy(rng::RNG) + newrng = RNG() + return copy!(newrng, rng) +end + +Base.:(==)(rng1::RNG, rng2::RNG) = _get_rng_state(rng1) == _get_rng_state(rng2) + +function Random.seed!(rng::RNG, seed::Integer) + return copy!(rng, RNG(seed)) +end + +function Random.seed!(rng::RNG) + return copy!(rng, RNG()) +end + +const GLOBAL_RNG = Ref{RNG}() +function BNNS.default_rng() + if !isassigned(GLOBAL_RNG) + GLOBAL_RNG[] = BNNS.bnns_rng() + end + return GLOBAL_RNG[] +end + +const BNNSInt = Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64} +const BNNSFloat = Union{Float16, Float32, BFloat16} + +const BNNSUniform = Union{<:BNNSInt,<:BNNSFloat} +const BNNSNormal = BNNSFloat + +function Random.rand!(rng::RNG, A::DenseArray{T}) where {T<:BNNSInt} + isempty(A) && return A + desc = Ref(BNNSNDArrayDescriptor(A)) + res = BNNSRandomFillUniformInt(rng.ptr, desc, typemin(signed(T)), typemax(signed(T))) + @assert res == 0 + return A +end +function Random.rand!(rng::RNG, A::DenseArray{T}) where {T<:BNNSFloat} + isempty(A) && return A + desc = Ref(BNNSNDArrayDescriptor(A)) + res = BNNSRandomFillUniformFloat(rng.ptr, desc, T(0), T(1)) + @assert res == 0 + return A +end +function Random.randn!(rng::RNG, A::DenseArray{T}) where {T<:BNNSFloat} + isempty(A) && return A + desc = Ref(BNNSNDArrayDescriptor(A)) + res = BNNSRandomFillNormalFloat(rng.ptr, desc, Float32(0), Float32(1)) + @assert res == 0 + return A +end + +# Out of place +Random.rand(rng::RNG, ::Type{T}, dims::Dims) where T <: BNNSUniform = + Random.rand!(rng, Array{T,length(dims)}(undef, dims...)) +Random.randn(rng::RNG, ::Type{T}, dims::Dims) where T <: BNNSNormal = + Random.randn!(rng, Array{T,length(dims)}(undef, dims...)) + +# support all dimension specifications +Random.rand(rng::RNG, ::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSUniform = + Random.rand!(rng, Array{T,length(dims) + 1}(undef, dim1, dims...)) +Random.randn(rng::RNG, ::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSNormal = + Random.randn!(rng, Array{T,length(dims) + 1}(undef, dim1, dims...)) + +# untyped out-of-place +Random.rand(rng::RNG, dim1::Integer, dims::Integer...) = + Random.rand!(rng, Array{Float32,length(dims) + 1}(undef, dim1, dims...)) +Random.randn(rng::RNG, dim1::Integer, dims::Integer...) = + Random.randn!(rng, Array{Float32,length(dims) + 1}(undef, dim1, dims...)) + +# scalars +Random.rand(rng::RNG, T::Union{Type{Float16}, Type{Float32}, Type{BFloat16}, +Type{Int8}, Type{UInt8}, +Type{Int16}, Type{UInt16}, +Type{Int32}, Type{UInt32}, +Type{Int64}, Type{UInt64}}=Float32) = Random.rand(rng, T, 1)[1] + +# This is the only way I could fix method ambiguity +Random.randn(rng::RNG, T::Type{BFloat16}) = Random.randn(rng, T, 1)[1] +Random.randn(rng::RNG, T::Type{Float16}) = Random.randn(rng, T, 1)[1] +Random.randn(rng::RNG, T::Type{Float32}) = Random.randn(rng, T, 1)[1] +Random.randn(rng::RNG) = Random.randn(rng, Float32) + + +# GPUArrays out-of-place +function BNNS.rand(::Type{T}, dims::Dims) where T <: BNNSUniform + return Random.rand!(BNNS.default_rng(), Array{T,length(dims)}(undef, dims...)) +end +function BNNS.randn(::Type{T}, dims::Dims) where T <: BNNSNormal + return Random.randn!(BNNS.default_rng(), Array{T,length(dims)}(undef, dims...)) +end + +# support all dimension specifications +function BNNS.rand(::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSUniform + return Random.rand!(BNNS.default_rng(), Array{T,length(dims) + 1}(undef, dim1, dims...)) +end +function BNNS.randn(::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSNormal + return Random.randn!(BNNS.default_rng(), Array{T,length(dims) + 1}(undef, dim1, dims...)) +end + +# untyped out-of-place +BNNS.rand(dim1::Integer, dims::Integer...) = + Random.rand!(BNNS.default_rng(), Array{Float32,length(dims) + 1}(undef, dim1, dims...)) +BNNS.randn(dim1::Integer, dims::Integer...) = + Random.randn!(BNNS.default_rng(), Array{Float32,length(dims) + 1}(undef, dim1, dims...)) + +# scalars +BNNS.rand(T::Type=Float32) = BNNS.rand(T, 1)[1] +BNNS.randn(T::Type=Float32) = BNNS.randn(T, 1)[1] + +# seeding +function BNNS.seed!(seed=Base.rand(UInt64)) + Random.seed!(BNNS.default_rng(), seed) +end + +end +end # module diff --git a/lib/BNNS/BNNS.jl b/lib/BNNS/BNNS.jl new file mode 100644 index 0000000..4f87498 --- /dev/null +++ b/lib/BNNS/BNNS.jl @@ -0,0 +1,42 @@ +using BFloat16s + +include("libBNNS.jl") + +bnnsdatatype_modifier(::Type{T}) where {T <: Union{AbstractFloat, Bool}} = BNNSDataTypeFloatBit +bnnsdatatype_modifier(::Type{T}) where {T <: Signed} = BNNSDataTypeIntBit +bnnsdatatype_modifier(::Type{T}) where {T <: Unsigned} = BNNSDataTypeUIntBit +bnnsdatatype_modifier(::Type{Bool}) = BNNSDataTypeMiscellaneousBit +bnnsdatatype_modifier(::Type{BFloat16}) = 0x18000 + +Base.convert(::Type{BNNSDataType}, T) = BNNSDataType(bnnsdatatype_modifier(T) | UInt32(sizeof(T)*8)) + +function BNNSNDArrayDescriptor(arr::AbstractArray{T, N}) where {T,N} + N > 8 && throw(ArgumentError("BNNSNDArrays do not support more than 8 dimensions.")) + + + layout = BNNSDataLayout(UInt32(N) * UInt32(BNNSDataLayoutVector) | 0x8000) + # layout = datalayout[N] + sz = ntuple(Val(8)) do i + Csize_t(get(size(arr), i, 0)) + end + stride = ntuple(_ -> Csize_t(0), Val(8)) + return GC.@preserve arr BNNSNDArrayDescriptor(BNNSNDArrayFlagBackpropSet, + layout, + sz, + stride, + Ptr{Nothing}(pointer(arr)), + T, + 0, + T, + 1, + 0) +end + +# Definitions for the Random extension +function bnns_rng end +function default_rng end +function rand end +function randn end +function rand! end +function randn! end +function seed! end diff --git a/lib/BNNS/libBNNS.jl b/lib/BNNS/libBNNS.jl new file mode 100644 index 0000000..22aff9d --- /dev/null +++ b/lib/BNNS/libBNNS.jl @@ -0,0 +1,2173 @@ +# This file is automatically generated. Do not edit! +# To re-generate, execute res/wrap/wrap.jl + +using CEnum: CEnum, @cenum + +@cenum BNNSDataType::UInt32 begin + BNNSDataTypeFloatBit = 0x0000000000010000 + BNNSDataTypeFloat16 = 0x0000000000010010 + BNNSDataTypeFloat32 = 0x0000000000010020 + BNNSDataTypeBFloat16 = 0x0000000000018010 + BNNSDataTypeIntBit = 0x0000000000020000 + BNNSDataTypeInt1 = 0x0000000000020001 + BNNSDataTypeInt2 = 0x0000000000020002 + BNNSDataTypeInt4 = 0x0000000000020004 + BNNSDataTypeInt8 = 0x0000000000020008 + BNNSDataTypeInt16 = 0x0000000000020010 + BNNSDataTypeInt32 = 0x0000000000020020 + BNNSDataTypeInt64 = 0x0000000000020040 + BNNSDataTypeUIntBit = 0x0000000000040000 + BNNSDataTypeUInt1 = 0x0000000000040001 + BNNSDataTypeUInt2 = 0x0000000000040002 + BNNSDataTypeUInt3 = 0x0000000000040003 + BNNSDataTypeUInt4 = 0x0000000000040004 + BNNSDataTypeUInt6 = 0x0000000000040006 + BNNSDataTypeUInt8 = 0x0000000000040008 + BNNSDataTypeUInt16 = 0x0000000000040010 + BNNSDataTypeUInt32 = 0x0000000000040020 + BNNSDataTypeUInt64 = 0x0000000000040040 + BNNSDataTypeIndexedBit = 0x0000000000080000 + BNNSDataTypeIndexed1 = 0x0000000000080001 + BNNSDataTypeIndexed2 = 0x0000000000080002 + BNNSDataTypeIndexed4 = 0x0000000000080004 + BNNSDataTypeIndexed8 = 0x0000000000080008 + BNNSDataTypeMiscellaneousBit = 0x0000000000100000 + BNNSDataTypeBoolean = 0x0000000000100008 +end + +@cenum BNNSPoolingFunction::UInt32 begin + BNNSPoolingFunctionMax = 0x0000000000000000 + BNNSPoolingFunctionAverageCountIncludePadding = 0x0000000000000001 + BNNSPoolingFunctionAverageCountExcludePadding = 0x0000000000000002 + BNNSPoolingFunctionUnMax = 0x0000000000000003 + BNNSPoolingFunctionL2Norm = 0x0000000000000004 + BNNSPoolingFunctionAverage = 0x0000000000000001 +end + +@cenum BNNSActivationFunction::UInt32 begin + BNNSActivationFunctionIdentity = 0x0000000000000000 + BNNSActivationFunctionRectifiedLinear = 0x0000000000000001 + BNNSActivationFunctionLeakyRectifiedLinear = 0x0000000000000002 + BNNSActivationFunctionSigmoid = 0x0000000000000003 + BNNSActivationFunctionTanh = 0x0000000000000004 + BNNSActivationFunctionScaledTanh = 0x0000000000000005 + BNNSActivationFunctionAbs = 0x0000000000000006 + BNNSActivationFunctionLinear = 0x0000000000000007 + BNNSActivationFunctionClamp = 0x0000000000000008 + BNNSActivationFunctionIntegerLinearSaturate = 0x0000000000000009 + BNNSActivationFunctionIntegerLinearSaturatePerChannel = 0x000000000000000a + BNNSActivationFunctionSoftmax = 0x000000000000000b + BNNSActivationFunctionGELUApproximation = 0x000000000000000c + BNNSActivationFunctionGumbel = 0x000000000000000d + BNNSActivationFunctionGumbelMax = 0x000000000000000e + BNNSActivationFunctionHardSigmoid = 0x000000000000000f + BNNSActivationFunctionSoftplus = 0x0000000000000010 + BNNSActivationFunctionSoftsign = 0x0000000000000011 + BNNSActivationFunctionELU = 0x0000000000000012 + BNNSActivationFunctionClampedLeakyRectifiedLinear = 0x0000000000000013 + BNNSActivationFunctionLinearWithBias = 0x0000000000000014 + BNNSActivationFunctionLogSoftmax = 0x0000000000000015 + BNNSActivationFunctionLogSigmoid = 0x0000000000000016 + BNNSActivationFunctionSELU = 0x0000000000000017 + BNNSActivationFunctionCELU = 0x0000000000000018 + BNNSActivationFunctionHardShrink = 0x0000000000000019 + BNNSActivationFunctionSoftShrink = 0x000000000000001a + BNNSActivationFunctionTanhShrink = 0x000000000000001b + BNNSActivationFunctionThreshold = 0x000000000000001c + BNNSActivationFunctionPReLUPerChannel = 0x000000000000001d + BNNSActivationFunctionGELUApproximation2 = 0x000000000000001e + BNNSActivationFunctionHardSwish = 0x000000000000001e + BNNSActivationFunctionSiLU = 0x000000000000001f + BNNSActivationFunctionReLU6 = 0x0000000000000020 + BNNSActivationFunctionErf = 0x0000000000000021 + BNNSActivationFunctionGELU = 0x0000000000000022 + BNNSActivationFunctionGELUApproximationSigmoid = 0x0000000000000023 +end + +@cenum BNNSFlags::UInt32 begin + BNNSFlagsUseClientPtr = 0x0000000000000001 +end + +@cenum BNNSLossFunction::UInt32 begin + BNNSLossFunctionSoftmaxCrossEntropy = 0x0000000000000001 + BNNSLossFunctionSigmoidCrossEntropy = 0x0000000000000002 + BNNSLossFunctionMeanSquareError = 0x0000000000000003 + BNNSLossFunctionHuber = 0x0000000000000004 + BNNSLossFunctionYolo = 0x0000000000000005 + BNNSLossFunctionLog = 0x0000000000000006 + BNNSLossFunctionCosineDistance = 0x0000000000000007 + BNNSLossFunctionHinge = 0x0000000000000008 + BNNSLossFunctionMeanAbsoluteError = 0x0000000000000009 + BNNSLossFunctionCategoricalCrossEntropy = 0x000000000000000a +end + +@cenum BNNSLossReductionFunction::UInt32 begin + BNNSLossReductionNone = 0x0000000000000000 + BNNSLossReductionSum = 0x0000000000000001 + BNNSLossReductionWeightedMean = 0x0000000000000002 + BNNSLossReductionMean = 0x0000000000000003 + BNNSLossReductionNonZeroWeightMean = 0x0000000000000004 +end + +@cenum BNNSArithmeticFunction::UInt32 begin + BNNSArithmeticAdd = 0x0000000000000000 + BNNSArithmeticSubtract = 0x0000000000000001 + BNNSArithmeticMultiply = 0x0000000000000002 + BNNSArithmeticDivide = 0x0000000000000003 + BNNSArithmeticSquareRoot = 0x0000000000000004 + BNNSArithmeticReciprocalSquareRoot = 0x0000000000000005 + BNNSArithmeticCeil = 0x0000000000000006 + BNNSArithmeticFloor = 0x0000000000000007 + BNNSArithmeticRound = 0x0000000000000008 + BNNSArithmeticSin = 0x0000000000000009 + BNNSArithmeticCos = 0x000000000000000a + BNNSArithmeticTan = 0x000000000000000b + BNNSArithmeticAsin = 0x000000000000000c + BNNSArithmeticAcos = 0x000000000000000d + BNNSArithmeticAtan = 0x000000000000000e + BNNSArithmeticSinh = 0x000000000000000f + BNNSArithmeticCosh = 0x0000000000000010 + BNNSArithmeticTanh = 0x0000000000000011 + BNNSArithmeticAsinh = 0x0000000000000012 + BNNSArithmeticAcosh = 0x0000000000000013 + BNNSArithmeticAtanh = 0x0000000000000014 + BNNSArithmeticPow = 0x0000000000000015 + BNNSArithmeticExp = 0x0000000000000016 + BNNSArithmeticExp2 = 0x0000000000000017 + BNNSArithmeticLog = 0x0000000000000018 + BNNSArithmeticLog2 = 0x0000000000000019 + BNNSArithmeticMultiplyNoNaN = 0x000000000000001a + BNNSArithmeticDivideNoNaN = 0x000000000000001b + BNNSArithmeticMultiplyAdd = 0x000000000000001c + BNNSArithmeticMinimum = 0x000000000000001d + BNNSArithmeticMaximum = 0x000000000000001e + BNNSArithmeticSelect = 0x000000000000001f + BNNSArithmeticAbs = 0x0000000000000020 + BNNSArithmeticSign = 0x0000000000000021 + BNNSArithmeticNegate = 0x0000000000000022 + BNNSArithmeticReciprocal = 0x0000000000000023 + BNNSArithmeticSquare = 0x0000000000000024 + BNNSArithmeticFloorDivide = 0x0000000000000025 + BNNSArithmeticTruncDivide = 0x0000000000000026 + BNNSArithmeticTruncRemainder = 0x0000000000000027 + BNNSArithmeticErf = 0x0000000000000028 +end + +@cenum BNNSDescriptorType::UInt32 begin + BNNSConstant = 0x0000000000000000 + BNNSSample = 0x0000000000000001 + BNNSParameter = 0x0000000000000002 +end + +@cenum BNNSOptimizerFunction::UInt32 begin + BNNSOptimizerFunctionSGDMomentum = 0x0000000000000001 + BNNSOptimizerFunctionAdam = 0x0000000000000002 + BNNSOptimizerFunctionRMSProp = 0x0000000000000003 + BNNSOptimizerFunctionAdamW = 0x0000000000000004 + BNNSOptimizerFunctionAdamAMSGrad = 0x0000000000000005 + BNNSOptimizerFunctionAdamWAMSGrad = 0x0000000000000006 + BNNSOptimizerFunctionSGDMomentumWithClipping = 0x0000000000000007 + BNNSOptimizerFunctionAdamWithClipping = 0x0000000000000008 + BNNSOptimizerFunctionRMSPropWithClipping = 0x0000000000000009 + BNNSOptimizerFunctionAdamWWithClipping = 0x000000000000000a + BNNSOptimizerFunctionAdamAMSGradWithClipping = 0x000000000000000b + BNNSOptimizerFunctionAdamWAMSGradWithClipping = 0x000000000000000c +end + +@cenum BNNSOptimizerRegularizationFunction::UInt32 begin + BNNSOptimizerRegularizationNone = 0x0000000000000000 + BNNSOptimizerRegularizationL1 = 0x0000000000000001 + BNNSOptimizerRegularizationL2 = 0x0000000000000002 +end + +@cenum BNNSOptimizerSGDMomentumVariant::UInt32 begin + BNNSSGDMomentumVariant0 = 0x0000000000000000 + BNNSSGDMomentumVariant1 = 0x0000000000000001 + BNNSSGDMomentumVariant2 = 0x0000000000000002 +end + +@cenum BNNSOptimizerClippingFunction::UInt32 begin + BNNSOptimizerClippingNone = 0x0000000000000000 + BNNSOptimizerClippingByValue = 0x0000000000000001 + BNNSOptimizerClippingByNorm = 0x0000000000000002 + BNNSOptimizerClippingByGlobalNorm = 0x0000000000000003 +end + +@cenum BNNSNormType::UInt32 begin + BNNSL2Norm = 0x0000000000000001 +end + +@cenum BNNSFilterType::UInt32 begin + BNNSConvolution = 0x0000000000000000 + BNNSFullyConnected = 0x0000000000000001 + BNNSBatchNorm = 0x0000000000000002 + BNNSInstanceNorm = 0x0000000000000003 + BNNSLayerNorm = 0x0000000000000004 + BNNSGroupNorm = 0x0000000000000005 + BNNSTransposedConvolution = 0x0000000000000006 + BNNSQuantization = 0x0000000000000007 + BNNSArithmetic = 0x0000000000000008 +end + +@cenum BNNSReduceFunction::UInt32 begin + BNNSReduceFunctionMax = 0x0000000000000000 + BNNSReduceFunctionMin = 0x0000000000000001 + BNNSReduceFunctionArgMax = 0x0000000000000002 + BNNSReduceFunctionArgMin = 0x0000000000000003 + BNNSReduceFunctionMean = 0x0000000000000004 + BNNSReduceFunctionMeanNonZero = 0x0000000000000005 + BNNSReduceFunctionSum = 0x0000000000000006 + BNNSReduceFunctionSumSquare = 0x0000000000000007 + BNNSReduceFunctionSumLog = 0x0000000000000008 + BNNSReduceFunctionL1Norm = 0x0000000000000009 + BNNSReduceFunctionLogicalOr = 0x000000000000000a + BNNSReduceFunctionLogicalAnd = 0x000000000000000b + BNNSReduceFunctionL2Norm = 0x000000000000000c + BNNSReduceFunctionLogSumExp = 0x000000000000000d + BNNSReduceFunctionProduct = 0x000000000000000e + BNNSReduceFunctionNone = 0x000000000000000f + BNNSReduceFunctionLogSum = 0x0000000000000010 + BNNSReduceFunctionAny = 0x000000000000000a + BNNSReduceFunctionAll = 0x000000000000000b +end + +@cenum BNNSLayerFlags::UInt32 begin + BNNSLayerFlagsLSTMBidirectional = 0x0000000000000001 + BNNSLayerFlagsLSTMDefaultActivations = 0x0000000000000002 +end + +@cenum BNNSDataLayout::UInt32 begin + BNNSDataLayoutVector = 0x0000000000010000 + BNNSDataLayout1DLastMajor = 0x0000000000018000 + BNNSDataLayout1DFirstMajor = 0x0000000000018001 + BNNSDataLayoutRowMajorMatrix = 0x0000000000020000 + BNNSDataLayoutColumnMajorMatrix = 0x0000000000020001 + BNNSDataLayout2DLastMajor = 0x0000000000028000 + BNNSDataLayout2DFirstMajor = 0x0000000000028001 + BNNSDataLayoutFullyConnectedSparse = 0x0000000000021001 + BNNSDataLayoutImageCHW = 0x0000000000030000 + BNNSDataLayoutSNE = 0x0000000000030001 + BNNSDataLayoutNSE = 0x0000000000030002 + BNNSDataLayoutMHA_DHK = 0x0000000000030003 + BNNSDataLayout3DLastMajor = 0x0000000000038000 + BNNSDataLayout3DFirstMajor = 0x0000000000038001 + BNNSDataLayoutConvolutionWeightsOIHW = 0x0000000000040000 + BNNSDataLayoutConvolutionWeightsOIHrWr = 0x0000000000040001 + BNNSDataLayoutConvolutionWeightsIOHrWr = 0x0000000000040002 + BNNSDataLayoutConvolutionWeightsOIHW_Pack32 = 0x0000000000040010 + BNNSDataLayout4DLastMajor = 0x0000000000048000 + BNNSDataLayout4DFirstMajor = 0x0000000000048001 + BNNSDataLayout5DLastMajor = 0x0000000000058000 + BNNSDataLayout5DFirstMajor = 0x0000000000058001 + BNNSDataLayout6DLastMajor = 0x0000000000068000 + BNNSDataLayout6DFirstMajor = 0x0000000000068001 + BNNSDataLayout7DLastMajor = 0x0000000000078000 + BNNSDataLayout7DFirstMajor = 0x0000000000078001 + BNNSDataLayout8DLastMajor = 0x0000000000088000 + BNNSDataLayout8DFirstMajor = 0x0000000000088001 +end + +@cenum BNNSInterpolationMethod::UInt32 begin + BNNSInterpolationMethodNearest = 0x0000000000000000 + BNNSInterpolationMethodLinear = 0x0000000000000001 +end + +@cenum BNNSLinearSamplingMode::UInt32 begin + BNNSLinearSamplingDefault = 0x0000000000000000 + BNNSLinearSamplingAlignCorners = 0x0000000000000001 + BNNSLinearSamplingUnalignCorners = 0x0000000000000002 + BNNSLinearSamplingStrictAlignCorners = 0x0000000000000003 + BNNSLinearSamplingOffsetCorners = 0x0000000000000004 +end + +@cenum BNNSBoxCoordinateMode::UInt32 begin + BNNSCornersHeightFirst = 0x0000000000000000 + BNNSCornersWidthFirst = 0x0000000000000001 + BNNSCenterSizeHeightFirst = 0x0000000000000002 + BNNSCenterSizeWidthFirst = 0x0000000000000003 +end + +@cenum BNNSPaddingMode::UInt32 begin + BNNSPaddingModeConstant = 0x0000000000000000 + BNNSPaddingModeReflect = 0x0000000000000001 + BNNSPaddingModeSymmetric = 0x0000000000000002 +end + +@cenum BNNSRelationalOperator::UInt32 begin + BNNSRelationalOperatorEqual = 0x0000000000000000 + BNNSRelationalOperatorLess = 0x0000000000000001 + BNNSRelationalOperatorLessEqual = 0x0000000000000002 + BNNSRelationalOperatorGreater = 0x0000000000000003 + BNNSRelationalOperatorGreaterEqual = 0x0000000000000004 + BNNSRelationalOperatorNotEqual = 0x0000000000000005 + BNNSRelationalOperatorLogicalAND = 0x0000000000000006 + BNNSRelationalOperatorLogicalOR = 0x0000000000000007 + BNNSRelationalOperatorLogicalNOT = 0x0000000000000008 + BNNSRelationalOperatorLogicalNAND = 0x0000000000000009 + BNNSRelationalOperatorLogicalNOR = 0x000000000000000a + BNNSRelationalOperatorLogicalXOR = 0x000000000000000b +end + +@cenum BNNSPointerSpecifier::UInt32 begin + BNNSPointerSpecifierAlpha = 0x0000000000000000 + BNNSPointerSpecifierBeta = 0x0000000000000001 +end + +@cenum BNNSNDArrayFlags::UInt32 begin + BNNSNDArrayFlagBackpropSet = 0x0000000000000000 + BNNSNDArrayFlagBackpropAccumulate = 0x0000000000000001 +end + +@cenum BNNSEmbeddingFlags::UInt32 begin + BNNSEmbeddingFlagScaleGradientByFrequency = 0x0000000000000001 +end + +@cenum BNNSQuantizerFunction::UInt32 begin + BNNSQuantizerFunctionQuantize = 0x0000000000000000 + BNNSQuantizerFunctionDequantize = 0x0000000000000001 +end + +@cenum BNNSRandomGeneratorMethod::UInt32 begin + BNNSRandomGeneratorMethodAES_CTR = 0x0000000000000000 +end + +@cenum BNNSSparsityType::UInt32 begin + BNNSSparsityTypeUnstructured = 0x0000000000000000 +end + +@cenum BNNSTargetSystem::UInt32 begin + BNNSTargetSystemGeneric = 0x0000000000000000 +end + +@cenum BNNSShuffleType::UInt32 begin + BNNSShuffleTypePixelShuffleNCHW = 0x0000000000000000 + BNNSShuffleTypePixelUnshuffleNCHW = 0x0000000000000001 + BNNSShuffleTypeDepthToSpaceNCHW = 0x0000000000000002 + BNNSShuffleTypeSpaceToDepthNCHW = 0x0000000000000003 +end + +# typedef int ( * BNNSAlloc ) ( void * _Nullable * _Nullable memptr , size_t alignment , size_t size ) +const BNNSAlloc = Ptr{Cvoid} + +# typedef void ( * BNNSFree ) ( void * _Null_unspecified ptr ) +const BNNSFree = Ptr{Cvoid} + +struct BNNSActivation + _function::BNNSActivationFunction + alpha::Cfloat + beta::Cfloat + iscale::Int32 + ioffset::Int32 + ishift::Int32 + iscale_per_channel::Ptr{Int32} + ioffset_per_channel::Ptr{Int32} + ishift_per_channel::Ptr{Int32} +end + +struct BNNSNDArrayDescriptor + flags::BNNSNDArrayFlags + layout::BNNSDataLayout + size::NTuple{8,Csize_t} + stride::NTuple{8,Csize_t} + data::Ptr{Cvoid} + data_type::BNNSDataType + table_data::Ptr{Cvoid} + table_data_type::BNNSDataType + data_scale::Cfloat + data_bias::Cfloat +end + +struct BNNSTensor + data_type::BNNSDataType + rank::UInt8 + shape::NTuple{8,Cssize_t} + stride::NTuple{8,Cssize_t} + data::Ptr{Cvoid} + data_size_in_bytes::Csize_t + name::Cstring +end + +struct BNNSLSTMGateDescriptor + iw_desc::NTuple{2,BNNSNDArrayDescriptor} + hw_desc::BNNSNDArrayDescriptor + cw_desc::BNNSNDArrayDescriptor + b_desc::BNNSNDArrayDescriptor + activation::BNNSActivation +end + +struct BNNSLSTMDataDescriptor + data_desc::BNNSNDArrayDescriptor + hidden_desc::BNNSNDArrayDescriptor + cell_state_desc::BNNSNDArrayDescriptor +end + +struct BNNSArithmeticUnary + in::BNNSNDArrayDescriptor + in_type::BNNSDescriptorType + out::BNNSNDArrayDescriptor + out_type::BNNSDescriptorType +end + +struct BNNSArithmeticBinary + in1::BNNSNDArrayDescriptor + in1_type::BNNSDescriptorType + in2::BNNSNDArrayDescriptor + in2_type::BNNSDescriptorType + out::BNNSNDArrayDescriptor + out_type::BNNSDescriptorType +end + +struct BNNSArithmeticTernary + in1::BNNSNDArrayDescriptor + in1_type::BNNSDescriptorType + in2::BNNSNDArrayDescriptor + in2_type::BNNSDescriptorType + in3::BNNSNDArrayDescriptor + in3_type::BNNSDescriptorType + out::BNNSNDArrayDescriptor + out_type::BNNSDescriptorType +end + +struct BNNSMHAProjectionParameters + target_desc::BNNSNDArrayDescriptor + weights::BNNSNDArrayDescriptor + bias::BNNSNDArrayDescriptor +end + +struct BNNSLayerParametersConvolution + i_desc::BNNSNDArrayDescriptor + w_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + bias::BNNSNDArrayDescriptor + activation::BNNSActivation + x_stride::Csize_t + y_stride::Csize_t + x_dilation_stride::Csize_t + y_dilation_stride::Csize_t + x_padding::Csize_t + y_padding::Csize_t + groups::Csize_t + pad::NTuple{4,Csize_t} +end + +struct BNNSLayerParametersFullyConnected + i_desc::BNNSNDArrayDescriptor + w_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + bias::BNNSNDArrayDescriptor + activation::BNNSActivation +end + +struct BNNSLayerParametersPooling + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + bias::BNNSNDArrayDescriptor + activation::BNNSActivation + pooling_function::BNNSPoolingFunction + k_width::Csize_t + k_height::Csize_t + x_stride::Csize_t + y_stride::Csize_t + x_dilation_stride::Csize_t + y_dilation_stride::Csize_t + x_padding::Csize_t + y_padding::Csize_t + pad::NTuple{4,Csize_t} +end + +struct BNNSLayerParametersActivation + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + activation::BNNSActivation + axis_flags::UInt32 +end + +struct BNNSLayerParametersLossBase + _function::BNNSLossFunction + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + reduction::BNNSLossReductionFunction +end + +struct BNNSLayerParametersLossSoftmaxCrossEntropy + _function::BNNSLossFunction + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + reduction::BNNSLossReductionFunction + label_smooth::Cfloat +end + +struct BNNSLayerParametersLossSigmoidCrossEntropy + _function::BNNSLossFunction + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + reduction::BNNSLossReductionFunction + label_smooth::Cfloat +end + +struct BNNSLayerParametersLossHuber + _function::BNNSLossFunction + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + reduction::BNNSLossReductionFunction + huber_delta::Cfloat +end + +struct BNNSLayerParametersLossYolo + _function::BNNSLossFunction + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + reduction::BNNSLossReductionFunction + huber_delta::Cfloat + number_of_grid_columns::Csize_t + number_of_grid_rows::Csize_t + number_of_anchor_boxes::Csize_t + anchor_box_size::Csize_t + rescore::Bool + scale_xy::Cfloat + scale_wh::Cfloat + scale_object::Cfloat + scale_no_object::Cfloat + scale_classification::Cfloat + object_minimum_iou::Cfloat + no_object_maximum_iou::Cfloat + anchors_data::Ptr{Cfloat} +end + +struct BNNSOptimizerSGDMomentumFields + learning_rate::Cfloat + momentum::Cfloat + gradient_scale::Cfloat + regularization_scale::Cfloat + clip_gradients::Bool + clip_gradients_min::Cfloat + clip_gradients_max::Cfloat + nesterov::Bool + regularization_func::BNNSOptimizerRegularizationFunction + sgd_momentum_variant::BNNSOptimizerSGDMomentumVariant +end + +struct BNNSOptimizerSGDMomentumWithClippingFields + learning_rate::Cfloat + momentum::Cfloat + gradient_scale::Cfloat + regularization_scale::Cfloat + nesterov::Bool + regularization_func::BNNSOptimizerRegularizationFunction + sgd_momentum_variant::BNNSOptimizerSGDMomentumVariant + clipping_func::BNNSOptimizerClippingFunction + clip_gradients_min::Cfloat + clip_gradients_max::Cfloat + clip_gradients_max_norm::Cfloat + clip_gradients_use_norm::Cfloat +end + +struct BNNSOptimizerAdamFields + learning_rate::Cfloat + beta1::Cfloat + beta2::Cfloat + time_step::Cfloat + epsilon::Cfloat + gradient_scale::Cfloat + regularization_scale::Cfloat + clip_gradients::Bool + clip_gradients_min::Cfloat + clip_gradients_max::Cfloat + regularization_func::BNNSOptimizerRegularizationFunction +end + +struct BNNSOptimizerAdamWithClippingFields + learning_rate::Cfloat + beta1::Cfloat + beta2::Cfloat + time_step::Cfloat + epsilon::Cfloat + gradient_scale::Cfloat + regularization_scale::Cfloat + regularization_func::BNNSOptimizerRegularizationFunction + clipping_func::BNNSOptimizerClippingFunction + clip_gradients_min::Cfloat + clip_gradients_max::Cfloat + clip_gradients_max_norm::Cfloat + clip_gradients_use_norm::Cfloat +end + +struct BNNSOptimizerRMSPropFields + learning_rate::Cfloat + alpha::Cfloat + epsilon::Cfloat + centered::Bool + momentum::Cfloat + gradient_scale::Cfloat + regularization_scale::Cfloat + clip_gradients::Bool + clip_gradients_min::Cfloat + clip_gradients_max::Cfloat + regularization_func::BNNSOptimizerRegularizationFunction +end + +struct BNNSOptimizerRMSPropWithClippingFields + learning_rate::Cfloat + alpha::Cfloat + epsilon::Cfloat + centered::Bool + momentum::Cfloat + gradient_scale::Cfloat + regularization_scale::Cfloat + regularization_func::BNNSOptimizerRegularizationFunction + clipping_func::BNNSOptimizerClippingFunction + clip_gradients_min::Cfloat + clip_gradients_max::Cfloat + clip_gradients_max_norm::Cfloat + clip_gradients_use_norm::Cfloat +end + +struct BNNSLayerParametersNormalization + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + beta_desc::BNNSNDArrayDescriptor + gamma_desc::BNNSNDArrayDescriptor + moving_mean_desc::BNNSNDArrayDescriptor + moving_variance_desc::BNNSNDArrayDescriptor + momentum::Cfloat + epsilon::Cfloat + activation::BNNSActivation + num_groups::Csize_t + normalization_axis::Csize_t +end + +struct BNNSLayerParametersDropout + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + rate::Cfloat + seed::UInt32 + control::UInt8 +end + +struct BNNSLayerParametersLSTM + data::NTuple{5040,UInt8} +end + +function Base.getproperty(x::Ptr{BNNSLayerParametersLSTM}, f::Symbol) + f === :input_size && return Ptr{Csize_t}(x + 0) + f === :hidden_size && return Ptr{Csize_t}(x + 8) + f === :batch_size && return Ptr{Csize_t}(x + 16) + f === :num_layers && return Ptr{Csize_t}(x + 24) + f === :seq_len && return Ptr{Csize_t}(x + 32) + f === :dropout && return Ptr{Cfloat}(x + 40) + f === :lstm_flags && return Ptr{UInt32}(x + 44) + f === :sequence_descriptor && return Ptr{BNNSNDArrayDescriptor}(x + 48) + f === :input_descriptor && return Ptr{BNNSLSTMDataDescriptor}(x + 224) + f === :output_descriptor && return Ptr{BNNSLSTMDataDescriptor}(x + 752) + f === :input_gate && return Ptr{BNNSLSTMGateDescriptor}(x + 1280) + f === :forget_gate && return Ptr{BNNSLSTMGateDescriptor}(x + 2208) + f === :candidate_gate && return Ptr{BNNSLSTMGateDescriptor}(x + 3136) + f === :output_gate && return Ptr{BNNSLSTMGateDescriptor}(x + 4064) + f === :hidden_activation && return Ptr{BNNSActivation}(x + 4992) + return getfield(x, f) +end + +function Base.getproperty(x::BNNSLayerParametersLSTM, f::Symbol) + r = Ref{BNNSLayerParametersLSTM}(x) + ptr = Base.unsafe_convert(Ptr{BNNSLayerParametersLSTM}, r) + fptr = getproperty(ptr, f) + GC.@preserve r unsafe_load(fptr) +end + +function Base.setproperty!(x::Ptr{BNNSLayerParametersLSTM}, f::Symbol, v) + return unsafe_store!(getproperty(x, f), v) +end + +struct BNNSLayerParametersArithmetic + arithmetic_function::BNNSArithmeticFunction + arithmetic_function_fields::Ptr{Cvoid} + activation::BNNSActivation +end + +struct BNNSLayerParametersPermute + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + permutation::NTuple{8,Csize_t} +end + +struct BNNSLayerParametersTensorContraction + operation::Cstring + alpha::Cfloat + beta::Cfloat + iA_desc::BNNSNDArrayDescriptor + iB_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor +end + +struct BNNSLayerParametersGram + alpha::Cfloat + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor +end + +struct BNNSLayerParametersResize + method::BNNSInterpolationMethod + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + align_corners::Bool +end + +struct BNNSLayerParametersCropResize + normalized_coordinates::Bool + spatial_scale::Cfloat + extrapolation_value::Cfloat + sampling_mode::BNNSLinearSamplingMode + box_coordinate_mode::BNNSBoxCoordinateMode + method::BNNSInterpolationMethod +end + +struct BNNSLayerParametersBroadcastMatMul + alpha::Cfloat + beta::Cfloat + transA::Bool + transB::Bool + quadratic::Bool + a_is_weights::Bool + b_is_weights::Bool + iA_desc::BNNSNDArrayDescriptor + iB_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor +end + +struct BNNSLayerParametersMultiheadAttention + query::BNNSMHAProjectionParameters + key::BNNSMHAProjectionParameters + value::BNNSMHAProjectionParameters + add_zero_attn::Bool + key_attn_bias::BNNSNDArrayDescriptor + value_attn_bias::BNNSNDArrayDescriptor + output::BNNSMHAProjectionParameters + dropout::Cfloat + seed::UInt32 +end + +struct BNNSLayerParametersReduction + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + w_desc::BNNSNDArrayDescriptor + reduce_func::BNNSReduceFunction + epsilon::Cfloat +end + +struct BNNSLayerParametersPadding + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + padding_size::NTuple{8,NTuple{2,Csize_t}} + padding_mode::BNNSPaddingMode + padding_value::UInt32 +end + +struct BNNSLayerParametersEmbedding + flags::BNNSEmbeddingFlags + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + dictionary::BNNSNDArrayDescriptor + padding_idx::Csize_t + max_norm::Cfloat + norm_type::Cfloat +end + +struct BNNSLayerParametersQuantization + axis_mask::Csize_t + _function::BNNSQuantizerFunction + i_desc::BNNSNDArrayDescriptor + o_desc::BNNSNDArrayDescriptor + scale::BNNSNDArrayDescriptor + bias::BNNSNDArrayDescriptor +end + +struct BNNSSparsityParameters + flags::UInt64 + sparsity_ratio::NTuple{2,UInt32} + sparsity_type::BNNSSparsityType + target_system::BNNSTargetSystem +end + +struct BNNSImageStackDescriptor + width::Csize_t + height::Csize_t + channels::Csize_t + row_stride::Csize_t + image_stride::Csize_t + data_type::BNNSDataType + data_scale::Cfloat + data_bias::Cfloat +end + +struct BNNSVectorDescriptor + size::Csize_t + data_type::BNNSDataType + data_scale::Cfloat + data_bias::Cfloat +end + +struct BNNSLayerData + data::Ptr{Cvoid} + data_type::BNNSDataType + data_scale::Cfloat + data_bias::Cfloat + data_table::Ptr{Cfloat} +end + +struct BNNSConvolutionLayerParameters + data::NTuple{176,UInt8} +end + +function Base.getproperty(x::Ptr{BNNSConvolutionLayerParameters}, f::Symbol) + f === :x_stride && return Ptr{Csize_t}(x + 0) + f === :y_stride && return Ptr{Csize_t}(x + 8) + f === :x_padding && return Ptr{Csize_t}(x + 16) + f === :y_padding && return Ptr{Csize_t}(x + 24) + f === :k_width && return Ptr{Csize_t}(x + 32) + f === :k_height && return Ptr{Csize_t}(x + 40) + f === :in_channels && return Ptr{Csize_t}(x + 48) + f === :out_channels && return Ptr{Csize_t}(x + 56) + f === :weights && return Ptr{BNNSLayerData}(x + 64) + f === :bias && return Ptr{BNNSLayerData}(x + 96) + f === :activation && return Ptr{BNNSActivation}(x + 128) + return getfield(x, f) +end + +function Base.getproperty(x::BNNSConvolutionLayerParameters, f::Symbol) + r = Ref{BNNSConvolutionLayerParameters}(x) + ptr = Base.unsafe_convert(Ptr{BNNSConvolutionLayerParameters}, r) + fptr = getproperty(ptr, f) + GC.@preserve r unsafe_load(fptr) +end + +function Base.setproperty!(x::Ptr{BNNSConvolutionLayerParameters}, f::Symbol, v) + return unsafe_store!(getproperty(x, f), v) +end + +struct BNNSFullyConnectedLayerParameters + data::NTuple{128,UInt8} +end + +function Base.getproperty(x::Ptr{BNNSFullyConnectedLayerParameters}, f::Symbol) + f === :in_size && return Ptr{Csize_t}(x + 0) + f === :out_size && return Ptr{Csize_t}(x + 8) + f === :weights && return Ptr{BNNSLayerData}(x + 16) + f === :bias && return Ptr{BNNSLayerData}(x + 48) + f === :activation && return Ptr{BNNSActivation}(x + 80) + return getfield(x, f) +end + +function Base.getproperty(x::BNNSFullyConnectedLayerParameters, f::Symbol) + r = Ref{BNNSFullyConnectedLayerParameters}(x) + ptr = Base.unsafe_convert(Ptr{BNNSFullyConnectedLayerParameters}, r) + fptr = getproperty(ptr, f) + GC.@preserve r unsafe_load(fptr) +end + +function Base.setproperty!(x::Ptr{BNNSFullyConnectedLayerParameters}, f::Symbol, v) + return unsafe_store!(getproperty(x, f), v) +end + +struct BNNSPoolingLayerParameters + data::NTuple{152,UInt8} +end + +function Base.getproperty(x::Ptr{BNNSPoolingLayerParameters}, f::Symbol) + f === :x_stride && return Ptr{Csize_t}(x + 0) + f === :y_stride && return Ptr{Csize_t}(x + 8) + f === :x_padding && return Ptr{Csize_t}(x + 16) + f === :y_padding && return Ptr{Csize_t}(x + 24) + f === :k_width && return Ptr{Csize_t}(x + 32) + f === :k_height && return Ptr{Csize_t}(x + 40) + f === :in_channels && return Ptr{Csize_t}(x + 48) + f === :out_channels && return Ptr{Csize_t}(x + 56) + f === :pooling_function && return Ptr{BNNSPoolingFunction}(x + 64) + f === :bias && return Ptr{BNNSLayerData}(x + 72) + f === :activation && return Ptr{BNNSActivation}(x + 104) + return getfield(x, f) +end + +function Base.getproperty(x::BNNSPoolingLayerParameters, f::Symbol) + r = Ref{BNNSPoolingLayerParameters}(x) + ptr = Base.unsafe_convert(Ptr{BNNSPoolingLayerParameters}, r) + fptr = getproperty(ptr, f) + GC.@preserve r unsafe_load(fptr) +end + +function Base.setproperty!(x::Ptr{BNNSPoolingLayerParameters}, f::Symbol, v) + return unsafe_store!(getproperty(x, f), v) +end + +struct BNNSFilterParameters + flags::UInt32 + n_threads::Csize_t + alloc_memory::BNNSAlloc + free_memory::BNNSFree +end + +struct bnns_graph_t + data::Ptr{Cvoid} + size::Csize_t +end + +struct bnns_graph_context_t + data::Ptr{Cvoid} + size::Csize_t +end + +struct bnns_graph_compile_options_t + data::Ptr{Cvoid} + size::Csize_t +end + +struct bnns_graph_shape_t + rank::Csize_t + shape::Ptr{UInt64} +end + +# typedef int ( * bnns_graph_realloc_fn_t ) ( void * _Nullable user_memory_context , size_t user_memory_context_size , void * _Nullable * _Nonnull memptr , size_t alignment , size_t size ) +const bnns_graph_realloc_fn_t = Ptr{Cvoid} + +# typedef void ( * bnns_graph_free_all_fn_t ) ( void * _Nullable user_memory_context , size_t user_memory_context_size ) +const bnns_graph_free_all_fn_t = Ptr{Cvoid} + +@cenum BNNSGraphMessageLevel::UInt32 begin + BNNSGraphMessageLevelInfo = 0x0000000000000001 + BNNSGraphMessageLevelUnsupported = 0x0000000000000002 + BNNSGraphMessageLevelWarning = 0x0000000000000004 + BNNSGraphMessageLevelError = 0x0000000000000008 +end + +struct bnns_user_message_data_t + size::Csize_t + data::Ptr{Cvoid} +end + +# typedef void ( * bnns_graph_execute_message_fn_t ) ( BNNSGraphMessageLevel msg_level , char const * _Nonnull error_msg , char const * _Nullable op_info , bnns_user_message_data_t * _Nullable additional_logging_arguments ) +const bnns_graph_execute_message_fn_t = Ptr{Cvoid} + +# typedef void ( * bnns_graph_compile_message_fn_t ) ( BNNSGraphMessageLevel msg_level , char const * _Nonnull error_msg , char const * _Nullable source_location , bnns_user_message_data_t * _Nullable additional_logging_arguments ) +const bnns_graph_compile_message_fn_t = Ptr{Cvoid} + +function BNNSGraphCompileOptionsMakeDefault() + return ccall((:BNNSGraphCompileOptionsMakeDefault, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + bnns_graph_compile_options_t, ()) +end + +function BNNSGraphCompileOptionsDestroy(options) + return ccall((:BNNSGraphCompileOptionsDestroy, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, (bnns_graph_compile_options_t,), options) +end + +function BNNSGraphCompileOptionsSetTargetSingleThread(options, value) + return ccall((:BNNSGraphCompileOptionsSetTargetSingleThread, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, (bnns_graph_compile_options_t, Bool), options, value) +end + +function BNNSGraphCompileOptionsGetTargetSingleThread(options) + return ccall((:BNNSGraphCompileOptionsGetTargetSingleThread, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Bool, (bnns_graph_compile_options_t,), options) +end + +function BNNSGraphCompileOptionsSetGenerateDebugInfo(options, value) + return ccall((:BNNSGraphCompileOptionsSetGenerateDebugInfo, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, (bnns_graph_compile_options_t, Bool), options, value) +end + +function BNNSGraphCompileOptionsGetGenerateDebugInfo(options) + return ccall((:BNNSGraphCompileOptionsGetGenerateDebugInfo, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Bool, (bnns_graph_compile_options_t,), options) +end + +@cenum BNNSGraphOptimizationPreference::UInt32 begin + BNNSGraphOptimizationPreferencePerformance = 0x0000000000000000 + BNNSGraphOptimizationPreferenceIRSize = 0x0000000000000001 +end + +function BNNSGraphCompileOptionsSetOptimizationPreference(options, preference) + return ccall((:BNNSGraphCompileOptionsSetOptimizationPreference, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, (bnns_graph_compile_options_t, BNNSGraphOptimizationPreference), + options, preference) +end + +function BNNSGraphCompileOptionsGetOptimizationPreference(options) + return ccall((:BNNSGraphCompileOptionsGetOptimizationPreference, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + BNNSGraphOptimizationPreference, (bnns_graph_compile_options_t,), options) +end + +function BNNSGraphCompileOptionsSetMessageLogCallback(options, log_callback, + additional_logging_arguments) + return ccall((:BNNSGraphCompileOptionsSetMessageLogCallback, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, + (bnns_graph_compile_options_t, bnns_graph_compile_message_fn_t, + Ptr{bnns_user_message_data_t}), options, log_callback, + additional_logging_arguments) +end + +function BNNSGraphCompileOptionsSetMessageLogMask(options, log_level_mask) + return ccall((:BNNSGraphCompileOptionsSetMessageLogMask, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, (bnns_graph_compile_options_t, UInt32), options, log_level_mask) +end + +function BNNSGraphCompileOptionsSetOutputPath(options, path) + return ccall((:BNNSGraphCompileOptionsSetOutputPath, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, (bnns_graph_compile_options_t, Cstring), options, path) +end + +function BNNSGraphCompileOptionsGetOutputPath(options) + return ccall((:BNNSGraphCompileOptionsGetOutputPath, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cstring, (bnns_graph_compile_options_t,), options) +end + +function BNNSGraphCompileOptionsSetOutputFD(options, fd) + return ccall((:BNNSGraphCompileOptionsSetOutputFD, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, (bnns_graph_compile_options_t, Cint), options, fd) +end + +function BNNSGraphCompileOptionsGetOutputFD(options) + return ccall((:BNNSGraphCompileOptionsGetOutputFD, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_compile_options_t,), options) +end + +function BNNSGraphCompileFromFile(filename, _function, options) + return ccall((:BNNSGraphCompileFromFile, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + bnns_graph_t, (Cstring, Cstring, bnns_graph_compile_options_t), filename, + _function, options) +end + +function BNNSGraphGetInputCount(graph, _function) + return ccall((:BNNSGraphGetInputCount, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Csize_t, (bnns_graph_t, Cstring), graph, _function) +end + +function BNNSGraphGetOutputCount(graph, _function) + return ccall((:BNNSGraphGetOutputCount, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Csize_t, (bnns_graph_t, Cstring), graph, _function) +end + +function BNNSGraphGetArgumentCount(graph, _function) + return ccall((:BNNSGraphGetArgumentCount, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Csize_t, (bnns_graph_t, Cstring), graph, _function) +end + +function BNNSGraphGetFunctionCount(graph) + return ccall((:BNNSGraphGetFunctionCount, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Csize_t, (bnns_graph_t,), graph) +end + +function BNNSGraphGetInputNames(graph, _function, input_names_count, input_names) + return ccall((:BNNSGraphGetInputNames, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_t, Cstring, Csize_t, Ptr{Cstring}), graph, _function, + input_names_count, input_names) +end + +function BNNSGraphGetOutputNames(graph, _function, output_names_count, output_names) + return ccall((:BNNSGraphGetOutputNames, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_t, Cstring, Csize_t, Ptr{Cstring}), graph, _function, + output_names_count, output_names) +end + +function BNNSGraphGetArgumentNames(graph, _function, argument_names_count, argument_names) + return ccall((:BNNSGraphGetArgumentNames, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_t, Cstring, Csize_t, Ptr{Cstring}), graph, _function, + argument_names_count, argument_names) +end + +function BNNSGraphGetFunctionNames(graph, function_name_count, function_names) + return ccall((:BNNSGraphGetFunctionNames, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_t, Csize_t, Ptr{Cstring}), graph, function_name_count, + function_names) +end + +@cenum BNNSGraphArgumentIntent::UInt32 begin + BNNSGraphArgumentIntentIn = 0x0000000000000001 + BNNSGraphArgumentIntentOut = 0x0000000000000002 + BNNSGraphArgumentIntentInOut = 0x0000000000000003 +end + +function BNNSGraphGetArgumentIntents(graph, _function, argument_intents_count, + argument_intents) + return ccall((:BNNSGraphGetArgumentIntents, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_t, Cstring, Csize_t, Ptr{BNNSGraphArgumentIntent}), + graph, _function, argument_intents_count, argument_intents) +end + +function BNNSGraphGetArgumentPosition(graph, _function, argument) + return ccall((:BNNSGraphGetArgumentPosition, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Csize_t, (bnns_graph_t, Cstring, Cstring), graph, _function, argument) +end + +function BNNSGraphGetArgumentInterleaveFactors(graph, _function, argument_count, + argument_interleave, + argument_interleave_counts) + return ccall((:BNNSGraphGetArgumentInterleaveFactors, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_t, Cstring, Csize_t, Ptr{Ptr{UInt16}}, Ptr{Csize_t}), + graph, _function, argument_count, argument_interleave, + argument_interleave_counts) +end + +@cenum BNNSGraphArgumentType::UInt32 begin + BNNSGraphArgumentTypePointer = 0x0000000000000000 + BNNSGraphArgumentTypeTensor = 0x0000000000000002 +end + +struct bnns_graph_argument_t + data::NTuple{16,UInt8} +end + +function Base.getproperty(x::Ptr{bnns_graph_argument_t}, f::Symbol) + f === :tensor && return Ptr{Ptr{BNNSTensor}}(x + 0) + f === :descriptor && return Ptr{Ptr{BNNSNDArrayDescriptor}}(x + 0) + f === :data_ptr && return Ptr{Ptr{Cvoid}}(x + 0) + f === :data_ptr_size && return Ptr{Csize_t}(x + 8) + return getfield(x, f) +end + +function Base.getproperty(x::bnns_graph_argument_t, f::Symbol) + r = Ref{bnns_graph_argument_t}(x) + ptr = Base.unsafe_convert(Ptr{bnns_graph_argument_t}, r) + fptr = getproperty(ptr, f) + GC.@preserve r unsafe_load(fptr) +end + +function Base.setproperty!(x::Ptr{bnns_graph_argument_t}, f::Symbol, v) + return unsafe_store!(getproperty(x, f), v) +end + +function BNNSGraphContextMake(graph) + return ccall((:BNNSGraphContextMake, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + bnns_graph_context_t, (bnns_graph_t,), graph) +end + +function BNNSGraphContextMakeStreaming(graph, _function, initial_states_count, + initial_states) + return ccall((:BNNSGraphContextMakeStreaming, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + bnns_graph_context_t, (bnns_graph_t, Cstring, Csize_t, Ptr{BNNSTensor}), + graph, _function, initial_states_count, initial_states) +end + +function BNNSGraphContextDestroy(context) + return ccall((:BNNSGraphContextDestroy, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, (bnns_graph_context_t,), context) +end + +function BNNSGraphContextSetDynamicShapes(context, _function, shapes_count, shapes) + return ccall((:BNNSGraphContextSetDynamicShapes, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_context_t, Cstring, Csize_t, Ptr{bnns_graph_shape_t}), + context, _function, shapes_count, shapes) +end + +function BNNSGraphContextSetBatchSize(context, _function, batch_size) + return ccall((:BNNSGraphContextSetBatchSize, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_context_t, Cstring, UInt64), context, _function, + batch_size) +end + +function BNNSGraphContextSetArgumentType(context, argument_type) + return ccall((:BNNSGraphContextSetArgumentType, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_context_t, BNNSGraphArgumentType), context, + argument_type) +end + +function BNNSGraphContextEnableNanAndInfChecks(context, enable_check_for_nans_inf) + return ccall((:BNNSGraphContextEnableNanAndInfChecks, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, (bnns_graph_context_t, Bool), context, enable_check_for_nans_inf) +end + +function BNNSGraphContextExecute(context, _function, argument_count, arguments, + workspace_size, workspace) + return ccall((:BNNSGraphContextExecute, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (bnns_graph_context_t, Cstring, Csize_t, Ptr{bnns_graph_argument_t}, + Csize_t, Cstring), context, _function, argument_count, arguments, + workspace_size, workspace) +end + +function BNNSGraphContextGetWorkspaceSize(context, _function) + return ccall((:BNNSGraphContextGetWorkspaceSize, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Csize_t, (bnns_graph_context_t, Cstring), context, _function) +end + +function BNNSGraphContextGetTensor(context, _function, argument, fill_known_dynamic_shapes, + tensor) + return ccall((:BNNSGraphContextGetTensor, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_context_t, Cstring, Cstring, Bool, Ptr{BNNSTensor}), + context, _function, argument, fill_known_dynamic_shapes, tensor) +end + +function BNNSGraphTensorFillStrides(graph, _function, argument, tensor) + return ccall((:BNNSGraphTensorFillStrides, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_t, Cstring, Cstring, Ptr{BNNSTensor}), graph, _function, + argument, tensor) +end + +function BNNSGraphContextSetWorkspaceAllocationCallback(context, realloc, free, + user_memory_context_size, + user_memory_context) + return ccall((:BNNSGraphContextSetWorkspaceAllocationCallback, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (bnns_graph_context_t, bnns_graph_realloc_fn_t, bnns_graph_free_all_fn_t, + Csize_t, Ptr{Cvoid}), context, realloc, free, user_memory_context_size, + user_memory_context) +end + +function BNNSGraphContextSetOutputAllocationCallback(context, realloc, free, + user_memory_context_size, + user_memory_context) + return ccall((:BNNSGraphContextSetOutputAllocationCallback, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (bnns_graph_context_t, bnns_graph_realloc_fn_t, bnns_graph_free_all_fn_t, + Csize_t, Ptr{Cvoid}), context, realloc, free, user_memory_context_size, + user_memory_context) +end + +function BNNSGraphContextSetMessageLogCallback(context, log_callback_fn, + additional_logging_arguments) + return ccall((:BNNSGraphContextSetMessageLogCallback, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (bnns_graph_context_t, bnns_graph_execute_message_fn_t, + Ptr{bnns_user_message_data_t}), context, log_callback_fn, + additional_logging_arguments) +end + +function BNNSGraphContextSetMessageLogMask(context, log_level_mask) + return ccall((:BNNSGraphContextSetMessageLogMask, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (bnns_graph_context_t, UInt32), context, log_level_mask) +end + +const BNNSFilter = Ptr{Cvoid} + +function BNNSFilterCreateLayerConvolution(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerConvolution, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Ptr{BNNSLayerParametersConvolution}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerTransposedConvolution(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerTransposedConvolution, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Ptr{BNNSLayerParametersConvolution}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerFullyConnected(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerFullyConnected, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Ptr{BNNSLayerParametersFullyConnected}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerPooling(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerPooling, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, (Ptr{BNNSLayerParametersPooling}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerActivation(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerActivation, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Ptr{BNNSLayerParametersActivation}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerLoss(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerLoss, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, (Ptr{Cvoid}, Ptr{BNNSFilterParameters}), layer_params, + filter_params) +end + +function BNNSFilterCreateLayerNormalization(normType, layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerNormalization, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (BNNSFilterType, Ptr{BNNSLayerParametersNormalization}, + Ptr{BNNSFilterParameters}), normType, layer_params, filter_params) +end + +function BNNSFilterCreateLayerArithmetic(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerArithmetic, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Ptr{BNNSLayerParametersArithmetic}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerPermute(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerPermute, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, (Ptr{BNNSLayerParametersPermute}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerDropout(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerDropout, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, (Ptr{BNNSLayerParametersDropout}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerPadding(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerPadding, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, (Ptr{BNNSLayerParametersPadding}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerBroadcastMatMul(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerBroadcastMatMul, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Ptr{BNNSLayerParametersBroadcastMatMul}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerTensorContraction(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerTensorContraction, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Ptr{BNNSLayerParametersTensorContraction}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerGram(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerGram, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, (Ptr{BNNSLayerParametersGram}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerResize(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerResize, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, (Ptr{BNNSLayerParametersResize}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerMultiheadAttention(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerMultiheadAttention, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Ptr{BNNSLayerParametersMultiheadAttention}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateLayerReduction(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerReduction, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, (Ptr{BNNSLayerParametersReduction}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterCreateFusedLayer(number_of_fused_filters, filter_type, layer_params, + filter_params) + return ccall((:BNNSFilterCreateFusedLayer, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Csize_t, Ptr{BNNSFilterType}, Ptr{Ptr{Cvoid}}, Ptr{BNNSFilterParameters}), + number_of_fused_filters, filter_type, layer_params, filter_params) +end + +function BNNSFilterCreateLayerEmbedding(layer_params, filter_params) + return ccall((:BNNSFilterCreateLayerEmbedding, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, (Ptr{BNNSLayerParametersEmbedding}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSFilterApply(filter, in, out) + return ccall((:BNNSFilterApply, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), filter, in, out) +end + +function BNNSFilterApplyBatch(filter, batch_size, in, in_stride, out, out_stride) + return ccall((:BNNSFilterApplyBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t), + filter, batch_size, in, in_stride, out, out_stride) +end + +function BNNSPoolingFilterApplyBatch(filter, batch_size, in, in_stride, out, out_stride, + indices, idx_stride) + return ccall((:BNNSPoolingFilterApplyBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, + Ptr{Csize_t}, Csize_t), filter, batch_size, in, in_stride, out, + out_stride, indices, idx_stride) +end + +function BNNSPoolingFilterApplyBatchEx(filter, batch_size, in, in_stride, out, out_stride, + indices_data_type, indices, idx_stride) + return ccall((:BNNSPoolingFilterApplyBatchEx, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, + BNNSDataType, Ptr{Cvoid}, Csize_t), filter, batch_size, in, in_stride, + out, out_stride, indices_data_type, indices, idx_stride) +end + +function BNNSFilterApplyTwoInput(filter, inA, inB, out) + return ccall((:BNNSFilterApplyTwoInput, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), filter, inA, inB, + out) +end + +function BNNSFilterApplyTwoInputBatch(filter, batch_size, inA, inA_stride, inB, inB_stride, + out, out_stride) + return ccall((:BNNSFilterApplyTwoInputBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, + Csize_t), filter, batch_size, inA, inA_stride, inB, inB_stride, out, + out_stride) +end + +function BNNSNormalizationFilterApplyBatch(filter, batch_size, in, in_stride, out, + out_stride, training) + return ccall((:BNNSNormalizationFilterApplyBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Bool), + filter, batch_size, in, in_stride, out, out_stride, training) +end + +function BNNSFusedFilterApplyBatch(filter, batch_size, in, in_stride, out, out_stride, + training) + return ccall((:BNNSFusedFilterApplyBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Bool), + filter, batch_size, in, in_stride, out, out_stride, training) +end + +function BNNSFusedFilterApplyMultiInputBatch(filter, batch_size, number_of_inputs, in, + in_stride, out, out_stride, training) + return ccall((:BNNSFusedFilterApplyMultiInputBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Csize_t, Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, Ptr{Cvoid}, + Csize_t, Bool), filter, batch_size, number_of_inputs, in, in_stride, out, + out_stride, training) +end + +function BNNSArithmeticFilterApplyBatch(filter, batch_size, number_of_inputs, in, in_stride, + out, out_stride) + return ccall((:BNNSArithmeticFilterApplyBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Csize_t, Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, Ptr{Cvoid}, + Csize_t), filter, batch_size, number_of_inputs, in, in_stride, out, + out_stride) +end + +function BNNSApplyMultiheadAttention(F, batch_size, query, query_stride, key, key_stride, + key_mask, key_mask_stride, value, value_stride, output, + output_stride, add_to_attention, backprop_cache_size, + backprop_cache, workspace_size, workspace) + return ccall((:BNNSApplyMultiheadAttention, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, + Ptr{BNNSNDArrayDescriptor}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, + Csize_t, Ptr{BNNSNDArrayDescriptor}, Ptr{Csize_t}, Ptr{Cvoid}, + Ptr{Csize_t}, Ptr{Cvoid}), F, batch_size, query, query_stride, key, + key_stride, key_mask, key_mask_stride, value, value_stride, output, + output_stride, add_to_attention, backprop_cache_size, backprop_cache, + workspace_size, workspace) +end + +function BNNSDirectApplyQuantizer(layer_params, filter_params, batch_size, input_stride, + output_stride) + return ccall((:BNNSDirectApplyQuantizer, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSLayerParametersQuantization}, Ptr{BNNSFilterParameters}, Csize_t, + Csize_t, Csize_t), layer_params, filter_params, batch_size, input_stride, + output_stride) +end + +function BNNSFilterDestroy(filter) + return ccall((:BNNSFilterDestroy, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, (Ptr{Cvoid},), filter) +end + +function BNNSOptimizerStep(_function, OptimizerAlgFields, number_of_parameters, parameters, + gradients, accumulators, filter_params) + return ccall((:BNNSOptimizerStep, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (BNNSOptimizerFunction, Ptr{Cvoid}, Csize_t, + Ptr{Ptr{BNNSNDArrayDescriptor}}, Ptr{Ptr{BNNSNDArrayDescriptor}}, + Ptr{Ptr{BNNSNDArrayDescriptor}}, Ptr{BNNSFilterParameters}), _function, + OptimizerAlgFields, number_of_parameters, parameters, gradients, + accumulators, filter_params) +end + +function BNNSClipByValue(dest, src, min_val, max_val) + return ccall((:BNNSClipByValue, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, Cfloat, Cfloat), + dest, src, min_val, max_val) +end + +function BNNSClipByNorm(dest, src, max_norm, axis_flags) + return ccall((:BNNSClipByNorm, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, Cfloat, UInt32), + dest, src, max_norm, axis_flags) +end + +function BNNSClipByGlobalNorm(dest, src, count, max_norm, use_norm) + return ccall((:BNNSClipByGlobalNorm, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Ptr{BNNSNDArrayDescriptor}}, Ptr{Ptr{BNNSNDArrayDescriptor}}, Csize_t, + Cfloat, Cfloat), dest, src, count, max_norm, use_norm) +end + +function BNNSComputeNorm(dest, src, norm_type, axis_flags) + return ccall((:BNNSComputeNorm, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, BNNSNormType, + UInt32), dest, src, norm_type, axis_flags) +end + +function BNNSComputeNormBackward(in, in_delta, out, out_delta, norm_type, axis_flags) + return ccall((:BNNSComputeNormBackward, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Ptr{BNNSNDArrayDescriptor}, Ptr{Cvoid}, + Ptr{BNNSNDArrayDescriptor}, BNNSNormType, UInt32), in, in_delta, out, + out_delta, norm_type, axis_flags) +end + +function BNNSFilterApplyBackwardBatch(filter, batch_size, in, in_stride, in_delta, + in_delta_stride, out, out_stride, out_delta, + out_delta_stride, weights_delta, bias_delta) + return ccall((:BNNSFilterApplyBackwardBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, + Csize_t, Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}), filter, + batch_size, in, in_stride, in_delta, in_delta_stride, out, out_stride, + out_delta, out_delta_stride, weights_delta, bias_delta) +end + +function BNNSPoolingFilterApplyBackwardBatch(filter, batch_size, in, in_stride, in_delta, + in_delta_stride, out, out_stride, out_delta, + out_delta_stride, bias_delta, indices, + idx_stride) + return ccall((:BNNSPoolingFilterApplyBackwardBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, + Csize_t, Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, + Ptr{BNNSNDArrayDescriptor}, Ptr{Csize_t}, Csize_t), filter, batch_size, + in, in_stride, in_delta, in_delta_stride, out, out_stride, out_delta, + out_delta_stride, bias_delta, indices, idx_stride) +end + +function BNNSPoolingFilterApplyBackwardBatchEx(filter, batch_size, in, in_stride, in_delta, + in_delta_stride, out, out_stride, out_delta, + out_delta_stride, bias_delta, + indices_data_type, indices, idx_stride) + return ccall((:BNNSPoolingFilterApplyBackwardBatchEx, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, + Csize_t, Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, + Ptr{BNNSNDArrayDescriptor}, BNNSDataType, Ptr{Cvoid}, Csize_t), filter, + batch_size, in, in_stride, in_delta, in_delta_stride, out, out_stride, + out_delta, out_delta_stride, bias_delta, indices_data_type, indices, + idx_stride) +end + +function BNNSFilterApplyBackwardTwoInputBatch(filter, batch_size, inA, inA_stride, + inA_delta, inA_delta_stride, inB, inB_stride, + inB_delta, inB_delta_stride, out, out_stride, + out_delta, out_delta_stride, weights_delta, + bias_delta) + return ccall((:BNNSFilterApplyBackwardTwoInputBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, + Csize_t, Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, + Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}), filter, + batch_size, inA, inA_stride, inA_delta, inA_delta_stride, inB, inB_stride, + inB_delta, inB_delta_stride, out, out_stride, out_delta, out_delta_stride, + weights_delta, bias_delta) +end + +function BNNSNormalizationFilterApplyBackwardBatch(filter, batch_size, in_delta, + in_delta_stride, out, out_stride, + out_delta, out_delta_stride, beta_delta, + gamma_delta) + return ccall((:BNNSNormalizationFilterApplyBackwardBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, Ptr{Cvoid}, + Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}), filter, batch_size, in_delta, + in_delta_stride, out, out_stride, out_delta, out_delta_stride, beta_delta, + gamma_delta) +end + +function BNNSFusedFilterApplyBackwardBatch(filter, batch_size, in, in_stride, in_delta, + in_delta_stride, out, out_stride, out_delta, + out_delta_stride, delta_parameters) + return ccall((:BNNSFusedFilterApplyBackwardBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, + Csize_t, Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, + Ptr{Ptr{BNNSNDArrayDescriptor}}), filter, batch_size, in, in_stride, + in_delta, in_delta_stride, out, out_stride, out_delta, out_delta_stride, + delta_parameters) +end + +function BNNSFusedFilterApplyBackwardMultiInputBatch(filter, batch_size, number_of_inputs, + in, in_stride, in_delta, + in_delta_stride, out, out_stride, + out_delta, out_delta_stride, + delta_parameters) + return ccall((:BNNSFusedFilterApplyBackwardMultiInputBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Csize_t, Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, + Ptr{Ptr{BNNSNDArrayDescriptor}}, Ptr{Csize_t}, Ptr{Cvoid}, Csize_t, + Ptr{BNNSNDArrayDescriptor}, Csize_t, Ptr{Ptr{BNNSNDArrayDescriptor}}), + filter, batch_size, number_of_inputs, in, in_stride, in_delta, + in_delta_stride, out, out_stride, out_delta, out_delta_stride, + delta_parameters) +end + +function BNNSArithmeticFilterApplyBackwardBatch(filter, batch_size, number_of_inputs, in, + in_stride, in_delta, in_delta_stride, out, + out_stride, out_delta, out_delta_stride) + return ccall((:BNNSArithmeticFilterApplyBackwardBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Csize_t, Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, + Ptr{Ptr{BNNSNDArrayDescriptor}}, Ptr{Csize_t}, Ptr{Cvoid}, Csize_t, + Ptr{BNNSNDArrayDescriptor}, Csize_t), filter, batch_size, + number_of_inputs, in, in_stride, in_delta, in_delta_stride, out, + out_stride, out_delta, out_delta_stride) +end + +function BNNSPermuteFilterApplyBackwardBatch(filter, batch_size, in_delta, in_delta_stride, + out_delta, out_delta_stride) + return ccall((:BNNSPermuteFilterApplyBackwardBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, + Ptr{BNNSNDArrayDescriptor}, Csize_t), filter, batch_size, in_delta, + in_delta_stride, out_delta, out_delta_stride) +end + +function BNNSLossFilterApplyBatch(filter, batch_size, in, in_stride, labels, labels_stride, + weights, weights_size, out, in_delta, in_delta_stride) + return ccall((:BNNSLossFilterApplyBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, + Csize_t, Ptr{Cvoid}, Ptr{BNNSNDArrayDescriptor}, Csize_t), filter, + batch_size, in, in_stride, labels, labels_stride, weights, weights_size, + out, in_delta, in_delta_stride) +end + +function BNNSLossFilterApplyBackwardBatch(filter, batch_size, in, in_stride, in_delta, + in_delta_stride, labels, labels_stride, weights, + weights_size, out_delta, out_delta_stride) + return ccall((:BNNSLossFilterApplyBackwardBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{BNNSNDArrayDescriptor}, + Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, + Ptr{BNNSNDArrayDescriptor}, Csize_t), filter, batch_size, in, in_stride, + in_delta, in_delta_stride, labels, labels_stride, weights, weights_size, + out_delta, out_delta_stride) +end + +function BNNSApplyMultiheadAttentionBackward(F, batch_size, query, query_stride, + query_param_delta, key, key_stride, key_mask, + key_mask_stride, key_param_delta, value, + value_stride, value_param_delta, + add_to_attention, key_attn_bias_delta, + value_attn_bias_delta, output, output_stride, + output_param_delta, backprop_cache_size, + backprop_cache, workspace_size, workspace) + return ccall((:BNNSApplyMultiheadAttentionBackward, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, + Ptr{BNNSMHAProjectionParameters}, Ptr{Cvoid}, Csize_t, + Ptr{BNNSNDArrayDescriptor}, Csize_t, Ptr{BNNSMHAProjectionParameters}, + Ptr{Cvoid}, Csize_t, Ptr{BNNSMHAProjectionParameters}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{Cvoid}, Csize_t, + Ptr{BNNSMHAProjectionParameters}, Csize_t, Ptr{Cvoid}, Ptr{Csize_t}, + Ptr{Cvoid}), F, batch_size, query, query_stride, query_param_delta, key, + key_stride, key_mask, key_mask_stride, key_param_delta, value, + value_stride, value_param_delta, add_to_attention, key_attn_bias_delta, + value_attn_bias_delta, output, output_stride, output_param_delta, + backprop_cache_size, backprop_cache, workspace_size, workspace) +end + +function BNNSComputeLSTMTrainingCacheCapacity(layer_params) + return ccall((:BNNSComputeLSTMTrainingCacheCapacity, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Csize_t, (Ptr{BNNSLayerParametersLSTM},), layer_params) +end + +function BNNSDirectApplyLSTMBatchTrainingCaching(layer_params, filter_params, + training_cache_ptr, + training_cache_capacity) + return ccall((:BNNSDirectApplyLSTMBatchTrainingCaching, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSLayerParametersLSTM}, Ptr{BNNSFilterParameters}, Ptr{Cvoid}, + Csize_t), layer_params, filter_params, training_cache_ptr, + training_cache_capacity) +end + +function BNNSDirectApplyActivationBatch(layer_params, filter_params, batch_size, in_stride, + out_stride) + return ccall((:BNNSDirectApplyActivationBatch, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSLayerParametersActivation}, Ptr{BNNSFilterParameters}, Csize_t, + Csize_t, Csize_t), layer_params, filter_params, batch_size, in_stride, + out_stride) +end + +function BNNSCopy(dest, src, filter_params) + return ccall((:BNNSCopy, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSFilterParameters}), dest, src, filter_params) +end + +function BNNSMatMulWorkspaceSize(transA, transB, alpha, inputA, inputB, output, + filter_params) + return ccall((:BNNSMatMulWorkspaceSize, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cssize_t, + (Bool, Bool, Cfloat, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSFilterParameters}), transA, transB, alpha, inputA, inputB, output, + filter_params) +end + +function BNNSMatMul(transA, transB, alpha, inputA, inputB, output, workspace, filter_params) + return ccall((:BNNSMatMul, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Bool, Bool, Cfloat, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, Ptr{Cvoid}, + Ptr{BNNSFilterParameters}), transA, transB, alpha, inputA, inputB, output, + workspace, filter_params) +end + +function BNNSDirectApplyBroadcastMatMul(transA, transB, alpha, inputA, inputB, output, + filter_params) + return ccall((:BNNSDirectApplyBroadcastMatMul, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, + (Bool, Bool, Cfloat, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSFilterParameters}), transA, transB, alpha, inputA, inputB, output, + filter_params) +end + +function BNNSTranspose(dest, src, axis0, axis1, filter_params) + return ccall((:BNNSTranspose, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, Csize_t, Csize_t, + Ptr{BNNSFilterParameters}), dest, src, axis0, axis1, filter_params) +end + +function BNNSDirectApplyReduction(layer_params, filter_params) + return ccall((:BNNSDirectApplyReduction, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (Ptr{BNNSLayerParametersReduction}, Ptr{BNNSFilterParameters}), + layer_params, filter_params) +end + +function BNNSCompareTensor(in0, in1, op, out) + return ccall((:BNNSCompareTensor, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + BNNSRelationalOperator, Ptr{BNNSNDArrayDescriptor}), in0, in1, op, out) +end + +function BNNSTile(input, output, filter_params) + return ccall((:BNNSTile, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSFilterParameters}), input, output, filter_params) +end + +function BNNSDirectApplyTopK(K, axis, batch_size, input, input_batch_stride, best_values, + best_values_batch_stride, best_indices, + best_indices_batch_stride, filter_params) + return ccall((:BNNSDirectApplyTopK, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Csize_t, Csize_t, Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, + Ptr{BNNSNDArrayDescriptor}, Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, + Ptr{BNNSFilterParameters}), K, axis, batch_size, input, + input_batch_stride, best_values, best_values_batch_stride, best_indices, + best_indices_batch_stride, filter_params) +end + +function BNNSDirectApplyInTopK(K, axis, batch_size, input, input_batch_stride, test_indices, + test_indices_batch_stride, output, output_batch_stride, + filter_params) + return ccall((:BNNSDirectApplyInTopK, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Csize_t, Csize_t, Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, + Ptr{BNNSNDArrayDescriptor}, Csize_t, Ptr{BNNSNDArrayDescriptor}, Csize_t, + Ptr{BNNSFilterParameters}), K, axis, batch_size, input, + input_batch_stride, test_indices, test_indices_batch_stride, output, + output_batch_stride, filter_params) +end + +function BNNSGather(axis, input, indices, output, filter_params) + return ccall((:BNNSGather, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Csize_t, Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSFilterParameters}), axis, input, + indices, output, filter_params) +end + +function BNNSScatter(axis, op, input, indices, output, filter_params) + return ccall((:BNNSScatter, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Csize_t, BNNSReduceFunction, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSFilterParameters}), axis, op, input, indices, output, + filter_params) +end + +function BNNSGatherND(input, indices, output, filter_params) + return ccall((:BNNSGatherND, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSFilterParameters}), input, indices, + output, filter_params) +end + +function BNNSScatterND(op, input, indices, output, filter_params) + return ccall((:BNNSScatterND, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (BNNSReduceFunction, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSFilterParameters}), op, input, indices, output, filter_params) +end + +function BNNSShuffle(type, input, output, filter_params) + return ccall((:BNNSShuffle, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (BNNSShuffleType, Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSFilterParameters}), type, input, output, filter_params) +end + +function BNNSBandPart(num_lower, num_upper, input, output, filter_params) + return ccall((:BNNSBandPart, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Cint, Cint, Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSFilterParameters}), num_lower, num_upper, input, output, + filter_params) +end + +function BNNSCropResize(layer_params, input, roi, output, filter_params) + return ccall((:BNNSCropResize, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSLayerParametersCropResize}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSFilterParameters}), layer_params, input, roi, output, + filter_params) +end + +function BNNSDirectApplyLSTMBatchBackward(layer_params, layer_delta_params, filter_params, + training_cache_ptr, training_cache_capacity) + return ccall((:BNNSDirectApplyLSTMBatchBackward, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSLayerParametersLSTM}, Ptr{BNNSLayerParametersLSTM}, + Ptr{BNNSFilterParameters}, Ptr{Cvoid}, Csize_t), layer_params, + layer_delta_params, filter_params, training_cache_ptr, + training_cache_capacity) +end + +function BNNSTileBackward(in_delta, out_delta, filter_params) + return ccall((:BNNSTileBackward, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSFilterParameters}), in_delta, out_delta, filter_params) +end + +function BNNSCropResizeBackward(layer_params, in_delta, roi, out_delta, filter_params) + return ccall((:BNNSCropResizeBackward, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSLayerParametersCropResize}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSFilterParameters}), layer_params, in_delta, roi, out_delta, + filter_params) +end + +function BNNSGetPointer(filter, target) + return ccall((:BNNSGetPointer, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + BNNSNDArrayDescriptor, (Ptr{Cvoid}, BNNSPointerSpecifier), filter, target) +end + +function BNNSNDArrayGetDataSize(array) + return ccall((:BNNSNDArrayGetDataSize, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Csize_t, (Ptr{BNNSNDArrayDescriptor},), array) +end + +function BNNSTensorGetAllocationSize(tensor) + return ccall((:BNNSTensorGetAllocationSize, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Csize_t, (Ptr{BNNSTensor},), tensor) +end + +function BNNSDataLayoutGetRank(layout) + return ccall((:BNNSDataLayoutGetRank, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Csize_t, (BNNSDataLayout,), layout) +end + +function BNNSNDArrayFullyConnectedSparsifySparseCOO(in_dense_shape, in_indices, in_values, + out, sparse_params, batch_size, + workspace, workspace_size, + filter_params) + return ccall((:BNNSNDArrayFullyConnectedSparsifySparseCOO, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSSparsityParameters}, Csize_t, Ptr{Cvoid}, Csize_t, + Ptr{BNNSFilterParameters}), in_dense_shape, in_indices, in_values, out, + sparse_params, batch_size, workspace, workspace_size, filter_params) +end + +function BNNSNDArrayFullyConnectedSparsifySparseCSR(in_dense_shape, in_column_indices, + in_row_starts, in_values, out, + sparse_params, batch_size, workspace, + workspace_size, filter_params) + return ccall((:BNNSNDArrayFullyConnectedSparsifySparseCSR, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, + Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSSparsityParameters}, Csize_t, + Ptr{Cvoid}, Csize_t, Ptr{BNNSFilterParameters}), in_dense_shape, + in_column_indices, in_row_starts, in_values, out, sparse_params, + batch_size, workspace, workspace_size, filter_params) +end + +const BNNSRandomGenerator = Ptr{Cvoid} + +function BNNSCreateRandomGenerator(method, filter_params) + return ccall((:BNNSCreateRandomGenerator, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, (BNNSRandomGeneratorMethod, Ptr{BNNSFilterParameters}), method, + filter_params) +end + +function BNNSCreateRandomGeneratorWithSeed(method, seed, filter_params) + return ccall((:BNNSCreateRandomGeneratorWithSeed, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, (BNNSRandomGeneratorMethod, UInt64, Ptr{BNNSFilterParameters}), + method, seed, filter_params) +end + +function BNNSDestroyRandomGenerator(generator) + return ccall((:BNNSDestroyRandomGenerator, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, (Ptr{Cvoid},), generator) +end + +function BNNSRandomGeneratorStateSize(generator) + return ccall((:BNNSRandomGeneratorStateSize, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Csize_t, (Ptr{Cvoid},), generator) +end + +function BNNSRandomGeneratorGetState(generator, state_size, state) + return ccall((:BNNSRandomGeneratorGetState, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}), generator, state_size, state) +end + +function BNNSRandomGeneratorSetState(generator, state_size, state) + return ccall((:BNNSRandomGeneratorSetState, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}), generator, state_size, state) +end + +function BNNSRandomFillUniformFloat(generator, desc, a, b) + return ccall((:BNNSRandomFillUniformFloat, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (Ptr{Cvoid}, Ptr{BNNSNDArrayDescriptor}, Cfloat, Cfloat), generator, + desc, a, b) +end + +function BNNSRandomFillUniformInt(generator, desc, a, b) + return ccall((:BNNSRandomFillUniformInt, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (Ptr{Cvoid}, Ptr{BNNSNDArrayDescriptor}, Int64, Int64), generator, + desc, a, b) +end + +function BNNSRandomFillNormalFloat(generator, desc, mean, stddev) + return ccall((:BNNSRandomFillNormalFloat, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (Ptr{Cvoid}, Ptr{BNNSNDArrayDescriptor}, Cfloat, Cfloat), generator, + desc, mean, stddev) +end + +function BNNSRandomFillCategoricalFloat(generator, desc, probabilities, log_probabilities) + return ccall((:BNNSRandomFillCategoricalFloat, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, + (Ptr{Cvoid}, Ptr{BNNSNDArrayDescriptor}, Ptr{BNNSNDArrayDescriptor}, Bool), + generator, desc, probabilities, log_probabilities) +end + +const BNNSNearestNeighbors = Ptr{Cvoid} + +function BNNSCreateNearestNeighbors(max_n_samples, n_features, n_neighbors, data_type, + filter_params) + return ccall((:BNNSCreateNearestNeighbors, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, (Cuint, Cuint, Cuint, BNNSDataType, Ptr{BNNSFilterParameters}), + max_n_samples, n_features, n_neighbors, data_type, filter_params) +end + +function BNNSDestroyNearestNeighbors(knn) + return ccall((:BNNSDestroyNearestNeighbors, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cvoid, (Ptr{Cvoid},), knn) +end + +function BNNSNearestNeighborsLoad(knn, n_new_samples, data_ptr) + return ccall((:BNNSNearestNeighborsLoad, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (Ptr{Cvoid}, Cuint, Ptr{Cvoid}), knn, n_new_samples, data_ptr) +end + +function BNNSNearestNeighborsGetInfo(knn, sample_number, indices, distances) + return ccall((:BNNSNearestNeighborsGetInfo, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Cint, (Ptr{Cvoid}, Cint, Ptr{Cint}, Ptr{Cvoid}), knn, sample_number, + indices, distances) +end + +function BNNSFilterCreateConvolutionLayer(in_desc, out_desc, layer_params, filter_params) + return ccall((:BNNSFilterCreateConvolutionLayer, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Ptr{BNNSImageStackDescriptor}, Ptr{BNNSImageStackDescriptor}, + Ptr{BNNSConvolutionLayerParameters}, Ptr{BNNSFilterParameters}), in_desc, + out_desc, layer_params, filter_params) +end + +function BNNSFilterCreateFullyConnectedLayer(in_desc, out_desc, layer_params, filter_params) + return ccall((:BNNSFilterCreateFullyConnectedLayer, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Ptr{BNNSVectorDescriptor}, Ptr{BNNSVectorDescriptor}, + Ptr{BNNSFullyConnectedLayerParameters}, Ptr{BNNSFilterParameters}), + in_desc, out_desc, layer_params, filter_params) +end + +function BNNSFilterCreatePoolingLayer(in_desc, out_desc, layer_params, filter_params) + return ccall((:BNNSFilterCreatePoolingLayer, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Ptr{BNNSImageStackDescriptor}, Ptr{BNNSImageStackDescriptor}, + Ptr{BNNSPoolingLayerParameters}, Ptr{BNNSFilterParameters}), in_desc, + out_desc, layer_params, filter_params) +end + +function BNNSFilterCreateVectorActivationLayer(in_desc, out_desc, activation, filter_params) + return ccall((:BNNSFilterCreateVectorActivationLayer, + Symbol("/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib")), + Ptr{Cvoid}, + (Ptr{BNNSVectorDescriptor}, Ptr{BNNSVectorDescriptor}, Ptr{BNNSActivation}, + Ptr{BNNSFilterParameters}), in_desc, out_desc, activation, filter_params) +end diff --git a/src/AppleAccelerate.jl b/src/AppleAccelerate.jl index ecb09d7..e440be0 100644 --- a/src/AppleAccelerate.jl +++ b/src/AppleAccelerate.jl @@ -96,10 +96,16 @@ function __init__() load_accelerate(; load_ilp64=true, use_external_lapack=false) end -if Sys.isapple() +@static if Sys.isapple() include("Util.jl") include("Array.jl") include("DSP.jl") end +module BNNS +@static if Sys.isapple() + include("../lib/BNNS/BNNS.jl") +end +end # module BNNS + end # module diff --git a/test/BNNS.jl b/test/BNNS.jl new file mode 100644 index 0000000..e6ce6cc --- /dev/null +++ b/test/BNNS.jl @@ -0,0 +1,99 @@ +const RAND_TYPES = [BFloat16, Float16, Float32, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, + UInt64] +const RANDN_TYPES = [BFloat16, Float16, Float32] +const INPLACE_TUPLES = [[(rand!, T) for T in RAND_TYPES]; + [(randn!, T) for T in RANDN_TYPES]] +const OOPLACE_TUPLES = [[(BNNS.rand, rand, T) for T in RAND_TYPES]; + [(BNNS.randn, rand, T) for T in RANDN_TYPES]] + +@testset "random" begin + # in-place + @testset "in-place" begin + rng = BNNS.bnns_rng() + + @testset "$f with $T" for (f, T) in INPLACE_TUPLES + # d == 2 and d == 3 are to hit the test cases where sizeof(A) <= 4 + @testset "$d" for d in (2, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000)) + A = Array{T}(undef, d) + + # specifie BNNS rng + fill!(A, T(0)) + f(rng, A) + @test !iszero(collect(A)) + end + + @testset "0" begin + A = Array{T}(undef, 0) + + # specified BNNS rng + fill!(A, T(0)) + f(rng, A) + @test Array(A) == fill(1, 0) + end + end + end + # out-of-place + @testset "out-of-place" begin + @testset "$fr with implicit type" for (fm, fr, T) in + ((BNNS.rand, Random.rand, Float32), (BNNS.randn, Random.randn, Float32)) + rng = BNNS.bnns_rng() + @testset "args" for args in ((0,), (1,), (3,), (3, 3), (16,), (16, 16), (1000,), (1000,1000)) + # default_rng + A = fm(args...) + @test eltype(A) == T + + # specified MPS rng + B = fr(rng, args...) + @test eltype(B) == T + end + + @testset "scalar" begin + a = fm() + @test typeof(a) == T + b = fr(rng) + @test typeof(b) == T + end + end + + # out-of-place, with type specified + @testset "$fr with $T" for (fm, fr, T) in OOPLACE_TUPLES + rng = BNNS.bnns_rng() + @testset "$args" for args in ((T, 0), + (T, 1), + (T, 3), + (T, 3, 3), + (T, (3, 3)), + (T, 16), + (T, 16, 16), + (T, (16, 16)), + (T, 1000), + (T, 1000, 1000),) + # default_rng + A = fm(args...) + @test eltype(A) == T + + # specified RNG rng + B = fr(rng, args...) + @test eltype(B) == T + end + + @testset "scalar" begin + a = fm(T) + @test typeof(a) == T + b = fr(rng, T) + @test typeof(b) == T + end + end + end + + ## seeding + @testset "Seeding" begin + @testset "$d" for d in (1, 3, (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000), (3,3,3,3), (3,3,3,3,3), (3,3,3,3,3,3)) + rng = BNNS.bnns_rng(1) + a = rand(rng, Float32, d) + Random.seed!(rng, 1) + b = rand(rng, Float32, d) + @test a == b + end + end +end # testset diff --git a/test/runtests.jl b/test/runtests.jl index b15415e..01a2c42 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using LinearAlgebra -using AppleAccelerate -using DSP, Test, Random, Statistics +using AppleAccelerate: AppleAccelerate, BNNS +using DSP, Test, Random, Statistics, BFloat16s if !Sys.isapple() @info("AppleAccelerate.jl will be tested only on macOS. Exiting.") @@ -165,6 +165,10 @@ for T in (Float32, Float64) end end +@testset "BNNS" begin + include("BNNS.jl") +end + @testset "DCT::Float32" begin r=rand(Float32,2^16) diff --git a/wrap/.gitignore b/wrap/.gitignore new file mode 100644 index 0000000..d33fed7 --- /dev/null +++ b/wrap/.gitignore @@ -0,0 +1 @@ +*.JLD2 diff --git a/wrap/Project.toml b/wrap/Project.toml new file mode 100644 index 0000000..2daf101 --- /dev/null +++ b/wrap/Project.toml @@ -0,0 +1,9 @@ +[deps] +Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31" +Clang_jll = "0ee61d77-7f21-5576-8119-9fcc46b10100" +Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" +Tokenize = "0796e94c-ce3b-5d07-9a54-7f471281c624" diff --git a/wrap/libBNNS.toml b/wrap/libBNNS.toml new file mode 100644 index 0000000..6f2fe33 --- /dev/null +++ b/wrap/libBNNS.toml @@ -0,0 +1,21 @@ +[general] +library_name = "Symbol(\"/System/Library/Frameworks/Accelerate.framework/Frameworks/vecLib.framework/vecLib\")" +output_file_path = "../lib/BNNS/libBNNS.jl" + +minimum_macos_supported = "13" + +generate_isystem_symbols = false + +output_ignorelist = ["^(?i:(?!BNNS).)+$"] + +[codegen] +# use_ccall_macro = true +always_NUL_terminated_string = true + +version_function = "AppleAccelerate.get_macos_version()" + +[codegen.macro] +# it's highly recommended to set this entry to "basic". +# if you'd like to skip all of the macros, please set this entry to "disable". +# if you'd like to translate function-like macros to Julia, please set this entry to "aggressive". +macro_mode = "disable" diff --git a/wrap/wrap.jl b/wrap/wrap.jl new file mode 100644 index 0000000..3277fab --- /dev/null +++ b/wrap/wrap.jl @@ -0,0 +1,195 @@ +# Not necessary, but removes some errors that don't seem to affect the output +using Clang_jll +Clang_jll.libclang = "/Applications/Xcode.app/Contents/Frameworks/libclang.dylib" + +using Clang.Generators +using Clang +using Glob +using JLD2 +using JuliaFormatter +using Logging + +# Use system SDK +SDK_PATH = `xcrun --show-sdk-path` |> open |> readchomp |> String + +main(name::AbstractString; kwargs...) = main([name]; kwargs...) +function main(names::AbstractVector=["all"]; sdk_path=SDK_PATH) + path_to_framework(framework) = joinpath(sdk_path, "System/Library/Frameworks/","Accelerate.framework","Frameworks",framework*".framework","Headers") + path_to_mps_framework(framework) = joinpath(sdk_path, "System/Library/Frameworks/","Accelerate.framework","Frameworks",framework*".framework","Headers") + + defines = [] + + ctxs = [] + + + if "all" in names || "libBNNS" in names + fwpath = path_to_framework("vecLib") + tctx = wrap("libBNNS", joinpath(fwpath, "vecLib.h"); defines) + push!(ctxs, tctx) + end + + # if "all" in names || "veclib" in names + # fwpath = path_to_framework("vecLib") + # tctx = wrap("vecLib", joinpath(fwpath, "vecLib.h"); defines) + # push!(ctxs, tctx) + # end + + # if "all" in names || "vimage" in names + # fwpath = path_to_framework("vImage") + # tctx = wrap("vImage", joinpath(fwpath, "vImage.h"); defines) + # push!(ctxs, tctx) + # end + + return ctxs +end + +function wrap(name, headers; defines=[]) + @info "Wrapping $name" + + options = load_options(joinpath(@__DIR__, "$(name).toml")) + + args = [ + "-x","objective-c", + "-isysroot", SDK_PATH, + "-fblocks", + "-fregister-global-dtors-with-atexit", + "-fgnuc-version=4.2.1", + "-fobjc-runtime=macosx-15.0.0", + "-fobjc-exceptions", + "-fexceptions", + "-fmax-type-align=16", + "-fcommon", + "-DNS_FORMAT_ARGUMENT(A)=", + "-D__GCC_HAVE_DWARF2_CFI_ASM=1", + ] + + for define in defines + if isa(define, Pair) + append!(args, ["-D", "$(first(define))=$(last(define))"]) + else + append!(args, ["-D", "$define"]) + end + end + + @info "Creating context" + ctx = create_objc_context(headers, args, options) + + @info "Building no printing" + build!(ctx, BUILDSTAGE_NO_PRINTING) + + rewriter!(ctx, options) + + @info "Building only printing" + build!(ctx, BUILDSTAGE_PRINTING_ONLY) + + output_file = options["general"]["output_file_path"] + + # prepend "autogenerated, do not edit!" comment + output_data = read(output_file, String) + open(output_file, "w") do io + println(io, """# This file is automatically generated. Do not edit! + # To re-generate, execute res/wrap/wrap.jl""") + println(io) + print(io, output_data) + end + + format_file(output_file, YASStyle()) + + return ctx +end + +# Uses the same passes as with C, but with some other changes +create_objc_context(header::AbstractString, args=String[], ops=Dict()) = create_objc_context([header], args, ops) +function create_objc_context(headers::Vector, args::Vector=String[], options::Dict=Dict()) + system_dirs = [ + SDK_PATH, + "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain" + ] + + regen = if haskey(options, "general") && haskey(options["general"], "regenerate_dependent_headers") + options["general"]["regenerate_dependent_headers"] + else + false + end + + # Since the framework we're wrapping is a system header, + # find all dependent headers, then remove all but the relevant ones + # also temporarily disable logging + dep_headers_fname = if haskey(options, "general") && haskey(options["general"], "library_name") + splitext(splitpath(options["general"]["output_file_path"])[end])[1]*".JLD2" + else + nothing + end + Base.CoreLogging._min_enabled_level[] = Logging.Info+1 + dependent_headers = if !regen && !isnothing(dep_headers_fname) && isfile(dep_headers_fname) + JLD2.load(dep_headers_fname, "dep_headers") + else + all_headers = find_dependent_headers(headers,args,[]) + dep_headers = Vector{eltype(all_headers)}(undef, 0) + for header in headers + target_framework = "/"*joinpath(Sys.splitpath(header)[end-2:end-1]) + dep_headers = append!(dep_headers, filter(s -> occursin(target_framework, s), all_headers)) + end + if haskey(options, "general") && haskey(options["general"], "extra_target_headers") + append!(dep_headers, options["general"]["extra_target_headers"]) + end + regen || JLD2.@save dep_headers_fname dep_headers + dep_headers + end + Base.CoreLogging._min_enabled_level[] = Logging.Debug + + ctx = Context(; options) + + @info "Parsing headers..." + parse_headers!(ctx, headers, args) + + Generators.add_default_passes!(ctx, options, system_dirs, dependent_headers) +end + +function rewriter!(ctx, options) + for node in get_nodes(ctx.dag) + if haskey(options, "api") + nodetype = typeof(node) + if nodetype <: Generators.ExprNode{<:Generators.AbstractStructNodeType} + expr = node.exprs[1] + structName = string(node.id) + + if haskey(options["api"], structName) && haskey(options["api"][structName], "constructor") + expr = node.exprs[1] + con = options["api"][structName]["constructor"] |> Meta.parse + + if con.head == :(=) && con.args[2] isa Expr && con.args[2].head == :block && + con.args[2].args[1] isa LineNumberNode && con.args[2].args[2].head == :call + con.args[2] = con.args[2].args[2] + end + push!(expr.args[3].args, con) + end + elseif nodetype <: Generators.ExprNode{<:Generators.AbstractObjCObjNodeType} + expr = node.exprs[1] + className = string(node.id) + if haskey(options["api"], className) + if haskey(options["api"][className], "immutable") + expr = node.exprs[1] + con = options["api"][className]["immutable"] + + expr.args[3].args[2] = con + end + if haskey(options["api"][className], "override_supertype") + expr2 = if expr.head == :macrocall && first(expr.args) == Symbol("@static") + expr.args[3].args[2].args[1].args[4] + else + expr.args[4] + end + typ = options["api"][className]["override_supertype"] |> Meta.parse + + expr2.args[2] = typ + end + end + end + end + end +end + +if abspath(PROGRAM_FILE) == @__FILE__ + main() +end