-
Notifications
You must be signed in to change notification settings - Fork 80
Open
Description
MWE below @vchuravy
Works on Julia 1.10, but not on Julia 1.11/1.12.
"""
Minimal working example to demonstrate the difference between:
1. Using `array[I] +=` in a loop (BROKEN)
2. Using a local accumulator then `array[I] =` (WORKS)
This appears to be a bug in KernelAbstractions.jl, or a misunderstanding on my part of how it
handles compound assignment operators on array elements within loops.
"""
using KernelAbstractions
using Test
# Version 1: Using += directly on array element (BROKEN)
#
# Each work-item takes its global Cartesian index I = (i, j) and adds
# input[i, k] for every k in j:n onto output[I], one `+=` per loop iteration.
# The additions are applied to whatever output[I] already holds, so the caller
# must supply the starting value (the test in this file passes a zero-filled array).
#
# This kernel is the reproduction case: keep the `output[I] += ...` form exactly
# as written, since that statement inside the loop is what is being demonstrated.
# NOTE(review): `n` is a plain integer at the call site in this file, while
# `@Const` is usually applied to array arguments — confirm this is intended.
@kernel function accumulate_broken!(output, input, @Const(n))
# Global Cartesian index of this work-item; its two components are read below.
I = @index(Global, Cartesian)
i = I[1]
j = I[2]
# Expected: output[I] ends up increased by the sum of input[i, k] for k in j:n.
# Reported: the result is WRONG on Julia 1.11/1.12 (correct on 1.10).
for k in j:n
output[I] += input[i, k]
end
end
# Version 2: Using local accumulator (WORKS)
#
# Same intended computation as the `+=`-on-array variant: for the work-item at
# global Cartesian index I = (i, j), sum input[i, k] over k in j:n. Here the sum
# is built in a local variable and stored to output[I] with a single plain
# assignment, so any previous content of output[I] is overwritten, not added to.
# NOTE(review): `n` is a plain integer at the call site in this file, while
# `@Const` is usually applied to array arguments — confirm this is intended.
@kernel function accumulate_fixed!(output, input, @Const(n))
# Global Cartesian index of this work-item; its two components are read below.
I = @index(Global, Cartesian)
i = I[1]
j = I[2]
# Use local accumulator, started at zero of the output's element type
sum_val = zero(eltype(output))
for k in j:n
sum_val += input[i, k]
end
# Single store of the finished sum.
output[I] = sum_val
end
# Test function
#
# Builds a 5×8 Float32 input, computes the suffix sums
# `out[i, j] = sum(input[i, k] for k in j:n)` three ways (plain CPU loops, the
# `+=`-on-array kernel, the local-accumulator kernel), prints every matrix, and
# compares both kernel results against the plain-loop reference.
#
# Returns the tuple `(is_broken_wrong, is_fixed_correct)`:
#   - first element is `true` when the `+=` kernel result is NOT ≈ the reference,
#   - second element is `true` when the local-accumulator kernel IS ≈ the reference.
function test_kernel_bug()
    rule = "="^70
    println(rule)
    println("Minimal Kernel Bug Demonstration")
    println(rule)

    # Problem size and data
    n = 8
    m = 5
    input = Float32[i + k for i in 1:m, k in 1:n]
    output_broken, output_fixed, output_cpu = ntuple(_ -> zeros(Float32, m, n), 3)
    backend = CPU()

    # Print a heading, the matrix itself, then a spacer.
    function show_matrix(heading, mat)
        println(heading)
        display(mat)
        println("\n")
    end

    # Instantiate a kernel for `backend`, run it over every element of `out`,
    # and wait for completion.
    function launch!(kernel_def, out)
        kernel_def(backend)(out, input, n, ndrange=size(out))
        KernelAbstractions.synchronize(backend)
    end

    show_matrix("\nInput matrix ($(m)×$(n)):", input)

    # Reference result with ordinary nested loops (same i → j → k order).
    println("Computing CPU reference...")
    for i in 1:m, j in 1:n, k in j:n
        output_cpu[i, j] += input[i, k]
    end

    # Kernel that accumulates straight into the array element
    println("Running BROKEN kernel (using output[I] += ...)...")
    launch!(accumulate_broken!, output_broken)

    # Kernel that accumulates into a local and stores once
    println("Running FIXED kernel (using local accumulator)...")
    launch!(accumulate_fixed!, output_fixed)

    # Report
    println("\n" * rule)
    println("RESULTS")
    println(rule)
    show_matrix("\nCPU Reference:", output_cpu)
    show_matrix("\nBROKEN Kernel (output[I] +=):", output_broken)
    show_matrix("\nFIXED Kernel (local accumulator):", output_fixed)

    # Verdicts: does the += kernel disagree, and does the accumulator kernel agree?
    broken_disagrees = !(output_broken ≈ output_cpu)
    fixed_agrees = output_fixed ≈ output_cpu
    return (broken_disagrees, fixed_agrees)
end
# Run the test only when this file is executed as a script (not when it is
# `include`d), and print the verdict: the tuple returned by `test_kernel_bug()`
# would otherwise be discarded, leaving the reader to compare matrices by eye.
if abspath(PROGRAM_FILE) == @__FILE__
    broken_differs, fixed_matches = test_kernel_bug()
    println("`output[I] +=` kernel differs from CPU reference: ", broken_differs)
    println("Local-accumulator kernel matches CPU reference: ", fixed_matches)
end
Metadata
Metadata
Assignees
Labels
No labels