Skip to content

Commit 4585ca9

Browse files
authored
Avoid the exception branch in expand (#518)
1 parent e5ef261 commit 4585ca9

File tree

5 files changed

+72
-33
lines changed

5 files changed

+72
-33
lines changed

src/KernelAbstractions.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ synchronize(backend)
5050
```
5151
"""
5252
macro kernel(expr)
53-
__kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false)
53+
return __kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false)
5454
end
5555

5656
"""
@@ -68,7 +68,7 @@ This allows for two different configurations:
6868
"""
6969
macro kernel(ex...)
7070
if length(ex) == 1
71-
__kernel(ex[1], true, false)
71+
return __kernel(ex[1], true, false)
7272
else
7373
generate_cpu = true
7474
force_inbounds = false
@@ -88,7 +88,7 @@ macro kernel(ex...)
8888
)
8989
end
9090
end
91-
__kernel(ex[end], generate_cpu, force_inbounds)
91+
return __kernel(ex[end], generate_cpu, force_inbounds)
9292
end
9393
end
9494

@@ -206,7 +206,7 @@ a tuple corresponding to kernel configuration. In order to get
206206
the total size you can use `prod(@groupsize())`.
207207
"""
208208
macro groupsize()
209-
quote
209+
return quote
210210
$groupsize($(esc(:__ctx__)))
211211
end
212212
end
@@ -218,7 +218,7 @@ Query the ndrange on the backend. This function returns
218218
a tuple corresponding to kernel configuration.
219219
"""
220220
macro ndrange()
221-
quote
221+
return quote
222222
$size($ndrange($(esc(:__ctx__))))
223223
end
224224
end
@@ -232,7 +232,7 @@ macro localmem(T, dims)
232232
# Stay in sync with CUDAnative
233233
id = gensym("static_shmem")
234234

235-
quote
235+
return quote
236236
$SharedMemory($(esc(T)), Val($(esc(dims))), Val($(QuoteNode(id))))
237237
end
238238
end
@@ -253,7 +253,7 @@ macro private(T, dims)
253253
if dims isa Integer
254254
dims = (dims,)
255255
end
256-
quote
256+
return quote
257257
$Scratchpad($(esc(:__ctx__)), $(esc(T)), Val($(esc(dims))))
258258
end
259259
end
@@ -265,7 +265,7 @@ Creates a private local of `mem` per item in the workgroup. This can be safely u
265265
across [`@synchronize`](@ref) statements.
266266
"""
267267
macro private(expr)
268-
esc(expr)
268+
return esc(expr)
269269
end
270270

271271
"""
@@ -275,7 +275,7 @@ end
275275
that span workitems, or are reused across `@synchronize` statements.
276276
"""
277277
macro uniform(value)
278-
esc(value)
278+
return esc(value)
279279
end
280280

281281
"""
@@ -286,7 +286,7 @@ from each thread in the workgroup are visible from all other threads in the
286286
workgroup.
287287
"""
288288
macro synchronize()
289-
quote
289+
return quote
290290
$__synchronize()
291291
end
292292
end
@@ -303,7 +303,7 @@ workgroup. `cond` is not allowed to have any visible side effects.
303303
- `CPU`: This synchronization will always occur.
304304
"""
305305
macro synchronize(cond)
306-
quote
306+
return quote
307307
$(esc(cond)) && $__synchronize()
308308
end
309309
end
@@ -328,7 +328,7 @@ end
328328
```
329329
"""
330330
macro context()
331-
esc(:(__ctx__))
331+
return esc(:(__ctx__))
332332
end
333333

334334
"""
@@ -368,7 +368,7 @@ macro print(items...)
368368
end
369369
end
370370

371-
quote
371+
return quote
372372
$__print($(map(esc, args)...))
373373
end
374374
end
@@ -424,7 +424,7 @@ macro index(locale, args...)
424424
end
425425

426426
index_function = Symbol(:__index_, locale, :_, indexkind)
427-
Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...)
427+
return Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...)
428428
end
429429

430430
###
@@ -662,7 +662,7 @@ struct Kernel{Backend, WorkgroupSize <: _Size, NDRange <: _Size, Fun}
662662
end
663663

