Skip to content

Commit 7e44d41

Browse files
Merge pull request #62 from jd-foster/jdf/alt_findminmax
Implement `findmin`, `findmax`, `argmin`, `argmax`
2 parents 1fde86f + 7a1f4d7 commit 7e44d41

File tree

2 files changed

+349
-0
lines changed

2 files changed

+349
-0
lines changed

src/NaNMath.jl

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,4 +369,220 @@ for f in (:min, :max)
369369
@eval ($f)(a, b, c, xs...) = Base.afoldl($f, ($f)(($f)(a, b), c), xs...)
370370
end
371371

372+
# The functions `findmin`, `findmax`, `argmin`, and `argmax` are supported
373+
# to work correctly for the following iterable types:
374+
_valtype(x::AbstractArray{T}) where T<:AbstractFloat = eltype(x)
375+
_valtype(x::Tuple{Vararg{T} where T<:AbstractFloat}) = eltype(x)
376+
_valtype(x::NamedTuple{syms, <:Tuple{Vararg{T} where T<:AbstractFloat}}) where {syms} = eltype(x)
377+
_valtype(x::AbstractDict{K,T}) where {K,T<:AbstractFloat} = valtype(x)
378+
_valtype(x) = error(
379+
"Iterables with value type AbstractFloat or its subtypes are supported.
380+
The provided input type $(typeof(x)) is not.
381+
Consider using the convert function before passing the iterable argument."
382+
383+
)
384+
385+
function _find_extreme(f,compare_op::Function, x)
386+
result_index = 1 # Note: default index value.
387+
result_value = convert(_valtype(x), NaN)
388+
389+
for (k, v) in pairs(x)
390+
if !isnan(v)
391+
if (isnan(result_value) || compare_op(f(v),f(result_value)))
392+
result_index = k
393+
result_value = v
394+
end
395+
end
396+
end
397+
return f(result_value), result_index
398+
end
399+
400+
"""
401+
NaNMath.findmin(f, domain) -> (f(x), index)
402+
403+
NaNMath.findmin(domain) -> (x, index)
404+
405+
##### Args:
406+
* `f`: A function applied to the elements of `domain`;
407+
defaults to `identity` when `domain` is the only argument.
408+
* `domain`: A non-empty collection of floating point numbers such that
409+
`f` is defined on elements of `domain`.
410+
411+
##### Returns:
412+
* Returns a `Tuple` consisting of a value `f(x)` and the index of `x`
413+
in `domain`, ignoring NaN's, such that `f(x)` is minimized.
414+
If there are multiple minimal elements, then the first one will be returned.
415+
416+
If `domain` is a `NamedTuple` or dictionary-like `AbstractDict` L,
417+
the function is applied to its values. The returned index is a key `k`,
418+
such that `f(L[k])` is minimized.
419+
420+
##### Examples:
421+
```julia
422+
julia> NaNMath.findmin([1., 1., 2., 2., NaN])
423+
(1.0, 1)
424+
425+
julia> NaNMath.findmin(-, [1., 1., 2., 2., NaN])
426+
(-2.0, 3)
427+
428+
julia> NaNMath.findmin(abs, Dict(:x => 3.0, :w => -2.2, :y => -3.0, :z => NaN))
429+
(2.2, :w)
430+
```
431+
"""
432+
function findmin end
433+
findmin(f,x) = _find_extreme(f,<,x)
434+
findmin(x) = findmin(identity,x)
435+
436+
"""
437+
NaNMath.findmax(f, domain) -> (f(x), index)
438+
439+
NaNMath.findmax(domain) -> (x, index)
440+
441+
##### Args:
442+
* `f`: A function applied to the elements of `domain`;
443+
defaults to `identity` when `domain` is the only argument.
444+
* `domain`: A non-empty collection of floating point numbers such that
445+
`f` is defined on elements of `domain`.
446+
447+
##### Returns:
448+
* Returns a `Tuple` consisting of a value `f(x)` and the index of `x`
449+
in `domain`, ignoring NaN's, such that `f(x)` is maximized.
450+
If there are multiple maximal elements, then the first one will be returned.
451+
452+
If `domain` is a `NamedTuple` or dictionary-like `AbstractDict` L,
453+
the function `f` is applied to its values. The returned index is a key `k`,
454+
such that `f(L[k])` is maximized.
455+
456+
##### Examples:
457+
```julia
458+
julia> NaNMath.findmax([1., 1., 2., 2., NaN])
459+
(2.0, 3)
460+
461+
julia> NaNMath.findmax(-, [1., 1., 2., 2., NaN])
462+
(-1.0, 1)
463+
464+
julia> NaNMath.findmax(abs, Dict(:x => 3.0, :w => -2.2, :y => -3.0, :z => NaN))
465+
(3.0, :y)
466+
```
467+
"""
468+
function findmax end
469+
findmax(f,x) = _find_extreme(f,>,x)
470+
findmax(x) = findmax(identity,x)
471+
472+
"""
473+
NaNMath.argmin(f, domain) -> x
474+
475+
##### Args:
476+
* `f`: A function applied to the elements of `domain`;
477+
defaults to `identity` when `domain` is the only argument.
478+
* `domain`: A non-empty collection of floating point numbers such that
479+
`f` is defined on elements of `domain`.
480+
481+
##### Returns:
482+
* Returns a value `x` in `domain`, ignoring NaN's, for which `f(x)` is minimized.
483+
If there are multiple minimal values for `f(x)`, then the first one will be returned.
484+
485+
If `domain` is a `NamedTuple` or dictionary-like `AbstractDict` L,
486+
the function is applied to its values. The returned value is `L[k]` for some key `k`
487+
such that `f(L[k])` is minimal.
488+
489+
##### Examples:
490+
```julia
491+
julia> NaNMath.argmin(abs, [1., -1., -2., 2., NaN])
492+
1.0
493+
494+
julia> NaNMath.argmin(identity, [7, 1, 1, NaN])
495+
1.0
496+
```
497+
498+
julia> NaNMath.argmin(exp,Dict("x" => 1.0, "y" => -1.2, "z" => NaN))
499+
-1.2
500+
501+
───────────────────────────────────────────────────────────
502+
503+
NaNMath.argmin(itr) -> key
504+
505+
##### Args:
506+
* `itr`: A non-empty iterable of floating point numbers.
507+
508+
##### Returns:
509+
* Returns the index or key of the minimal element in `itr`, ignoring NaN's.
510+
If there are multiple minimal elements, then the first one will be returned.
511+
512+
If `itr` is a `NamedTuple` or dictionary-like `AbstractDict` L, the returned index is a key `k`,
513+
such that `f(L[k])` is minimal.
514+
515+
##### Examples:
516+
```julia
517+
julia> NaNMath.argmin([7, 1, 1, NaN])
518+
2
519+
520+
julia> NaNMath.argmin([1.0 2; 3 NaN])
521+
CartesianIndex(1, 1)
522+
523+
julia> NaNMath.argmin(Dict("x" => 1.0, "y" => -1.2, "z" => NaN))
524+
"y"
525+
```
526+
"""
527+
function argmin end
528+
argmin(f,x) = getindex(x,findmin(f,x)[2])
529+
argmin(x) = findmin(identity,x)[2]
530+
531+
"""
532+
NaNMath.argmax(f, domain) -> x
533+
534+
##### Args:
535+
* `f`: A function applied to the elements of `domain`;
536+
defaults to `identity` when `domain` is the only argument.
537+
* `domain`: A non-empty collection of floating point numbers such that
538+
`f` is defined on elements of `domain`.
539+
540+
##### Returns:
541+
* Returns a value `x` in `domain`, ignoring NaN's, for which `f(x)` is maximized.
542+
If there are multiple maximal values for `f(x)`, then the first one will be returned.
543+
544+
If `domain` is a `NamedTuple` or dictionary-like `AbstractDict` L,
545+
the function is applied to its values. The returned value is `L[k]` for some key `k`
546+
such that `f(L[k])` is maximal.
547+
548+
##### Examples:
549+
```julia
550+
julia> NaNMath.argmax(abs, [1., -1., -2., NaN])
551+
-2.0
552+
553+
julia> NaNMath.argmax(identity, [7, 1, 1, NaN])
554+
7.0
555+
```
556+
557+
───────────────────────────────────────────────────────────
558+
559+
NaNMath.argmax(itr) -> key
560+
561+
##### Args:
562+
* `itr`: A non-empty iterable of floating point numbers.
563+
564+
##### Returns:
565+
* Returns the index or key of the maximal element in `itr`, ignoring NaN's.
566+
If there are multiple maximal elements, then the first one will be returned.
567+
568+
If `itr` is a `NamedTuple` or dictionary-like `AbstractDict` L, the returned index is a key `k`,
569+
such that `f(L[k])` is maximal.
570+
571+
572+
##### Examples:
573+
```julia
574+
julia> NaNMath.argmax([7, 1, 1, NaN])
575+
1
576+
577+
julia> NaNMath.argmax([1.0 2; 3 NaN])
578+
CartesianIndex(2, 1)
579+
580+
julia> NaNMath.argmax(Dict("x" => 1.0, "y" => -1.2, "z" => NaN))
581+
"x"
582+
```
583+
"""
584+
function argmax end
585+
argmax(x) = findmax(identity,x)[2]
586+
argmax(f,x) = getindex(x,findmax(f,x)[2])
587+
372588
end

