Skip to content

Commit a17315c

Browse files
feat: specialize dispatches for faster concrete array generation (#213)
* feat: specialize dispatches for faster concrete array generation * chore: apply formatting suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent b6ee968 commit a17315c

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

src/Tracing.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,29 @@ end
563563

564564
@inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Tuple}=())
565565
track_numbers isa Bool && (track_numbers = track_numbers ? (Number,) : ())
566+
return to_rarray_internal(x, track_numbers)
567+
end
568+
569+
@inline function to_rarray_internal(@nospecialize(x), track_numbers::Tuple)
566570
return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers)
567571
end
568572

569-
to_rarray(x::ReactantPrimitive) = ConcreteRArray(x)
573+
function to_rarray_internal(@nospecialize(::TracedRArray), ::Tuple)
574+
return error("Cannot convert TracedRArray to ConcreteRArray")
575+
end
576+
@inline to_rarray_internal(@nospecialize(x::ConcreteRArray), ::Tuple) = x
577+
@inline function to_rarray_internal(
578+
@nospecialize(x::AbstractArray{<:ReactantPrimitive}), ::Tuple
579+
)
580+
return ConcreteRArray(x)
581+
end
582+
583+
@inline to_rarray_internal(@nospecialize(x::ConcreteRNumber), ::Tuple) = x
584+
@inline function to_rarray_internal(
585+
@nospecialize(x::ReactantPrimitive), track_numbers::Tuple
586+
)
587+
for T in track_numbers
588+
typeof(x) <: T && return ConcreteRNumber(x)
589+
end
590+
return x
591+
end

test/tracing.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,18 @@ using Test
100100
end
101101
end
102102
end
103+
104+
@testset "specialized dispatches" begin
105+
@test @inferred Union{Float64,ConcreteRArray{Float64}} Reactant.to_rarray(
106+
1.0; track_numbers=(Number,)
107+
) isa ConcreteRNumber
108+
@test @inferred Reactant.to_rarray(1.0) isa Float64
109+
@test @inferred Reactant.to_rarray(rand(3)) isa ConcreteRArray
110+
111+
x_ra = Reactant.to_rarray(rand(3))
112+
@test @inferred Reactant.to_rarray(x_ra) isa ConcreteRArray
113+
114+
x_ra = Reactant.to_rarray(1.0; track_numbers=(Number,))
115+
@test @inferred Reactant.to_rarray(x_ra) isa ConcreteRNumber
116+
end
103117
end

0 commit comments

Comments
 (0)