664664
function Base.similar(kernel::Kernel{D, WS, ND}, f::F) where {D, WS, ND, F}
665-
Kernel{D, WS, ND, F}(kernel.backend, f)
665+
return Kernel{D, WS, ND, F}(kernel.backend, f)
666666
end
667667

668668
workgroupsize(::Kernel{D, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize
@@ -772,7 +772,7 @@ end
772772
push!(args, item)
773773
end
774774

775-
quote
775+
return quote
776776
print($(args...))
777777
end
778778
end

src/cpu.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ function (obj::Kernel{CPU})(args...; ndrange = nothing, workgroupsize = nothing)
4343
end
4444

4545
__run(obj, ndrange, iterspace, args, dynamic, obj.backend.static)
46+
return nothing
4647
end
4748

4849
const CPU_GRAINSIZE = 1024 # Vectorization, 4x unrolling, minimal grain size
@@ -161,15 +162,15 @@ end
161162

162163
@inline function __index_Global_Linear(ctx, idx::CartesianIndex)
163164
I = @inbounds expand(__iterspace(ctx), __groupindex(ctx), idx)
164-
@inbounds LinearIndices(__ndrange(ctx))[I]
165+
return @inbounds LinearIndices(__ndrange(ctx))[I]
165166
end
166167

167168
@inline function __index_Local_Cartesian(_, idx::CartesianIndex)
168169
return idx
169170
end
170171

171172
@inline function __index_Group_Cartesian(ctx, ::CartesianIndex)
172-
__groupindex(ctx)
173+
return __groupindex(ctx)
173174
end
174175

175176
@inline function __index_Global_Cartesian(ctx, idx::CartesianIndex)
@@ -190,7 +191,7 @@ end
190191
# CPU implementation of shared memory
191192
###
192193
@inline function SharedMemory(::Type{T}, ::Val{Dims}, ::Val) where {T, Dims}
193-
MArray{__size(Dims), T}(undef)
194+
return MArray{__size(Dims), T}(undef)
194195
end
195196

196197
###
@@ -211,7 +212,7 @@ end
211212
# https://github.com/JuliaLang/julia/issues/39308
212213
@inline function aview(A, I::Vararg{Any, N}) where {N}
213214
J = Base.to_indices(A, I)
214-
Base.unsafe_view(Base._maybe_reshape_parent(A, Base.index_ndims(J...)), J...)
215+
return Base.unsafe_view(Base._maybe_reshape_parent(A, Base.index_ndims(J...)), J...)
215216
end
216217

217218
@inline function Base.getindex(A::ScratchArray{N}, idx) where {N}

src/macros.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ function find_return(stmt)
66
result |= @capture(expr, return x_)
77
expr
88
end
9-
result
9+
return result
1010
end
1111

1212
# XXX: Proper errors
@@ -103,6 +103,7 @@ function transform_gpu!(def, constargs, force_inbounds)
103103
Expr(:block, let_constargs...),
104104
body,
105105
)
106+
return nothing
106107
end
107108

108109
# The hard case, transform the function for CPU execution
@@ -137,6 +138,7 @@ function transform_cpu!(def, constargs, force_inbounds)
137138
Expr(:block, let_constargs...),
138139
Expr(:block, new_stmts...),
139140
)
141+
return nothing
140142
end
141143

142144
struct WorkgroupLoop
@@ -150,7 +152,7 @@ end
150152
is_sync(expr) = @capture(expr, @synchronize() | @synchronize(a_))
151153

152154
function is_scope_construct(expr::Expr)
153-
expr.head === :block # ||
155+
return expr.head === :block # ||
154156
# expr.head === :let
155157
end
156158

@@ -160,7 +162,7 @@ function find_sync(stmt)
160162
result |= is_sync(expr)
161163
expr
162164
end
163-
result
165+
return result
164166
end
165167

166168
# TODO proper handling of LineInfo

src/nditeration.jl

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ abstract type _Size end
1313
struct DynamicSize <: _Size end
1414
struct StaticSize{S} <: _Size
1515
function StaticSize{S}() where {S}
16-
new{S::Tuple{Vararg{Int}}}()
16+
return new{S::Tuple{Vararg{Int}}}()
1717
end
1818
end
1919

