Skip to content

Commit fc7bfa5

Browse files
authored
Merge pull request #69 from christiangnrd/add_functions
Add some more functions and improve broadcasting support
2 parents 9355d11 + f5c2d71 commit fc7bfa5

File tree

4 files changed

+226
-7
lines changed

4 files changed

+226
-7
lines changed

src/AppleAccelerate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ function __init__()
9292
end
9393

9494
if Sys.isapple()
95+
include("Util.jl")
9596
include("Array.jl")
9697
include("DSP.jl")
97-
include("Util.jl")
9898
end
9999

100100
end # module

src/Array.jl

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ for (T, suff) in ((Float64, ""), (Float32, "f"))
116116
Base.copy(bc::Base.Broadcast.Broadcasted{Style, Axes, typeof($f), Tuple{Array{$T, N}}}) where {Style, Axes, N} = ($f)(bc.args...)
117117
Base.copyto!(dest::Array{$T, N}, bc::Base.Broadcast.Broadcasted{Style, Axes, typeof($f), Tuple{Array{$T, N}}}) where {Style, Axes, N} = ($f!)(dest, bc.args...)
118118
end
119+
# Broadcasting `f` over an array (or over a still-lazy broadcast expression) is handed
# to the array method `f` in a single call. The signature below is written for both
# element types at once (F <: Union{Float32,Float64}) and does not mention `T`, so the
# guard presumably exists to evaluate the definition on only one element-type pass.
if T == Float32
    @eval function Base.broadcasted(::typeof($f), arg::Union{Array{F,N},Base.Broadcast.Broadcasted}) where {N,F<:Union{Float32,Float64}}
        # `maybecopy` passes arrays through and calls `copy` on Broadcasted objects.
        return ($f)(maybecopy(arg))
    end
end
119122
end
120123
for (f, fa) in (twoarg_funcs...,(:pow,:pow))
121124
f! = Symbol("$(f)!")
@@ -124,14 +127,19 @@ for (T, suff) in ((Float64, ""), (Float32, "f"))
124127
Base.copy(bc::Base.Broadcast.Broadcasted{Style, Axes, typeof($f), Tuple{Array{$T, N},Array{$T,N}}}) where {Style, Axes, N} = ($f)(bc.args...)
125128
Base.copyto!(dest::Array{$T, N}, bc::Base.Broadcast.Broadcasted{Style, Axes, typeof($f), Tuple{Array{$T,N},Array{$T,N}}}) where {Style, Axes, N} = ($f!)(dest, bc.args...)
126129
end
130+
# Two-argument counterpart: `f.(A, B)` is handed to the array method `f(A, B)`, with
# each argument passed through `maybecopy` first. The signature covers both element
# types (F <: Union{Float32,Float64}) and does not mention `T`, so the guard presumably
# exists to evaluate the definition on only one element-type pass.
if T == Float32
    @eval function Base.broadcasted(::typeof($f), arg1::Union{Array{F,N},Base.Broadcast.Broadcasted}, arg2::Union{Array{F,N},Base.Broadcast.Broadcasted}) where {N,F<:Union{Float32,Float64}}
        return ($f)(maybecopy(arg1), maybecopy(arg2))
    end
end
127133
end
128134
end
129135

130136
# Functions over single vectors that return scalars/tuples
131137
for (T, suff) in ((Float32, ""), (Float64, "D"))
132138

