Skip to content

Commit 60c753e

Browse files
committed
Merge branch 'sd/tuple_brc'
2 parents 683087e + aba90b6 commit 60c753e

File tree

3 files changed

+43
-50
lines changed

3 files changed

+43
-50
lines changed

src/broadcast.jl

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function broadcast_t(f::Any, ::Type{Any}, ::Any, ::Any, A::GPUArrays.GPUArray, a
7373
end
7474

7575
deref(x) = x
76-
deref(x::RefValue) = (x[],) # RefValue doesn't work with CUDAnative
76+
deref(x::RefValue) = (x[],) # RefValue doesn't work with CUDAnative so we use Tuple
7777

7878
function _broadcast!(
7979
func, out::GPUArray,
@@ -84,9 +84,9 @@ function _broadcast!(
8484
shape = UInt32.(size(out))
8585
args = (A, Bs...)
8686
descriptor_tuple = ntuple(length(args)) do i
87-
BroadcastDescriptor(args[i], keeps[i], Idefaults[i])
87+
BInfo(args[i], keeps[i], Idefaults[i])
8888
end
89-
gpu_call(broadcast_kernel!, out, (func, out, shape, UInt32(length(out)), descriptor_tuple, A, deref.(Bs)...))
89+
gpu_call(broadcast_kernel!, out, (func, out, shape, descriptor_tuple, (A, deref.(Bs)...)))
9090
out
9191
end
9292

@@ -97,9 +97,9 @@ function Base.foreach(func, over::GPUArray, Bs...)
9797
keeps, Idefaults = map_newindexer(shape, over, Bs)
9898
args = (over, Bs...)
9999
descriptor_tuple = ntuple(length(args)) do i
100-
BroadcastDescriptor(args[i], keeps[i], Idefaults[i])
100+
BInfo(args[i], keeps[i], Idefaults[i])
101101
end
102-
gpu_call(foreach_kernel, over, (func, shape, UInt32.(length(over)), descriptor_tuple, over, deref.(Bs)...))
102+
gpu_call(foreach_kernel, over, (func, shape, descriptor_tuple, (over, deref.(Bs)...)))
103103
return
104104
end
105105

@@ -116,51 +116,60 @@ arg_length(x::Tuple) = (UInt32(length(x)),)
116116
arg_length(x::GPUArray) = UInt32.(size(x))
117117
arg_length(x) = () # Scalar
118118

119-
abstract type BroadcastDescriptor{Typ} end
120-
121-
struct BroadcastDescriptorN{Typ, N} <: BroadcastDescriptor{Typ}
119+
struct BInfo{Typ, N}
122120
size::NTuple{N, UInt32}
123121
keep::NTuple{N, UInt32}
124122
idefault::NTuple{N, UInt32}
125123
end
126-
function BroadcastDescriptor(val::RefValue, keep, idefault)
127-
BroadcastDescriptorN{Tuple, 1}((UInt32(1),), (UInt32(0),), (UInt32(1),))
128-
end
129124

130-
function BroadcastDescriptor(val, keep, idefault)
125+
function BInfo(val, keep, idefault)
131126
N = length(keep)
132127
typ = Broadcast.containertype(val)
133-
BroadcastDescriptorN{typ, N}(arg_length(val), UInt32.(keep), UInt32.(idefault))
128+
BInfo{typ, N}(arg_length(val), UInt32.(keep), UInt32.(idefault))
134129
end
135130

136131
@propagate_inbounds @inline function _broadcast_getindex(
137-
::BroadcastDescriptor{Array}, A, I
132+
::BInfo{Array}, A, I
138133
)
139134
A[I]
140135
end
141136
@propagate_inbounds @inline function _broadcast_getindex(
142-
::BroadcastDescriptor{Tuple}, A, I
137+
::BInfo{Tuple}, A, I
143138
)
144139
A[I]
145140
end
146141
@propagate_inbounds @inline function _broadcast_getindex(
147-
::BroadcastDescriptor{Array}, A::Ref, I
142+
::BInfo{Array}, A::Ref, I
148143
)
149144
A[]
150145
end
151146

152147
@inline _broadcast_getindex(any, A, I) = A
153148

154-
for N = 0:10
149+
@inline function broadcast_kernel!(state, func, out, shape, descriptor, args)
150+
ilin = @linearidx(out, state)
151+
@inbounds out[ilin] = apply_broadcast(ilin, func, shape, descriptor, args)
152+
return
153+
end
154+
function foreach_kernel(state, func, shape, descriptor, args)
155+
ilin = @linearidx(args[1], state)
156+
apply_broadcast(ilin, func, shape, descriptor, args)
157+
return
158+
end
159+
160+
function mapidx_kernel(state, f, A, args)
161+
ilin = @linearidx(A, state)
162+
f(ilin, A, args...)
163+
return
164+
end
165+
166+
for N = 0:15
155167
nargs = N + 1
156168
inner_expr = []
157-
args = []
158169
valargs = []
159170
for i = 1:N
160-
Ai = Symbol("A_", i);
161171
val_i = Symbol("val_", i); I_i = Symbol("I_", i);
162172
desi = Symbol("deref_", i)
163-
push!(args, Ai)
164173
inner = quote
165174
# destructure the keeps and As tuples
166175
$desi = descriptor[$i]
@@ -172,51 +181,25 @@ for N = 0:10
172181
$desi.size
173182
)
174183
# extract array values
175-
@inbounds $val_i = _broadcast_getindex($desi, $Ai, $I_i)
184+
@inbounds $val_i = _broadcast_getindex($desi, args[$i], $I_i)
176185
end
177186
push!(inner_expr, inner)
178187
push!(valargs, val_i)
179188
end
180189
@eval begin
181-
182-
@inline function apply_broadcast(ilin, state, func, shape, len, descriptor, $(args...))
190+
@inline function apply_broadcast(ilin, func, shape, descriptor, args::NTuple{$N, Any})
183191
# this will hopefully get dead code removed,
184192
# if only arrays with linear index are involved, because I should be unused in that case
185193
I = gpu_ind2sub(shape, ilin)
186194
$(inner_expr...)
187195
# call the function and store the result
188196
func($(valargs...))
189197
end
190-
191-
@inline function broadcast_kernel!(state, func, B, shape, len, descriptor, $(args...))
192-
ilin = linear_index(state)
193-
if ilin <= len
194-
@inbounds B[ilin] = apply_broadcast(ilin, state, func, shape, len, descriptor, $(args...))
195-
end
196-
return
197-
end
198-
199-
function foreach_kernel(state, func, shape, len, descriptor, A, $(args...))
200-
ilin = linear_index(state)
201-
if ilin <= len
202-
apply_broadcast(ilin, state, func, shape, len, descriptor, A, $(args...))
203-
end
204-
return
205-
end
206-
207-
function mapidx_kernel(state, f, A, len, $(args...))
208-
i = linear_index(state)
209-
@inbounds if i <= len
210-
f(i, A, $(args...))
211-
end
212-
return
213-
end
214198
end
215-
216199
end
217200

218201
function mapidx(f, A::GPUArray, args::NTuple{N, Any}) where N
219-
gpu_call(mapidx_kernel, A, (f, A, UInt32(length(A)), args...))
202+
gpu_call(mapidx_kernel, A, (f, A, args))
220203
end
221204

222205
# don't do anything for empty tuples

src/testsuite/blas.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function run_blas(Typ)
77
T = Typ{Float32}
88
@testset "matmul" begin
99
against_base(*, T, (5, 5), (5, 5))
10-
against_base(*, T, (5, 5), (5))
10+
against_base(*, T, (5, 5), (5,))
1111
against_base(A_mul_Bt, T, (5, 5), (5, 5))
1212
against_base(A_mul_Bt!, T, (10, 32), (10, 60), (32, 60))
1313
against_base(At_mul_B, T, (5, 5), (5, 5))

src/testsuite/broadcasting.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ function test_broadcast(Typ)
4848
gres = Typ(cres)
4949
gres .= test_idx.(gidx, Base.RefValue(gy))
5050
cres .= test_idx.(cidx, Base.RefValue(cy))
51+
@test Array(gres) == cres
5152
end
5253
@testset "Tuple" begin
5354
against_base(T, (3, N), (3, N), (N,), (N,), (N,)) do out, arr, a, b, c
@@ -88,6 +89,15 @@ function test_broadcast(Typ)
8889
@. u = uprev + dt*duprev + dt2*(fract*ku)
8990
end
9091
against_base((x) -> (-).(x), T, (2, 3))
92+
93+
against_base(T, dim, dim, dim, dim, dim, dim) do utilde, gA, k1, k2, k3, k4
94+
btilde1 = ET(1)
95+
btilde2 = ET(1)
96+
btilde3 = ET(1)
97+
btilde4 = ET(1)
98+
dt = ET(1)
99+
@. utilde = dt*(btilde1*k1 + btilde2*k2 + btilde3*k3 + btilde4*k4)
100+
end
91101
end
92102

93103
against_base((x) -> fill!(x, 1), T, (3,3))

0 commit comments

Comments
 (0)