Skip to content

Commit 4d21152

Browse files
committed
feat: specialize dispatches for faster concrete array generation
1 parent babeb7c commit 4d21152

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

src/Tracing.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,29 @@ end
519519

520520
@inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Tuple}=())
521521
track_numbers isa Bool && (track_numbers = track_numbers ? (Number,) : ())
522+
return to_rarray_internal(x, track_numbers)
523+
end
524+
525+
@inline function to_rarray_internal(@nospecialize(x), track_numbers::Tuple)
522526
return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers)
523527
end
524528

525-
to_rarray(x::ReactantPrimitive) = ConcreteRArray(x)
529+
function to_rarray_internal(@nospecialize(::TracedRArray), ::Tuple)
530+
error("Cannot convert TracedRArray to ConcreteRArray")
531+
end
532+
@inline to_rarray_internal(@nospecialize(x::ConcreteRArray), ::Tuple) = x
533+
@inline function to_rarray_internal(
534+
@nospecialize(x::AbstractArray{<:ReactantPrimitive}), ::Tuple
535+
)
536+
return ConcreteRArray(x)
537+
end
538+
539+
@inline to_rarray_internal(@nospecialize(x::ConcreteRNumber), ::Tuple) = x
540+
@inline function to_rarray_internal(
541+
@nospecialize(x::ReactantPrimitive), track_numbers::Tuple
542+
)
543+
for T in track_numbers
544+
typeof(x) <: T && return ConcreteRNumber(x)
545+
end
546+
return x
547+
end

test/tracing.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,18 @@ using Test
100100
end
101101
end
102102
end
103+
104+
@testset "specialized dispatches" begin
105+
@test @inferred Union{Float64,ConcreteRArray{Float64}} Reactant.to_rarray(
106+
1.0; track_numbers=(Number,)
107+
) isa ConcreteRNumber
108+
@test @inferred Reactant.to_rarray(1.0) isa Float64
109+
@test @inferred Reactant.to_rarray(rand(3)) isa ConcreteRArray
110+
111+
x_ra = Reactant.to_rarray(rand(3))
112+
@test @inferred Reactant.to_rarray(x_ra) isa ConcreteRArray
113+
114+
x_ra = Reactant.to_rarray(1.0; track_numbers=(Number,))
115+
@test @inferred Reactant.to_rarray(x_ra) isa ConcreteRNumber
116+
end
103117
end

0 commit comments

Comments
 (0)