Skip to content

Commit cc8f0cd

Browse files
committed
Don't try and convert to FloatX except if Integer or AbstractFloat
1 parent 424a0b7 commit cc8f0cd

File tree

3 files changed

+59
-5
lines changed

3 files changed

+59
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.0.1"
3+
version = "1.0.2"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/projection.jl

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,44 @@ ProjectTo(::Real) = ProjectTo{Real}()
135135
ProjectTo(::Complex) = ProjectTo{Complex}()
136136
ProjectTo(::Number) = ProjectTo{Number}()
137137
for T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)
138-
# Preserve low-precision floats as accidental promotion is a common perforance bug
138+
# Preserve low-precision floats as accidental promotion is a common perforance bug
139139
@eval ProjectTo(::$T) = ProjectTo{$T}()
140140
end
141141
ProjectTo(x::Integer) = ProjectTo(float(x))
142142
ProjectTo(x::Complex{<:Integer}) = ProjectTo(float(x))
143-
(::ProjectTo{T})(dx::Number) where {T<:Number} = convert(T, dx)
144-
(::ProjectTo{T})(dx::Number) where {T<:Real} = convert(T, real(dx))
143+
144+
# Preserve low-precision floats as accidental promotion is a common perforance bug
145+
(::ProjectTo{T})(dx::AbstractFloat) where T<:AbstractFloat = convert(T, dx)
146+
(::ProjectTo{T})(dx::Integer) where T<:AbstractFloat = convert(T, dx)
147+
148+
149+
# We asked for a number/real and they gave use one. We did ask for a particular concrete
150+
# type, but that is just for the preserving low precision floats, which is handled above.
151+
# Any Number/Real actually occupies the same subspace, so we can trust them.
152+
# In particular, this makes weirder Real subtypes that are not simply the values like
153+
# ForwardDiff.Dual and Symbolics.Sym work, because we stay out of their way.
154+
(::ProjectTo{<:Number})(dx::Number) where {T<:Number} = dx
155+
(::ProjectTo{<:Real})(dx::Real) = dx
156+
157+
(::ProjectTo{T})(dx::Complex) where T<:Real = ProjectTo(zero(T))(real(dx))
158+
159+
# Complex
160+
function (proj::ProjectTo{<:Complex{<:AbstractFloat}})(
161+
dx::Complex{<:Union{AbstractFloat,Integer}}
162+
)
163+
# in this case we can just convert as we know we are dealing with
164+
# boring floating point types or integers
165+
return convert(project_type(proj), dx)
166+
end
167+
# Pass though non-AbstractFloat to project each component
168+
function (::ProjectTo{<:Complex{T}})(dx::Complex) where T
169+
project = ProjectTo(zero(T))
170+
return Complex(project(real(dx)), project(imag(dx)))
171+
end
172+
function (::ProjectTo{<:Complex{T}})(dx::Real) where T
173+
project = ProjectTo(zero(T))
174+
return Complex(project(dx), project(zero(dx)))
175+
end
145176

146177
# Arrays
147178
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is

test/projection.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@ using ChainRulesCore, Test
22
using LinearAlgebra, SparseArrays
33
using OffsetArrays, BenchmarkTools
44

5+
# Like ForwardDiff.jl's Dual
6+
struct Dual{T<:Real} <: Real
7+
value::T
8+
partial::T
9+
end
10+
Base.real(x::Dual) = x
11+
Base.float(x::Dual) = Dual(float(x.value), float(x.partial))
12+
Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
13+
514
@testset "projection" begin
615

716
#####
@@ -12,14 +21,28 @@ using OffsetArrays, BenchmarkTools
1221
# real / complex
1322
@test ProjectTo(1.0)(2.0 + 3im) === 2.0
1423
@test ProjectTo(1.0 + 2.0im)(3.0) === 3.0 + 0.0im
24+
@test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im
25+
@test ProjectTo(2.0)(1+1im) === 1.0
26+
1527

1628
# storage
17-
@test ProjectTo(1)(pi) === Float64(pi)
29+
@test ProjectTo(1)(pi) === pi
1830
@test ProjectTo(1 + im)(pi) === ComplexF64(pi)
1931
@test ProjectTo(1//2)(3//4) === 3//4
2032
@test ProjectTo(1.0f0)(1 / 2) === 0.5f0
2133
@test ProjectTo(1.0f0 + 2im)(3) === 3.0f0 + 0im
2234
@test ProjectTo(big(1.0))(2) === 2
35+
@test ProjectTo(1.0)(2) === 2.0
36+
end
37+
38+
@testset "Dual" begin # some weird Real subtype that we should basically leave alone
39+
@test ProjectTo(1.0)(Dual(1.0, 2.0)) isa Dual
40+
@test ProjectTo(1.0)(Dual(1, 2)) isa Dual
41+
@test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual}
42+
@test ProjectTo(1.0 + 1im)(
43+
Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))
44+
) isa Complex{<:Dual}
45+
@test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual
2346
end
2447

2548
@testset "Base: arrays of numbers" begin

0 commit comments

Comments
 (0)