Skip to content

Commit 9aa8703

Browse files
authored
Fix ustrip broadcasting when range step has different unit than eltype (#715)
1 parent 6bfc193 commit 9aa8703

File tree

2 files changed

+90
-7
lines changed

2 files changed

+90
-7
lines changed

src/range.jl

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,54 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::AbstractQu
154154
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::AbstractQuantity, r::AbstractRange) =
155155
broadcasted(DefaultArrayStyle{1}(), *, ustrip(x), r) * unit(x)
156156

157-
const BCAST_PROPAGATE_CALLS = Union{typeof(upreferred), typeof(ustrip), Units}
157+
const BCAST_PROPAGATE_CALLS = Union{typeof(upreferred), Units}
158158
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Ref{<:Units}) = r * x[]
159159
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Ref{<:Units}, r::AbstractRange) = x[] * r
160160
broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::StepRangeLen) = StepRangeLen{typeof(x(zero(eltype(r))))}(x(r.ref), x(r.step), r.len, r.offset)
161-
broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::StepRange) = StepRange(x(r.start), x(r.step), x(r.stop))
161+
function broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::StepRange)
162+
start = x(r.start)
163+
au_to = absoluteunit(unit(start))
164+
step = uconvert(au_to, r.step)
165+
if Base.ArithmeticStyle(start) == Base.ArithmeticRounds() || Base.ArithmeticStyle(step) == Base.ArithmeticRounds()
166+
au_from = absoluteunit(unit(r.start))
167+
astart = ustrip(au_from, r.start)
168+
astop = ustrip(au_from, r.stop)
169+
len = length(r)
170+
offset = _offset_for_steprangelen(astart, astop, len)
171+
nb = ndigits(max(offset-1, len-offset), base=2, pad=0)
172+
T = promote_type(typeof(start/unit(start)), typeof(step/unit(step)))
173+
unitless_range = Base.steprangelen_hp(T, ustrip(au_to, r[offset]), ustrip(au_to, step), nb, len, offset)
174+
return unitless_range * unit(start)
175+
else
176+
return StepRange(start, step, x(r.stop))
177+
end
178+
end
162179
broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::LinRange) = LinRange(x(r.start), x(r.stop), r.len)
163180
broadcasted(::DefaultArrayStyle{1}, ::typeof(|>), r::AbstractRange, x::Ref{<:BCAST_PROPAGATE_CALLS}) = broadcasted(DefaultArrayStyle{1}(), x[], r)
164181

182+
function _offset_for_steprangelen(start, stop, len)
183+
if iszero(start)
184+
return oneunit(len)
185+
elseif iszero(stop)
186+
return len
187+
elseif signbit(start) == signbit(stop)
188+
return abs(start) < abs(stop) ? oneunit(len) : len
189+
else
190+
fstart = Float64(start)
191+
fstop = Float64(stop)
192+
return round(typeof(len), (fstop-len*fstart)/(fstop-fstart))
193+
end
194+
end
195+
196+
broadcasted(::DefaultArrayStyle{1}, ::typeof(ustrip), r::StepRangeLen) =
197+
StepRangeLen{typeof(ustrip(zero(eltype(r))))}(ustrip(unit(eltype(r)), r.ref), ustrip(unit(eltype(r)), r.step), r.len, r.offset)
198+
broadcasted(::DefaultArrayStyle{1}, ::typeof(ustrip), r::StepRange) =
199+
ustrip(unit(eltype(r)), r.start):ustrip(unit(eltype(r)), r.step):ustrip(unit(eltype(r)), r.stop)
200+
broadcasted(::DefaultArrayStyle{1}, ::typeof(ustrip), r::LinRange) =
201+
LinRange(ustrip(unit(eltype(r)), r.start), ustrip(unit(eltype(r)), r.stop), r.len)
202+
broadcasted(::DefaultArrayStyle{1}, ::typeof(|>), r::AbstractRange, ::Ref{typeof(ustrip)}) =
203+
broadcasted(DefaultArrayStyle{1}(), ustrip, r)
204+
165205
# for ambiguity resolution
166206
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::AbstractQuantity) where T =
167207
broadcasted(DefaultArrayStyle{1}(), *, r, ustrip(x)) * unit(x)

test/runtests.jl

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import Unitful:
1111
Ra, °F, °C, K,
1212
rad, mrad, °,
1313
ms, s, minute, hr, d, yr, Hz,
14-
J, A, N, mol, V,
14+
J, A, N, mol, V, mJ, eV,
1515
mW, W,
1616
dB, dB_rp, dB_p, dBm, dBV, dBSPL, Decibel,
1717
Np, Np_rp, Np_p, Neper,
@@ -1325,14 +1325,61 @@ end
13251325

