Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
eff6232
Faster NaN checks
MilesCranmer May 12, 2023
53b937b
Include empty array check
MilesCranmer May 12, 2023
23dd552
Fix behavior of tails in `is_good_array`
MilesCranmer May 12, 2023
414aefb
Enable `is_good_array` for complex numbers
MilesCranmer May 12, 2023
10dfdb4
More detailed NaN detection testing
MilesCranmer May 13, 2023
a748ebf
Fix NaN detector for no tail
MilesCranmer May 13, 2023
d94339b
Refactor benchmark
MilesCranmer May 13, 2023
7ba498d
Include benchmark of `is_bad_array`
MilesCranmer May 13, 2023
bd0eb4f
Fix benchmark definition
MilesCranmer May 13, 2023
5bd0e64
Test multiple sizes in NaN check benchmark
MilesCranmer May 13, 2023
1664691
Use `muladd` over `fma`
MilesCranmer May 15, 2023
c084e17
Prefer to use `Val(unroll)` over passing previous Val
MilesCranmer May 15, 2023
c696f9f
Tune unroll based on array length
MilesCranmer May 17, 2023
5030ad7
Reduce `is_good_array` with LoopVectorization.jl
MilesCranmer May 19, 2023
f1aad67
Turn off fastmath mode in turbo
MilesCranmer May 19, 2023
fa05fd0
Add `is_good_array` for non-Float32/64
MilesCranmer May 19, 2023
82beba4
Fix NaN detection test
MilesCranmer May 19, 2023
884ad20
vmapreduce version of NaN detector
MilesCranmer May 19, 2023
528b5ae
Fix `is_good_array` for empty input
MilesCranmer May 20, 2023
5ae822c
Avoid `vmapreduce` on Windows
MilesCranmer May 20, 2023
9177e69
Avoid LoopVectorization within `@generated`
MilesCranmer May 20, 2023
643e7fd
Fix `is_good_array` for empty input
MilesCranmer May 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 90 additions & 39 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using DynamicExpressions, BenchmarkTools, Random
using DynamicExpressions.EquationUtilsModule: is_constant
using DynamicExpressions.UtilsModule: is_bad_array

include("benchmark_utils.jl")

Expand All @@ -10,55 +11,88 @@ function benchmark_evaluation()
operators = OperatorEnum(;
binary_operators=[+, -, /, *], unary_operators=[cos, exp], enable_autodiff=true
)
for T in (ComplexF32, ComplexF64, Float32, Float64)
if !(T <: Real) && PACKAGE_VERSION < v"0.5.0" && PACKAGE_VERSION != v"0.0.0"
continue
end
suite[T] = BenchmarkGroup()

n = 1_000
config_options = [
[
(turbo=turbo, T=T, n=n, derivative=derivative) for turbo in (false, true) for
T in (ComplexF32, ComplexF64, Float32, Float64) for n in (100, 1_000, 10_000)
for derivative in (false, true)
]...,
]

#! format: off
for turbo in (false, true)
if turbo && !(T in (Float32, Float64))
continue
end
extra_key = turbo ? "_turbo" : ""
config_options = filter!(config_options) do config
!(config.T <: Real) &&
PACKAGE_VERSION < v"0.5.0" &&
PACKAGE_VERSION != v"0.0.0" &&
return false

config.turbo && !(config.T in (Float32, Float64)) && return false

config.T != Float32 && config.n != 1_000 && return false

config.T != Float32 && config.derivative && return false

return true
end

for config in config_options
T = config.T
turbo = config.turbo
n = config.n
derivative = config.derivative

derivative_s = derivative ? "derivative" : "evaluation"
turbo_s = turbo ? "turbo" : "standard"

haskey(suite, derivative_s) || (suite[derivative_s] = BenchmarkGroup())
haskey(suite[derivative_s], T) || (suite[derivative_s][T] = BenchmarkGroup())
haskey(suite[derivative_s][T], n) || (suite[derivative_s][T][n] = BenchmarkGroup())
haskey(suite[derivative_s][T][n], turbo_s) ||
(suite[derivative_s][T][n][turbo_s] = BenchmarkGroup())

