@@ -12,85 +12,83 @@ Base.@propagate_inbounds function fd_operator_shmem(
1212 # allocate temp output
1313 RT = return_eltype (op, args... )
1414 Ju³ = CUDA. CuStaticSharedArray (RT, (Nvt,))
15- return Ju³
15+ lJu³ = CUDA. CuStaticSharedArray (RT, (1 ,))
16+ rJu³ = CUDA. CuStaticSharedArray (RT, (1 ,))
17+ return (Ju³, lJu³, rJu³)
1618end
1719
18- Base. @propagate_inbounds function fd_operator_fill_shmem_interior ! (
20+ Base. @propagate_inbounds function fd_operator_fill_shmem ! (
1921 op:: Operators.DivergenceF2C ,
20- Ju³,
21- loc, # can be any location
22- space,
23- idx:: Utilities.PlusHalf ,
24- hidx,
25- arg,
26- )
27- @inbounds begin
28- vt = threadIdx (). x
29- lg = Geometry. LocalGeometry (space, idx, hidx)
30- u³ = Operators. getidx (space, arg, loc, idx, hidx)
31- Ju³[vt] = Geometry. Jcontravariant3 (u³, lg)
32- end
33- return nothing
34- end
35-
36- Base. @propagate_inbounds function fd_operator_fill_shmem_left_boundary! (
37- op:: Operators.DivergenceF2C ,
38- bc:: Operators.SetValue ,
39- Ju³,
22+ (Ju³, lJu³, rJu³),
4023 loc,
24+ bc_bds,
25+ arg_space,
4126 space,
4227 idx:: Utilities.PlusHalf ,
4328 hidx,
4429 arg,
4530)
46- idx == Operators. left_face_boundary_idx (space) ||
47- error (" Incorrect left idx" )
4831 @inbounds begin
4932 vt = threadIdx (). x
5033 lg = Geometry. LocalGeometry (space, idx, hidx)
51- u³ = Operators. getidx (space, bc. val, loc, nothing , hidx)
52- Ju³[vt] = Geometry. Jcontravariant3 (u³, lg)
53- end
54- return nothing
55- end
56-
57- Base. @propagate_inbounds function fd_operator_fill_shmem_right_boundary! (
58- op:: Operators.DivergenceF2C ,
59- bc:: Operators.SetValue ,
60- Ju³,
61- loc,
62- space,
63- idx:: Utilities.PlusHalf ,
64- hidx,
65- arg,
66- )
67- # The right boundary is called at `idx + 1`, so we need to subtract 1 from idx (shmem is loaded at vt+1)
68- idx == Operators. right_face_boundary_idx (space) ||
69- error (" Incorrect right idx" )
70- @inbounds begin
71- vt = threadIdx (). x
72- lg = Geometry. LocalGeometry (space, idx, hidx)
73- u³ = Operators. getidx (space, bc. val, loc, nothing , hidx)
74- Ju³[vt] = Geometry. Jcontravariant3 (u³, lg)
34+ if ! on_boundary (space, op, loc, idx)
35+ u³ = Operators. getidx (space, arg, loc, idx, hidx)
36+ Ju³[vt] = Geometry. Jcontravariant3 (u³, lg)
37+ else
38+ bc = Operators. get_boundary (op, loc)
39+ ub = Operators. getidx (space, bc. val, loc, nothing , hidx)
40+ bJu³ = on_left_boundary (idx, space) ? lJu³ : rJu³
41+ if bc isa Operators. SetValue
42+ bJu³[1 ] = Geometry. Jcontravariant3 (ub, lg)
43+ elseif bc isa Operators. SetDivergence
44+ bJu³[1 ] = ub
45+ elseif bc isa Operators. Extrapolate # no shmem needed
46+ end
47+ end
7548 end
7649 return nothing
7750end
7851
7952Base. @propagate_inbounds function fd_operator_evaluate (
8053 op:: Operators.DivergenceF2C ,
81- Ju³,
54+ ( Ju³, lJu³, rJu³) ,
8255 loc,
8356 space,
8457 idx:: Integer ,
8558 hidx,
86- args ... ,
59+ arg ,
8760)
8861 @inbounds begin
8962 vt = threadIdx (). x
90- local_geometry = Geometry. LocalGeometry (space, idx, hidx)
91- Ju³₋ = Ju³[vt] # corresponds to idx - half
92- Ju³₊ = Ju³[vt + 1 ] # corresponds to idx + half
93- return (Ju³₊ ⊟ Ju³₋) ⊠ local_geometry. invJ
63+ lg = Geometry. LocalGeometry (space, idx, hidx)
64+ if ! on_boundary (space, op, loc, idx)
65+ Ju³₋ = Ju³[vt] # corresponds to idx - half
66+ Ju³₊ = Ju³[vt + 1 ] # corresponds to idx + half
67+ return (Ju³₊ ⊟ Ju³₋) ⊠ lg. invJ
68+ else
69+ bc = Operators. get_boundary (op, loc)
70+ @assert bc isa Operators. SetValue || bc isa Operators. SetDivergence
71+ if on_left_boundary (idx, space)
72+ if bc isa Operators. SetValue
73+ Ju³₋ = lJu³[1 ] # corresponds to idx - half
74+ Ju³₊ = Ju³[vt + 1 ] # corresponds to idx + half
75+ return (Ju³₊ ⊟ Ju³₋) ⊠ lg. invJ
76+ else
77+ # @assert bc isa Operators.SetDivergence
78+ return lJu³[1 ]
79+ end
80+ else
81+ @assert on_right_boundary (idx, space)
82+ if bc isa Operators. SetValue
83+ Ju³₋ = Ju³[vt] # corresponds to idx - half
84+ Ju³₊ = rJu³[1 ] # corresponds to idx + half
85+ return (Ju³₊ ⊟ Ju³₋) ⊠ lg. invJ
86+ else
87+ @assert bc isa Operators. SetDivergence
88+ return rJu³[1 ]
89+ end
90+ end
91+ end
9492 end
9593end
9694
@@ -108,10 +106,12 @@ Base.@propagate_inbounds function fd_operator_shmem(
108106 return (u, lb, rb)
109107end
110108
111- Base. @propagate_inbounds function fd_operator_fill_shmem_interior ! (
109+ Base. @propagate_inbounds function fd_operator_fill_shmem ! (
112110 op:: Operators.GradientC2F ,
113111 (u, lb, rb),
114112 loc, # can be any location
113+ bc_bds,
114+ arg_space,
115115 space,
116116 idx:: Integer ,
117117 hidx,
@@ -120,50 +120,33 @@ Base.@propagate_inbounds function fd_operator_fill_shmem_interior!(
120120 @inbounds begin
121121 vt = threadIdx (). x
122122 cov3 = Geometry. Covariant3Vector (1 )
123- u[vt] = cov3 ⊗ Operators. getidx (space, arg, loc, idx, hidx)
124- end
125- return nothing
126- end
127-
128- Base. @propagate_inbounds function fd_operator_fill_shmem_left_boundary! (
129- op:: Operators.GradientC2F ,
130- bc:: Operators.SetValue ,
131- (u, lb, rb),
132- loc,
133- space,
134- idx:: Integer ,
135- hidx,
136- arg,
137- )
138- idx == Operators. left_center_boundary_idx (space) ||
139- error (" Incorrect left idx" )
140- @inbounds begin
141- vt = threadIdx (). x
142- cov3 = Geometry. Covariant3Vector (1 )
143- u[vt] = cov3 ⊗ Operators. getidx (space, arg, loc, idx, hidx)
144- lb[1 ] = cov3 ⊗ Operators. getidx (space, bc. val, loc, nothing , hidx)
145- end
146- return nothing
147- end
148-
149- Base. @propagate_inbounds function fd_operator_fill_shmem_right_boundary! (
150- op:: Operators.GradientC2F ,
151- bc:: Operators.SetValue ,
152- (u, lb, rb),
153- loc,
154- space,
155- idx:: Integer ,
156- hidx,
157- arg,
158- )
159- # The right boundary is called at `idx + 1`, so we need to subtract 1 from idx (shmem is loaded at vt+1)
160- idx == Operators. right_center_boundary_idx (space) ||
161- error (" Incorrect right idx" )
162- @inbounds begin
163- vt = threadIdx (). x
164- cov3 = Geometry. Covariant3Vector (1 )
165- u[vt] = cov3 ⊗ Operators. getidx (space, arg, loc, idx, hidx)
166- rb[1 ] = cov3 ⊗ Operators. getidx (space, bc. val, loc, nothing , hidx)
123+ if in_domain (idx, arg_space)
124+ u[vt] = cov3 ⊗ Operators. getidx (space, arg, loc, idx, hidx)
125+ else # idx can be Spaces.nlevels(ᶜspace)+1 because threads must extend to faces
126+ ᶜspace = Spaces. center_space (arg_space)
127+ @assert idx == Spaces. nlevels (ᶜspace) + 1
128+ end
129+ if on_any_boundary (idx, space, op)
130+ lloc =
131+ Operators. LeftBoundaryWindow {Spaces.left_boundary_name(space)} ()
132+ rloc = Operators. RightBoundaryWindow{
133+ Spaces. right_boundary_name (space),
134+ }()
135+ bloc = on_left_boundary (idx, space, op) ? lloc : rloc
136+ @assert bloc isa typeof (lloc) && on_left_boundary (idx, space, op) ||
137+ bloc isa typeof (rloc) && on_right_boundary (idx, space, op)
138+ bc = Operators. get_boundary (op, bloc)
139+ @assert bc isa Operators. SetValue || bc isa Operators. SetGradient
140+ ub = Operators. getidx (space, bc. val, bloc, nothing , hidx)
141+ bu = on_left_boundary (idx, space) ? lb : rb
142+ if bc isa Operators. SetValue
143+ bu[1 ] = cov3 ⊗ ub
144+ elseif bc isa Operators. SetGradient
145+ lg = Geometry. LocalGeometry (space, idx, hidx)
146+ bu[1 ] = Geometry. project (Geometry. Covariant3Axis (), ub, lg)
147+ elseif bc isa Operators. Extrapolate # no shmem needed
148+ end
149+ end
167150 end
168151 return nothing
169152end
@@ -179,17 +162,28 @@ Base.@propagate_inbounds function fd_operator_evaluate(
179162)
180163 @inbounds begin
181164 vt = threadIdx (). x
182- # @assert idx.i == vt-1 # assertion passes, but commented to remove potential thrown exception in llvm output
183- if idx == Operators. right_face_boundary_idx (space)
184- u₋ = 2 * u[vt - 1 ] # corresponds to idx - half
185- u₊ = 2 * rb[1 ] # corresponds to idx + half
186- elseif idx == Operators. left_face_boundary_idx (space)
187- u₋ = 2 * lb[1 ] # corresponds to idx - half
188- u₊ = 2 * u[vt] # corresponds to idx + half
189- else
165+ lg = Geometry. LocalGeometry (space, idx, hidx)
166+ if ! on_boundary (space, op, loc, idx)
190167 u₋ = u[vt - 1 ] # corresponds to idx - half
191168 u₊ = u[vt] # corresponds to idx + half
169+ return u₊ ⊟ u₋
170+ else
171+ bc = Operators. get_boundary (op, loc)
172+ @assert bc isa Operators. SetValue
173+ if on_left_boundary (idx, space)
174+ if bc isa Operators. SetValue
175+ u₋ = 2 * lb[1 ] # corresponds to idx - half
176+ u₊ = 2 * u[vt] # corresponds to idx + half
177+ return u₊ ⊟ u₋
178+ end
179+ else
180+ @assert on_right_boundary (idx, space)
181+ if bc isa Operators. SetValue
182+ u₋ = 2 * u[vt - 1 ] # corresponds to idx - half
183+ u₊ = 2 * rb[1 ] # corresponds to idx + half
184+ return u₊ ⊟ u₋
185+ end
186+ end
192187 end
193- return u₊ ⊟ u₋
194188 end
195189end
0 commit comments