13261326
@test @inferred((1:2:5) .* cm .|> mm) === 10mm:20mm:50mm
13271327
@test mm.((1:2:5) .* cm) === 10mm:20mm:50mm
1328+
@test @inferred(StepRange(1cm,1mm,2cm) .|> km) === (1//100_000)km:(1//1_000_000)km:(2//100_000)km
1329+
13281330
@test @inferred((1:2:5) .* km .|> upreferred) === 1000m:2000m:5000m
13291331
@test @inferred((1:2:5)km .|> upreferred) === 1000m:2000m:5000m
13301332
@test @inferred((1:2:5) .|> upreferred) === 1:2:5
13311333
@test @inferred((1.0:2.0:5.0) .* km .|> upreferred) === 1000.0m:2000.0m:5000.0m
13321334
@test @inferred((1.0:2.0:5.0)km .|> upreferred) === 1000.0m:2000.0m:5000.0m
13331335
@test @inferred((1.0:2.0:5.0) .|> upreferred) === 1.0:2.0:5.0
1336+
@test @inferred(StepRange(1cm,1mm,2cm) .|> upreferred) === (1//100)m:(1//1000)m:(2//100)m
1337+
1338+
# float conversion, dimensionful
1339+
for r = [1eV:1eV:5eV, 1eV:1eV:5_000_000eV, 5_000_000eV:-1eV:-1eV, -123_456_789eV:2eV:987_654_321eV, (-11//12)eV:(1//3)eV:(11//4)eV]
1340+
for f = (mJ, upreferred)
1341+
rf = @inferred(r .|> f)
1342+
test_indices = length(r) 10_000 ? eachindex(r) : rand(eachindex(r), 10_000)
1343+
@test eltype(rf) === typeof(f(zero(eltype(r))))
1344+
@test all((rf[i], f(r[i]); rtol=eps()) for i = test_indices)
1345+
end
1346+
end
1347+
1348+
# float conversion from unitless
1349+
r = 1:1:360
1350+
rf = °.(r)
1351+
@test all((rf[i], °(r[i]); rtol=eps()) for i = eachindex(r))
1352+
1353+
# float conversion to unitless
1354+
r = (1:1:360
1355+
for f = (mrad, NoUnits, upreferred)
1356+
rf = f.(r)
1357+
@test eltype(rf) === typeof(f(zero(eltype(r))))
1358+
@test all((rf[i], f(r[i]); rtol=eps()) for i = eachindex(r))
1359+
end
1360+
1361+
# exact conversion from and to unitless
1362+
@test rad.(1:1:360) === (1:1:360)rad
1363+
@test mrad.(1:1:360) === (1_000:1_000:360_000)mrad
1364+
@test upreferred.(1:1:360) === 1:1:360
1365+
@test NoUnits.((1:1:360)rad) === 1:1:360
1366+
@test upreferred.((1:1:360)rad) === 1:1:360
1367+
@test NoUnits.((1:2:5)mrad) === 1//1000:1//500:1//200
1368+
@test upreferred.((1:2:5)mrad) === 1//1000:1//500:1//200
1369+
13341370
@test @inferred((1:2:5) .* cm .|> mm .|> ustrip) === 10:20:50
13351371
@test @inferred((1f0:2f0:5f0) .* cm .|> mm .|> ustrip) === 10f0:20f0:50f0
1372+
@test @inferred(StepRange{typeof(1m),typeof(1cm)}(1m,1cm,2m) .|> ustrip) === 1:1//100:2
1373+
@test @inferred(StepRangeLen{typeof(1f0m)}(1.0m, 1.0cm, 101) .|> ustrip) === StepRangeLen{Float32}(1.0, 0.01, 101)
1374+
@test @inferred(StepRangeLen{typeof(1.0m)}(Base.TwicePrecision(1.0m), Base.TwicePrecision(1.0cm), 101) .|> ustrip) === StepRangeLen{Float64}(Base.TwicePrecision(1.0), Base.TwicePrecision(0.01), 101)
1375+
@test @inferred((1:0.1:1.0) .|> ustrip) == 1:0.1:1.0
1376+
@test @inferred((1m:0.1m:1.0m) .|> ustrip) == 1:0.1:1.0
1377+
@test @inferred(StepRange{typeof(0m),typeof(1cm)}(1m,1cm,2m) .|> ustrip) === 1:1//100:2
1378+
@test @inferred(StepRangeLen{typeof(1f0m)}(1.0m, 1.0cm, 101) .|> ustrip) === StepRangeLen{Float32}(1.0, 0.01, 101)
1379+
@test @inferred(StepRangeLen{typeof(1.0m)}(Base.TwicePrecision(1.0m), Base.TwicePrecision(1.0cm), 101) .|> ustrip) === StepRangeLen{Float64}(Base.TwicePrecision(1.0), Base.TwicePrecision(0.01), 101)
1380+
@test @inferred(StepRangeLen{typeof(1.0mm)}(Base.TwicePrecision(1.0m), Base.TwicePrecision(1.0cm), 101) .|> ustrip) === 1000.0:10.0:2000.0
1381+
@test ustrip.(1:0.1:1.0) == 1:0.1:1.0
1382+
@test ustrip.(1m:0.1m:1.0m) == 1:0.1:1.0
13361383
end
13371384
@testset ">> quantities and non-quantities" begin
13381385
@test range(1, step=1m/mm, length=5) == 1:1000:4001
@@ -1400,10 +1447,6 @@ end
14001447
@test_throws ArgumentError range(step=1m, length=5)
14011448
end
14021449
end
1403-
@testset ">> broadcast ustrip" begin
1404-
@test ustrip.(1:0.1:1.0) == 1:0.1:1.0
1405-
@test ustrip.(1m:0.1m:1.0m) == 1:0.1:1.0
1406-
end
14071450
end
14081451
@testset "> Arrays" begin
14091452
@testset ">> Array multiplication" begin

0 commit comments

Comments
 (0)