if derivative
eval_grad_tree_array(
gen_random_tree_fixed_size(20, operators, 5, T),
randn(MersenneTwister(0), T, 5, n),
operators;
variable=true,
turbo=turbo,
)
suite[derivative_s][T][n][turbo_s] = @benchmarkable(
[
eval_grad_tree_array(tree, X, $operators; variable=true, turbo=$turbo)
for tree in trees
],
setup = (
X = randn(MersenneTwister(0), $T, 5, $n);
treesize = 20;
ntrees = 100;
trees = [
gen_random_tree_fixed_size(treesize, $operators, 5, $T) for
_ in 1:ntrees
]
)
)
else
eval_tree_array(
gen_random_tree_fixed_size(20, operators, 5, T),
randn(MersenneTwister(0), T, 5, n),
operators;
turbo=turbo
turbo=turbo,
)
suite[T]["evaluation$(extra_key)"] = @benchmarkable(
suite[derivative_s][T][n][turbo_s] = @benchmarkable(
[eval_tree_array(tree, X, $operators; turbo=$turbo) for tree in trees],
setup=(
X=randn(MersenneTwister(0), $T, 5, $n);
treesize=20;
ntrees=100;
trees=[gen_random_tree_fixed_size(treesize, $operators, 5, $T) for _ in 1:ntrees]
setup = (
X = randn(MersenneTwister(0), $T, 5, $n);
treesize = 20;
ntrees = 100;
trees = [
gen_random_tree_fixed_size(treesize, $operators, 5, $T) for
_ in 1:ntrees
]
)
)
if T <: Real
eval_grad_tree_array(
gen_random_tree_fixed_size(20, operators, 5, T),
randn(MersenneTwister(0), T, 5, n),
operators;
variable=true,
turbo=turbo
)
suite[T]["derivative$(extra_key)"] = @benchmarkable(
[eval_grad_tree_array(tree, X, $operators; variable=true, turbo=$turbo) for tree in trees],
setup=(
X=randn(MersenneTwister(0), $T, 5, $n);
treesize=20;
ntrees=100;
trees=[gen_random_tree_fixed_size(treesize, $operators, 5, $T) for _ in 1:ntrees]
)
)
end
end
#! format: on
end
return suite
end
Expand All @@ -82,7 +116,7 @@ end
"""
    f_tree_op(f, tree, operators)
    f_tree_op(f, tree)

Apply `f` to `tree`, forwarding `operators` as a second argument when it is
supplied. The `where {F}` parameter makes each method specialize on the
concrete type of `f`.
"""
function f_tree_op(f::F, tree, operators) where {F}
    return f(tree, operators)
end
function f_tree_op(f::F, tree) where {F}
    return f(tree)
end

function benchmark_utilities()
function tree_utilities()
suite = BenchmarkGroup()

all_funcs = (
Expand Down Expand Up @@ -138,6 +172,23 @@ function benchmark_utilities()
s
end
end
return suite
end

"""
    benchmark_utilities()

Assemble the utilities benchmark suite and return it as a `BenchmarkGroup`.

The result has two entries:
- `"trees"`: whatever `tree_utilities()` returns.
- `"extra"`: a group holding `"is_bad_array_x16"`, which is keyed by a base
  length `m` (50, 500, 5000). Each benchmark evaluates
  `any(is_bad_array, Xs)`, where `Xs` is a tuple of 16 random `Float64`
  vectors of lengths `m + 1` through `m + 16` built in the setup phase.
"""
function benchmark_utilities()
    # Checked once per benchmark sample over the whole tuple of arrays.
    any_bad = Xs -> any(is_bad_array, Xs)

    bad_array_group = BenchmarkGroup()
    for base_len in (50, 500, 5000)
        bad_array_group[base_len] = @benchmarkable(
            $(any_bad)(Xs), setup = (Xs = ntuple(n -> randn(Float64, $base_len + n), 16))
        )
    end

    extra = BenchmarkGroup()
    extra["is_bad_array_x16"] = bad_array_group

    suite = BenchmarkGroup()
    suite["trees"] = tree_utilities()
    suite["extra"] = extra
    return suite
end
Expand Down
13 changes: 10 additions & 3 deletions src/Utils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Useful functions to be used throughout the library."""
module UtilsModule

using LoopVectorization: @turbo
using LoopVectorization: @turbo, vmapreduce
using MacroTools: postwalk, @capture, splitdef, combinedef

"""Remove all type assertions in an expression."""
Expand Down Expand Up @@ -74,8 +74,15 @@ macro return_on_false2(flag, retval, retval2)
end

# Fastest way to check for NaN in an array.
# (due to optimizations in sum())
is_bad_array(array) = !(isempty(array) || isfinite(sum(array)))
# Thanks @mikmore https://discourse.julialang.org/t/fastest-way-to-check-for-inf-or-nan-in-an-array/76954/33?u=milescranmer
# Evaluated once at load time; selects the plain `sum` reduction on Windows
# instead of `vmapreduce`.
const IS_WINDOWS = Sys.iswindows()

"""
    is_bad_array(x)

Return the negation of `is_good_array(x)`.
"""
is_bad_array(x) = !is_good_array(x)

"""
    is_good_array(x::AbstractArray{T})

Return `true` if `x` is empty, or if the sum of `xi * zero(xi)` over all
elements equals `zero(T)`.

Under IEEE floating-point arithmetic a finite value times zero is zero, while
`Inf` or `NaN` times zero is `NaN`, so a single non-finite element makes the
reduction compare unequal to zero.
"""
function is_good_array(x::AbstractArray{T}) where {T}
    if isempty(x)
        return true
    end
    zeroed = xi -> xi * zero(xi)
    total = IS_WINDOWS ? sum(zeroed, x) : vmapreduce(zeroed, +, x)
    return total == zero(T)
end

"""
    isgood(x)

For a `Number`, return `true` only when `x` is not `NaN` and is finite.
Any non-`Number` value is always considered good.
"""
isgood(x::Number) = !isnan(x) && isfinite(x)
isgood(@nospecialize(x)) = true

"""
    isbad(x)

Return the negation of `isgood(x)`.
"""
isbad(x) = !isgood(x)
Expand Down
33 changes: 29 additions & 4 deletions test/test_nan_detection.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
println("Testing NaN detection.")
using DynamicExpressions
using Test

Expand Down Expand Up @@ -29,8 +28,34 @@ function run_nan_detection_test(T)
@test !flag
end

for T in [Float16, Float32, Float64]
run_nan_detection_test(T)
# Run the NaN detection check once per floating-point element type, with each
# type reported as its own nested test set.
@testset "Simple NaN detections" begin
    @testset "NaN detection with $T" for T in [Float16, Float32, Float64]
        run_nan_detection_test(T)
    end
end

println("Passed.")
using DynamicExpressions.UtilsModule: is_bad_array
using StaticArrays

"""
    manual_nan_test(T, array_size, nan_location, Val(static_array))

Check `is_bad_array` on an all-ones array of element type `T` and length
`array_size`: it must report the array as good, and then as bad once the entry
at `nan_location` is overwritten with `T(NaN)`.

When `static_array` is `true`, the data is wrapped in an
`MVector{array_size}` before testing; otherwise the array from `ones` is used
directly.
"""
function manual_nan_test(
    ::Type{T}, array_size, nan_location, ::Val{static_array}
) where {T,static_array}
    dense = ones(T, array_size)
    data = if static_array
        MVector{array_size}(dense)
    else
        dense
    end
    # Clean input must pass before the NaN is injected.
    @test !is_bad_array(data)
    data[nan_location] = T(NaN)
    @test is_bad_array(data)
end

# Exhaustively place a NaN at every position of every array length from 1 to
# `2 * unroll_size + 1`, for each element type and for both array kinds.
@testset "Manual NaN tests" begin
    # NOTE(review): presumably chosen to cover an unroll width used by the
    # implementation under test — confirm.
    unroll_size = 32
    max_len = 2 * unroll_size + 1
    eltypes = [Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64]
    for T in eltypes, array_size in 1:max_len, nan_location in 1:array_size
        for static_array in (false, true)
            manual_nan_test(T, array_size, nan_location, Val(static_array))
        end
    end
end