Skip to content

Commit 198c602

Browse files
committed
intercept column spectral broadcast
1 parent 2411834 commit 198c602

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

ext/cuda/operators_spectral_element.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,6 @@ function Base.copyto!(
3535
mask = DataLayouts.NoMask(),
3636
)
3737
space = axes(out)
38-
# Column spaces don't have horizontal spectral element structure,
39-
# so operators with empty axes just return zero without launching a kernel
40-
if space isa Spaces.FiniteDifferenceSpace
41-
fill!(
42-
parent(Fields.field_values(out)),
43-
zero(eltype(parent(Fields.field_values(out)))),
44-
)
45-
Operators.call_post_op_callback() &&
46-
Operators.post_op_callback(out, out, sbc)
47-
return out
48-
end
4938
us = UniversalSize(Fields.field_values(out))
5039
# executed
5140
p = spectral_partition(us)

src/Operators/spectralelement.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,12 @@ function Base.Broadcast.instantiate(sbc::SpectralBroadcasted)
130130
Base.Broadcast.check_broadcast_axes(axes, args...)
131131
end
132132
end
133+
# For FiniteDifferenceSpace, fall back to pointwise broadcasting
134+
if axes isa Spaces.FiniteDifferenceSpace
135+
return Base.Broadcast.instantiate(
136+
Base.Broadcast.Broadcasted{Fields.FieldStyle}(op, args, axes)
137+
)
138+
end
133139
# If we've already instantiated, then we need to strip the type parameters,
134140
# for example, `Divergence{()}(axes)`.
135141
op = unionall_type(typeof(op)){()}(axes)
@@ -149,6 +155,12 @@ function Base.Broadcast.instantiate(
149155
axes = bc.axes
150156
Base.Broadcast.check_broadcast_axes(axes, args...)
151157
end
158+
# For FiniteDifferenceSpace, fall back to pointwise broadcasting
159+
if axes isa Spaces.FiniteDifferenceSpace
160+
return Base.Broadcast.instantiate(
161+
Base.Broadcast.Broadcasted{Fields.FieldStyle}(bc.f, args, axes)
162+
)
163+
end
152164
Style = AbstractSpectralStyle(ClimaComms.device(axes))
153165
return Base.Broadcast.Broadcasted{Style}(bc.f, args, axes)
154166
end

0 commit comments

Comments
 (0)