32 changes: 16 additions & 16 deletions src/KernelAbstractions.jl
@@ -50,7 +50,7 @@ synchronize(backend)
```
"""
macro kernel(expr)
- __kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false)
+ return __kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false)
end

"""
@@ -68,7 +68,7 @@ This allows for two different configurations:
"""
macro kernel(ex...)
if length(ex) == 1
- __kernel(ex[1], true, false)
+ return __kernel(ex[1], true, false)
else
generate_cpu = true
force_inbounds = false
@@ -88,7 +88,7 @@ macro kernel(ex...)
)
end
end
- __kernel(ex[end], generate_cpu, force_inbounds)
+ return __kernel(ex[end], generate_cpu, force_inbounds)
end
end
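For reference, a minimal sketch of the two configurations described in the docstring above; the kernel names and bodies are illustrative, not part of this diff.

```julia
using KernelAbstractions

@kernel inbounds=true function scale!(A, s)    # body wrapped in a forced @inbounds
    I = @index(Global)
    A[I] = s * A[I]
end

@kernel cpu=false function scale_gpu!(A, s)    # no CPU fallback is generated
    I = @index(Global)
    A[I] = s * A[I]
end
```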

@@ -206,7 +206,7 @@ a tuple corresponding to kernel configuration. In order to get
the total size you can use `prod(@groupsize())`.
"""
macro groupsize()
- quote
+ return quote
$groupsize($(esc(:__ctx__)))
end
end
@@ -218,7 +218,7 @@ Query the ndrange on the backend. This function returns
a tuple corresponding to kernel configuration.
"""
macro ndrange()
- quote
+ return quote
$size($ndrange($(esc(:__ctx__))))
end
end
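As an illustration of the two query macros, a hypothetical kernel that uses both; it could be launched as `occupancy!(CPU(), 64)(out, ndrange = length(out))`.

```julia
using KernelAbstractions

@kernel function occupancy!(out)
    i = @index(Global, Linear)
    per_group = prod(@groupsize())   # workitems per workgroup
    total     = prod(@ndrange())     # workitems in the whole launch
    @inbounds out[i] = per_group / total
end
```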
@@ -232,7 +232,7 @@ macro localmem(T, dims)
# Stay in sync with CUDAnative
id = gensym("static_shmem")

- quote
+ return quote
$SharedMemory($(esc(T)), Val($(esc(dims))), Val($(QuoteNode(id))))
end
end
@@ -253,7 +253,7 @@ macro private(T, dims)
if dims isa Integer
dims = (dims,)
end
- quote
+ return quote
$Scratchpad($(esc(:__ctx__)), $(esc(T)), Val($(esc(dims))))
end
end
@@ -265,7 +265,7 @@ Creates a private local of `mem` per item in the workgroup. This can be safely u
across [`@synchronize`](@ref) statements.
"""
macro private(expr)
- esc(expr)
+ return esc(expr)
end

"""
@@ -275,7 +275,7 @@ end
that span workitems, or are reused across `@synchronize` statements.
"""
macro uniform(value)
- esc(value)
+ return esc(value)
end

"""
@@ -286,7 +286,7 @@ from each thread in the workgroup are visible in from all other threads in the
workgroup.
"""
macro synchronize()
- quote
+ return quote
$__synchronize()
end
end
@@ -303,7 +303,7 @@ workgroup. `cond` is not allowed to have any visible sideffects.
- `CPU`: This synchronization will always occur.
"""
macro synchronize(cond)
- quote
+ return quote
$(esc(cond)) && $__synchronize()
end
end
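An illustrative use of the barrier (hypothetical kernel, again assuming a workgroup size of 64): every workitem's write to `tile` becomes visible before any neighbour reads it.

```julia
using KernelAbstractions

@kernel function neighbour_shift!(B, @Const(A))
    li = @index(Local, Linear)
    I  = @index(Global, Linear)

    tile = @localmem eltype(B) (64,)
    @inbounds tile[li] = A[I]
    @synchronize()                            # all writes to `tile` are now visible
    @inbounds B[I] = tile[mod1(li + 1, 64)]
end
```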
@@ -328,7 +328,7 @@ end
```
"""
macro context()
- esc(:(__ctx__))
+ return esc(:(__ctx__))
end

"""
@@ -368,7 +368,7 @@ macro print(items...)
end
end

- quote
+ return quote
$__print($(map(esc, args)...))
end
end
@@ -424,7 +424,7 @@ macro index(locale, args...)
end

index_function = Symbol(:__index_, locale, :_, indexkind)
- Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...)
+ return Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...)
end

###
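The common index queries in one hypothetical kernel, for reference:

```julia
using KernelAbstractions

@kernel function show_indices!(lin, grp, loc)
    I = @index(Global, Cartesian)               # CartesianIndex into the ndrange
    @inbounds lin[I] = @index(Global, Linear)   # same position, linearised
    @inbounds grp[I] = @index(Group, Linear)    # which workgroup this item is in
    @inbounds loc[I] = @index(Local, Linear)    # position within the workgroup
end
```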
@@ -662,7 +662,7 @@ struct Kernel{Backend, WorkgroupSize <: _Size, NDRange <: _Size, Fun}
end

function Base.similar(kernel::Kernel{D, WS, ND}, f::F) where {D, WS, ND, F}
- Kernel{D, WS, ND, F}(kernel.backend, f)
+ return Kernel{D, WS, ND, F}(kernel.backend, f)
end

workgroupsize(::Kernel{D, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize
@@ -772,7 +772,7 @@ end
push!(args, item)
end

- quote
+ return quote
print($(args...))
end
end
9 changes: 5 additions & 4 deletions src/cpu.jl
@@ -43,6 +43,7 @@ function (obj::Kernel{CPU})(args...; ndrange = nothing, workgroupsize = nothing)
end

__run(obj, ndrange, iterspace, args, dynamic, obj.backend.static)
+ return nothing
end

const CPU_GRAINSIZE = 1024 # Vectorization, 4x unrolling, minimal grain size
@@ -161,15 +162,15 @@ end

@inline function __index_Global_Linear(ctx, idx::CartesianIndex)
I = @inbounds expand(__iterspace(ctx), __groupindex(ctx), idx)
- @inbounds LinearIndices(__ndrange(ctx))[I]
+ return @inbounds LinearIndices(__ndrange(ctx))[I]
end

@inline function __index_Local_Cartesian(_, idx::CartesianIndex)
return idx
end

@inline function __index_Group_Cartesian(ctx, ::CartesianIndex)
- __groupindex(ctx)
+ return __groupindex(ctx)
end

@inline function __index_Global_Cartesian(ctx, idx::CartesianIndex)
@@ -190,7 +191,7 @@ end
# CPU implementation of shared memory
###
@inline function SharedMemory(::Type{T}, ::Val{Dims}, ::Val) where {T, Dims}
- MArray{__size(Dims), T}(undef)
+ return MArray{__size(Dims), T}(undef)
end

###
@@ -211,7 +212,7 @@ end
# https://github.com/JuliaLang/julia/issues/39308
@inline function aview(A, I::Vararg{Any, N}) where {N}
J = Base.to_indices(A, I)
- Base.unsafe_view(Base._maybe_reshape_parent(A, Base.index_ndims(J...)), J...)
+ return Base.unsafe_view(Base._maybe_reshape_parent(A, Base.index_ndims(J...)), J...)
end

@inline function Base.getindex(A::ScratchArray{N}, idx) where {N}
8 changes: 5 additions & 3 deletions src/macros.jl
@@ -6,7 +6,7 @@ function find_return(stmt)
result |= @capture(expr, return x_)
expr
end
- result
+ return result
end

# XXX: Proper errors
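For readers unfamiliar with the `MacroTools` pattern used in `find_return`, a standalone sketch of the same idea; this is an assumed shape for illustration, not the exact source.

```julia
using MacroTools: postwalk, @capture

function contains_return(ex)
    found = false
    postwalk(ex) do node
        found |= @capture(node, return x_)
        node                   # postwalk expects the (possibly rewritten) node back
    end
    return found
end

contains_return(:(f(x) = x + 1))                    # false
contains_return(:(function g(x); return x; end))    # true
```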
@@ -103,6 +103,7 @@ function transform_gpu!(def, constargs, force_inbounds)
Expr(:block, let_constargs...),
body,
)
+ return nothing
end

# The hard case, transform the function for CPU execution
@@ -137,6 +138,7 @@ function transform_cpu!(def, constargs, force_inbounds)
Expr(:block, let_constargs...),
Expr(:block, new_stmts...),
)
+ return nothing
end

struct WorkgroupLoop
@@ -150,7 +152,7 @@ end
is_sync(expr) = @capture(expr, @synchronize() | @synchronize(a_))

function is_scope_construct(expr::Expr)
- expr.head === :block # ||
+ return expr.head === :block # ||
# expr.head === :let
end

@@ -160,7 +162,7 @@ function find_sync(stmt)
result |= is_sync(expr)
expr
end
- result
+ return result
end

# TODO proper handling of LineInfo
50 changes: 43 additions & 7 deletions src/nditeration.jl
@@ -13,7 +13,7 @@ abstract type _Size end
struct DynamicSize <: _Size end
struct StaticSize{S} <: _Size
function StaticSize{S}() where {S}
- new{S::Tuple{Vararg{Int}}}()
+ return new{S::Tuple{Vararg{Int}}}()
end
end

@@ -51,11 +51,11 @@ struct NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems}
workitems::DynamicWorkitems

function NDRange{N, B, W}() where {N, B, W}
- new{N, B, W, Nothing, Nothing}(nothing, nothing)
+ return new{N, B, W, Nothing, Nothing}(nothing, nothing)
end

function NDRange{N, B, W}(blocks, workitems) where {N, B, W}
- new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems)
+ return new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems)
end
end

@@ -78,19 +78,55 @@ Base.length(range::NDRange) = length(blocks(range))
gidx = groupidx.I[I]
(gidx - 1) * stride + idx.I[I]
end
- CartesianIndex(nI)
+ return CartesianIndex(nI)
end
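For intuition, a worked instance of the expansion above. It imports from the internal `NDIteration` module, so treat it as a sketch rather than supported API.

```julia
using KernelAbstractions.NDIteration: NDRange, DynamicSize, expand

# Four workgroups of four workitems each, one-dimensional.
ndr = NDRange{1, DynamicSize, DynamicSize}(CartesianIndices((4,)), CartesianIndices((4,)))

# Workgroup 3, local item 2: stride = 4, so (3 - 1) * 4 + 2 = 10.
@assert expand(ndr, 3, 2) == CartesianIndex(10)
```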


"""
assume(cond::Bool)

Assume that the condition `cond` is true. This is a hint to the compiler, possibly enabling
it to optimize more aggressively.
"""
@inline assume(cond::Bool) = Base.llvmcall(
(
"""
declare void @llvm.assume(i1)

define void @entry(i8) #0 {
%cond = icmp eq i8 %0, 1
call void @llvm.assume(i1 %cond)
ret void
}

attributes #0 = { alwaysinline }""", "entry",
),
Nothing, Tuple{Bool}, cond
)

@inline function assume_nonzero(CI::CartesianIndices)
return ntuple(Val(ndims(CI))) do I
Base.@_inline_meta
indices = CI.indices[I]
assume(indices.stop > 0)
end
end

Base.@propagate_inbounds function expand(ndrange::NDRange, groupidx::Integer, idx::Integer)
- expand(ndrange, blocks(ndrange)[groupidx], workitems(ndrange)[idx])
+ # this causes an exception branch and a div
+ B = blocks(ndrange)
+ W = workitems(ndrange)
+ assume_nonzero(B)
+ assume_nonzero(W)
+ return expand(ndrange, B[groupidx], workitems(ndrange)[idx])
end

Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::CartesianIndex{N}, idx::Integer) where {N}
- expand(ndrange, groupidx, workitems(ndrange)[idx])
+ return expand(ndrange, groupidx, workitems(ndrange)[idx])
end

Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::Integer, idx::CartesianIndex{N}) where {N}
- expand(ndrange, blocks(ndrange)[groupidx], idx)
+ return expand(ndrange, blocks(ndrange)[groupidx], idx)
end

"""
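The new `assume` helper emits `llvm.assume`, letting LLVM drop checks it can now prove redundant. A hypothetical example of the effect outside the kernel machinery; the import path is an assumption since `assume` is internal and unexported.

```julia
using KernelAbstractions.NDIteration: assume

# Integer division normally carries a zero-divisor/overflow exception branch;
# promising LLVM that `width > 0` lets it drop that branch, which is what
# `assume_nonzero` does for the Cartesian ranges indexed inside `expand`.
function bucket(i::Int, width::Int)
    assume(width > 0)
    return div(i - 1, width) + 1
end

bucket(10, 4)   # == 3
```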
6 changes: 3 additions & 3 deletions src/reflection.jl
@@ -34,7 +34,7 @@ end


function ka_code_llvm(kernel, argtypes; ndrange = nothing, workgroupsize = nothing, kwargs...)
- ka_code_llvm(stdout, kernel, argtypes; ndrange = ndrange, workgroupsize = nothing, kwargs...)
+ return ka_code_llvm(stdout, kernel, argtypes; ndrange = ndrange, workgroupsize = nothing, kwargs...)
end

function ka_code_llvm(io::IO, kernel, argtypes; ndrange = nothing, workgroupsize = nothing, kwargs...)
@@ -119,7 +119,7 @@ macro ka_code_typed(ex0...)

thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_typed, ex)

- quote
+ return quote
local $(esc(args)) = $(old_args)
# e.g. translate CuArray to CuBackendArray
$(esc(args)) = map(x -> argconvert($kern, x), $(esc(args)))
@@ -152,7 +152,7 @@ macro ka_code_llvm(ex0...)

thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_llvm, ex)

- quote
+ return quote
local $(esc(args)) = $(old_args)

if isa($kern, Kernel{G} where {G <: GPU})
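For context, hypothetical REPL usage of the two reflection macros touched here; the explicit import is an assumption, since the macros live in `src/reflection.jl` but may not be exported.

```julia
using KernelAbstractions
using KernelAbstractions: @ka_code_typed, @ka_code_llvm

@kernel function mul2!(A)
    I = @index(Global)
    A[I] = 2A[I]
end

A = ones(Float32, 1024)
kern = mul2!(CPU(), 64)

@ka_code_typed kern(A, ndrange = size(A))   # inferred code for this launch
@ka_code_llvm kern(A, ndrange = size(A))    # LLVM IR (CPU kernels only)
```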