133139
for (f, fa) in ((:maximum, :maxv), (:minimum, :minv), (:mean, :meanv),
134-
(:meansqr, :measqv), (:meanmag, :meamgv), (:sum, :sve))
140+
(:meanmag, :meamgv), (:meansqr, :measqv), (:meanssqr, :mvessq),
141+
(:sum, :sve), (:summag, :svemg), (:sumsqr, :svesq),
142+
(:sumssqr, :svs))
135143
@eval begin
136144
function ($f)(X::Vector{$T})
137145
val = Ref{$T}(0.0)
@@ -192,7 +200,133 @@ for (T, suff) in ((Float32, ""), (Float64, "D"))
192200
return result
193201
end
194202
end
203+
204+
@eval begin
    # Broadcasting overrides such that `f.(X, Y)` turns into `f(X, Y)`.

    # Allocating path: materialising a broadcast of `f` over two arrays of the same
    # element type and rank calls `f` on the two arrays.
    function Base.copy(bc::Base.Broadcast.Broadcasted{Style,Axes,typeof($f),Tuple{Array{$T,N},Array{$T,N}}}) where {Style,Axes,N}
        return ($f)(bc.args...)
    end

    # In-place path: `dest .= f.(X, Y)` calls the mutating variant on `dest`.
    function Base.copyto!(dest::Array{$T,N}, bc::Base.Broadcast.Broadcasted{Style,Axes,typeof($f),Tuple{Array{$T,N},Array{$T,N}}}) where {Style,Axes,N}
        return ($f!)(dest, bc.args...)
    end

    # Eager path: arguments that are arrays or lazy broadcast expressions are passed
    # through `maybecopy` and then given to `f` directly.
    function Base.broadcasted(::typeof($f), arg1::Union{Array{$T,N},Base.Broadcast.Broadcasted}, arg2::Union{Array{$T,N},Base.Broadcast.Broadcasted}) where {N}
        return ($f)(maybecopy(arg1), maybecopy(arg2))
    end
end
195210
end
196211
end
197212

213+
# Element-wise operations over a vector and a scalar
214+
for (T, suff) in ((Float32, ""), (Float64, "D"))
215+
216+
for (f, name) in ((:vsadd, "addition"), (:vsdiv, "division"), (:vsmul, "multiplication"))
    f! = Symbol("$(f)!")

    @eval begin
        @doc """
        `$($f!)(result::Vector{$($T)}, X::Vector{$($T)}, c::$($T))`

        Implements vector-scalar **$($name)** over **Vector{$($T)}** and $($T) and overwrites
        the result vector with computed value. `length(result)` elements are processed, so
        `X` must hold at least that many; a shorter `X` throws a `DimensionMismatch`.
        *Returns:* **Vector{$($T)}** `result`
        """ ->
        function ($f!)(result::Vector{$T}, X::Vector{$T}, c::$T)
            # The C routine is told to process `length(result)` elements; with a shorter
            # `X` it would read past the end of the input, so reject that case up front.
            length(X) >= length(result) || throw(DimensionMismatch(
                "X has length $(length(X)), but result needs $(length(result)) elements"))
            # Both vectors are walked with stride 1; the scalar is passed by reference.
            ccall(($(string("vDSP_", f, suff), libacc)), Cvoid,
                  (Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}, Int64, UInt64),
                  X, 1, Ref(c), result, 1, length(result))
            return result
        end
    end

    @eval begin
        @doc """
        `$($f)(X::Vector{$($T)}, c::$($T))`

        Implements vector-scalar **$($name)** over **Vector{$($T)}** and $($T). Allocates
        memory to store result. *Returns:* **Vector{$($T)}**
        """ ->
        function ($f)(X::Vector{$T}, c::$T)
            # Allocate an output of the same length as `X` and delegate to the
            # in-place variant.
            result = similar(X)
            ($f!)(result, X, c)
            return result
        end
    end
end
248+
f = :vssub
f! = Symbol("$(f)!")

@eval begin
    @doc """
    `$($f!)(result::Vector{$($T)}, X::Vector{$($T)}, c::$($T))`

    Implements vector-scalar **subtraction** over **Vector{$($T)}** and $($T) and overwrites
    the result vector with computed value (`X .- c`). `length(result)` elements are
    processed, so `X` must hold at least that many; a shorter `X` throws a
    `DimensionMismatch`. *Returns:* **Vector{$($T)}** `result`
    """ ->
    function ($f!)(result::Vector{$T}, X::Vector{$T}, c::$T)
        # The C routine is told to process `length(result)` elements; with a shorter
        # `X` it would read past the end of the input, so reject that case up front.
        length(X) >= length(result) || throw(DimensionMismatch(
            "X has length $(length(X)), but result needs $(length(result)) elements"))
        # Subtraction is expressed as vector-scalar addition of the negated scalar.
        ccall(($(string("vDSP_vsadd", suff), libacc)), Cvoid,
              (Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}, Int64, UInt64),
              X, 1, Ref(-c), result, 1, length(result))
        return result
    end
end

@eval begin
    @doc """
    `$($f)(X::Vector{$($T)}, c::$($T))`

    Implements vector-scalar **subtraction** over **Vector{$($T)}** and $($T). Allocates
    memory to store result. *Returns:* **Vector{$($T)}**
    """ ->
    function ($f)(X::Vector{$T}, c::$T)
        # Allocate an output of the same length as `X` and delegate to the
        # in-place variant.
        result = similar(X)
        ($f!)(result, X, c)
        return result
    end
