Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/AbstractOperations/AbstractOperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using Oceananigans.Fields
using Oceananigans.Utils

using Oceananigans: location
using Oceananigans.Fields: instantiated_location
using Oceananigans.Operators: interpolation_operator

import Adapt
Expand Down
32 changes: 16 additions & 16 deletions src/AbstractOperations/binary_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,21 @@ indices(β::BinaryOperation) = construct_regionally(intersect_indices, location(

"""Create a binary operation for `op` acting on `a` and `b` at `Lc`, where
`a` and `b` have location `La` and `Lb`."""
function _binary_operation(Lc, op, a, b, La, Lb, grid)
function _binary_operation(Lc::Tuple{LX, LY, LZ}, op, a, b, La, Lb, grid) where {LX, LY, LZ}
▶a = interpolation_operator(La, Lc)
▶b = interpolation_operator(Lb, Lc)

return BinaryOperation{Lc[1], Lc[2], Lc[3]}(op, a, b, ▶a, ▶b, grid)
return BinaryOperation{LX, LY, LZ}(op, a, b, ▶a, ▶b, grid)
end

const ConcreteLocationType = Union{Type{Face}, Type{Center}}
const ConcreteLocationType = Union{Face, Center}

# Precedence rules for choosing operation location:
choose_location(La, Lb, Lc) = Lc # Fallback to the specification Lc, but also...
choose_location(::Type{Face}, ::Type{Face}, Lc) = Face # keep common locations; and
choose_location(::Type{Center}, ::Type{Center}, Lc) = Center #
choose_location(La::ConcreteLocationType, ::Type{Nothing}, Lc) = La # don't interpolate unspecified locations.
choose_location(::Type{Nothing}, Lb::ConcreteLocationType, Lc) = Lb #
choose_location(::Face, ::Face, Lc) = Face # keep common locations; and
choose_location(::Center, ::Center, Lc) = Center #
choose_location(La::ConcreteLocationType, ::Nothing, Lc) = La # don't interpolate unspecified locations.
choose_location(::Nothing, Lb::ConcreteLocationType, Lc) = Lb #

# Apply the function if the inputs are scalars, otherwise broadcast it over the inputs
# This can occur in the binary operator code if we index into with an array, e.g. array[1:10]
Expand Down Expand Up @@ -107,8 +107,8 @@ function define_binary_operator(op)
if that is also Nothing, `Lc`.
"""
function $op(Lc::Tuple, a, b)
La = location(a)
Lb = location(b)
La = instantiated_location(a)
Lb = instantiated_location(b)
Lab = choose_location.(La, Lb, Lc)

grid = Oceananigans.AbstractOperations.validate_grid(a, b)
Expand All @@ -127,15 +127,15 @@ function define_binary_operator(op)
$op(Lc::Tuple, a::AbstractField, m::GridMetric) = $op(Lc, a, grid_metric_operation(location(a), m, a.grid))

# Sugary versions with default locations
$op(a::AF, b::AF) = $op(location(a), a, b)
$op(a::AF, b) = $op(location(a), a, b)
$op(a, b::AF) = $op(location(b), a, b)
$op(a::AF, b::AF) = $op(instantiated_location(a), a, b)
$op(a::AF, b) = $op(instantiated_location(a), a, b)
$op(a, b::AF) = $op(instantiated_location(b), a, b)

$op(a::AF, b::Number) = $op(location(a), a, b)
$op(a::Number, b::AF) = $op(location(b), a, b)
$op(a::AF, b::Number) = $op(instantiated_location(a), a, b)
$op(a::Number, b::AF) = $op(instantiated_location(b), a, b)

$op(a::AF, b::ConstantField) = $op(location(a), a, b.constant)
$op(a::ConstantField, b::AF) = $op(location(b), a.constant, b)
$op(a::AF, b::ConstantField) = $op(instantiated_location(a), a, b.constant)
$op(a::ConstantField, b::AF) = $op(instantiated_location(b), a.constant, b)

$op(a::Number, b::ConstantField) = ConstantField($op(a, b.constant))
$op(a::ConstantField, b::Number) = ConstantField($op(a.constant, b))
Expand Down
22 changes: 9 additions & 13 deletions src/AbstractOperations/conditional_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ struct ConditionalOperation{LX, LY, LZ, F, C, O, G, M, T} <: AbstractOperation{L
end

# Some special cases
const NoFuncCO = ConditionalOperation{<:Any, <:Any, <:Any, Nothing}
const NoConditionCO = ConditionalOperation{<:Any, <:Any, <:Any, <:Any, Nothing}
const NoFuncNoConditionCO = ConditionalOperation{<:Any, <:Any, <:Any, Nothing, Nothing}
const NoFuncCO{LX, LY, LZ} = ConditionalOperation{LX, LY, LZ, Nothing}
const NoConditionCO{LX, LY, LZ} = ConditionalOperation{LX, LY, LZ, <:Any, Nothing}
const NoFuncNoConditionCO{LX, LY, LZ} = ConditionalOperation{LX, LY, LZ, Nothing, Nothing}

"""
ConditionalOperation(operand::AbstractField;
Expand Down Expand Up @@ -91,12 +91,11 @@ julia> d[2, 1, 1]
10.0
```
"""
function ConditionalOperation(operand::AbstractField;
function ConditionalOperation(operand::AbstractField{LX, LY, LZ};
func = nothing,
condition = nothing,
mask = zero(eltype(operand)))
mask = zero(eltype(operand))) where {LX, LY, LZ}
condition = validate_condition(condition, operand)
LX, LY, LZ = location(operand)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this has an effect (but tell me if its not true). I think this compiles identically. We use the function since its the recommendation of YASGuide (eg type parameters should be used for dispatch, not to access type info). It's not that important, but wanted to mention.

Copy link
Collaborator Author

@simone-silvestri simone-silvestri Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a matter of passing arguments directly as parameteric types, which is type unstable, rather than inferring a type. The constructor cannot infer the values of an argument, but can infer the type of the arguments that are passed. Check the difference between these two constructors (which basically build the same object with just a different name)

struct InstableType{T1, T2, T3} end
struct StableType{T1, T2, T3} end

StableType(t::Tuple{T1, T2, T3}) where {T1, T2, T3} = StableType{T1, T2, T3}()
InstableType(T::Tuple) = InstableType{T[1], T[2], T[3]}()

@code_warntype StableType((1.0, 1.f0, Float16(1.0))) # This is type stable
@code_warntype InstableType((Float64, Float32, Float16)) # This is type - unstable

Copy link
Collaborator Author

@simone-silvestri simone-silvestri Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, you are right. In the above case, there is no issue because LX, LY, LZ should be inferred in the function call. I will revert it. The problem arises only for constructors like this

function _binary_operation(Lc, op, a, b, La, Lb, grid)
▶a = interpolation_operator(La, Lc)
▶b = interpolation_operator(Lb, Lc)
return BinaryOperation{Lc[1], Lc[2], Lc[3]}(op, a, b, ▶a, ▶b, grid)
end

which should actually be

 function _binary_operation(Lc::Tuple{LX, LY, LZ}, op, a, b, La, Lb, grid) where {LX, LY, LZ}
      ▶a = interpolation_operator(La, Lc) 
      ▶b = interpolation_operator(Lb, Lc) 
  
     return BinaryOperation{LX, LY, LZ}(op, a, b, ▶a, ▶b, grid) 
 end 

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it!

return ConditionalOperation{LX, LY, LZ}(operand, func, operand.grid, condition, mask)
end

Expand All @@ -111,23 +110,20 @@ function validate_condition(cond::AbstractArray, operand::AbstractField)
return cond
end

function ConditionalOperation(c::ConditionalOperation;
function ConditionalOperation(c::ConditionalOperation{LX, LY, LZ};
func = c.func,
condition = c.condition,
mask = c.mask)
mask = c.mask) where {LX, LY, LZ}
condition = validate_condition(condition, operand)
LX, LY, LZ = location(c)
compined_func = func ∘ c.func

return ConditionalOperation{LX, LY, LZ}(c.operand, compined_func, c.grid, condition, mask)
end

function ConditionalOperation(c::NoFuncCO;
function ConditionalOperation(c::NoFuncCO{LX, LY, LZ};
func = c.func,
condition = c.condition,
mask = c.mask)
mask = c.mask) where {LX, LY, LZ}
condition = validate_condition(condition, operand)
LX, LY, LZ = location(c)
return ConditionalOperation{LX, LY, LZ}(c.operand, func, c.grid, condition, mask)
end

Expand Down
18 changes: 10 additions & 8 deletions src/AbstractOperations/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ end

"""Create a derivative operator `∂` acting on `arg` at `L∂`, followed by
interpolation to `L` on `grid`."""
function _derivative(L, ∂, arg, L∂, abstract_∂, grid)
function _derivative(L::Tuple{LX, LY, LZ}, ∂, arg, L∂, abstract_∂, grid) where {LX, LY, LZ}
▶ = interpolation_operator(L∂, L)
return Derivative{L[1], L[2], L[3]}(∂, arg, ▶, abstract_∂, grid)
return Derivative{LX, LY, LZ}(∂, arg, ▶, abstract_∂, grid)
end

indices(d::Derivative) = indices(d.arg)
Expand All @@ -41,6 +41,8 @@ indices(d::Derivative) = indices(d.arg)
"""Return `Center` if given `Face` or `Face` if given `Center`."""
flip(::Type{Face}) = Center
flip(::Type{Center}) = Face
flip(::Face) = Center()
flip(::Center) = Face()

const LocationType = Union{Type{Face}, Type{Center}, Type{Nothing}}

Expand All @@ -63,7 +65,7 @@ Return an abstract representation of an ``x``-derivative acting on field `arg` f
by interpolation to `L`, where `L` is a 3-tuple of `Face`s and `Center`s.
"""
∂x(L::Tuple, arg::AF{LX, LY, LZ}) where {LX, LY, LZ} =
_derivative(L, ∂x(LX, LY, LZ), arg, (flip(LX), LY, LZ), ∂x, arg.grid)
_derivative(L, ∂x(LX(), LY(), LZ()), arg, (flip(LX()), LY(), LZ()), ∂x, arg.grid)

"""
∂y(L::Tuple, arg::AbstractField)
Expand All @@ -72,7 +74,7 @@ Return an abstract representation of a ``y``-derivative acting on field `arg` fo
by interpolation to `L`, where `L` is a 3-tuple of `Face`s and `Center`s.
"""
∂y(L::Tuple, arg::AF{LX, LY, LZ}) where {LX, LY, LZ} =
_derivative(L, ∂y(LX, LY, LZ), arg, (LX, flip(LY), LZ), ∂y, arg.grid)
_derivative(L, ∂y(LX(), LY(), LZ()), arg, (LX(), flip(LY()), LZ()), ∂y, arg.grid)

"""
∂z(L::Tuple, arg::AbstractField)
Expand All @@ -81,7 +83,7 @@ Return an abstract representation of a ``z``-derivative acting on field `arg` fo
by interpolation to `L`, where `L` is a 3-tuple of `Face`s and `Center`s.
"""
∂z(L::Tuple, arg::AF{LX, LY, LZ}) where {LX, LY, LZ} =
_derivative(L, ∂z(LX, LY, LZ), arg, (LX, LY, flip(LZ)), ∂z, arg.grid)
_derivative(L, ∂z(LX(), LY(), LZ()), arg, (LX(), LY(), flip(LZ())), ∂z, arg.grid)

# Defaults

Expand All @@ -90,21 +92,21 @@ by interpolation to `L`, where `L` is a 3-tuple of `Face`s and `Center`s.

Return an abstract representation of a ``x``-derivative acting on field `arg`.
"""
∂x(arg::AF{LX, LY, LZ}) where {LX, LY, LZ} = ∂x((flip(LX), LY, LZ), arg)
∂x(arg::AF{LX, LY, LZ}) where {LX, LY, LZ} = ∂x((flip(LX()), LY(), LZ()), arg)

"""
∂y(arg::AbstractField)

Return an abstract representation of a ``y``-derivative acting on field `arg`.
"""
∂y(arg::AF{LX, LY, LZ}) where {LX, LY, LZ} = ∂y((LX, flip(LY), LZ), arg)
∂y(arg::AF{LX, LY, LZ}) where {LX, LY, LZ} = ∂y((LX(), flip(LY()), LZ()), arg)

"""
∂z(arg::AbstractField)

Return an abstract representation of a ``z``-derivative acting on field `arg`.
"""
∂z(arg::AF{LX, LY, LZ}) where {LX, LY, LZ} = ∂z((LX, LY, flip(LZ)), arg)
∂z(arg::AF{LX, LY, LZ}) where {LX, LY, LZ} = ∂z((LX(), LY(), flip(LZ())), arg)

#####
##### Nested computations
Expand Down
9 changes: 4 additions & 5 deletions src/AbstractOperations/grid_metrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,13 @@ julia> c_dz[1, 1, 1]
3.0
```
"""
grid_metric_operation(loc, metric, grid) =
KernelFunctionOperation{loc[1], loc[2], loc[3]}(metric_function(loc, metric), grid)
grid_metric_operation(loc::Tuple{LX, LY, LZ}, metric, grid) where {LX, LY, LZ} =
KernelFunctionOperation{LX, LY, LZ}(metric_function(loc, metric), grid)

const NodeMetric = Union{XNode, YNode, ZNode, ΛNode, ΦNode, RNode}

function grid_metric_operation(loc, metric::NodeMetric, grid)
LX, LY, LZ = loc
ℓx, ℓy, ℓz = LX(), LY(), LZ()
function grid_metric_operation(loc::Tuple{LX, LY, LZ}, metric::NodeMetric, grid) where {LX, LY, LZ}
ℓx, ℓy, ℓz = loc
ξnode = metric_function(loc, metric)
return KernelFunctionOperation{LX, LY, LZ}(ξnode, grid, ℓx, ℓy, ℓz)
end
Expand Down
6 changes: 3 additions & 3 deletions src/AbstractOperations/multiary_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ end

indices(Π::MultiaryOperation) = construct_regionally(intersect_indices, location(Π), Π.args...)

function _multiary_operation(L, op, args, Largs, grid)
function _multiary_operation(L::Tuple{LX, LY, LZ}, op, args, Largs, grid) where {LX, LY, LZ}
▶ = Tuple(interpolation_operator(La, L) for La in Largs)
return MultiaryOperation{L[1], L[2], L[3]}(op, Tuple(a for a in args), ▶, grid)
return MultiaryOperation{LX, LY, LZ}(op, Tuple(a for a in args), ▶, grid)
end

# Recompute location of multiary operation
Expand Down Expand Up @@ -52,7 +52,7 @@ function define_multiary_operator(op)
$op(a::Oceananigans.Fields.AbstractField,
b::Union{Function, Oceananigans.Fields.AbstractField},
c::Union{Function, Oceananigans.Fields.AbstractField},
d::Union{Function, Oceananigans.Fields.AbstractField}...) = $op(Oceananigans.Fields.location(a), a, b, c, d...)
d::Union{Function, Oceananigans.Fields.AbstractField}...) = $op(Oceananigans.Fields.instantiated_location(a), a, b, c, d...)
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/AbstractOperations/unary_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ indices(υ::UnaryOperation) = indices(υ.arg)

"""Create a unary operation for `operator` acting on `arg` which interpolates the
result from `Larg` to `L`."""
function _unary_operation(L, operator, arg, Larg, grid)
function _unary_operation(L::Tuple{LX, LY, LZ}, operator, arg, Larg, grid) where {LX, LY, LZ}
▶ = interpolation_operator(Larg, L)
return UnaryOperation{L[1], L[2], L[3]}(operator, arg, ▶, grid)
return UnaryOperation{LX, LY, LZ}(operator, arg, ▶, grid)
end

# Recompute location of unary operation
Expand Down
10 changes: 7 additions & 3 deletions src/Fields/scans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ Base.summary(::Accumulating) = "Accumulating"
const Reduction = Scan{<:AbstractReducing}
const Accumulation = Scan{<:AbstractAccumulating}

scan_indices(::AbstractReducing, indices; dims) = Tuple(i ∈ dims ? Colon() : indices[i] for i in 1:3)
scan_indices(::AbstractAccumulating, indices; dims) = indices
@inline location(s::Scan) = location(s.operand)
@inline instantiated_location(s::Scan) = instantiated_location(s.operand)

scan_indices(::AbstractReducing, indices, dims) = Tuple(i ∈ dims ? Colon() : indices[i] for i in 1:3)
scan_indices(::AbstractAccumulating, indices, dims) = indices
scan_indices(::AbstractReducing, ::Tuple{Colon, Colon, Colon}, dims) = (:, :, :)

Base.summary(s::Scan) = string(summary(s.type), " ",
s.scan!,
Expand All @@ -52,7 +56,7 @@ function Field(scan::Scan;
grid = operand.grid
LX, LY, LZ = loc = instantiated_location(scan)
dims = filter_nothing_dims(scan.dims, loc)
indices = scan_indices(scan.type, indices; dims)
indices = scan_indices(scan.type, indices, dims)

if isnothing(data)
data = new_data(grid, loc, indices)
Expand Down
Loading