Skip to content

Commit 78c657f

Browse files
authored
Merge pull request #677 from JuliaParallel/jps/stencil-more-bcs
stencils: Support mixed boundaries, add Clamp and LinearExtrapolate
2 parents f39e908 + 49dbb60 commit 78c657f

File tree

4 files changed

+749
-70
lines changed

4 files changed

+749
-70
lines changed

docs/src/stencils.md

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ The fundamental structure of a `@stencil` block involves iterating over an impli
88

99
```julia
1010
using Dagger
11-
import Dagger: @stencil, Wrap, Pad, Reflect
11+
import Dagger: @stencil, Wrap, Pad, Reflect, Clamp, LinearExtrapolate
1212

1313
# Initialize a DArray
1414
A = zeros(Blocks(2, 2), Int, 4, 4)
@@ -31,13 +31,16 @@ The true power of stencils comes from accessing neighboring elements. The `@neig
3131
`@neighbors(array[idx], distance, boundary_condition)`
3232

3333
- `array[idx]`: The array and current index from which to find neighbors.
34-
- `distance`: An integer specifying the extent of the neighborhood (e.g., `1` for a 3x3 neighborhood in 2D).
34+
- `distance`: An integer or `Tuple` of integers specifying the extent of the neighborhood (e.g., `1` for a 3x3 neighborhood in 2D).
3535
- `boundary_condition`: Defines how to handle accesses beyond the array boundaries. Available conditions are:
36-
- `Wrap()`: Wraps around to the other side of the array.
36+
- `Wrap()`: Wraps around to the other side of the array (periodic boundaries).
3737
- `Pad(value)`: Pads with a specified `value`.
3838
- `Reflect(symmetric)`: Reflects values back into the array at boundaries. The `symmetric` boolean controls whether the edge element is included in the reflection:
3939
- `Reflect(true)` (symmetric): Edge element IS repeated. For array `[a,b,c,d]`, extends as `[...,c,b,a,a,b,c,d,d,c,b,...]`.
4040
- `Reflect(false)` (mirror): Edge element NOT repeated. For array `[a,b,c,d]`, extends as `[...,d,c,b,a,b,c,d,c,b,a,...]`.
41+
- `Clamp()`: Clamps to the boundary value (repeats edge elements). For array `[a,b,c,d]`, extends as `[...,a,a,a,a,b,c,d,d,d,d,...]`.
42+
- `LinearExtrapolate()`: Linearly extrapolates using the slope at the boundary. Only works with `Real` element types. For array `[2,4,6,8]`, the slope at the low boundary is `4-2=2`, so index 0 would be `2-2=0`.
43+
- **Mixed BCs (Tuple)**: You can specify different boundary conditions per dimension using a tuple. For example, `(Wrap(), Pad(0))` uses `Wrap` for dimension 1 and `Pad(0)` for dimension 2.
4144

4245
### Example: Averaging Neighbors with `Wrap`
4346

@@ -149,6 +152,80 @@ end
149152
@assert collect(B) == [5, 6, 9, 10]
150153
```
151154

155+
### Example: Edge Detection with `Clamp`
156+
157+
The `Clamp` boundary condition repeats edge values, which is useful when you want boundary elements to have a neutral effect:
158+
159+
```julia
160+
import Dagger: Clamp
161+
162+
# Array [1, 2, 3, 4] extends as [..., 1, 1, 1, 1, 2, 3, 4, 4, 4, 4, ...]
163+
A = DArray([1, 2, 3, 4], Blocks(2))
164+
B = zeros(Blocks(2), Int, 4)
165+
166+
Dagger.spawn_datadeps() do
167+
@stencil begin
168+
B[idx] = sum(@neighbors(A[idx], 1, Clamp()))
169+
end
170+
end
171+
172+
# B[1]: indices 0,1,2 -> 0 clamps to 1, so [1,1,2] = 4
173+
# B[2]: indices 1,2,3 -> all in bounds, [1,2,3] = 6
174+
# B[3]: indices 2,3,4 -> all in bounds, [2,3,4] = 9
175+
# B[4]: indices 3,4,5 -> 5 clamps to 4, so [3,4,4] = 11
176+
@assert collect(B) == [4, 6, 9, 11]
177+
```
178+
179+
### Example: Smooth Extrapolation with `LinearExtrapolate`
180+
181+
The `LinearExtrapolate` boundary condition extrapolates linearly based on the slope at the boundary. This is useful for maintaining trends at edges:
182+
183+
```julia
184+
import Dagger: LinearExtrapolate
185+
186+
# Array [2.0, 4.0, 6.0, 8.0] has slope 2.0 at both boundaries
187+
# At low boundary: index 0 -> 2.0 + 2.0*(-1) = 0.0
188+
# At high boundary: index 5 -> 8.0 + 2.0*(1) = 10.0
189+
A = DArray([2.0, 4.0, 6.0, 8.0], Blocks(2))
190+
B = zeros(Blocks(2), Float64, 4)
191+
192+
Dagger.spawn_datadeps() do
193+
@stencil begin
194+
B[idx] = sum(@neighbors(A[idx], 1, LinearExtrapolate()))
195+
end
196+
end
197+
198+
# B[1]: indices 0,1,2 -> [0.0, 2.0, 4.0] = 6.0
199+
# B[2]: indices 1,2,3 -> [2.0, 4.0, 6.0] = 12.0
200+
# B[3]: indices 2,3,4 -> [4.0, 6.0, 8.0] = 18.0
201+
# B[4]: indices 3,4,5 -> [6.0, 8.0, 10.0] = 24.0
202+
@assert collect(B) [6.0, 12.0, 18.0, 24.0]
203+
```
204+
205+
### Example: Mixed Boundary Conditions
206+
207+
You can specify different boundary conditions for each dimension using a tuple. This is useful when different boundaries have different physical meanings:
208+
209+
```julia
210+
import Dagger: Wrap, Pad
211+
212+
# 2D array with Wrap in dimension 1 (rows) and Pad(0) in dimension 2 (columns)
213+
A = DArray(reshape(1:16, 4, 4), Blocks(2, 2))
214+
B = zeros(Blocks(2, 2), Int, 4, 4)
215+
216+
Dagger.spawn_datadeps() do
217+
@stencil begin
218+
B[idx] = sum(@neighbors(A[idx], 1, (Wrap(), Pad(0))))
219+
end
220+
end
221+
222+
# For each element:
223+
# - Row neighbors wrap around (periodic in rows)
224+
# - Column neighbors are padded with 0 (zero-flux at column boundaries)
225+
```
226+
227+
This is particularly useful for physical simulations where, for example, you might have periodic boundaries in one direction and fixed boundaries in another.
228+
152229
## Sequential Semantics
153230

154231
Expressions within a `@stencil` block are executed sequentially in terms of their effect on the data. This means that the result of one statement is visible to the subsequent statements, as if they were applied "all at once" across all indices before the next statement begins.

src/Dagger.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ end
4646

4747
import MacroTools: @capture, prewalk
4848

49+
import KernelAbstractions
50+
import KernelAbstractions: @kernel, @index
51+
import Adapt
52+
4953
include("lib/util.jl")
5054
include("utils/dagdebug.jl")
5155

@@ -123,8 +127,6 @@ include("array/mul.jl")
123127
include("array/cholesky.jl")
124128
include("array/lu.jl")
125129

126-
import KernelAbstractions, Adapt
127-
128130
# GPU
129131
include("gpu.jl")
130132

0 commit comments

Comments
 (0)