Skip to content

Commit c1f51a7

Browse files
committed
fixup! datadeps: Add at-stencil helper
1 parent ea01b52 commit c1f51a7

File tree

2 files changed

+82
-19
lines changed

2 files changed

+82
-19
lines changed

src/stencil.jl

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,20 @@ const Read = In
33
const Write = Out
44
const ReadWrite = InOut
55

6+
"""
    validate_neigh_dist(neigh_dist, size)

Check that `neigh_dist` can be used as a neighborhood distance for a chunk
whose per-dimension lengths are given by `size`.

Throws an `ArgumentError` when `neigh_dist` is not an `Integer`, is not greater
than 0, or is larger than any entry of `size`. Returns `nothing` otherwise.
"""
function validate_neigh_dist(neigh_dist, size)
    # Guard clauses: each check throws as soon as it fails.
    neigh_dist isa Integer ||
        throw(ArgumentError("Neighborhood distance ($neigh_dist) must be an Integer"))
    neigh_dist > 0 ||
        throw(ArgumentError("Neighborhood distance ($neigh_dist) must be greater than 0"))
    # Every dimension of the chunk must be at least `neigh_dist` elements long.
    too_small = any(len -> len < neigh_dist, size)
    too_small &&
        throw(ArgumentError("Neighborhood distance ($neigh_dist) must not be larger than the chunk size ($size)"))
    return nothing
end
17+
618
function load_neighbor_edge(arr, dim, dir, neigh_dist)
19+
validate_neigh_dist(neigh_dist, size(arr))
720
if dir == -1
821
start_idx = CartesianIndex(ntuple(i -> i == dim ? (lastindex(arr, i) - neigh_dist + 1) : firstindex(arr, i), ndims(arr)))
922
stop_idx = CartesianIndex(ntuple(i -> i == dim ? lastindex(arr, i) : lastindex(arr, i), ndims(arr)))
@@ -15,6 +28,7 @@ function load_neighbor_edge(arr, dim, dir, neigh_dist)
1528
return move(task_processor(), collect(@view arr[start_idx:stop_idx]))
1629
end
1730
function load_neighbor_corner(arr, corner_side, neigh_dist)
31+
validate_neigh_dist(neigh_dist, size(arr))
1832
start_idx = CartesianIndex(ntuple(i -> corner_side[i] == 0 ? (lastindex(arr, i) - neigh_dist + 1) : firstindex(arr, i), ndims(arr)))
1933
stop_idx = CartesianIndex(ntuple(i -> corner_side[i] == 0 ? lastindex(arr, i) : (firstindex(arr, i) + neigh_dist - 1), ndims(arr)))
2034
return move(task_processor(), collect(@view arr[start_idx:stop_idx]))
@@ -134,9 +148,9 @@ end
134148
"""
135149
@stencil begin body end
136150
137-
Allows the specification of stencil operations within a `spawn_datadeps`
138-
region. The `idx` variable is used to iterate over `range`, which must be a
139-
`DArray`. An example usage may look like:
151+
Allows the execution of stencil operations within a `spawn_datadeps` region.
152+
The `idx` variable is used to iterate over one or more `DArray`s. An example
153+
usage may look like:
140154
141155
```julia
142156
import Dagger: @stencil, Wrap
@@ -146,19 +160,19 @@ A[5, 5] = 1
146160
B = zeros(Blocks(3, 3), Int, 9, 9)
147161
Dagger.spawn_datadeps() do
148162
@stencil begin
149-
# Sum values of all neighbors with self
150-
A[idx] = sum(@neighbors(A[idx], 1, Wrap()))
151-
# Decrement all values by 1
152-
A[idx] -= 1
153-
# Copy A to B
154-
B[idx] = A[idx]
163+
# Increment all values by 1
164+
A[idx] = A[idx] + 1
165+
# Sum values of all neighbors with self and write to B
166+
B[idx] = sum(@neighbors(A[idx], 1, Wrap()))
167+
# Copy B back to A
168+
A[idx] = B[idx]
155169
end
156170
end
157171
```
158172
159173
Each expression within an `@stencil` region that performs an in-place indexing
160174
expression like `A[idx] = ...` is transformed into a set of tasks that operate
161-
on each chunk of `A` or any other arrays specified as `A[idx]`, and within each
175+
on each chunk of `A` or any other arrays specified as `A[idx]`; within each
162176
task, elements of that chunk of `A` can be accessed. Elements of multiple
163177
`DArray`s can be accessed, such as `B[idx]`, so long as `B` has the same size,
164178
shape, and chunk layout as `A`.
@@ -168,18 +182,20 @@ values around `A[idx]`, at a configurable distance (in this case, 1 element
168182
distance) and with various kinds of boundary conditions (in this case, `Wrap()`
169183
specifies wrapping behavior on the boundaries). Neighborhoods are computed with
170184
respect to neighboring chunks as well - if a neighborhood would overflow from
171-
the current chunk into one or more neighboring chunks, values from those
172-
neighboring chunks will be included in the neighborhood.
185+
the current chunk into a neighboring chunk, values from that neighboring chunk
186+
will be included in the neighborhood.
173187
174188
Note that, while `@stencil` may look like a `for` loop, it does not follow the
175189
same semantics; in particular, an expression within `@stencil` occurs "all at
176190
once" (across all indices) before the next expression occurs. This means that
177-
`A[idx] = sum(@neighbors(A[idx], 1, Wrap()))` will write the sum of
178-
neighbors for all `idx` values into `A[idx]` before `A[idx] -= 1` decrements
179-
the values `A` by 1, and that occurs before any of the values are copied to `B`
180-
in `B[idx] = A[idx]`. Of course, pipelining and other optimizations may still
181-
occur, so long as they respect the sequential nature of `@stencil` (just like
182-
with other operations in `spawn_datadeps`).
191+
`A[idx] = A[idx] + 1` increments the values of `A` by 1, which occurs before
192+
`B[idx] = sum(@neighbors(A[idx], 1, Wrap()))` writes the sum of neighbors for
193+
all `idx` values into `B[idx]`, and that occurs before any of the values are
194+
copied to `A` in `A[idx] = B[idx]`. Of course, pipelining and other optimizations
195+
may still occur, so long as they respect the sequential nature of `@stencil`
196+
(just like with other operations in `spawn_datadeps`). Due to this behavior,
197+
expressions like `A[idx] = sum(@neighbors(A[idx], 1, Wrap()))` are not valid,
198+
as that would currently cause race conditions and lead to undefined behavior.
183199
"""
184200
macro stencil(orig_ex)
185201
@assert Meta.isexpr(orig_ex, :block) "Invalid stencil block: $orig_ex"
@@ -200,7 +216,12 @@ macro stencil(orig_ex)
200216
push!(accessed_vars, read_var)
201217
push!(read_vars, read_var)
202218
elseif @capture(read_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, boundary_))
203-
@assert read_idx == write_idx "Neighborhood access must be at the same index as the write: $read_inner_ex"
219+
if read_idx != write_idx
220+
throw(ArgumentError("Neighborhood access must be at the same index as the write: $read_inner_ex"))
221+
end
222+
if write_var == read_var
223+
throw(ArgumentError("Cannot write to the same variable as the neighborhood access: $read_inner_ex"))
224+
end
204225
push!(accessed_vars, read_var)
205226
push!(read_vars, read_var)
206227
neighborhoods[read_var] = (neigh_dist, boundary)

