Skip to content

Commit d0081e7

Browse files
authored
Merge pull request #674 from JuliaParallel/jps/stencil-tuple-neighborhood
DArray/stencils: Support Tuple for neighborhood distance
2 parents 31632a2 + 3944b1b commit d0081e7

File tree

2 files changed

+96
-26
lines changed

2 files changed

+96
-26
lines changed

src/stencil.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ const Write = Out
44
const ReadWrite = InOut
55

66
function validate_neigh_dist(neigh_dist)
7-
if !(neigh_dist isa Integer)
8-
throw(ArgumentError("Neighborhood distance ($neigh_dist) must be an Integer"))
7+
if !(neigh_dist isa Integer) && !(neigh_dist isa Tuple)
8+
throw(ArgumentError("Neighborhood distance ($neigh_dist) must be an Integer or Tuple"))
99
end
10-
if neigh_dist <= 0
10+
if any(neigh_dist .<= 0)
1111
throw(ArgumentError("Neighborhood distance ($neigh_dist) must be greater than 0"))
1212
end
1313
end
@@ -18,6 +18,10 @@ function validate_neigh_dist(neigh_dist, size)
1818
end
1919
end
2020

21+
get_neigh_dist(neigh_dist::Integer, i::Int) = neigh_dist
22+
get_neigh_dist(neigh_dist::Tuple, i::Int) = neigh_dist[i]
23+
24+
2125
# Load a halo region from a neighboring chunk
2226
# region_code: N-tuple where each element is -1 (low), 0 (full extent), or +1 (high)
2327
# For dimensions with code 0, we take the full extent of the array
@@ -27,14 +31,14 @@ function load_neighbor_region(arr, region_code::NTuple{N,Int}, neigh_dist) where
2731
validate_neigh_dist(neigh_dist, size(arr))
2832
start_idx = CartesianIndex(ntuple(N) do i
2933
if region_code[i] == -1
30-
lastindex(arr, i) - neigh_dist + 1
34+
lastindex(arr, i) - get_neigh_dist(neigh_dist, i) + 1
3135
else
3236
firstindex(arr, i)
3337
end
3438
end)
3539
stop_idx = CartesianIndex(ntuple(N) do i
3640
if region_code[i] == +1
37-
firstindex(arr, i) + neigh_dist - 1
41+
firstindex(arr, i) + get_neigh_dist(neigh_dist, i) - 1
3842
else
3943
lastindex(arr, i)
4044
end
@@ -88,13 +92,11 @@ function build_halo(neigh_dist, boundary, center, all_halos...)
8892
N = ndims(center)
8993
expected_halos = 3^N - 1
9094
@assert length(all_halos) == expected_halos "Halo mismatch: N=$N expected $expected_halos halos, got $(length(all_halos))"
91-
return HaloArray(center, (all_halos...,), ntuple(_->neigh_dist, N))
95+
return HaloArray(center, (all_halos...,), ntuple(i->get_neigh_dist(neigh_dist, i), N))
9296
end
9397
function load_neighborhood(arr::HaloArray{T,N}, idx) where {T,N}
94-
@assert all(arr.halo_width .== arr.halo_width[1])
95-
neigh_dist = arr.halo_width[1]
96-
start_idx = idx - CartesianIndex(ntuple(_->neigh_dist, ndims(arr)))
97-
stop_idx = idx + CartesianIndex(ntuple(_->neigh_dist, ndims(arr)))
98+
start_idx = idx - CartesianIndex(ntuple(i->arr.halo_width[i], ndims(arr)))
99+
stop_idx = idx + CartesianIndex(ntuple(i->arr.halo_width[i], ndims(arr)))
98100
return @view arr[start_idx:stop_idx]
99101
end
100102
function inner_stencil!(f, output, read_vars)
@@ -126,7 +128,7 @@ function load_boundary_region(pad::Pad, arr, region_code::NTuple{N,Int}, neigh_d
126128
# For dimensions with code 0, use full array size
127129
# For dimensions with code -1 or +1, use neigh_dist
128130
region_size = ntuple(N) do i
129-
region_code[i] == 0 ? size(arr, i) : neigh_dist
131+
region_code[i] == 0 ? size(arr, i) : get_neigh_dist(neigh_dist, i)
130132
end
131133
# FIXME: return Fill(pad.padval, region_size)
132134
return move(task_processor(), fill(pad.padval, region_size))

test/array/stencil.jl

Lines changed: 83 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -137,27 +137,94 @@ function test_stencil()
137137
end
138138
end
139139

140-
@testset "Invalid neighborhood distance" begin
141-
A = ones(Blocks(1, 1), Int, 2, 2)
142-
B = zeros(Blocks(1, 1), Int, 2, 2)
143-
@test_throws_unwrap ArgumentError Dagger.spawn_datadeps() do
144-
@stencil begin
145-
B[idx] = sum(@neighbors(A[idx], 0, Wrap()))
140+
@testset "Tuple neighborhood distance" begin
141+
# 1D case: distance (2,)
142+
@testset "1D with distance (2,)" begin
143+
A = DArray([1, 2, 3, 4, 5, 6], Blocks(2,))
144+
B = zeros(Blocks(2,), Int, 6)
145+
Dagger.spawn_datadeps() do
146+
@stencil begin
147+
B[idx] = sum(@neighbors(A[idx], (2,), Wrap()))
148+
end
146149
end
150+
# For each element, neighbors at distance 2 in 1D: [-2, -1, 0, 1, 2]
151+
# B[1] neighbors: A[5], A[6], A[1], A[2], A[3] (wrapping) = 5+6+1+2+3 = 17
152+
# B[2] neighbors: A[6], A[1], A[2], A[3], A[4] = 6+1+2+3+4 = 16
153+
# B[3] neighbors: A[1], A[2], A[3], A[4], A[5] = 1+2+3+4+5 = 15
154+
# B[4] neighbors: A[2], A[3], A[4], A[5], A[6] = 2+3+4+5+6 = 20
155+
# B[5] neighbors: A[3], A[4], A[5], A[6], A[1] = 3+4+5+6+1 = 19
156+
# B[6] neighbors: A[4], A[5], A[6], A[1], A[2] = 4+5+6+1+2 = 18
157+
expected_B_1d = [17, 16, 15, 20, 19, 18]
158+
@test collect(B) == expected_B_1d
147159
end
148-
@test_throws_unwrap ArgumentError Dagger.spawn_datadeps() do
149-
@stencil begin
150-
B[idx] = sum(@neighbors(A[idx], -1, Wrap()))
160+
161+
# 2D case: distance (1, 2) - different per dimension
162+
@testset "2D with distance (1, 2)" begin
163+
A = DArray(reshape(1:12, 3, 4), Blocks(1, 2))
164+
B = zeros(Blocks(1, 2), Int, 3, 4)
165+
Dagger.spawn_datadeps() do
166+
@stencil begin
167+
B[idx] = sum(@neighbors(A[idx], (1, 2), Wrap()))
168+
end
151169
end
170+
# Distance (1, 2) means:
171+
# - dimension 1 (rows): offsets -1, 0, 1
172+
# - dimension 2 (cols): offsets -2, -1, 0, 1, 2
173+
# Total neighborhood size: 3 * 5 = 15 elements
174+
expected_B_2d = zeros(Int, 3, 4)
175+
for i in 1:3, j in 1:4
176+
sum_val = 0
177+
for di in -1:1, dj in -2:2
178+
row = mod1(i+di, 3)
179+
col = mod1(j+dj, 4)
180+
sum_val += A[row, col]
181+
end
182+
expected_B_2d[i, j] = sum_val
183+
end
184+
@test collect(B) == expected_B_2d
152185
end
153-
@test_throws_unwrap ArgumentError Dagger.spawn_datadeps() do
154-
@stencil begin
155-
B[idx] = sum(@neighbors(A[idx], 1.5, Wrap()))
186+
187+
# 3D case: distance (1, 2, 1) - different per dimension
188+
@testset "3D with distance (1, 2, 1)" begin
189+
# Need chunk sizes >= 2*distance+1 for each dimension
190+
# distance (1, 2, 1) requires chunks >= (3, 5, 3)
191+
A = DArray(reshape(1:120, 4, 5, 6), Blocks(4, 5, 3))
192+
B = zeros(Blocks(4, 5, 3), Int, 4, 5, 6)
193+
Dagger.spawn_datadeps() do
194+
@stencil begin
195+
B[idx] = sum(@neighbors(A[idx], (1, 2, 1), Wrap()))
196+
end
156197
end
198+
# Distance (1, 2, 1) means:
199+
# - dimension 1: offsets -1, 0, 1 (3 elements)
200+
# - dimension 2: offsets -2, -1, 0, 1, 2 (5 elements)
201+
# - dimension 3: offsets -1, 0, 1 (3 elements)
202+
# Total neighborhood size: 3 * 5 * 3 = 45 elements
203+
expected_B_3d = zeros(Int, 4, 5, 6)
204+
for i in 1:4, j in 1:5, k in 1:6
205+
sum_val = 0
206+
for di in -1:1, dj in -2:2, dk in -1:1
207+
row = mod1(i+di, 4)
208+
col = mod1(j+dj, 5)
209+
depth = mod1(k+dk, 6)
210+
sum_val += A[row, col, depth]
211+
end
212+
expected_B_3d[i, j, k] = sum_val
213+
end
214+
@test collect(B) == expected_B_3d
157215
end
158-
@test_throws_unwrap ArgumentError Dagger.spawn_datadeps() do
159-
@stencil begin
160-
B[idx] = sum(@neighbors(A[idx], 2, Wrap()))
216+
end
217+
218+
@testset "Invalid neighborhood distance" begin
219+
A = ones(Blocks(1, 1), Int, 2, 2)
220+
B = zeros(Blocks(1, 1), Int, 2, 2)
221+
for value in [0, -1, 1.5, 2]
222+
for dist in [value, (value,)]
223+
@test_throws_unwrap ArgumentError Dagger.spawn_datadeps() do
224+
@stencil begin
225+
B[idx] = sum(@neighbors(A[idx], value, Wrap()))
226+
end
227+
end
161228
end
162229
end
163230
end
@@ -196,3 +263,4 @@ end
196263
end
197264
end
198265
end
266+

0 commit comments

Comments
 (0)