end
279+
280+
f = :svsub
f! = Symbol("$(f)!")

@eval begin
    @doc """
    `$($f!)(result::Vector{$($T)}, X::Vector{$($T)}, c::$($T))`

    Implements scalar-vector **subtraction** over $($T) and **Vector{$($T)}** and overwrites
    the result vector with computed value (`c .- X`). `length(result)` elements are
    processed, so `X` must hold at least that many; a shorter `X` throws a
    `DimensionMismatch`. *Returns:* **Vector{$($T)}** `result`
    """ ->
    function ($f!)(result::Vector{$T}, X::Vector{$T}, c::$T)
        # The C routine is told to process `length(result)` elements; with a shorter
        # `X` it would read past the end of the input, so reject that case up front.
        length(X) >= length(result) || throw(DimensionMismatch(
            "X has length $(length(X)), but result needs $(length(result)) elements"))
        # vDSP_vsmsa computes D[i] = A[i] * b + c. With b = -1 this yields c - X[i]
        # in one pass, without first materialising the temporary array `-X`.
        ccall(($(string("vDSP_vsmsa", suff), libacc)), Cvoid,
              (Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}, Ptr{$T}, Int64, UInt64),
              X, 1, Ref(-one(c)), Ref(c), result, 1, length(result))
        return result
    end
end

@eval begin
    @doc """
    `$($f)(X::Vector{$($T)}, c::$($T))`

    Implements scalar-vector **subtraction** over $($T) and **Vector{$($T)}** (`c .- X`).
    Allocates memory to store result. *Returns:* **Vector{$($T)}**
    """ ->
    function ($f)(X::Vector{$T}, c::$T)
        # Allocate an output of the same length as `X` and delegate to the
        # in-place variant.
        result = similar(X)
        ($f!)(result, X, c)
        return result
    end
end
311+
312+
# Broadcasting overrides for every vector-scalar wrapper, such that `f.(X, c)` turns
# into `f(X, c)`. `svsub` takes the same (array, scalar) argument tuple as the others,
# so it is generated by the same loop.
for f in (:vsadd, :vssub, :vsdiv, :vsmul, :svsub)
    f! = Symbol(string(f, "!"))

    @eval begin
        # Allocating path: materialising the broadcast calls `f` on its arguments.
        function Base.copy(bc::Base.Broadcast.Broadcasted{Style,Axes,typeof($f),Tuple{Array{$T,N},$T}}) where {Style,Axes,N}
            return ($f)(bc.args...)
        end

        # In-place path: `dest .= f.(X, c)` calls the mutating variant on `dest`.
        function Base.copyto!(dest::Array{$T,N}, bc::Base.Broadcast.Broadcasted{Style,Axes,typeof($f),Tuple{Array{$T,N},$T}}) where {Style,Axes,N}
            return ($f!)(dest, bc.args...)
        end

        # Eager path: an array or lazy broadcast expression as the first argument is
        # passed through `maybecopy` and then given to `f` directly.
        function Base.broadcasted(::typeof($f), arg1::Union{Array{$T,N},Base.Broadcast.Broadcasted}, arg2::$T) where {N}
            return ($f)(maybecopy(arg1), arg2)
        end
    end
end
331+
end
198332

src/Util.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
tupletypelength(a)=length(a.parameters)
44

5+
"""
    maybecopy(x)

Return `x` unchanged when it is an `Array`; when it is a `Base.Broadcast.Broadcasted`
object, return `copy(x)` instead.
"""
@inline function maybecopy(x::Base.Broadcast.Broadcasted)
    return copy(x)
end

@inline function maybecopy(x::Array)
    return x
end
7+
8+
# Maps a Base arithmetic operator to the names of the wrappers that `replaceBase`
# substitutes for it. The three slots are, in order:
#   1. vector (op) vector
#   2. vector (op) scalar
#   3. scalar (op) vector -- invoked with the vector first and the scalar second
const OPS = Dict{Symbol,Tuple{Symbol,Symbol,Symbol}}(
    :+ => (:vadd, :vsadd, :vsadd),
    :- => (:vsub, :vssub, :svsub),
    :* => (:vmul, :vsmul, :vsmul),
    # NOTE(review): slot 3 still routes `c ./ X` to `vsdiv(X, c)`, which is `X ./ c`;
    # division is not commutative, so this needs a scalar-by-vector divide wrapper.
    :/ => (:vdiv, :vsdiv, :vsdiv),
)
512

