Skip to content

Commit babeb7c

Browse files
authored
feat: allow type-casting numbers to tracednumbers (#209)
* feat: allow type-casting numbers to tracednumbers * chore: apply formatting suggestion
1 parent f570fcc commit babeb7c

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

src/TracedRNumber.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T}
5252
return promote_to(TracedRNumber{T}, x)
5353
end
5454

55+
TracedRNumber{T}(x::TracedRNumber{T}) where {T} = x
56+
function TracedRNumber{T}(x::Number) where {T}
57+
return promote_to(TracedRNumber{T}, x)
58+
end
59+
5560
function promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
5661
if isa(rhs, TracedRNumber)
5762
rhs isa TracedRNumber{T} && return rhs

test/basic.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,3 +488,13 @@ end
488488
@test res3 isa ConcreteRArray
489489
end
490490
end
491+
492+
relu(x::T) where {T<:Number} = max(T(0), x)
493+
relu(x) = relu.(x)
494+
495+
@testset "type casting" begin
496+
x = randn(2, 10)
497+
x_ra = Reactant.to_rarray(x)
498+
499+
@test @jit(relu(x_ra)) relu(x)
500+
end

0 commit comments

Comments
 (0)