Skip to content

Commit 425c53d

Browse files
mcognettasimonbyrne
authored andcommitted
adding fast max and min (#31866)
1 parent 8cc2f12 commit 425c53d

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

base/fastmath.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,15 @@ ComplexTypes = Union{ComplexF32, ComplexF64}
233233
(a==real(y)) & (T(0)==imag(y))
234234

235235
ne_fast(x::T, y::T) where {T<:ComplexTypes} = !(x==y)
236+
237+
# Note: we use the same comparison for min, max, and minmax, so
238+
# that the compiler can convert between them
239+
max_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, y, x)
240+
min_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, x, y)
241+
minmax_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, (x,y), (y,x))
242+
243+
max_fast(x::T, y::T, z::T...) where {T<:FloatTypes} = max_fast(max_fast(x, y), z...)
244+
min_fast(x::T, y::T, z::T...) where {T<:FloatTypes} = min_fast(min_fast(x, y), z...)
236245
end
237246

238247
# fall-back implementations and type promotion
@@ -245,7 +254,7 @@ for op in (:abs, :abs2, :conj, :inv, :sign)
245254
end
246255
end
247256

248-
for op in (:+, :-, :*, :/, :(==), :!=, :<, :<=, :cmp, :rem)
257+
for op in (:+, :-, :*, :/, :(==), :!=, :<, :<=, :cmp, :rem, :min, :max, :minmax)
249258
op_fast = fast_op[op]
250259
@eval begin
251260
# fall-back implementation for non-numeric types
@@ -304,12 +313,6 @@ sincos_fast(v) = (sin_fast(v), cos_fast(v))
304313
@fastmath begin
305314
hypot_fast(x::T, y::T) where {T<:FloatTypes} = sqrt(x*x + y*y)
306315

307-
# Note: we use the same comparison for min, max, and minmax, so
308-
# that the compiler can convert between them
309-
max_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, y, x)
310-
min_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, x, y)
311-
minmax_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, (x,y), (y,x))
312-
313316
# complex numbers
314317

315318
function cis_fast(x::T) where {T<:FloatTypes}
@@ -362,7 +365,7 @@ for f in (:acos, :acosh, :angle, :asin, :asinh, :atan, :atanh, :cbrt,
362365
end
363366
end
364367

365-
for f in (:^, :atan, :hypot, :max, :min, :minmax, :log)
368+
for f in (:^, :atan, :hypot, :log)
366369
f_fast = fast_op[f]
367370
@eval begin
368371
# fall-back implementation for non-numeric types

test/fastmath.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ end
109109
for f in (:+, :-, :abs, :abs2, :conj, :inv, :sign,
110110
:acos, :asin, :asinh, :atan, :atanh, :cbrt, :cos, :cosh,
111111
:exp10, :exp2, :exp, :log10, :log1p,
112-
:log2, :log, :sin, :sinh, :sqrt, :tan, :tanh)
112+
:log2, :log, :sin, :sinh, :sqrt, :tan, :tanh,
113+
:min, :max)
113114
@eval begin
114115
@test @fastmath($f($half)) $f($half)
115116
@test @fastmath($f($third)) $f($third)
@@ -142,6 +143,14 @@ end
142143
@test @fastmath($f($third, $half)) $f($third, $half)
143144
end
144145
end
146+
147+
# issue 31795
148+
for f in (:min, :max)
149+
@eval begin
150+
@test @fastmath($f($half, $third, 1+$half)) $f($half, $third, 1+$half)
151+
end
152+
end
153+
145154
for f in (:minmax,)
146155
@eval begin
147156
@test @fastmath($f($half, $third)[1]) $f($half, $third)[1]

0 commit comments

Comments
 (0)