Skip to content

Commit b3f3f12

Browse files
committed
Implement findmin, findmax, argmin, argmax
1 parent 65a5928 commit b3f3f12

File tree

2 files changed

+349
-0
lines changed

2 files changed

+349
-0
lines changed

src/NaNMath.jl

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,4 +351,220 @@ for f in (:min, :max)
351351
@eval ($f)(a, b, c, xs...) = Base.afoldl($f, ($f)(($f)(a, b), c), xs...)
352352
end
353353

354+
# The functions `findmin`, `findmax`, `argmin`, and `argmax` are supported
355+
# to work correctly for the following iterable types:
356+
_valtype(x::AbstractArray{T}) where T<:AbstractFloat = eltype(x)
357+
_valtype(x::Tuple{Vararg{T} where T<:AbstractFloat}) = eltype(x)
358+
_valtype(x::NamedTuple{syms, <:Tuple{Vararg{T} where T<:AbstractFloat}}) where {syms} = eltype(x)
359+
_valtype(x::AbstractDict{K,T}) where {K,T<:AbstractFloat} = valtype(x)
360+
_valtype(x) = error(
361+
"Iterables with value type AbstractFloat or its subtypes are supported.
362+
The provided input type $(typeof(x)) is not.
363+
Consider using the convert function before passing the iterable argument."
364+
365+
)
366+
367+
function _find_extreme(f,compare_op::Function, x)
368+
result_index = 1 # Note: default index value.
369+
result_value = convert(_valtype(x), NaN)
370+
371+
for (k, v) in pairs(x)
372+
if !isnan(v)
373+
if (isnan(result_value) || compare_op(f(v),f(result_value)))
374+
result_index = k
375+
result_value = v
376+
end
377+
end
378+
end
379+
return f(result_value), result_index
380+
end
381+
382+
"""
383+
NaNMath.findmin(f, domain) -> (f(x), index)
384+
385+
NaNMath.findmin(domain) -> (x, index)
386+
387+
##### Args:
388+
* `f`: A function applied to the elements of `domain`;
389+
defaults to `identity` when `domain` is the only argument.
390+
* `domain`: A non-empty collection of floating point numbers such that
391+
`f` is defined on elements of `domain`.
392+
393+
##### Returns:
394+
* Returns a `Tuple` consisting of a value `f(x)` and the index of `x`
395+
in `domain`, ignoring NaN's, such that `f(x)` is minimized.
396+
If there are multiple minimal elements, then the first one will be returned.
397+
398+
If `domain` is a `NamedTuple` or dictionary-like `AbstractDict` L,
399+
the function is applied to its values. The returned index is a key `k`,
400+
such that `f(L[k])` is minimized.
401+
402+
##### Examples:
403+
```julia
404+
julia> NaNMath.findmin([1., 1., 2., 2., NaN])
405+
(1.0, 1)
406+
407+
julia> NaNMath.findmin(-, [1., 1., 2., 2., NaN])
408+
(-2.0, 3)
409+
410+
julia> NaNMath.findmin(abs, Dict(:x => 3.0, :w => -2.2, :y => -3.0, :z => NaN))
411+
(2.2, :w)
412+
```
413+
"""
414+
function findmin end
415+
findmin(f,x) = _find_extreme(f,<,x)
416+
findmin(x) = findmin(identity,x)
417+
418+
"""
419+
NaNMath.findmax(f, domain) -> (f(x), index)
420+
421+
NaNMath.findmax(domain) -> (x, index)
422+
423+
##### Args:
424+
* `f`: A function applied to the elements of `domain`;
425+
defaults to `identity` when `domain` is the only argument.
426+
* `domain`: A non-empty collection of floating point numbers such that
427+
`f` is defined on elements of `domain`.
428+
429+
##### Returns:
430+
* Returns a `Tuple` consisting of a value `f(x)` and the index of `x`
431+
in `domain`, ignoring NaN's, such that `f(x)` is maximized.
432+
If there are multiple maximal elements, then the first one will be returned.
433+
434+
If `domain` is a `NamedTuple` or dictionary-like `AbstractDict` L,
435+
the function `f` is applied to its values. The returned index is a key `k`,
436+
such that `f(L[k])` is maximized.
437+
438+
##### Examples:
439+
```julia
440+
julia> NaNMath.findmax([1., 1., 2., 2., NaN])
441+
(2.0, 3)
442+
443+
julia> NaNMath.findmax(-, [1., 1., 2., 2., NaN])
444+
(-1.0, 1)
445+
446+
julia> NaNMath.findmax(abs, Dict(:x => 3.0, :w => -2.2, :y => -3.0, :z => NaN))
447+
(3.0, :y)
448+
```
449+
"""
450+
function findmax end
451+
findmax(f,x) = _find_extreme(f,>,x)
452+
findmax(x) = findmax(identity,x)
453+
454+
"""
455+
NaNMath.argmin(f, domain) -> x
456+
457+
##### Args:
458+
* `f`: A function applied to the elements of `domain`;
459+
defaults to `identity` when `domain` is the only argument.
460+
* `domain`: A non-empty collection of floating point numbers such that
461+
`f` is defined on elements of `domain`.
462+
463+
##### Returns:
464+
* Returns a value `x` in `domain`, ignoring NaN's, for which `f(x)` is minimized.
465+
If there are multiple minimal values for `f(x)`, then the first one will be returned.
466+
467+
If `domain` is a `NamedTuple` or dictionary-like `AbstractDict` L,
468+
the function is applied to its values. The returned value is `L[k]` for some key `k`
469+
such that `f(L[k])` is minimal.
470+
471+
##### Examples:
472+
```julia
473+
julia> NaNMath.argmin(abs, [1., -1., -2., 2., NaN])
474+
1.0
475+
476+
julia> NaNMath.argmin(identity, [7, 1, 1, NaN])
477+
1.0
478+
```
479+
480+
julia> NaNMath.argmin(exp,Dict("x" => 1.0, "y" => -1.2, "z" => NaN))
481+
-1.2
482+
483+
───────────────────────────────────────────────────────────
484+
485+
NaNMath.argmin(itr) -> key
486+
487+
##### Args:
488+
* `itr`: A non-empty iterable of floating point numbers.
489+
490+
##### Returns:
491+
* Returns the index or key of the minimal element in `itr`, ignoring NaN's.
492+
If there are multiple minimal elements, then the first one will be returned.
493+
494+
If `itr` is a `NamedTuple` or dictionary-like `AbstractDict` L, the returned index is a key `k`,
495+
such that `f(L[k])` is minimal.
496+
497+
##### Examples:
498+
```julia
499+
julia> NaNMath.argmin([7, 1, 1, NaN])
500+
2
501+
502+
julia> NaNMath.argmin([1.0 2; 3 NaN])
503+
CartesianIndex(1, 1)
504+
505+
julia> NaNMath.argmin(Dict("x" => 1.0, "y" => -1.2, "z" => NaN))
506+
"y"
507+
```
508+
"""
509+
function argmin end
510+
argmin(f,x) = getindex(x,findmin(f,x)[2])
511+
argmin(x) = findmin(identity,x)[2]
512+
513+
"""
514+
NaNMath.argmax(f, domain) -> x
515+
516+
##### Args:
517+
* `f`: A function applied to the elements of `domain`;
518+
defaults to `identity` when `domain` is the only argument.
519+
* `domain`: A non-empty collection of floating point numbers such that
520+
`f` is defined on elements of `domain`.
521+
522+
##### Returns:
523+
* Returns a value `x` in `domain`, ignoring NaN's, for which `f(x)` is maximized.
524+
If there are multiple maximal values for `f(x)`, then the first one will be returned.
525+
526+
If `domain` is a `NamedTuple` or dictionary-like `AbstractDict` L,
527+
the function is applied to its values. The returned value is `L[k]` for some key `k`
528+
such that `f(L[k])` is maximal.
529+
530+
##### Examples:
531+
```julia
532+
julia> NaNMath.argmax(abs, [1., -1., -2., NaN])
533+
-2.0
534+
535+
julia> NaNMath.argmax(identity, [7, 1, 1, NaN])
536+
7.0
537+
```
538+
539+
───────────────────────────────────────────────────────────
540+
541+
NaNMath.argmax(itr) -> key
542+
543+
##### Args:
544+
* `itr`: A non-empty iterable of floating point numbers.
545+
546+
##### Returns:
547+
* Returns the index or key of the maximal element in `itr`, ignoring NaN's.
548+
If there are multiple maximal elements, then the first one will be returned.
549+
550+
If `itr` is a `NamedTuple` or dictionary-like `AbstractDict` L, the returned index is a key `k`,
551+
such that `f(L[k])` is maximal.
552+
553+
554+
##### Examples:
555+
```julia
556+
julia> NaNMath.argmax([7, 1, 1, NaN])
557+
1
558+
559+
julia> NaNMath.argmax([1.0 2; 3 NaN])
560+
CartesianIndex(2, 1)
561+
562+
julia> NaNMath.argmax(Dict("x" => 1.0, "y" => -1.2, "z" => NaN))
563+
"x"
564+
```
565+
"""
566+
function argmax end
567+
argmax(x) = findmax(identity,x)[2]
568+
argmax(f,x) = getindex(x,findmax(f,x)[2])
569+
354570
end