@@ -51,11 +51,11 @@ struct NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems}
5151
workitems::DynamicWorkitems
5252

5353
function NDRange{N, B, W}() where {N, B, W}
54-
new{N, B, W, Nothing, Nothing}(nothing, nothing)
54+
return new{N, B, W, Nothing, Nothing}(nothing, nothing)
5555
end
5656

5757
function NDRange{N, B, W}(blocks, workitems) where {N, B, W}
58-
new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems)
58+
return new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems)
5959
end
6060
end
6161

@@ -78,19 +78,55 @@ Base.length(range::NDRange) = length(blocks(range))
7878
gidx = groupidx.I[I]
7979
(gidx - 1) * stride + idx.I[I]
8080
end
81-
CartesianIndex(nI)
81+
return CartesianIndex(nI)
82+
end
83+
84+
85+
"""
86+
assume(cond::Bool)
87+
88+
Assume that the condition `cond` is true. This is a hint to the compiler, possibly enabling
89+
it to optimize more aggressively.
90+
"""
91+
@inline assume(cond::Bool) = Base.llvmcall(
92+
(
93+
"""
94+
declare void @llvm.assume(i1)
95+
96+
define void @entry(i8) #0 {
97+
%cond = icmp eq i8 %0, 1
98+
call void @llvm.assume(i1 %cond)
99+
ret void
100+
}
101+
102+
attributes #0 = { alwaysinline }""", "entry",
103+
),
104+
Nothing, Tuple{Bool}, cond
105+
)
106+
107+
@inline function assume_nonzero(CI::CartesianIndices)
108+
return ntuple(Val(ndims(CI))) do I
109+
Base.@_inline_meta
110+
indices = CI.indices[I]
111+
assume(indices.stop > 0)
112+
end
82113
end
83114

84115
Base.@propagate_inbounds function expand(ndrange::NDRange, groupidx::Integer, idx::Integer)
85-
expand(ndrange, blocks(ndrange)[groupidx], workitems(ndrange)[idx])
116+
# Linear indexing into the block/workitem ranges causes an exception branch and a div; the assumptions below let the compiler drop the branch
117+
B = blocks(ndrange)
118+
W = workitems(ndrange)
119+
assume_nonzero(B)
120+
assume_nonzero(W)
121+
return expand(ndrange, B[groupidx], workitems(ndrange)[idx])
86122
end
87123

88124
Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::CartesianIndex{N}, idx::Integer) where {N}
89-
expand(ndrange, groupidx, workitems(ndrange)[idx])
125+
return expand(ndrange, groupidx, workitems(ndrange)[idx])
90126
end
91127

92128
Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::Integer, idx::CartesianIndex{N}) where {N}
93-
expand(ndrange, blocks(ndrange)[groupidx], idx)
129+
return expand(ndrange, blocks(ndrange)[groupidx], idx)
94130
end
95131

96132
"""

src/reflection.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ end
3434

3535

3636
function ka_code_llvm(kernel, argtypes; ndrange = nothing, workgroupsize = nothing, kwargs...)
37-
ka_code_llvm(stdout, kernel, argtypes; ndrange = ndrange, workgroupsize = nothing, kwargs...)
37+
return ka_code_llvm(stdout, kernel, argtypes; ndrange = ndrange, workgroupsize = nothing, kwargs...)
3838
end
3939

4040
function ka_code_llvm(io::IO, kernel, argtypes; ndrange = nothing, workgroupsize = nothing, kwargs...)
@@ -119,7 +119,7 @@ macro ka_code_typed(ex0...)
119119

120120
thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_typed, ex)
121121

122-
quote
122+
return quote
123123
local $(esc(args)) = $(old_args)
124124
# e.g. translate CuArray to CuBackendArray
125125
$(esc(args)) = map(x -> argconvert($kern, x), $(esc(args)))
@@ -152,7 +152,7 @@ macro ka_code_llvm(ex0...)
152152

153153
thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_llvm, ex)
154154

155-
quote
155+
return quote
156156
local $(esc(args)) = $(old_args)
157157

158158
if isa($kern, Kernel{G} where {G <: GPU})

0 commit comments

Comments
 (0)