Skip to content

Commit 31259ee

Browse files
committed
Add proper promotion rules for Dimensions and SymbolicDimensions
1 parent 5a9bdd5 commit 31259ee

File tree

3 files changed

+63
-15
lines changed

3 files changed

+63
-15
lines changed

src/math.jl

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,52 @@
11
for (type, base_type, _) in ABSTRACT_QUANTITY_TYPES
22
@eval begin
3-
Base.:*(l::$type, r::$type) = new_quantity(typeof(l), ustrip(l) * ustrip(r), dimension(l) * dimension(r))
4-
Base.:/(l::$type, r::$type) = new_quantity(typeof(l), ustrip(l) / ustrip(r), dimension(l) / dimension(r))
5-
Base.div(x::$type, y::$type, r::RoundingMode=RoundToZero) = new_quantity(typeof(x), div(ustrip(x), ustrip(y), r), dimension(x) / dimension(y))
3+
function Base.:*(l::$type, r::$type)
4+
l, r = promote(l, r)
5+
new_quantity(typeof(l), ustrip(l) * ustrip(r), dimension(l) * dimension(r))
6+
end
7+
function Base.:/(l::$type, r::$type)
8+
l, r = promote(l, r)
9+
new_quantity(typeof(l), ustrip(l) / ustrip(r), dimension(l) / dimension(r))
10+
end
11+
function Base.div(x::$type, y::$type, r::RoundingMode=RoundToZero)
12+
x, y = promote(x, y)
13+
new_quantity(typeof(x), div(ustrip(x), ustrip(y), r), dimension(x) / dimension(y))
14+
end
615

7-
Base.:*(l::$type, r::$base_type) = new_quantity(typeof(l), ustrip(l) * r, dimension(l))
8-
Base.:/(l::$type, r::$base_type) = new_quantity(typeof(l), ustrip(l) / r, dimension(l))
9-
Base.div(x::$type, y::$base_type, r::RoundingMode=RoundToZero) = new_quantity(typeof(x), div(ustrip(x), y, r), dimension(x))
16+
# The rest of the functions are unchanged because they do not operate on two variables of the custom type
17+
function Base.:*(l::$type, r::$base_type)
18+
new_quantity(typeof(l), ustrip(l) * r, dimension(l))
19+
end
20+
function Base.:/(l::$type, r::$base_type)
21+
new_quantity(typeof(l), ustrip(l) / r, dimension(l))
22+
end
23+
function Base.div(x::$type, y::$base_type, r::RoundingMode=RoundToZero)
24+
new_quantity(typeof(x), div(ustrip(x), y, r), dimension(x))
25+
end
1026

11-
Base.:*(l::$base_type, r::$type) = new_quantity(typeof(r), l * ustrip(r), dimension(r))
12-
Base.:/(l::$base_type, r::$type) = new_quantity(typeof(r), l / ustrip(r), inv(dimension(r)))
13-
Base.div(x::$base_type, y::$type, r::RoundingMode=RoundToZero) = new_quantity(typeof(y), div(x, ustrip(y), r), inv(dimension(y)))
27+
function Base.:*(l::$base_type, r::$type)
28+
new_quantity(typeof(r), l * ustrip(r), dimension(r))
29+
end
30+
function Base.:/(l::$base_type, r::$type)
31+
new_quantity(typeof(r), l / ustrip(r), inv(dimension(r)))
32+
end
33+
function Base.div(x::$base_type, y::$type, r::RoundingMode=RoundToZero)
34+
new_quantity(typeof(y), div(x, ustrip(y), r), inv(dimension(y)))
35+
end
1436

15-
Base.:*(l::$type, r::AbstractDimensions) = new_quantity(typeof(l), ustrip(l), dimension(l) * r)
16-
Base.:/(l::$type, r::AbstractDimensions) = new_quantity(typeof(l), ustrip(l), dimension(l) / r)
37+
function Base.:*(l::$type, r::AbstractDimensions)
38+
new_quantity(typeof(l), ustrip(l), dimension(l) * r)
39+
end
40+
function Base.:/(l::$type, r::AbstractDimensions)
41+
new_quantity(typeof(l), ustrip(l), dimension(l) / r)
42+
end
1743

18-
Base.:*(l::AbstractDimensions, r::$type) = new_quantity(typeof(r), ustrip(r), l * dimension(r))
19-
Base.:/(l::AbstractDimensions, r::$type) = new_quantity(typeof(r), inv(ustrip(r)), l / dimension(r))
44+
function Base.:*(l::AbstractDimensions, r::$type)
45+
new_quantity(typeof(r), ustrip(r), l * dimension(r))
46+
end
47+
function Base.:/(l::AbstractDimensions, r::$type)
48+
new_quantity(typeof(r), inv(ustrip(r)), l / dimension(r))
49+
end
2050
end
2151
end
2252