test/runtests.jl

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,139 @@ using Test
8484
@test isnan(NaNMath.max(NaN))
8585
@test NaNMath.max(NaN, NaN, 0.0, 1.0) == 1.0
8686

87+
## Based on https://github.com/sethaxen/NaNMath.jl/blob/41b3e7edd9dd4cb6c2873abf6e0d90acf43138ec/test/runtests.jl
88+
@testset "findmin/findmax" begin
89+
if VERSION v"1.7"
90+
xvals = [
91+
[1., 2., 3., 3., 1.],
92+
(1., 2., 3., 3., .1),
93+
(1f0, 2f0, 3f0, -1f0),
94+
(x=1.0, y=3f0, z=-4.0, w=-2f0),
95+
Dict(:a => 1.0, :b => 1.0, :d => 3.0, :c => 2.0),
96+
]
97+
@testset for x in xvals
98+
@test NaNMath.findmin(x) === findmin(x)
99+
@test NaNMath.findmax(x) === findmax(x)
100+
@test NaNMath.findmin(identity, x) === findmin(identity, x)
101+
@test NaNMath.findmax(identity, x) === findmax(identity, x)
102+
@test NaNMath.findmin(sin, x) === findmin(sin, x)
103+
@test NaNMath.findmax(sin, x) === findmax(sin, x)
104+
end
105+
end
106+
107+
x = [7, 7, NaN, 1, 1, NaN]
108+
@test NaNMath.findmin(x) === (1.0, 4)
109+
@test NaNMath.findmax(x) === (7.0, 1)
110+
@test NaNMath.findmin(identity, x) === (1.0, 4)
111+
@test NaNMath.findmax(identity, x) === (7.0, 1)
112+
@test NaNMath.findmin(-, x) === (-7.0, 1)
113+
@test NaNMath.findmax(-, x) === (-1.0, 4)
114+
115+
x = [NaN, NaN]
116+
@test NaNMath.findmin(x) === (NaN, 1)
117+
@test NaNMath.findmax(x) === (NaN, 1)
118+
@test NaNMath.findmin(identity, x) === (NaN, 1)
119+
@test NaNMath.findmax(identity, x) === (NaN, 1)
120+
@test NaNMath.findmin(sin, x) === (NaN, 1)
121+
@test NaNMath.findmax(sin, x) === (NaN, 1)
122+
123+
x = Dict(:a => 1.0, :b => 1 + 2im, :d => 3.0, :c => 2.0)
124+
@test_throws ErrorException NaNMath.findmin(x)
125+
@test_throws ErrorException NaNMath.findmax(x)
126+
127+
x = [3, missing, NaN, -1]
128+
@test_throws ErrorException NaNMath.findmin(x)
129+
130+
x = Dict('a' => 1.0, missing => NaN, 'c' => 2.0)
131+
@test NaNMath.findmin(x) === (1.0, 'a')
132+
@test NaNMath.findmax(x) === (2.0, 'c')
133+
134+
x = Dict(:x => 3.0, :w => 2f0, :y => -1.0, :z => NaN)
135+
@test NaNMath.findmin(x) === (-1.0, :y)
136+
@test NaNMath.findmax(x) === (3.0, :x)
137+
@test NaNMath.findmin(identity, x) === (-1.0, :y)
138+
@test NaNMath.findmax(identity, x) === (3.0, :x)
139+
@test NaNMath.findmin(-, x) === (-3.0, :x)
140+
@test NaNMath.findmax(-, x) === (1.0, :y)
141+
@test NaNMath.findmin(exp, x) === (exp(-1.0), :y)
142+
@test NaNMath.findmax(exp, x) === (exp(3.0), :x)
143+
144+
x = (x=1.0, y=NaN, z=NaN, w=-2.0)
145+
@test NaNMath.findmin(x) === (-2.0, :w)
146+
@test NaNMath.findmax(x) === (1.0, :x)
147+
@test NaNMath.findmin(-,x) === (-1.0, :x)
148+
@test NaNMath.findmax(-,x) === (2.0, :w)
149+
150+
x = [2.0 3.0; 2.0 -1.0]
151+
@test NaNMath.findmin(x) === (-1.0, CartesianIndex(2, 2))
152+
@test NaNMath.findmax(x) === (3.0, CartesianIndex(1, 2))
153+
@test NaNMath.findmin(exp,x) === (exp(-1), CartesianIndex(2, 2))
154+
@test NaNMath.findmax(exp,x) === (exp(3.0), CartesianIndex(1, 2))
155+
end
156+
157+
@testset "argmin/argmax" begin
158+
if VERSION v"1.7"
159+
xvals = [
160+
[1., 2., 4., 3., 1.],
161+
(1., 2., 4., 3., .1),
162+
(1f0, 2f0, 3f0, -1f0),
163+
(x=1.0, y=3f0, z=-4.0, w=-2f0),
164+
Dict(:a => 1.0, :b => 1.0, :d => 3.0, :c => 2.0),
165+
]
166+
@testset for x in xvals
167+
@test NaNMath.argmin(x) === argmin(x)
168+
@test NaNMath.argmax(x) === argmax(x)
169+
x isa Dict || @test NaNMath.argmin(identity, x) === argmin(identity, x)
170+
x isa Dict || @test NaNMath.argmax(identity, x) === argmax(identity, x)
171+
x isa Dict || @test NaNMath.argmin(sin, x) === argmin(sin, x)
172+
x isa Dict || @test NaNMath.argmax(sin, x) === argmax(sin, x)
173+
end
174+
end
175+
x = [7, 7, NaN, 1, 1, NaN]
176+
@test NaNMath.argmin(x) === 4
177+
@test NaNMath.argmax(x) === 1
178+
@test NaNMath.argmin(identity, x) === 1.0
179+
@test NaNMath.argmax(identity, x) === 7.0
180+
@test NaNMath.argmin(-, x) === 7.0
181+
@test NaNMath.argmax(-, x) === 1.0
182+
183+
x = [NaN, NaN]
184+
@test NaNMath.argmin(x) === 1
185+
@test NaNMath.argmax(x) === 1
186+
@test NaNMath.argmin(identity, x) === NaN
187+
@test NaNMath.argmax(identity, x) === NaN
188+
@test NaNMath.argmin(-, x) === NaN
189+
@test NaNMath.argmax(-, x) === NaN
190+
191+
x = [3, missing, NaN, -1]
192+
@test_throws ErrorException NaNMath.argmin(x)
193+
@test_throws ErrorException NaNMath.argmax(x)
194+
195+
x = Dict('a' => 1.0, missing => NaN, 'c' => 2.0)
196+
@test NaNMath.argmin(x) === 'a'
197+
@test NaNMath.argmax(x) === 'c'
198+
199+
x = Dict(:v => NaN, :w => 2.1f0, :x => 3.1, :z => -1.0, :y => NaN)
200+
@test NaNMath.argmin(x) === :z
201+
@test NaNMath.argmax(x) === :x
202+
@test NaNMath.argmin(-, x) === 3.1
203+
@test NaNMath.argmax(-, x) === -1.0
204+
@test NaNMath.argmin(exp, x) === -1.0
205+
@test NaNMath.argmax(exp, x) === 3.1
206+
207+
x = (x=1.1, y=NaN, z=NaN, w=-2.3)
208+
@test NaNMath.argmin(x) === :w
209+
@test NaNMath.argmax(x) === :x
210+
@test NaNMath.argmin(exp, x) === -2.3
211+
@test NaNMath.argmax(exp, x) === 1.1
212+
213+
x = [2.0 3.0; 2.0 -1.0]
214+
@test NaNMath.argmin(x) === CartesianIndex(2, 2)
215+
@test NaNMath.argmax(x) === CartesianIndex(1, 2)
216+
@test NaNMath.argmin(exp,x) === -1.0
217+
@test NaNMath.argmax(exp,x) === 3.0
218+
end
219+
87220
# Test forwarding
88221
x = 1 + 2im
89222
for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,

0 commit comments

Comments
 (0)