613
macro replaceBase(fs...)
714
b = Expr(:block)
@@ -28,16 +35,29 @@ macro replaceBase(fs...)
2835
e = quote
2936
(Base.$f)(X::Array{T}) where {T <: Union{Float64,Float32}} = ($fa)(X)
3037
(Base.$f)(X::Union{Float64,Float32}) = ($fa)([X])[1]
38+
Base.broadcasted(::typeof(Base.$f), arg::Union{Array{F,N},Base.Broadcast.Broadcasted}) where {N,F<:Union{Float32,Float64}} = ($fa)(maybecopy(arg))
3139
end
3240
arg_consumed = true
3341
end
34-
if fa in (:copysign,:atan,:pow,:rem,:div_float, :vadd, :vsub, :vmul)
42+
if fa in (:copysign,:atan,:pow,:rem)
3543
e = quote
3644
(Base.$f)(X::Array{T},Y::Array{T}) where {T <: Union{Float32,Float64}} = ($fa)(X,Y)
3745
(Base.$f)(X::T,Y::T) where {T <: Union{Float32,Float64}} = ($fa)([X],[Y])[1]
46+
Base.broadcasted(::typeof(Base.$f), arg1::Union{Array{F, N},Base.Broadcast.Broadcasted}, arg2::Union{Array{F, N},Base.Broadcast.Broadcasted}) where {N,F<:Union{Float32,Float64}} = ($fa)(maybecopy(arg1), maybecopy(arg2))
3847
end
3948
arg_consumed = true
4049
end
50+
if f in (:+,:-,:*,:/)
51+
e = quote
52+
(Base.$f)(X::Array{T},Y::Array{T}) where {T <: Union{Float32,Float64}} = ($(OPS[f][1]))(X,Y)
53+
(Base.$f)(X::T,Y::T) where {T <: Union{Float32,Float64}} = ($(OPS[f][1]))([X],[Y])[1]
54+
Base.broadcasted(::typeof(Base.$f), arg1::Union{Array{F, N},Base.Broadcast.Broadcasted}, arg2::Union{Array{F, N},Base.Broadcast.Broadcasted}) where {N,F<:Union{Float32,Float64}} = ($(OPS[f][1]))(maybecopy(arg1), maybecopy(arg2))
55+
56+
Base.broadcasted(::typeof(Base.$f), arg1::Union{Array{T, N},Base.Broadcast.Broadcasted}, arg2::T) where {N, T <: Union{Float32,Float64}} = ($(OPS[f][2]))(maybecopy(arg1), arg2)
57+
Base.broadcasted(::typeof(Base.$f), arg1::T, arg2::Union{Array{T, N},Base.Broadcast.Broadcasted}) where {N, T <: Union{Float32,Float64}} = ($(OPS[f][3]))(maybecopy(arg2), arg1)
58+
end
59+
arg_consumed = true
60+
end
4161
if !arg_consumed
4262
error("Function $f not defined by AppleAccelerate.jl")
4363
end

test/runtests.jl

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,62 @@ end
1010
Random.seed!(7)
1111
N = 1_000
1212

