Skip to content

Commit 11875a0

Browse files
committed
Extend level and field_values to work with lazy Fields
1 parent 0d06d18 commit 11875a0

File tree

6 files changed

+95
-15
lines changed

6 files changed

+95
-15
lines changed

src/DataLayouts/DataLayouts.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ import Adapt
7171

7272
import ..Utilities: PlusHalf, unionall_type
7373
import ..DebugOnly: call_post_op_callback, post_op_callback
74-
import ..slab, ..slab_args, ..column, ..column_args, ..level
74+
import ..slab, ..slab_args, ..column, ..column_args, ..level, ..level_args
7575
export slab,
7676
column,
7777
level,
@@ -316,6 +316,9 @@ Base.parent(data::AbstractData) = getfield(data, :array)
316316

317317
Base.similar(data::AbstractData{S}) where {S} = similar(data, S)
318318

319+
@inline Base.:(==)(data1::D, data2::D) where {D <: AbstractData} =
320+
parent(data1) == parent(data2)
321+
319322
@inline function ncomponents(data::AbstractData{S}) where {S}
320323
typesize(eltype(parent(data)), S)
321324
end

src/DataLayouts/broadcast.jl

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ DataStyle(::Type{VF{S, Nv, A}}) where {S, Nv, A} =
2323
DataColumnStyle(::Type{VFStyle{Nv, A}}) where {Nv, A} = VFStyle{Nv, A}
2424
Data0DStyle(::Type{VFStyle{Nv, A}}) where {Nv, A} = DataFStyle{A}
2525

26-
abstract type Data1DStyle{Ni} <: DataStyle end
26+
abstract type DataLevelStyle <: DataStyle end
27+
abstract type Data1DStyle{Ni} <: DataLevelStyle end
2728
struct IFHStyle{Ni, A} <: Data1DStyle{Ni} end
2829
DataStyle(::Type{IFH{S, Ni, A}}) where {S, Ni, A} =
2930
IFHStyle{Ni, parent_array_type(A)}()
@@ -33,7 +34,7 @@ DataStyle(::Type{IHF{S, Ni, A}}) where {S, Ni, A} =
3334
IHFStyle{Ni, parent_array_type(A)}()
3435
Data0DStyle(::Type{IHFStyle{Ni, A}}) where {Ni, A} = DataFStyle{A}
3536

36-
abstract type DataSlab1DStyle{Ni} <: DataStyle end
37+
abstract type DataSlab1DStyle{Ni} <: DataLevelStyle end
3738
DataSlab1DStyle(::Type{IFHStyle{Ni, A}}) where {Ni, A} = IFStyle{Ni, A}
3839
DataSlab1DStyle(::Type{IHFStyle{Ni, A}}) where {Ni, A} = IFStyle{Ni, A}
3940

@@ -42,13 +43,13 @@ DataStyle(::Type{IF{S, Ni, A}}) where {S, Ni, A} =
4243
IFStyle{Ni, parent_array_type(A)}()
4344
Data0DStyle(::Type{IFStyle{Ni, A}}) where {Ni, A} = DataFStyle{A}
4445

45-
abstract type DataSlab2DStyle{Nij} <: DataStyle end
46+
abstract type DataSlab2DStyle{Nij} <: DataLevelStyle end
4647
struct IJFStyle{Nij, A} <: DataSlab2DStyle{Nij} end
4748
DataStyle(::Type{IJF{S, Nij, A}}) where {S, Nij, A} =
4849
IJFStyle{Nij, parent_array_type(A)}()
4950
Data0DStyle(::Type{IJFStyle{Nij, A}}) where {Nij, A} = DataFStyle{A}
5051

