From 6eac710b8977507ffbd2f576c9ea29d5c65fd14f Mon Sep 17 00:00:00 2001 From: Simeon David Schaub Date: Thu, 7 Aug 2025 13:42:16 +0200 Subject: [PATCH] [WIP] faster sum --- fast_sum.jl | 46 +++++++++++++++++++++++++++++++++ lib/intrinsics/src/atomic.jl | 9 +++++++ lib/intrinsics/src/work_item.jl | 44 +++++++++++++++++++++++++++++++ src/compiler/execution.jl | 2 +- 4 files changed, 100 insertions(+), 1 deletion(-) create mode 100644 fast_sum.jl diff --git a/fast_sum.jl b/fast_sum.jl new file mode 100644 index 00000000..e7cba52a --- /dev/null +++ b/fast_sum.jl @@ -0,0 +1,46 @@ +using OpenCL, pocl_jll, BenchmarkTools + +using SPIRVIntrinsics, Atomix + +function sum_columns_subgroup(X, result, M, N) + col = get_global_id(1) + row_thread = get_global_id(2) + row_stride = get_global_size(2) + + if col > N + return + end + + partial = 0.0f0 + for row = row_thread:row_stride:M + idx = (col - 1) * M + row # column-major layout + partial += X[idx] + end + + # Subgroup shuffle-based warp reduction + lane = get_sub_group_local_id() + width = get_sub_group_size() + + offset = 1 + while offset < width + if lane >= offset + other = sub_group_shuffle(partial, lane - offset) + partial += other + end + offset <<= 1 + end + + # Only one thread writes result + if lane == 1 + Atomix.@atomic result[col] += partial + end + nothing +end + + +X = OpenCL.rand(Float32, 1000, 1000) +out = OpenCL.zeros(Float32, 1000) +@benchmark begin + @opencl local_size = (1, 64) global_size = (1000, 64) extensions = ["SPV_EXT_shader_atomic_float_add"] sum_columns_subgroup(X, out, 1000, 1000) + OpenCL.synchronize(out) +end diff --git a/lib/intrinsics/src/atomic.jl b/lib/intrinsics/src/atomic.jl index 9bbbdbe6..08e71e88 100644 --- a/lib/intrinsics/src/atomic.jl +++ b/lib/intrinsics/src/atomic.jl @@ -57,6 +57,15 @@ for gentype in atomic_integer_types, as in atomic_memory_types end end +for gentype in [Float32, Float64], as in atomic_memory_types +@eval begin + +@device_function atomic_add!(p::LLVMPtr{$gentype,$as}, val::$gentype) = + @builtin_ccall("atomic_add", $gentype, + (LLVMPtr{$gentype,$as}, $gentype), p, val) +end +end + # specifically typed diff --git a/lib/intrinsics/src/work_item.jl b/lib/intrinsics/src/work_item.jl index bbe85adb..d9919f2e 100644 --- a/lib/intrinsics/src/work_item.jl +++ b/lib/intrinsics/src/work_item.jl @@ -34,6 +34,50 @@ for (julia_name, (spirv_name, julia_type, offset)) in [ end end + +# Sub-group shuffle intrinsics using a loop and @eval, matching the style of the 1D/3D value loops above +export sub_group_shuffle, sub_group_shuffle_xor + +for (jltype, llvmtype, julia_type_str) in [ + (Int8, "i8", :Int8), + (UInt8, "i8", :UInt8), + (Int16, "i16", :Int16), + (UInt16, "i16", :UInt16), + (Int32, "i32", :Int32), + (UInt32, "i32", :UInt32), + (Int64, "i64", :Int64), + (UInt64, "i64", :UInt64), + (Float16, "half", :Float16), + (Float32, "float", :Float32), + (Float64, "double",:Float64) + ] + @eval begin + export sub_group_shuffle, sub_group_shuffle_xor + function sub_group_shuffle(x::$jltype, idx::Integer) + Base.llvmcall( + $(""" + declare $llvmtype @__spirv_GroupNonUniformShuffle(i32, $llvmtype, i32) + define $llvmtype @entry($llvmtype %val, i32 %idx) #0 { + %res = call $llvmtype @__spirv_GroupNonUniformShuffle(i32 3, $llvmtype %val, i32 %idx) + ret $llvmtype %res + } + attributes #0 = { alwaysinline } + """, "entry"), $julia_type_str, Tuple{$julia_type_str, Int32}, x, Int32(idx)) + end + function sub_group_shuffle_xor(x::$jltype, mask::Integer) + Base.llvmcall( + $(""" + declare $llvmtype @__spirv_GroupNonUniformShuffleXor(i32, $llvmtype, i32) + define $llvmtype @entry($llvmtype %val, i32 %mask) #0 { + %res = call $llvmtype @__spirv_GroupNonUniformShuffleXor(i32 3, $llvmtype %val, i32 %mask) + ret $llvmtype %res + } + attributes #0 = { alwaysinline } + """, "entry"), $julia_type_str, Tuple{$julia_type_str, Int32}, x, Int32(mask)) + end + end +end + # 3D values for (julia_name, (spirv_name, offset)) in [ # indices diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 881ea906..ad761bea 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -4,7 +4,7 @@ export @opencl, clfunction ## high-level @opencl interface const MACRO_KWARGS = [:launch] -const COMPILER_KWARGS = [:kernel, :name, :always_inline] +const COMPILER_KWARGS = [:kernel, :name, :always_inline, :extensions] const LAUNCH_KWARGS = [:global_size, :local_size, :queue] macro opencl(ex...)