[WIP] faster sum #356
Conversation
Your PR requires formatting changes to meet the project's style guidelines. The suggested changes:

diff --git a/fast_sum.jl b/fast_sum.jl
index e7cba52..fec6952 100644
--- a/fast_sum.jl
+++ b/fast_sum.jl
@@ -12,7 +12,7 @@ function sum_columns_subgroup(X, result, M, N)
end
partial = 0.0f0
- for row = row_thread:row_stride:M
+ for row in row_thread:row_stride:M
idx = (col - 1) * M + row # column-major layout
partial += X[idx]
end
@@ -32,9 +32,9 @@ function sum_columns_subgroup(X, result, M, N)
# Only one thread writes result
if lane == 1
- Atomix.@atomic result[col] += partial
+ Atomix.@atomic result[col] += partial
end
- nothing
+ return nothing
end
diff --git a/lib/intrinsics/src/atomic.jl b/lib/intrinsics/src/atomic.jl
index 08e71e8..6d46c2e 100644
--- a/lib/intrinsics/src/atomic.jl
+++ b/lib/intrinsics/src/atomic.jl
@@ -58,12 +58,14 @@ end
end
for gentype in [Float32, Float64], as in atomic_memory_types
-@eval begin
+ @eval begin
-@device_function atomic_add!(p::LLVMPtr{$gentype,$as}, val::$gentype) =
- @builtin_ccall("atomic_add", $gentype,
- (LLVMPtr{$gentype,$as}, $gentype), p, val)
-end
+ @device_function atomic_add!(p::LLVMPtr{$gentype, $as}, val::$gentype) =
+ @builtin_ccall(
+ "atomic_add", $gentype,
+ (LLVMPtr{$gentype, $as}, $gentype), p, val
+ )
+ end
end
diff --git a/lib/intrinsics/src/work_item.jl b/lib/intrinsics/src/work_item.jl
index d9919f2..3f6446b 100644
--- a/lib/intrinsics/src/work_item.jl
+++ b/lib/intrinsics/src/work_item.jl
@@ -39,41 +39,47 @@ end
export sub_group_shuffle, sub_group_shuffle_xor
for (jltype, llvmtype, julia_type_str) in [
- (Int8, "i8", :Int8),
- (UInt8, "i8", :UInt8),
- (Int16, "i16", :Int16),
- (UInt16, "i16", :UInt16),
- (Int32, "i32", :Int32),
- (UInt32, "i32", :UInt32),
- (Int64, "i64", :Int64),
- (UInt64, "i64", :UInt64),
- (Float16, "half", :Float16),
+ (Int8, "i8", :Int8),
+ (UInt8, "i8", :UInt8),
+ (Int16, "i16", :Int16),
+ (UInt16, "i16", :UInt16),
+ (Int32, "i32", :Int32),
+ (UInt32, "i32", :UInt32),
+ (Int64, "i64", :Int64),
+ (UInt64, "i64", :UInt64),
+ (Float16, "half", :Float16),
(Float32, "float", :Float32),
- (Float64, "double",:Float64)
+ (Float64, "double", :Float64),
]
@eval begin
export sub_group_shuffle, sub_group_shuffle_xor
function sub_group_shuffle(x::$jltype, idx::Integer)
- Base.llvmcall(
- $("""
- declare $llvmtype @__spirv_GroupNonUniformShuffle(i32, $llvmtype, i32)
- define $llvmtype @entry($llvmtype %val, i32 %idx) #0 {
- %res = call $llvmtype @__spirv_GroupNonUniformShuffle(i32 3, $llvmtype %val, i32 %idx)
- ret $llvmtype %res
- }
- attributes #0 = { alwaysinline }
- """, "entry"), $julia_type_str, Tuple{$julia_type_str, Int32}, x, Int32(idx))
+ return Base.llvmcall(
+ $(
+ """
+ declare $llvmtype @__spirv_GroupNonUniformShuffle(i32, $llvmtype, i32)
+ define $llvmtype @entry($llvmtype %val, i32 %idx) #0 {
+ %res = call $llvmtype @__spirv_GroupNonUniformShuffle(i32 3, $llvmtype %val, i32 %idx)
+ ret $llvmtype %res
+ }
+ attributes #0 = { alwaysinline }
+ """, "entry",
+ ), $julia_type_str, Tuple{$julia_type_str, Int32}, x, Int32(idx)
+ )
end
function sub_group_shuffle_xor(x::$jltype, mask::Integer)
- Base.llvmcall(
- $("""
- declare $llvmtype @__spirv_GroupNonUniformShuffleXor(i32, $llvmtype, i32)
- define $llvmtype @entry($llvmtype %val, i32 %mask) #0 {
- %res = call $llvmtype @__spirv_GroupNonUniformShuffleXor(i32 3, $llvmtype %val, i32 %mask)
- ret $llvmtype %res
- }
- attributes #0 = { alwaysinline }
- """, "entry"), $julia_type_str, Tuple{$julia_type_str, Int32}, x, Int32(mask))
+ return Base.llvmcall(
+ $(
+ """
+ declare $llvmtype @__spirv_GroupNonUniformShuffleXor(i32, $llvmtype, i32)
+ define $llvmtype @entry($llvmtype %val, i32 %mask) #0 {
+ %res = call $llvmtype @__spirv_GroupNonUniformShuffleXor(i32 3, $llvmtype %val, i32 %mask)
+ ret $llvmtype %res
+ }
+ attributes #0 = { alwaysinline }
+ """, "entry",
+ ), $julia_type_str, Tuple{$julia_type_str, Int32}, x, Int32(mask)
+ )
end
end
end
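For context, the sub_group_shuffle_xor intrinsic added above is the building block of a butterfly reduction: each lane repeatedly adds the value held by the lane whose index differs in one bit, halving the offset each round until every lane holds the subgroup total, after which a single lane publishes the result with Atomix.@atomic (presumably backed by the new atomic_add! builtin). A minimal host-side model of that pattern, in plain Julia with no GPU; xor_reduce is a hypothetical name for illustration, not part of this PR:

# Model of a subgroup butterfly reduction: vals[i] plays the role of
# lane i's `partial`, and indexing with ((i - 1) ⊻ offset) + 1 stands
# in for sub_group_shuffle_xor(partial, offset) on lane i.
function xor_reduce(vals::Vector{Float32})
    sg = length(vals)          # subgroup size; assumed a power of two
    res = copy(vals)
    offset = sg >> 1
    while offset > 0
        # every lane adds its xor-partner's value, in lockstep
        res = [res[i] + res[((i - 1) ⊻ offset) + 1] for i in 1:sg]
        offset >>= 1
    end
    return res                 # every lane now holds sum(vals)
end

xor_reduce(Float32[1, 2, 3, 4])  # -> Float32[10.0, 10.0, 10.0, 10.0]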
Codecov Report

✅ All modified and coverable lines are covered by tests.

@@           Coverage Diff           @@
##           master     #356   +/-   ##
=======================================
  Coverage   78.86%   78.86%
=======================================
  Files          12       12
  Lines         672      672
=======================================
  Hits          530      530
  Misses        142      142
ref #352

This matches the speed of the C implementation, so there seems to be no inherent overhead compared to OpenCL C.

cc @maleadt

That's great to hear, at least. I presume that the Cartesian indexing introduced by the more complicated …
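Relatedly, the kernel sidesteps Cartesian indexing by computing the column-major linear index by hand. A quick host-side check, plain Julia and illustration only, confirms the formula used in fast_sum.jl agrees with Julia's own column-major layout:

M, N = 5, 3
X = reshape(1:M*N, M, N)          # column-major, as the kernel assumes
for col in 1:N, row in 1:M
    idx = (col - 1) * M + row     # the formula from the kernel
    @assert X[idx] == X[row, col] # linear and Cartesian indexing agree
end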