51-
abstract type Data2DStyle{Nij} <: DataStyle end
52+
abstract type Data2DStyle{Nij} <: DataLevelStyle end
5253
struct IJFHStyle{Nij, A} <: Data2DStyle{Nij} end
5354
DataStyle(::Type{IJFH{S, Nij, A}}) where {S, Nij, A} =
5455
IJFHStyle{Nij, parent_array_type(A)}()
@@ -67,6 +68,7 @@ DataStyle(::Type{VIFH{S, Nv, Ni, A}}) where {S, Nv, Ni, A} =
6768
VIFHStyle{Nv, Ni, parent_array_type(A)}()
6869
Data1DXStyle(::Type{VIFHStyle{Nv, Ni, A}}) where {Ni, Nv, A} =
6970
VIFHStyle{Nv, Ni, A}
71+
DataLevelStyle(::Type{VIFHStyle{Nv, Ni, A}}) where {Ni, Nv, A} = IFHStyle{Ni, A}
7072
DataColumnStyle(::Type{VIFHStyle{Nv, Ni, A}}) where {Ni, Nv, A} = VFStyle{Nv, A}
7173
DataSlab1DStyle(::Type{VIFHStyle{Nv, Ni, A}}) where {Ni, Nv, A} = IFStyle{Ni, A}
7274
Data0DStyle(::Type{VIFHStyle{Nv, Ni, A}}) where {Nv, Ni, A} = DataFStyle{A}
@@ -76,6 +78,7 @@ DataStyle(::Type{VIHF{S, Nv, Ni, A}}) where {S, Nv, Ni, A} =
7678
VIHFStyle{Nv, Ni, parent_array_type(A)}()
7779
Data1DXStyle(::Type{VIHFStyle{Nv, Ni, A}}) where {Ni, Nv, A} =
7880
VIHFStyle{Nv, Ni, A}
81+
DataLevelStyle(::Type{VIHFStyle{Nv, Ni, A}}) where {Ni, Nv, A} = IHFStyle{Ni, A}
7982
DataColumnStyle(::Type{VIHFStyle{Nv, Ni, A}}) where {Ni, Nv, A} = VFStyle{Nv, A}
8083
DataSlab1DStyle(::Type{VIHFStyle{Nv, Ni, A}}) where {Ni, Nv, A} = IFStyle{Ni, A}
8184
Data0DStyle(::Type{VIHFStyle{Nv, Ni, A}}) where {Nv, Ni, A} = DataFStyle{A}
@@ -86,6 +89,8 @@ DataStyle(::Type{VIJFH{S, Nv, Nij, A}}) where {S, Nv, Nij, A} =
8689
VIJFHStyle{Nv, Nij, parent_array_type(A)}()
8790
Data2DXStyle(::Type{VIJFHStyle{Nv, Nij, A}}) where {Nv, Nij, A} =
8891
VIJFHStyle{Nv, Nij, A}
92+
DataLevelStyle(::Type{VIJFHStyle{Nv, Nij, A}}) where {Nv, Nij, A} =
93+
IJFHStyle{Nij, A}
8994
DataColumnStyle(::Type{VIJFHStyle{Nv, Nij, A}}) where {Nv, Nij, A} =
9095
VFStyle{Nv, A}
9196
DataSlab2DStyle(::Type{VIJFHStyle{Nv, Nij, A}}) where {Nv, Nij, A} =
@@ -97,21 +102,18 @@ DataStyle(::Type{VIJHF{S, Nv, Nij, A}}) where {S, Nv, Nij, A} =
97102
VIJHFStyle{Nv, Nij, parent_array_type(A)}()
98103
Data2DXStyle(::Type{VIJHFStyle{Nv, Nij, A}}) where {Nv, Nij, A} =
99104
VIJHFStyle{Nv, Nij, A}
105+
DataLevelStyle(::Type{VIJHFStyle{Nv, Nij, A}}) where {Nv, Nij, A} =
106+
IJHFStyle{Nij, A}
100107
DataColumnStyle(::Type{VIJHFStyle{Nv, Nij, A}}) where {Nv, Nij, A} =
101108
VFStyle{Nv, A}
102109
DataSlab2DStyle(::Type{VIJHFStyle{Nv, Nij, A}}) where {Nv, Nij, A} =
103110
IJFStyle{Nij, A}
104111
Data0DStyle(::Type{VIJHFStyle{Nv, Nij, A}}) where {Nv, Nij, A} = DataFStyle{A}
105112

