Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions src/shared_utilities/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -593,13 +593,12 @@ function count_nans_state(
mask = nothing,
verbose = false,
)
# Note: this code uses `parent`; this pattern should not be replicated
num_nans = 0
ClimaComms.allowscalar(ClimaComms.device()) do
num_nans =
isnothing(mask) ? Int(sum(isnan.(parent(state)))) :
Int(sum(isnan.(parent(state)) .* parent(mask)))
if isnothing(mask)
num_nans = sum(@. ifelse(isnan(state), 1, 0))
else
num_nans = sum(@. ifelse(isnan(state), mask, 0))
end
num_nans = Int(num_nans)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can further optimize this and reduce some of the allocations. For the no mask case, we can go from 2 kernel launches to just one if we use mapreduce or count. I believe @kmdeck is correct, and line 597 will still allocate an intermediate. For the case with a mask, we can also get away with only one kernel launch and no allocations with mapreduce.

That could look something like

Suggested change
if isnothing(mask)
num_nans = sum(@. ifelse(isnan(state), 1, 0))
else
num_nans = sum(@. ifelse(isnan(state), mask, 0))
end
num_nans = Int(num_nans)
if isnothing(mask)
num_nans = count(isnan, parent(state)
else
num_nans = mapreduce((s, m) -> m != 0 && isnan(s), Base.add_sum, parent(state), parent(mask))
end

That would also avoid the type conversion.

I'm not sure if masks are guaranteed to be boolean valued, so I'm not sure if line 599 will always be valid

Copy link
Member

@kmdeck kmdeck Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

The mask is not boolean, because we create it using this:

apply_threshold(field, value) =

I think we did this because you cant create a field of bools by broadcasting over a field of floats. But we can do something like:

apply_threshold(field, value) =
    field > value ? 0 : 1

and then later down in the landsea_mask function, change thishttps://github.com/CliMA/ClimaLand.jl/blob/7ac4cd452663559f20df22d2a60e8f1aaa92d90f/src/shared_utilities/Domains.jl#L1088

to

binary_mask = ClimaCore.Fields.Field(Bool, surface_space)
fill!(ClimaCore.Fields.field_values(binary_mask), 0)
@. binary_mask = apply_threshold(mask, threshold)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One other thought is that we call this function for each state variable in Y, but with one exception, if one is NaN, the entire state is NaN at that same location. so maybe we can also speed it up by calling it for just one field.

I'm not sure how to do that in a nice automated way though, without making an assumption about what is in Y

if isapprox(num_nans, 0)
verbose && @info "No NaNs found"
else
Expand All @@ -608,7 +607,6 @@ function count_nans_state(
return nothing
end


"""
NaNCheckCallback(nancheck_frequency::Union{AbstractFloat, Dates.Period},
start_date, dt)
Expand Down
Loading