13+
@testset "AppleAccelerate.jl" begin
1314
for T in (Float32, Float64)
    @testset "Element-wise Operators::$T" begin
        X::Vector{T} = randn(N)
        Y::Vector{T} = randn(N)
        Z::Vector{T} = similar(X)

        # Vector-vector
        @test (X .+ Y) ≈ AppleAccelerate.vadd(X, Y)
        @test (X .- Y) ≈ AppleAccelerate.vsub(X, Y)
        @test (X .* Y) ≈ AppleAccelerate.vmul(X, Y)
        @test (X ./ Y) ≈ AppleAccelerate.vdiv(X, Y)

        # Vector-vector non-allocating
        AppleAccelerate.vadd!(Z, X, Y)
        @test (X .+ Y) ≈ Z
        AppleAccelerate.vsub!(Z, X, Y)
        @test (X .- Y) ≈ Z
        AppleAccelerate.vmul!(Z, X, Y)
        @test (X .* Y) ≈ Z
        AppleAccelerate.vdiv!(Z, X, Y)
        @test (X ./ Y) ≈ Z

        # Vector-vector broadcasting
        @test (X .+ Y) ≈ AppleAccelerate.vadd.(X, Y)
        @test (X .- Y) ≈ AppleAccelerate.vsub.(X, Y)
        @test (X .* Y) ≈ AppleAccelerate.vmul.(X, Y)
        @test (X ./ Y) ≈ AppleAccelerate.vdiv.(X, Y)

        # Vector-scalar: call the allocating methods directly. The dotted forms are
        # covered separately in the broadcasting section below.
        c::T = randn()
        @test (X .+ c) ≈ AppleAccelerate.vsadd(X, c)
        @test (X .- c) ≈ AppleAccelerate.vssub(X, c)
        @test (c .- X) ≈ AppleAccelerate.svsub(X, c)
        @test (X .* c) ≈ AppleAccelerate.vsmul(X, c)
        @test (X ./ c) ≈ AppleAccelerate.vsdiv(X, c)

        # Vector-scalar non-allocating (Y is reused as the output buffer)
        AppleAccelerate.vsadd!(Y, X, c)
        @test (X .+ c) ≈ Y
        AppleAccelerate.vssub!(Y, X, c)
        @test (X .- c) ≈ Y
        AppleAccelerate.svsub!(Y, X, c)
        @test (c .- X) ≈ Y
        AppleAccelerate.vsmul!(Y, X, c)
        @test (X .* c) ≈ Y
        AppleAccelerate.vsdiv!(Y, X, c)
        @test (X ./ c) ≈ Y

        # Vector-scalar broadcasting
        @test (X .+ c) ≈ AppleAccelerate.vsadd.(X, c)
        @test (X .- c) ≈ AppleAccelerate.vssub.(X, c)
        @test (c .- X) ≈ AppleAccelerate.svsub.(X, c)
        @test (X .* c) ≈ AppleAccelerate.vsmul.(X, c)
        @test (X ./ c) ≈ AppleAccelerate.vsdiv.(X, c)

        # Nested broadcast: the inner `Y .+ Y` reaches the wrapper as a lazy
        # Broadcasted argument.
        @test (X .+ Y .+ Y) ≈ AppleAccelerate.vadd.(X, Y .+ Y)
    end
end
2371

@@ -198,12 +246,28 @@ for T in (Float32, Float64)
198246
@test fa(X)[2] fb(X)[2]
199247
end
200248

249+
# Each reduction wrapper is compared against the equivalent Base/Statistics expression.
@testset "Testing meanmag::$T" begin
    @test AppleAccelerate.meanmag(X) ≈ mean(abs, X)
end

@testset "Testing meansqr::$T" begin
    @test AppleAccelerate.meansqr(X) ≈ mean(X .* X)
end

@testset "Testing meanssqr::$T" begin
    # Signed squares: each element times its own magnitude.
    @test AppleAccelerate.meanssqr(X) ≈ mean(X .* abs.(X))
end

@testset "Testing summag::$T" begin
    @test AppleAccelerate.summag(X) ≈ sum(abs, X)
end

@testset "Testing sumsqr::$T" begin
    @test AppleAccelerate.sumsqr(X) ≈ sum(abs2, X)
end

@testset "Testing sumssqr::$T" begin
    # Signed squares: each element times its own magnitude.
    @test AppleAccelerate.sumssqr(X) ≈ sum(X .* abs.(X))
end
208272

209273
end
@@ -270,6 +334,7 @@ Y::Array{T} = abs.(randn(N))
270334
@test X ./ Y == AppleAccelerate.div_float(X, Y)
271335
end
272336
=#
337+
end
273338

274339
if AppleAccelerate.get_macos_version() < v"13.3"
275340
@info("AppleAccelerate.jl needs macOS >= 13.3 for BLAS forwarding. Not testing forwarding capabilities.")
@@ -338,6 +403,6 @@ end
338403
end
339404

340405
run(`$(Base.julia_cmd()) --project=$(Base.active_project()) $(dir)/runtests.jl LinearAlgebra/blas LinearAlgebra/lapack`)
341-
end;
406+
end;
342407
end
343408

0 commit comments

Comments
 (0)