Skip to content

Commit 7d3511b

Browse files
authored
feat: add Base.Complex overloads (#1385)
1 parent 5dcf9bf commit 7d3511b

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.129"
4+
version = "0.2.130"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/TracedRNumber.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,22 @@ function TracedUtils.promote_to(::TracedRNumber{T}, rhs) where {T}
119119
return TracedUtils.promote_to(TracedRNumber{T}, rhs)
120120
end
121121

122+
for (aT, bT) in (
123+
(TracedRNumber{<:Real}, Real),
124+
(Real, TracedRNumber{<:Real}),
125+
(TracedRNumber{<:Real}, TracedRNumber{<:Real}),
126+
)
127+
@eval function Base.Complex(a::$aT, b::$bT)
128+
T = promote_type(unwrapped_eltype(a), unwrapped_eltype(b))
129+
a = TracedUtils.promote_to(TracedRNumber{T}, a)
130+
b = TracedUtils.promote_to(TracedRNumber{T}, b)
131+
return Ops.complex(a, b)
132+
end
133+
end
134+
135+
Base.Complex(x::TracedRNumber{<:Real}) = Ops.complex(x, zero(x))
136+
Base.Complex(x::TracedRNumber{<:Complex}) = x
137+
122138
for (jlop, hloop) in (
123139
(:(Base.min), :minimum),
124140
(:(Base.max), :maximum),

test/complex.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,18 @@ end
103103
x_ra = Reactant.to_rarray(x)
104104
@test @jit(sum(abs2, x_ra)) sum(abs2, x)
105105
end
106+
107+
@testset "create complex numbers" begin
108+
x = randn(ComplexF32)
109+
x_ra = Reactant.to_rarray(x; track_numbers=true)
110+
@test @jit(Complex(x_ra)) == x_ra
111+
112+
x = randn(Float32)
113+
y = randn(Float64)
114+
x_ra = Reactant.to_rarray(x; track_numbers=true)
115+
y_ra = Reactant.to_rarray(y; track_numbers=true)
116+
@test @jit(Complex(x_ra, y_ra)) == Complex(x, y)
117+
@test @jit(Complex(x_ra, y)) == Complex(x, y)
118+
@test @jit(Complex(x, y_ra)) == Complex(x, y)
119+
@test @jit(Complex(x_ra)) == Complex(x) == @jit(Complex(x_ra, 0))
120+
end

0 commit comments

Comments
 (0)