Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions fast_sum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using OpenCL, pocl_jll, BenchmarkTools

using SPIRVIntrinsics, Atomix

"""
    sum_columns_subgroup(X, result, M, N)

Device kernel that accumulates the column sums of `X` into `result`.

`X` is indexed linearly as `(col - 1) * M + row`, i.e. as an `M`-row column-major
matrix. Global dimension 1 selects the column, global dimension 2 strides over the
rows of that column. Each work-item sums its rows into a `Float32` partial, the
partials are combined across the sub-group, and one work-item per sub-group
atomically adds the sub-group total to `result[col]`.

`result` is added to, not overwritten, so it must be zeroed by the caller before the
launch. Returns `nothing`.
"""
function sum_columns_subgroup(X, result, M, N)
    col = get_global_id(1)
    row_thread = get_global_id(2)
    row_stride = get_global_size(2)

    # NOTE(review): this early exit happens before the sub-group shuffles below, so it
    # is only safe when every work-item of a sub-group shares the same `col`
    # (e.g. a local size of 1 along dimension 1) — confirm for other launch shapes.
    if col > N
        return
    end

    partial = 0.0f0
    for row = row_thread:row_stride:M
        idx = (col - 1) * M + row  # column-major layout
        partial += X[idx]
    end

    lane = get_sub_group_local_id()
    width = get_sub_group_size()

    # Butterfly (XOR) reduction. Every lane executes every shuffle, so no lane ever
    # reads from an inactive lane, and after the last step *all* lanes hold the full
    # sub-group sum. The previous scan-style loop shuffled inside a divergent branch
    # and left the total in the last lane while the first lane was the one published.
    # Assumes the sub-group size is a power of two.
    offset = Int(width) >> 1
    while offset > 0
        partial += sub_group_shuffle_xor(partial, offset)
        offset >>= 1
    end

    # Exactly one lane per sub-group publishes the total; since every lane holds the
    # same value after the butterfly, which lane does so is immaterial.
    if lane == 1
        Atomix.@atomic result[col] += partial
    end
    nothing
end


# Benchmark the sub-group column-sum kernel on a square Float32 matrix.
# The problem size is named once so the buffers, the launch grid and the kernel
# arguments cannot drift apart.
nrows, ncols = 1000, 1000
X = OpenCL.rand(Float32, nrows, ncols)
out = OpenCL.zeros(Float32, ncols)

# Globals are `$`-interpolated so BenchmarkTools times the kernel launch rather than
# dynamic lookups of untyped global variables.
# NOTE(review): the kernel atomically adds into `out`, which is not reset between
# evaluations, so `out` accumulates across samples; this affects only its contents.
@benchmark begin
    @opencl local_size = (1, 64) global_size = ($ncols, 64) extensions = ["SPV_EXT_shader_atomic_float_add"] sum_columns_subgroup($X, $out, $nrows, $ncols)
    OpenCL.synchronize($out)
end
9 changes: 9 additions & 0 deletions lib/intrinsics/src/atomic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ for gentype in atomic_integer_types, as in atomic_memory_types
end
end

# Floating-point `atomic_add!`: generate one method per element type and per address
# space listed in `atomic_memory_types`, each forwarding to the "atomic_add" builtin.
for gentype in (Float32, Float64)
    for as in atomic_memory_types
        # Pointer type shared by the method signature and the builtin's argument list.
        ptr_type = :(LLVMPtr{$gentype,$as})

        @eval @device_function atomic_add!(p::$ptr_type, val::$gentype) =
            @builtin_ccall("atomic_add", $gentype, ($ptr_type, $gentype), p, val)
    end
end


# specifically typed

Expand Down
44 changes: 44 additions & 0 deletions lib/intrinsics/src/work_item.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,50 @@ for (julia_name, (spirv_name, julia_type, offset)) in [
end
end


# Sub-group shuffle intrinsics.
#
# For every supported scalar type this generates two methods:
#   sub_group_shuffle(x, idx)      -> calls __spirv_GroupNonUniformShuffle
#   sub_group_shuffle_xor(x, mask) -> calls __spirv_GroupNonUniformShuffleXor
# Both pass `i32 3` as the first operand of the call and convert the second Julia
# argument to `Int32` before handing it to the LLVM IR wrapper.
export sub_group_shuffle, sub_group_shuffle_xor

for (T, ir_type) in (
        (Int8, "i8"),
        (UInt8, "i8"),
        (Int16, "i16"),
        (UInt16, "i16"),
        (Int32, "i32"),
        (UInt32, "i32"),
        (Int64, "i64"),
        (UInt64, "i64"),
        (Float16, "half"),
        (Float32, "float"),
        (Float64, "double"),
    )
    for (fname, spirv_name, arg) in (
            (:sub_group_shuffle, "__spirv_GroupNonUniformShuffle", "idx"),
            (:sub_group_shuffle_xor, "__spirv_GroupNonUniformShuffleXor", "mask"),
        )
        # IR module wrapping the intrinsic; the two functions differ only in the
        # intrinsic's name and in the name of the second operand.
        ir = """
        declare $(ir_type) @$(spirv_name)(i32, $(ir_type), i32)
        define $(ir_type) @entry($(ir_type) %val, i32 %$(arg)) #0 {
        %res = call $(ir_type) @$(spirv_name)(i32 3, $(ir_type) %val, i32 %$(arg))
        ret $(ir_type) %res
        }
        attributes #0 = { alwaysinline }
        """
        argsym = Symbol(arg)

        @eval function $fname(x::$T, $argsym::Integer)
            Base.llvmcall(($ir, "entry"), $T, Tuple{$T, Int32}, x, Int32($argsym))
        end
    end
end

# 3D values
for (julia_name, (spirv_name, offset)) in [
# indices
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export @opencl, clfunction
## high-level @opencl interface

# Keyword-argument names, grouped into three lists.
# NOTE(review): presumably the `@opencl` macro below uses these lists to route each
# keyword to the macro itself, the compiler, or the kernel launch — confirm against
# the macro body, which is not visible here.
const MACRO_KWARGS = [:launch]
# NOTE(review): this line and the next both bind COMPILER_KWARGS. This text looks like
# a diff hunk ("1 addition & 1 deletion"), so this is presumably the removed line and
# the next one its replacement, which additionally lists `:extensions`.
const COMPILER_KWARGS = [:kernel, :name, :always_inline]
const COMPILER_KWARGS = [:kernel, :name, :always_inline, :extensions]
const LAUNCH_KWARGS = [:global_size, :local_size, :queue]

macro opencl(ex...)
Expand Down
Loading