@@ -27,6 +57,7 @@ Base.:/(l::AbstractDimensions, r::AbstractDimensions) = map_dimensions(-, l, r)
2757
for (type, base_type, _) in ABSTRACT_QUANTITY_TYPES, op in (:+, :-)
2858
@eval begin
2959
function Base.$op(l::$type, r::$type)
60+
l, r = promote(l, r)
3061
dimension(l) == dimension(r) || throw(DimensionError(l, r))
3162
return new_quantity(typeof(l), $op(ustrip(l), ustrip(r)), dimension(l))
3263
end
@@ -125,6 +156,7 @@ for (type, base_type, _) in ABSTRACT_QUANTITY_TYPES, f in (:atan, :atand)
125156
return $f(ustrip(x))
126157
end
127158
function Base.$f(y::$type, x::$type)
159+
y, x = promote(y, x)
128160
dimension(y) == dimension(x) || throw(DimensionError(y, x))
129161
return $f(ustrip(y), ustrip(x))
130162
end
@@ -154,6 +186,7 @@ for (type, base_type, _) in ABSTRACT_QUANTITY_TYPES, f in (:copysign, :flipsign,
154186
# and ignore any dimensions on y, since those will cancel out.
155187
@eval begin
156188
function Base.$f(x::$type, y::$type)
189+
x, y = promote(x, y)
157190
return new_quantity(typeof(x), $f(ustrip(x), ustrip(y)), dimension(x))
158191
end
159192
function Base.$f(x::$type, y::$base_type)

src/utils.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Base.keys(q::UnionAbstractQuantity) = keys(ustrip(q))
7777

7878
# Numeric checks
7979
function Base.isapprox(l::UnionAbstractQuantity, r::UnionAbstractQuantity; kws...)
80+
l, r = promote(l, r)
8081
return isapprox(ustrip(l), ustrip(r); kws...) && dimension(l) == dimension(r)
8182
end
8283
function Base.isapprox(l::Number, r::UnionAbstractQuantity; kws...)
@@ -88,11 +89,15 @@ function Base.isapprox(l::UnionAbstractQuantity, r::Number; kws...)
8889
return isapprox(ustrip(l), r; kws...)
8990
end
9091
Base.iszero(d::AbstractDimensions) = all_dimensions(iszero, d)
91-
Base.:(==)(l::AbstractDimensions, r::AbstractDimensions) = all_dimensions(==, l, r)
92-
Base.:(==)(l::UnionAbstractQuantity, r::UnionAbstractQuantity) = ustrip(l) == ustrip(r) && dimension(l) == dimension(r)
92+
function Base.:(==)(l::UnionAbstractQuantity, r::UnionAbstractQuantity)
93+
l, r = promote(l, r)
94+
ustrip(l) == ustrip(r) && dimension(l) == dimension(r)
95+
end
9396
Base.:(==)(l::Number, r::UnionAbstractQuantity) = ustrip(l) == ustrip(r) && iszero(dimension(r))
9497
Base.:(==)(l::UnionAbstractQuantity, r::Number) = ustrip(l) == ustrip(r) && iszero(dimension(l))
98+
Base.:(==)(l::AbstractDimensions, r::AbstractDimensions) = all_dimensions(==, l, r)
9599
function Base.isless(l::UnionAbstractQuantity, r::UnionAbstractQuantity)
100+
l, r = promote(l, r)
96101
dimension(l) == dimension(r) || throw(DimensionError(l, r))
97102
return isless(ustrip(l), ustrip(r))
98103
end

test/unittests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,16 @@ end
633633
qa = [x, y]
634634
@test qa isa Vector{Quantity{Float64,SymbolicDimensions{Rational{Int}}}}
635635
DynamicQuantities.with_type_parameters(SymbolicDimensions{Float64}, Rational{Int}) == SymbolicDimensions{Rational{Int}}
636+
637+
@testset "Promotion with Dimensions" begin
638+
x = 0.5u"cm"
639+
y = -0.03u"m"
640+
x_s = 0.5us"cm"
641+
for op in (+, -, *, /, atan, atand, copysign, flipsign, mod)
642+
@test op(x, y) == op(x_s, y)
643+
@test op(y, x) == op(y, x_s)
644+
end
645+
end
636646
end
637647

638648
@testset "uconvert" begin

0 commit comments

Comments
 (0)