Skip to content

Commit 2bffb5f

Browse files
Simplify promote_operation_fallback (#335)
* Fixes * test: mark test as no longer broken * fix: implement `promote_operation_fallback` for `real` and `imag` * fix: make promotion change non-breaking Co-authored-by: Benoît Legat <[email protected]> * test: update allocation tests * fix: fix `promote_operation` for `Integer / Integer` * test: test `promote_operation` of `real` and `imag` * test: test new non-concrete `promote_operation` methods --------- Co-authored-by: Benoît Legat <[email protected]>
1 parent e470c6b commit 2bffb5f

File tree

5 files changed

+53
-6
lines changed

5 files changed

+53
-6
lines changed

src/interface.jl

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,46 @@ function promote_operation_fallback(
4040
::Type{S},
4141
::Type{T},
4242
) where {S,T}
43-
return typeof(op(_instantiate_zero(S), _instantiate_oneunit(T)))
43+
if isconcretetype(S) && isconcretetype(T)
44+
return typeof(op(_instantiate_zero(S), _instantiate_oneunit(T)))
45+
else
46+
return promote_type(S, T)
47+
end
48+
end
49+
50+
function promote_operation_fallback(
51+
op::typeof(/),
52+
::Type{S},
53+
::Type{T},
54+
) where {S<:Integer,T<:Integer}
55+
if isconcretetype(S) && isconcretetype(T)
56+
return typeof(op(_instantiate_zero(S), _instantiate_oneunit(T)))
57+
else
58+
return promote_type(float(S), float(T))
59+
end
4460
end
4561

4662
function promote_operation_fallback(
4763
op::F,
4864
::Type{S},
4965
::Type{T},
5066
) where {F<:Function,S,T}
51-
return typeof(op(_instantiate_zero(S), _instantiate_zero(T)))
67+
if isconcretetype(S) && isconcretetype(T)
68+
return typeof(op(_instantiate_zero(S), _instantiate_zero(T)))
69+
else
70+
return promote_type(S, T)
71+
end
5272
end
5373

5474
function promote_operation_fallback(
5575
op::F,
5676
args::Vararg{Type,N},
5777
) where {F<:Function,N}
58-
return typeof(op(_instantiate_zero.(args)...))
78+
if all(isconcretetype, args)
79+
return typeof(op(_instantiate_zero.(args)...))
80+
else
81+
return promote_type(args...)
82+
end
5983
end
6084

6185
promote_operation_fallback(::typeof(*), ::Type{T}) where {T} = T
@@ -103,6 +127,13 @@ function promote_operation_fallback(
103127
)
104128
end
105129

130+
function promote_operation(
131+
::Union{typeof(real),typeof(imag)},
132+
::Type{Complex{T}},
133+
) where {T}
134+
return T
135+
end
136+
106137
"""
107138
promote_operation(op::Function, ArgsTypes::Type...)
108139

test/int.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,18 @@ import MutableArithmetics as MA
2828
)
2929
@test_throws err MA.promote_operation(op, Int, Vector{Int})
3030
end
31+
for op in [+, -, *, /, div]
32+
@test MA.promote_operation(op, Int, Number) == Number
33+
@test MA.promote_operation(op, Number, Int) == Number
34+
end
35+
@test MA.promote_operation(/, Int, Integer) == Float64
36+
@test MA.promote_operation(/, Integer, Integer) == Float64
37+
@test MA.promote_operation(/, Integer, Int) == Float64
38+
@test MA.promote_operation(gcd, Int, Integer) == Integer
39+
@test MA.promote_operation(gcd, Integer, Integer) == Integer
40+
@test MA.promote_operation(gcd, Integer, Int) == Integer
41+
@test MA.promote_operation(&, Integer, Integer, Integer) == Integer
42+
@test MA.promote_operation(&, Integer, Integer, Int) == Integer
3143
end
3244
@testset "add_to!! / add!!" begin
3345
@test MA.mutability(Int, MA.add_to!!, Int, Int) isa MA.IsNotMutable

test/interface.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ Base.@irrational theodorus 1.73205080756887729353 sqrt(big(3))
5757
MathConstants.catalan
5858
@test MA._instantiate(typeof(theodorus)) == theodorus
5959
end
60+
61+
for op in [real, imag]
62+
@test MA.promote_operation(op, ComplexF64) == Float64
63+
@test MA.promote_operation(op, Complex{Real}) == Real
64+
end
6065
end
6166

6267
@testset "Errors" begin

test/rewrite_generic.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,7 @@ function test_rewrite_nonconcrete_vector()
8989
y = Vector{Union{Float64,String}}(x)
9090
@test MA.@rewrite(x' * y, move_factors_into_sums = false) == x' * y
9191
@test MA.@rewrite(x .+ y, move_factors_into_sums = false) == x .+ y
92-
# Reproducing buggy behavior in MA.@rewrite.
93-
@test_broken MA.@rewrite(x + y, move_factors_into_sums = false) == x + x
92+
@test MA.@rewrite(x + y, move_factors_into_sums = false) == x + x
9493
return
9594
end
9695

test/utilities.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
include("dummy.jl")
88

99
# Allocating size for allocating a `BigInt`. Half size on 32-bit.
10-
const BIGINT_ALLOC = @static if VERSION >= v"1.12"
10+
const BIGINT_ALLOC = @static if VERSION >= v"1.12-beta1"
1111
Sys.WORD_SIZE == 64 ? 72 : 36
1212
elseif VERSION >= v"1.11"
1313
Sys.WORD_SIZE == 64 ? 56 : 28

0 commit comments

Comments
 (0)