Skip to content

Commit 014debf

Browse files
committed
Refactor count_nans_state
Remove allow scalar indexing (which wasn't needed before as well). Reduce kernel launches when using gpu by using mapreduce. Credit to @riteshbhirud who created this change in #1215. That PR was innactive for a few months, and only needed the formatter applied. This copies those changes, and applies the formatting.
1 parent 7975335 commit 014debf

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/shared_utilities/utils.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -609,12 +609,15 @@ function count_nans_state(
609609
mask = nothing,
610610
verbose = false,
611611
)
612-
# Note: this code uses `parent`; this pattern should not be replicated
613-
num_nans = 0
614-
ClimaComms.allowscalar(ClimaComms.device()) do
615-
num_nans =
616-
isnothing(mask) ? Int(sum(isnan.(parent(state)))) :
617-
Int(sum(isnan.(parent(state)) .* parent(mask)))
612+
if isnothing(mask)
613+
num_nans = count(isnan, parent(state))
614+
else
615+
num_nans = mapreduce(
616+
(s, m) -> m != 0 && isnan(s),
617+
Base.add_sum,
618+
parent(state),
619+
parent(mask),
620+
)
618621
end
619622
if isapprox(num_nans, 0)
620623
verbose && @info "No NaNs found"

0 commit comments

Comments
 (0)