test/array/stencil.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,48 @@ function test_stencil()
118118
expected_B_pad_val = fill(pad_value*5 + 1*4, 2, 2)
119119
@test collect(B) == expected_B_pad_val
120120
end
121+
122+
@testset "Invalid neighborhood distance" begin
123+
A = ones(Blocks(1, 1), Int, 2, 2)
124+
@test_throws ArgumentError Dagger.spawn_datadeps() do
125+
@stencil begin
126+
B[idx] = sum(@neighbors(A[idx], 0, Wrap()))
127+
end
128+
end
129+
@test_throws ArgumentError Dagger.spawn_datadeps() do
130+
@stencil begin
131+
B[idx] = sum(@neighbors(A[idx], -1, Wrap()))
132+
end
133+
end
134+
@test_throws ArgumentError Dagger.spawn_datadeps() do
135+
@stencil begin
136+
B[idx] = sum(@neighbors(A[idx], 1.5, Wrap()))
137+
end
138+
end
139+
@test_throws ArgumentError Dagger.spawn_datadeps() do
140+
@stencil begin
141+
B[idx] = sum(@neighbors(A[idx], 2, Wrap()))
142+
end
143+
end
144+
end
145+
146+
@testset "Invalid neighborhood access of written variable" begin
147+
A = ones(Blocks(1, 1), Int, 2, 2)
148+
@test_throws ArgumentError @eval Dagger.spawn_datadeps() do
149+
@stencil begin
150+
A[idx] = sum(@neighbors(A[idx], 1, Wrap()))
151+
end
152+
end
153+
end
154+
155+
@testset "Invalid update expression" begin
156+
A = ones(Blocks(1, 1), Int, 2, 2)
157+
@test_throws ArgumentError @eval Dagger.spawn_datadeps() do
158+
@stencil begin
159+
A[idx] += 1
160+
end
161+
end
162+
end
121163
end
122164

123165
@testset "CPU" begin

0 commit comments

Comments
 (0)