diff --git a/docs/src/api/api.md b/docs/src/api/api.md index 412a1bfc2b..7ad1442535 100644 --- a/docs/src/api/api.md +++ b/docs/src/api/api.md @@ -43,9 +43,11 @@ OptimizeCommunicationOptions ShardyPropagationOptions ``` -## Tracing customization +## Tracing ```@docs +Reactant.transmute +Reactant.transmute_type Reactant.@skip_rewrite_func Reactant.@skip_rewrite_type ``` diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 3faf52c949..7b0ce4a299 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -1018,7 +1018,7 @@ function Reactant.make_tracer( ) x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr)) x = x::TracedRArray - Reactant.make_tracer(seen, x, path, mode; kwargs...) + Reactant.transmute(seen, x, path, mode; kwargs...) return prev end @@ -1027,7 +1027,7 @@ function Reactant.make_tracer( ) x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr)) x = x::TracedRNumber - Reactant.make_tracer(seen, x, path, mode; kwargs...) + Reactant.transmute(seen, x, path, mode; kwargs...) return prev end @@ -1093,7 +1093,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( kernelargsym = gensym("kernelarg") for (i, prev) in enumerate(Any[func.f, args...]) - Reactant.make_tracer(seen, prev, (kernelargsym, i), Reactant.NoStopTracedTrack) + Reactant.transmute(seen, prev, (kernelargsym, i), Reactant.NoStopTracedTrack) end wrapper_tys = MLIR.IR.Type[] for arg in values(seen) @@ -1420,7 +1420,7 @@ function Reactant.make_tracer( end error("Unsupported runtime $runtime") end - TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime) + TT = Reactant.transmute_type(eltype(RT), Val(mode), track_numbers, sharding, runtime) if TT === eltype(RT) return prev end @@ -1430,7 +1430,7 @@ function Reactant.make_tracer( for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] - nv = Reactant.make_tracer( + nv = Reactant.transmute( seen, pv, append_path(path, I), diff --git a/src/Compiler.jl b/src/Compiler.jl index 5cff4174aa..8438dd8906 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1547,7 +1547,7 @@ function compile_mlir!( concrete_seen = OrderedIdDict() - concrete_result = make_tracer( + concrete_result = Reactant.transmute( concrete_seen, traced_result, ("result",), TracedToConcrete; runtime ) diff --git a/src/Ops.jl b/src/Ops.jl index 890cec0433..43019d458a 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1887,7 +1887,7 @@ end traced_args = Vector{Any}(undef, N) for (i, prev) in enumerate(args) - @inbounds traced_args[i] = Reactant.make_tracer( + @inbounds traced_args[i] = Reactant.transmute( seen_args, prev, (), Reactant.NoStopTracedTrack; track_numbers ) end @@ -1972,14 +1972,14 @@ end tb_traced_args = Vector{Any}(undef, N) fb_traced_args = Vector{Any}(undef, N) for i in 1:N - @inbounds tb_traced_args[i] = Reactant.make_tracer( + @inbounds tb_traced_args[i] = Reactant.transmute( tb_seen_args, args[i], (true_fn_names[1], i), Reactant.TracedSetPath; track_numbers, ) - @inbounds fb_traced_args[i] = Reactant.make_tracer( + @inbounds fb_traced_args[i] = Reactant.transmute( fb_seen_args, args[i], (false_fn_names[1], i), @@ -2043,7 +2043,7 @@ end end seen_true_results = Reactant.OrderedIdDict() - traced_true_results = Reactant.make_tracer( + traced_true_results = Reactant.transmute( seen_true_results, tb_result, (true_fn_names[2],), @@ -2051,7 +2051,7 @@ end track_numbers, ) for (i, arg) in enumerate(tb_traced_args) - Reactant.make_tracer( + Reactant.transmute( seen_true_results, arg, (true_fn_names[3], i), @@ -2108,7 +2108,7 @@ end end seen_false_results = Reactant.OrderedIdDict() - traced_false_results = Reactant.make_tracer( + traced_false_results = Reactant.transmute( seen_false_results, fb_result, (false_fn_names[2],), @@ -2116,7 +2116,7 @@ end track_numbers, ) for (i, arg) in enumerate(fb_traced_args) - Reactant.make_tracer( + Reactant.transmute( seen_false_results, arg, (false_fn_names[3], i), @@ -2372,7 +2372,7 @@ end @noinline function call(f, args...) seen = Reactant.OrderedIdDict() cache_key = [] - Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes) + Reactant.transmute(seen, (f, args...), cache_key, Reactant.TracedToTypes) cache = Reactant.Compiler.callcache() if haskey(cache, cache_key) # cache lookup: @@ -2414,7 +2414,7 @@ end end seen_cache = Reactant.OrderedIdDict() - Reactant.make_tracer( + Reactant.transmute( seen_cache, args, (), # we have to insert something here, but we remove it immediately below. @@ -2438,7 +2438,7 @@ end ) seen_results = Reactant.OrderedIdDict() - traced_result = Reactant.make_tracer( + traced_result = Reactant.transmute( seen_results, traced_result, (), # we have to insert something here, but we remove it immediately below. diff --git a/src/Reactant.jl b/src/Reactant.jl index 02893f9516..f2b2284ddb 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -98,6 +98,8 @@ unwrapped_eltype(::TracedRNumber{T}) where {T} = T unwrapped_eltype(::Type{<:AbstractArray{T,N}}) where {T,N} = unwrapped_eltype(T) unwrapped_eltype(::AbstractArray{T,N}) where {T,N} = unwrapped_eltype(T) +# TODO replace `promote_trace_type` with `promote_transmute_type` on v0.3 release +promote_transmute_type(a::Type, b::Type) = promote_traced_type(a, b) promote_traced_type(a::Type, b::Type) = Base.promote_type(a, b) aos_to_soa(x::AbstractArray) = x diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 1203c02b43..c59703c239 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -431,7 +431,7 @@ function prepare_mlir_fn_args( seen_args0 = OrderedIdDict() try for i in 1:N - @inbounds traced_args[i] = Reactant.make_tracer( + @inbounds traced_args[i] = Reactant.transmute( seen_args0, args[i], (argprefix, i), inmode; toscalar, runtime ) end @@ -623,13 +623,11 @@ function finalize_mlir_fn( seen_results = OrderedIdDict() MLIR.IR.activate!(fnbody) traced_result = try - traced_result = Reactant.make_tracer( - seen_results, result, (resprefix,), outmode; runtime - ) + traced_result = Reactant.transmute(seen_results, result, (resprefix,), outmode; runtime) # marks buffers to be donated for i in 1:N - Reactant.make_tracer( + Reactant.transmute( seen_results, traced_args[i], (resargprefix, i), @@ -1162,7 +1160,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} end seen_results = OrderedIdDict() - traced2_result = Reactant.make_tracer( + traced2_result = Reactant.transmute( seen_results, result, (), Reactant.TracedSetPath; tobatch=OutShape ) diff --git a/src/Tracing.jl b/src/Tracing.jl index a256b51308..0dae07a774 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -852,6 +852,28 @@ const traced_type_cache = Dict{Tuple{TraceMode,Type,Any},Dict{Type,Type}}() # $(Expr(:meta, :generated, traced_type_generator)) # end +""" + transmute_type(args...; kwargs...) + +Return the adapted typed used for tracing. + +!!! warning + This is the new name for the `traced_type` function, which is deprecated and will be removed on v0.3 release. + If you extend it with new methods, transition to this new function instead on the next breaking release. +""" +transmute_type(args...; kwargs...) = traced_type(args...; kwargs...) + +""" + transmute(args...; kwargs...) + +Adapt the object to be suitable for tracing, returning a new object if needed. + +!!! warning + This is the new name for the `make_tracer` function, which is deprecated and will be removed on v0.3 release. + If you extend it with new methods, transition to this new function instead on the next breaking release. +""" +transmute(args...; kwargs...) = make_tracer(args...; kwargs...) + Base.@assume_effects :total @inline function traced_type( T::Type, ::Val{mode}, track_numbers::Type, sharding, runtime ) where {mode} @@ -945,7 +967,7 @@ Base.@nospecializeinfer function make_tracer_via_immutable_constructor( push!(path, RT) seen[prev] = VisitedObject(length(seen) + 1) end - TT = traced_type(RT, Val(mode), track_numbers, sharding, runtime) + TT = transmute_type(RT, Val(mode), track_numbers, sharding, runtime) @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) @@ -966,7 +988,7 @@ Base.@nospecializeinfer function make_tracer_via_immutable_constructor( if isdefined(prev, i) newpath = mode == TracedToTypes ? path : append_path(path, i) xi = Base.getfield(prev, i) - xi2 = make_tracer( + xi2 = transmute( seen, xi, newpath, @@ -1029,7 +1051,7 @@ Base.@nospecializeinfer function make_tracer_unknown( push!(path, RT) seen[prev] = VisitedObject(length(seen) + 1) end - TT = traced_type(RT, Val(mode), track_numbers, sharding, runtime) + TT = transmute_type(RT, Val(mode), track_numbers, sharding, runtime) @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) @@ -1050,7 +1072,7 @@ Base.@nospecializeinfer function make_tracer_unknown( if isdefined(prev, i) newpath = mode == TracedToTypes ? path : append_path(path, i) xi = Base.getfield(prev, i) - xi2 = make_tracer( + xi2 = transmute( seen, xi, newpath, @@ -1087,7 +1109,7 @@ Base.@nospecializeinfer function make_tracer_unknown( if isdefined(prev, i) newpath = mode == TracedToTypes ? path : append_path(path, i) xi = Base.getfield(prev, i) - xi2 = make_tracer( + xi2 = transmute( seen, xi, newpath, @@ -1505,13 +1527,13 @@ Base.@nospecializeinfer function make_tracer( Sharding.is_sharded(sharding) && error("Cannot specify sharding for Complex") if mode == TracedToTypes push!(path, Core.Typeof(prev)) - make_tracer(seen, prev.re, path, mode; kwargs...) - make_tracer(seen, prev.im, path, mode; kwargs...) + transmute(seen, prev.re, path, mode; kwargs...) + transmute(seen, prev.im, path, mode; kwargs...) return nothing end return Complex( - make_tracer(seen, prev.re, append_path(path, :re), mode; kwargs...), - make_tracer(seen, prev.im, append_path(path, :im), mode; kwargs...), + transmute(seen, prev.re, append_path(path, :re), mode; kwargs...), + transmute(seen, prev.im, append_path(path, :im), mode; kwargs...), ) end @@ -1556,7 +1578,7 @@ Base.@nospecializeinfer function make_tracer( for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] - make_tracer( + transmute( seen, pv, path, @@ -1572,14 +1594,14 @@ Base.@nospecializeinfer function make_tracer( end return nothing end - TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime) + TT = transmute_type(eltype(RT), Val(mode), track_numbers, sharding, runtime) newa = Array{TT,ndims(RT)}(undef, size(prev)) seen[prev] = newa same = true for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] - nv = make_tracer( + nv = transmute( seen, pv, append_path(path, I), @@ -1641,7 +1663,7 @@ Base.@nospecializeinfer function make_tracer( elseif mode == TracedToTypes push!(path, RT) for (k, v) in prev - make_tracer( + transmute( seen, k, path, @@ -1653,7 +1675,7 @@ Base.@nospecializeinfer function make_tracer( client, kwargs..., ) - make_tracer( + transmute( seen, v, path, @@ -1668,12 +1690,12 @@ Base.@nospecializeinfer function make_tracer( end return nothing end - Value2 = traced_type(Value, Val(mode), track_numbers, sharding, runtime) + Value2 = transmute_type(Value, Val(mode), track_numbers, sharding, runtime) newa = Dict{Key,Value2}() seen[prev] = newa same = true for (k, v) in prev - nv = make_tracer( + nv = transmute( seen, v, append_path(path, k), @@ -1707,7 +1729,7 @@ Base.@nospecializeinfer function make_tracer( if mode == TracedToTypes push!(path, RT) for (i, v) in enumerate(prev) - make_tracer( + transmute( seen, v, path, mode; sharding=Base.getproperty(sharding, i), kwargs... ) end @@ -1715,7 +1737,7 @@ Base.@nospecializeinfer function make_tracer( end return ( ( - make_tracer( + transmute( seen, v, append_path(path, i), @@ -1744,15 +1766,15 @@ Base.@nospecializeinfer function make_tracer( if mode == TracedToTypes push!(path, NT) for i in 1:length(A) - make_tracer( + transmute( seen, Base.getfield(prev, i), path, mode; track_numbers, sharding, kwargs... ) end return nothing end - return NamedTuple{A,traced_type(RT, Val(mode), track_numbers, sharding, runtime)}(( + return NamedTuple{A,transmute_type(RT, Val(mode), track_numbers, sharding, runtime)}(( ( - make_tracer( + transmute( seen, Base.getfield(prev, i), append_path(path, i), @@ -1784,7 +1806,7 @@ Base.@nospecializeinfer function make_tracer( if mode == TracedToTypes push!(path, Core.Box) - return make_tracer(seen, prev2, path, mode; sharding, kwargs...) + return transmute(seen, prev2, path, mode; sharding, kwargs...) end if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] @@ -1795,7 +1817,7 @@ Base.@nospecializeinfer function make_tracer( end res = Core.Box(prev2) seen[prev] = res - tr = make_tracer( + tr = transmute( seen, prev2, append_path(path, :contents), @@ -1845,7 +1867,7 @@ end @nospecialize(device), @nospecialize(client) ) - return make_tracer( + return transmute( OrderedIdDict(), x, (), @@ -2012,14 +2034,14 @@ function Reactant.make_tracer( Reactant.Sharding.is_sharded(sharding) && error("Cannot specify sharding for UnitRange") if mode == Reactant.TracedToTypes push!(path, Core.Typeof(prev)) - make_tracer(seen, prev.start, path, mode; kwargs...) - make_tracer(seen, prev.stop, path, mode; kwargs...) + transmute(seen, prev.start, path, mode; kwargs...) + transmute(seen, prev.stop, path, mode; kwargs...) return nothing end - newstart = Reactant.make_tracer( + newstart = Reactant.transmute( seen, prev.start, Reactant.append_path(path, :start), mode; kwargs... ) - newstop = Reactant.make_tracer( + newstop = Reactant.transmute( seen, prev.stop, Reactant.append_path(path, :stop), mode; kwargs... ) if typeof(newstart) == typeof(prev.start) && typeof(newstop) == typeof(prev.stop) @@ -2061,22 +2083,22 @@ function Reactant.make_tracer( error("Cannot specify sharding for StepRangeLen") if mode == Reactant.TracedToTypes push!(path, Core.Typeof(prev)) - make_tracer(seen, prev.ref, path, mode; sharding, kwargs...) - make_tracer(seen, prev.step, path, mode; sharding, kwargs...) - make_tracer(seen, prev.len, path, mode; sharding, kwargs...) - make_tracer(seen, prev.offset, path, mode; sharding, kwargs...) + transmute(seen, prev.ref, path, mode; sharding, kwargs...) + transmute(seen, prev.step, path, mode; sharding, kwargs...) + transmute(seen, prev.len, path, mode; sharding, kwargs...) + transmute(seen, prev.offset, path, mode; sharding, kwargs...) return nothing end - newref = Reactant.make_tracer( + newref = Reactant.transmute( seen, prev.ref, Reactant.append_path(path, :ref), mode; sharding, kwargs... ) - newstep = Reactant.make_tracer( + newstep = Reactant.transmute( seen, prev.step, Reactant.append_path(path, :step), mode; sharding, kwargs... ) - newlen = Reactant.make_tracer( + newlen = Reactant.transmute( seen, prev.len, Reactant.append_path(path, :len), mode; sharding, kwargs... ) - newoffset = Reactant.make_tracer( + newoffset = Reactant.transmute( seen, prev.offset, Reactant.append_path(path, :offset), mode; sharding, kwargs... ) if typeof(newref) == typeof(prev.ref) && diff --git a/test/compile.jl b/test/compile.jl index 73bffad47c..2341254e01 100644 --- a/test/compile.jl +++ b/test/compile.jl @@ -157,7 +157,7 @@ end seen, prev::MockTestCustomPath, path, mode; kwargs... ) custom_path = Reactant.append_path(path, (; custom_id=1)) - traced_x = Reactant.make_tracer(seen, prev.x, custom_path, mode; kwargs...) + traced_x = Reactant.transmute(seen, prev.x, custom_path, mode; kwargs...) return MockTestCustomPath(traced_x) end diff --git a/test/tracing.jl b/test/tracing.jl index 21a4f6559b..3f18aff7ff 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -1,6 +1,6 @@ using Reactant using Reactant: - traced_type, + transmute_type, TracedRArray, TracedRNumber, ConcreteToTraced, @@ -205,7 +205,7 @@ end ), (Wrapper, Wrapper, Wrapper), ] - tracedty = traced_type( + tracedty = transmute_type( origty, Val(ConcreteToTraced), Union{}, @@ -214,7 +214,7 @@ end ) @test tracedty == targetty - tracedty2 = traced_type( + tracedty2 = transmute_type( origty, Val(ConcreteToTraced), ReactantPrimitive, @@ -230,7 +230,7 @@ end TracedRArray{Float64,2}, TracedRArray{Float64,3}, ] - @test_throws Union{ErrorException,String} traced_type( + @test_throws Union{ErrorException,String} transmute_type( type, Val(ConcreteToTraced), Union{}, @@ -239,12 +239,12 @@ end ) end end - @testset "traced_type exceptions" begin + @testset "transmute_type exceptions" begin struct Node x::Vector{Float64} y::Union{Nothing,Node} end - @test_throws NoFieldMatchError traced_type( + @test_throws NoFieldMatchError transmute_type( Node, Val(ArrayToConcrete), Union{},