Skip to content

Commit 9016013

Browse files
committed
use tuple args, fix JuliaGPU/CLArrays.jl#5
1 parent 01458c6 commit 9016013

File tree

3 files changed

+43
-50
lines changed

3 files changed

+43
-50
lines changed

src/broadcast.jl

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function broadcast_t(f::Any, ::Type{Any}, ::Any, ::Any, A::GPUArrays.GPUArray, a
7373
end
7474

7575
deref(x) = x
76-
deref(x::RefValue) = (x[],) # RefValue doesn't work with CUDAnative
76+
deref(x::RefValue) = (x[],) # RefValue doesn't work with CUDAnative so we use Tuple
7777

7878
function _broadcast!(
7979
func, out::GPUArray,
@@ -84,9 +84,9 @@ function _broadcast!(
8484
shape = UInt32.(size(out))
8585
args = (A, Bs...)
8686
descriptor_tuple = ntuple(length(args)) do i
87-
BroadcastDescriptor(args[i], keeps[i], Idefaults[i])
87+
BInfo(args[i], keeps[i], Idefaults[i])
8888
end
89-
gpu_call(broadcast_kernel!, out, (func, out, shape, UInt32(length(out)), descriptor_tuple, A, deref.(Bs)...))
89+
gpu_call(broadcast_kernel!, out, (func, out, shape, descriptor_tuple, (A, deref.(Bs)...)))
9090
out
9191
end
9292

@@ -97,9 +97,9 @@ function Base.foreach(func, over::GPUArray, Bs...)
9797
keeps, Idefaults = map_newindexer(shape, over, Bs)
9898
args = (over, Bs...)
9999
descriptor_tuple = ntuple(length(args)) do i
100-
BroadcastDescriptor(args[i], keeps[i], Idefaults[i])
100+
BInfo(args[i], keeps[i], Idefaults[i])
101101
end
102-
gpu_call(foreach_kernel, over, (func, shape, UInt32.(length(over)), descriptor_tuple, over, deref.(Bs)...))
102+
gpu_call(foreach_kernel, over, (func, shape, descriptor_tuple, (over, deref.(Bs)...)))
103103
return
104104
end
105105

@@ -109,51 +109,60 @@ arg_length(x::Tuple) = (UInt32(length(x)),)
109109
arg_length(x::GPUArray) = UInt32.(size(x))
110110
arg_length(x) = ()
111111

112-
abstract type BroadcastDescriptor{Typ} end
113-
114-
struct BroadcastDescriptorN{Typ, N} <: BroadcastDescriptor{Typ}
112+
struct BInfo{Typ, N}
115113
size::NTuple{N, UInt32}
116114
keep::NTuple{N, UInt32}
117115
idefault::NTuple{N, UInt32}
118116
end
119-
function BroadcastDescriptor(val::RefValue, keep, idefault)
120-
BroadcastDescriptorN{Tuple, 1}((UInt32(1),), (UInt32(0),), (UInt32(1),))
121-
end
122117

123-
function BroadcastDescriptor(val, keep, idefault)
118+
function BInfo(val, keep, idefault)
124119
N = length(keep)
125120
typ = Broadcast.containertype(val)
126-
BroadcastDescriptorN{typ, N}(arg_length(val), UInt32.(keep), UInt32.(idefault))
121+
BInfo{typ, N}(arg_length(val), UInt32.(keep), UInt32.(idefault))
127122
end
128123

129124
@propagate_inbounds @inline function _broadcast_getindex(
130-
::BroadcastDescriptor{Array}, A, I
125+
::BInfo{Array}, A, I
131126
)
132127
A[I]
133128
end
134129
@propagate_inbounds @inline function _broadcast_getindex(
135-
::BroadcastDescriptor{Tuple}, A, I
130+
::BInfo{Tuple}, A, I
136131
)
137132
A[I]
138133
end
139134
@propagate_inbounds @inline function _broadcast_getindex(
140-
::BroadcastDescriptor{Array}, A::Ref, I
135+
::BInfo{Array}, A::Ref, I
141136
)
142137
A[]
143138
end
144139

145140
@inline _broadcast_getindex(any, A, I) = A
146141

147-
for N = 0:10
142+
@inline function broadcast_kernel!(state, func, out, shape, descriptor, args)
143+
ilin = @linearidx(out, state)
144+
@inbounds out[ilin] = apply_broadcast(ilin, func, shape, descriptor, args)
145+
return
146+
end
147+
function foreach_kernel(state, func, shape, descriptor, args)
148+
ilin = @linearidx(args[1], state)
149+
apply_broadcast(ilin, func, shape, descriptor, args)
150+
return
151+
end
152+
153+
function mapidx_kernel(state, f, A, args)
154+
ilin = @linearidx(A, state)
155+
f(ilin, A, args...)
156+
return
157+
end
158+
159+
for N = 0:15
148160
nargs = N + 1
149161
inner_expr = []
150-
args = []
151162
valargs = []
152163
for i = 1:N
153-
Ai = Symbol("A_", i);
154164
val_i = Symbol("val_", i); I_i = Symbol("I_", i);
155165
desi = Symbol("deref_", i)
156-
push!(args, Ai)
157166
inner = quote
158167
# destructure the keeps and As tuples
159168
$desi = descriptor[$i]
@@ -165,51 +174,25 @@ for N = 0:10
165174
$desi.size
166175
)
167176
# extract array values
168-
@inbounds $val_i = _broadcast_getindex($desi, $Ai, $I_i)
177+
@inbounds $val_i = _broadcast_getindex($desi, args[$i], $I_i)
169178
end
170179
push!(inner_expr, inner)
171180
push!(valargs, val_i)
172181
end
173182
@eval begin
174-
175-
@inline function apply_broadcast(ilin, state, func, shape, len, descriptor, $(args...))
183+
@inline function apply_broadcast(ilin, func, shape, descriptor, args::NTuple{$N, Any})
176184
# this will hopefully get dead code removed,
177185
# if only arrays with linear index are involved, because I should be unused in that case
178186
I = gpu_ind2sub(shape, ilin)
179187
$(inner_expr...)
180188
# call the function and store the result
181189
func($(valargs...))
182190
end
183-
184-
@inline function broadcast_kernel!(state, func, B, shape, len, descriptor, $(args...))
185-
ilin = linear_index(state)
186-
if ilin <= len
187-
@inbounds B[ilin] = apply_broadcast(ilin, state, func, shape, len, descriptor, $(args...))
188-
end
189-
return
190-
end
191-
192-
function foreach_kernel(state, func, shape, len, descriptor, A, $(args...))
193-
ilin = linear_index(state)
194-
if ilin <= len
195-
apply_broadcast(ilin, state, func, shape, len, descriptor, A, $(args...))
196-
end
197-
return
198-
end
199-
200-
function mapidx_kernel(state, f, A, len, $(args...))
201-
i = linear_index(state)
202-
@inbounds if i <= len
203-
f(i, A, $(args...))
204-
end
205-
return
206-
end
207191
end
208-
209192
end
210193

211194
function mapidx(f, A::GPUArray, args::NTuple{N, Any}) where N
212-
gpu_call(mapidx_kernel, A, (f, A, UInt32(length(A)), args...))
195+
gpu_call(mapidx_kernel, A, (f, A, args))
213196
end
214197

215198
# don't do anything for empty tuples

src/testsuite/blas.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function run_blas(Typ)
77
T = Typ{Float32}
88
@testset "matmul" begin
99
against_base(*, T, (5, 5), (5, 5))
10-
against_base(*, T, (5, 5), (5))
10+
against_base(*, T, (5, 5), (5,))
1111
against_base(A_mul_Bt, T, (5, 5), (5, 5))
1212
against_base(At_mul_B, T, (5, 5), (5, 5))
1313
end

src/testsuite/broadcasting.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ function test_broadcast(Typ)
4848
gres = Typ(cres)
4949
gres .= test_idx.(gidx, Base.RefValue(gy))
5050
cres .= test_idx.(cidx, Base.RefValue(cy))
51+
@test Array(gres) == cres
5152
end
5253
@testset "Tuple" begin
5354
against_base(T, (3, N), (3, N), (N,), (N,), (N,)) do out, arr, a, b, c
@@ -88,6 +89,15 @@ function test_broadcast(Typ)
8889
@. u = uprev + dt*duprev + dt2*(fract*ku)
8990
end
9091
against_base((x) -> (-).(x), T, (2, 3))
92+
93+
against_base(T, dim, dim, dim, dim, dim, dim) do utilde, gA, k1, k2, k3, k4
94+
btilde1 = ET(1)
95+
btilde2 = ET(1)
96+
btilde3 = ET(1)
97+
btilde4 = ET(1)
98+
dt = ET(1)
99+
@. utilde = dt*(btilde1*k1 + btilde2*k2 + btilde3*k3 + btilde4*k4)
100+
end
91101
end
92102

93103
against_base((x) -> fill!(x, 1), T, (3,3))

0 commit comments

Comments
 (0)