106-
const HorizontalDataStyle = Union{
107-
Data1DStyle,
108-
Data2DStyle,
109-
DataSlab1DStyle,
110-
DataSlab2DStyle,
111-
Data1DXStyle,
112-
Data2DXStyle,
113-
}
114-
DataColumnStyle(::Type{Style}) where {Style <: HorizontalDataStyle} =
113+
DataLevelStyle(::Type{Style}) where {Style <: DataLevelStyle} = Style
114+
DataLevelStyle(::Type{Style}) where {Style <: DataColumnStyle} =
115+
Data0DStyle(Style)
116+
DataColumnStyle(::Type{Style}) where {Style <: DataLevelStyle} =
115117
Data0DStyle(Style)
116118
DataSlabStyle(::Type{Style}) where {Style <: Union{Data1DStyle, Data1DXStyle}} =
117119
DataSlab1DStyle(Style)
@@ -374,6 +376,23 @@ Base.@propagate_inbounds function slab(
374376
Base.Broadcast.Broadcasted{DataSlab2DStyle(DS)}(bc.f, _args, _axes)
375377
end
376378

379+
Base.@propagate_inbounds function level(
380+
bc::Base.Broadcast.Broadcasted{DS},
381+
inds...,
382+
) where {DS <: DataStyle}
383+
_args = level_args(bc.args, inds...)
384+
_axes = nothing
385+
bcc = Base.Broadcast.Broadcasted{DataLevelStyle(DS)}(bc.f, _args, _axes)
386+
Base.Broadcast.instantiate(bcc)
387+
end
388+
389+
@inline function level(
390+
bc::Base.Broadcast.Broadcasted{DS},
391+
inds...,
392+
) where {DS <: DataLevelStyle}
393+
bc
394+
end
395+
377396
Base.@propagate_inbounds function column(
378397
bc::Base.Broadcast.Broadcasted{DS},
379398
inds...,

src/Fields/Fields.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module Fields
22

33
import ClimaComms
44
import MultiBroadcastFusion as MBF
5-
import ..slab, ..slab_args, ..column, ..column_args, ..level
5+
import ..slab, ..slab_args, ..column, ..column_args, ..level, ..level_args
66
import ..DebugOnly: call_post_op_callback, post_op_callback
77
import ..DataLayouts:
88
DataLayouts,

src/Fields/broadcast.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ struct FieldStyle{DS <: DataStyle} <: AbstractFieldStyle end
1717
FieldStyle(::DS) where {DS <: DataStyle} = FieldStyle{DS}()
1818
FieldStyle(x::Base.Broadcast.Unknown) = x
1919

20+
FieldLevelStyle(::Type{S}) where {DS, S <: FieldStyle{DS}} =
21+
FieldStyle{DataLayouts.DataLevelStyle(DS)}
2022
FieldColumnStyle(::Type{S}) where {DS, S <: FieldStyle{DS}} =
2123
FieldStyle{DataLayouts.DataColumnStyle(DS)}
2224
FieldSlabStyle(::Type{S}) where {DS, S <: FieldStyle{DS}} =
@@ -136,6 +138,26 @@ Base.@propagate_inbounds function slab(
136138
DataLayouts.NonExtrudedBroadcasted{_Style}(bc.f, _args, _axes)
137139
end
138140

141+
Base.@propagate_inbounds function level(
142+
bc::Base.Broadcast.Broadcasted{Style},
143+
inds...,
144+
) where {Style <: AbstractFieldStyle}
145+
_Style = FieldLevelStyle(Style)
146+
_args = level_args(bc.args, inds...)
147+
_axes = level(axes(bc), inds...)
148+
Base.Broadcast.Broadcasted{_Style}(bc.f, _args, _axes)
149+
end
150+
151+
Base.@propagate_inbounds function level(
152+
bc::DataLayouts.NonExtrudedBroadcasted{Style},
153+
inds...,
154+
) where {Style <: AbstractFieldStyle}
155+
_Style = FieldLevelStyle(Style)
156+
_args = level_args(bc.args, inds...)
157+
_axes = level(axes(bc), inds...)
158+
DataLayouts.NonExtrudedBroadcasted{_Style}(bc.f, _args, _axes)
159+
end
160+
139161
Base.@propagate_inbounds function column(
140162
bc::Base.Broadcast.Broadcasted{Style},
141163
inds...,
@@ -183,6 +205,8 @@ function todata(bc::DataLayouts.NonExtrudedBroadcasted{Style}) where {Style}
183205
DataLayouts.NonExtrudedBroadcasted{Style}(bc.f, _args)
184206
end
185207

208+
field_values(bc::Base.AbstractBroadcasted) = todata(bc)
209+
186210
# same logic as Base.Broadcast.Broadcasted (which only defines it for Tuples)
187211
Base.axes(bc::Base.Broadcast.Broadcasted{<:AbstractFieldStyle}) =
188212
_axes(bc, bc.axes)

src/interface.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,10 @@ Base.@propagate_inbounds column_args(args::NamedTuple, inds...) =
4343
NamedTuple{keys(args)}(column_args(values(args), inds...))
4444

4545
function level end
46+
47+
Base.@propagate_inbounds level(x, inds...) = x
48+
Base.@propagate_inbounds level_args(args::Tuple, inds...) =
49+
(level(args[1], inds...), level_args(Base.tail(args), inds...)...)
50+
Base.@propagate_inbounds level_args(args::Tuple{Any}, inds...) =
51+
(level(args[1], inds...),)
52+
Base.@propagate_inbounds level_args(args::Tuple{}, inds...) = ()

test/Fields/unit_field.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,33 @@ end
662662
end
663663
end
664664

665+
@testset "Lazy Field broadcasts" begin
666+
FT = Float64
667+
for space in TU.all_spaces(FT)
668+
field = fill((; x = FT(1)), space)
669+
@test field == Base.materialize(lazy.(identity.(field)))
670+
@test field .+ 1 == Base.materialize(lazy.(field .+ 1))
671+
@test Fields.field_values(field .+ 1) ==
672+
Base.materialize(Fields.field_values(lazy.(field .+ 1)))
673+
end
674+
end
675+
676+
@testset "Levels of Fields and Field broadcasts" begin
677+
FT = Float64
678+
for space in TU.all_spaces(FT)
679+
TU.levelable(space) || continue
680+
field = fill((; x = FT(1)), space)
681+
level_of_field = Fields.Field(
682+
Spaces.level(Fields.field_values(field), 1),
683+
Spaces.level(space, TU.fc_index(1, space)),
684+
)
685+
@test level_of_field == Spaces.level(field, TU.fc_index(1, space))
686+
@test level_of_field == Base.materialize(
687+
Spaces.level(lazy.(identity.(field)), TU.fc_index(1, space)),
688+
)
689+
end
690+
end
691+
665692
@testset "Columns of Fields and Field broadcasts" begin
666693
FT = Float64
667694
for space in TU.all_spaces(FT)

0 commit comments

Comments
 (0)