test/runtests.jl

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,136 @@ using Test
7070
@test isnan(NaNMath.max(NaN, NaN))
7171
@test isnan(NaNMath.max(NaN))
7272
@test NaNMath.max(NaN, NaN, 0.0, 1.0) == 1.0
73+
74+
## Based on https://github.com/sethaxen/NaNMath.jl/blob/41b3e7edd9dd4cb6c2873abf6e0d90acf43138ec/test/runtests.jl
75+
@testset "findmin/findmax" begin
76+
if VERSION v"1.7"
77+
xvals = [
78+
[1., 2., 3., 3., 1.],
79+
(1., 2., 3., 3., .1),
80+
(1f0, 2f0, 3f0, -1f0),
81+
(x=1.0, y=3f0, z=-4.0, w=-2f0),
82+
Dict(:a => 1.0, :b => 1.0, :d => 3.0, :c => 2.0),
83+
]
84+
@testset for x in xvals
85+
@test NaNMath.findmin(x) === findmin(x)
86+
@test NaNMath.findmax(x) === findmax(x)
87+
@test NaNMath.findmin(identity, x) === findmin(identity, x)
88+
@test NaNMath.findmax(identity, x) === findmax(identity, x)
89+
@test NaNMath.findmin(sin, x) === findmin(sin, x)
90+
@test NaNMath.findmax(sin, x) === findmax(sin, x)
91+
end
92+
end
93+
94+
x = [7, 7, NaN, 1, 1, NaN]
95+
@test NaNMath.findmin(x) === (1.0, 4)
96+
@test NaNMath.findmax(x) === (7.0, 1)
97+
@test NaNMath.findmin(identity, x) === (1.0, 4)
98+
@test NaNMath.findmax(identity, x) === (7.0, 1)
99+
@test NaNMath.findmin(-, x) === (-7.0, 1)
100+
@test NaNMath.findmax(-, x) === (-1.0, 4)
101+
102+
x = [NaN, NaN]
103+
@test NaNMath.findmin(x) === (NaN, 1)
104+
@test NaNMath.findmax(x) === (NaN, 1)
105+
@test NaNMath.findmin(identity, x) === (NaN, 1)
106+
@test NaNMath.findmax(identity, x) === (NaN, 1)
107+
@test NaNMath.findmin(sin, x) === (NaN, 1)
108+
@test NaNMath.findmax(sin, x) === (NaN, 1)
109+
110+
x = Dict(:a => 1.0, :b => 1 + 2im, :d => 3.0, :c => 2.0)
111+
@test_throws ErrorException NaNMath.findmin(x)
112+
@test_throws ErrorException NaNMath.findmax(x)
113+
114+
x = [3, missing, NaN, -1]
115+
@test_throws ErrorException NaNMath.findmin(x)
116+
117+
x = Dict('a' => 1.0, missing => NaN, 'c' => 2.0)
118+
@test NaNMath.findmin(x) === (1.0, 'a')
119+
@test NaNMath.findmax(x) === (2.0, 'c')
120+
121+
x = Dict(:x => 3.0, :w => 2f0, :y => -1.0, :z => NaN)
122+
@test NaNMath.findmin(x) === (-1.0, :y)
123+
@test NaNMath.findmax(x) === (3.0, :x)
124+
@test NaNMath.findmin(identity, x) === (-1.0, :y)
125+
@test NaNMath.findmax(identity, x) === (3.0, :x)
126+
@test NaNMath.findmin(-, x) === (-3.0, :x)
127+
@test NaNMath.findmax(-, x) === (1.0, :y)
128+
@test NaNMath.findmin(exp, x) === (exp(-1.0), :y)
129+
@test NaNMath.findmax(exp, x) === (exp(3.0), :x)
130+
131+
x = (x=1.0, y=NaN, z=NaN, w=-2.0)
132+
@test NaNMath.findmin(x) === (-2.0, :w)
133+
@test NaNMath.findmax(x) === (1.0, :x)
134+
@test NaNMath.findmin(-,x) === (-1.0, :x)
135+
@test NaNMath.findmax(-,x) === (2.0, :w)
136+
137+
x = [2.0 3.0; 2.0 -1.0]
138+
@test NaNMath.findmin(x) === (-1.0, CartesianIndex(2, 2))
139+
@test NaNMath.findmax(x) === (3.0, CartesianIndex(1, 2))
140+
@test NaNMath.findmin(exp,x) === (exp(-1), CartesianIndex(2, 2))
141+
@test NaNMath.findmax(exp,x) === (exp(3.0), CartesianIndex(1, 2))
142+
end
143+
144+
@testset "argmin/argmax" begin
145+
if VERSION v"1.7"
146+
xvals = [
147+
[1., 2., 4., 3., 1.],
148+
(1., 2., 4., 3., .1),
149+
(1f0, 2f0, 3f0, -1f0),
150+
(x=1.0, y=3f0, z=-4.0, w=-2f0),
151+
Dict(:a => 1.0, :b => 1.0, :d => 3.0, :c => 2.0),
152+
]
153+
@testset for x in xvals
154+
@test NaNMath.argmin(x) === argmin(x)
155+
@test NaNMath.argmax(x) === argmax(x)
156+
x isa Dict || @test NaNMath.argmin(identity, x) === argmin(identity, x)
157+
x isa Dict || @test NaNMath.argmax(identity, x) === argmax(identity, x)
158+
x isa Dict || @test NaNMath.argmin(sin, x) === argmin(sin, x)
159+
x isa Dict || @test NaNMath.argmax(sin, x) === argmax(sin, x)
160+
end
161+
end
162+
x = [7, 7, NaN, 1, 1, NaN]
163+
@test NaNMath.argmin(x) === 4
164+
@test NaNMath.argmax(x) === 1
165+
@test NaNMath.argmin(identity, x) === 1.0
166+
@test NaNMath.argmax(identity, x) === 7.0
167+
@test NaNMath.argmin(-, x) === 7.0
168+
@test NaNMath.argmax(-, x) === 1.0
169+
170+
x = [NaN, NaN]
171+
@test NaNMath.argmin(x) === 1
172+
@test NaNMath.argmax(x) === 1
173+
@test NaNMath.argmin(identity, x) === NaN
174+
@test NaNMath.argmax(identity, x) === NaN
175+
@test NaNMath.argmin(-, x) === NaN
176+
@test NaNMath.argmax(-, x) === NaN
177+
178+
x = [3, missing, NaN, -1]
179+
@test_throws ErrorException NaNMath.argmin(x)
180+
@test_throws ErrorException NaNMath.argmax(x)
181+
182+
x = Dict('a' => 1.0, missing => NaN, 'c' => 2.0)
183+
@test NaNMath.argmin(x) === 'a'
184+
@test NaNMath.argmax(x) === 'c'
185+
186+
x = Dict(:v => NaN, :w => 2.1f0, :x => 3.1, :z => -1.0, :y => NaN)
187+
@test NaNMath.argmin(x) === :z
188+
@test NaNMath.argmax(x) === :x
189+
@test NaNMath.argmin(-, x) === 3.1
190+
@test NaNMath.argmax(-, x) === -1.0
191+
@test NaNMath.argmin(exp, x) === -1.0
192+
@test NaNMath.argmax(exp, x) === 3.1
193+
194+
x = (x=1.1, y=NaN, z=NaN, w=-2.3)
195+
@test NaNMath.argmin(x) === :w
196+
@test NaNMath.argmax(x) === :x
197+
@test NaNMath.argmin(exp, x) === -2.3
198+
@test NaNMath.argmax(exp, x) === 1.1
199+
200+
x = [2.0 3.0; 2.0 -1.0]
201+
@test NaNMath.argmin(x) === CartesianIndex(2, 2)
202+
@test NaNMath.argmax(x) === CartesianIndex(1, 2)
203+
@test NaNMath.argmin(exp,x) === -1.0
204+
@test NaNMath.argmax(exp,x) === 3.0
205+
end

0 commit comments

Comments
 (0)