Skip to content

Introduce transmute, transmute_type functions as future names for make_tracer, traced_type #1441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/src/api/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ OptimizeCommunicationOptions
ShardyPropagationOptions
```

## Tracing customization
## Tracing

```@docs
Reactant.transmute
Reactant.transmute_type
Reactant.@skip_rewrite_func
Reactant.@skip_rewrite_type
```
Expand Down
10 changes: 5 additions & 5 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ function Reactant.make_tracer(
)
x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))
x = x::TracedRArray
Reactant.make_tracer(seen, x, path, mode; kwargs...)
Reactant.transmute(seen, x, path, mode; kwargs...)
return prev
end

Expand All @@ -1027,7 +1027,7 @@ function Reactant.make_tracer(
)
x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))
x = x::TracedRNumber
Reactant.make_tracer(seen, x, path, mode; kwargs...)
Reactant.transmute(seen, x, path, mode; kwargs...)
return prev
end

Expand Down Expand Up @@ -1093,7 +1093,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
kernelargsym = gensym("kernelarg")

for (i, prev) in enumerate(Any[func.f, args...])
Reactant.make_tracer(seen, prev, (kernelargsym, i), Reactant.NoStopTracedTrack)
Reactant.transmute(seen, prev, (kernelargsym, i), Reactant.NoStopTracedTrack)
end
wrapper_tys = MLIR.IR.Type[]
for arg in values(seen)
Expand Down Expand Up @@ -1420,7 +1420,7 @@ function Reactant.make_tracer(
end
error("Unsupported runtime $runtime")
end
TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime)
TT = Reactant.transmute_type(eltype(RT), Val(mode), track_numbers, sharding, runtime)
if TT === eltype(RT)
return prev
end
Expand All @@ -1430,7 +1430,7 @@ function Reactant.make_tracer(
for I in eachindex(prev)
if isassigned(prev, I)
pv = prev[I]
nv = Reactant.make_tracer(
nv = Reactant.transmute(
seen,
pv,
append_path(path, I),
Expand Down
2 changes: 1 addition & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1547,7 +1547,7 @@ function compile_mlir!(

concrete_seen = OrderedIdDict()

concrete_result = make_tracer(
concrete_result = Reactant.transmute(
concrete_seen, traced_result, ("result",), TracedToConcrete; runtime
)

Expand Down
20 changes: 10 additions & 10 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1887,7 +1887,7 @@ end
traced_args = Vector{Any}(undef, N)

for (i, prev) in enumerate(args)
@inbounds traced_args[i] = Reactant.make_tracer(
@inbounds traced_args[i] = Reactant.transmute(
seen_args, prev, (), Reactant.NoStopTracedTrack; track_numbers
)
end
Expand Down Expand Up @@ -1972,14 +1972,14 @@ end
tb_traced_args = Vector{Any}(undef, N)
fb_traced_args = Vector{Any}(undef, N)
for i in 1:N
@inbounds tb_traced_args[i] = Reactant.make_tracer(
@inbounds tb_traced_args[i] = Reactant.transmute(
tb_seen_args,
args[i],
(true_fn_names[1], i),
Reactant.TracedSetPath;
track_numbers,
)
@inbounds fb_traced_args[i] = Reactant.make_tracer(
@inbounds fb_traced_args[i] = Reactant.transmute(
fb_seen_args,
args[i],
(false_fn_names[1], i),
Expand Down Expand Up @@ -2043,15 +2043,15 @@ end
end

seen_true_results = Reactant.OrderedIdDict()
traced_true_results = Reactant.make_tracer(
traced_true_results = Reactant.transmute(
seen_true_results,
tb_result,
(true_fn_names[2],),
Reactant.NoStopTracedTrack;
track_numbers,
)
for (i, arg) in enumerate(tb_traced_args)
Reactant.make_tracer(
Reactant.transmute(
seen_true_results,
arg,
(true_fn_names[3], i),
Expand Down Expand Up @@ -2108,15 +2108,15 @@ end
end

seen_false_results = Reactant.OrderedIdDict()
traced_false_results = Reactant.make_tracer(
traced_false_results = Reactant.transmute(
seen_false_results,
fb_result,
(false_fn_names[2],),
Reactant.NoStopTracedTrack;
track_numbers,
)
for (i, arg) in enumerate(fb_traced_args)
Reactant.make_tracer(
Reactant.transmute(
seen_false_results,
arg,
(false_fn_names[3], i),
Expand Down Expand Up @@ -2372,7 +2372,7 @@ end
@noinline function call(f, args...)
seen = Reactant.OrderedIdDict()
cache_key = []
Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes)
Reactant.transmute(seen, (f, args...), cache_key, Reactant.TracedToTypes)
cache = Reactant.Compiler.callcache()
if haskey(cache, cache_key)
# cache lookup:
Expand Down Expand Up @@ -2414,7 +2414,7 @@ end
end

seen_cache = Reactant.OrderedIdDict()
Reactant.make_tracer(
Reactant.transmute(
seen_cache,
args,
(), # we have to insert something here, but we remove it immediately below.
Expand All @@ -2438,7 +2438,7 @@ end
)

seen_results = Reactant.OrderedIdDict()
traced_result = Reactant.make_tracer(
traced_result = Reactant.transmute(
seen_results,
traced_result,
(), # we have to insert something here, but we remove it immediately below.
Expand Down
2 changes: 2 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ unwrapped_eltype(::TracedRNumber{T}) where {T} = T
unwrapped_eltype(::Type{<:AbstractArray{T,N}}) where {T,N} = unwrapped_eltype(T)
unwrapped_eltype(::AbstractArray{T,N}) where {T,N} = unwrapped_eltype(T)

# TODO replace `promote_trace_type` with `promote_transmute_type` on v0.3 release
promote_transmute_type(a::Type, b::Type) = promote_traced_type(a, b)
promote_traced_type(a::Type, b::Type) = Base.promote_type(a, b)

aos_to_soa(x::AbstractArray) = x
Expand Down
10 changes: 4 additions & 6 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ function prepare_mlir_fn_args(
seen_args0 = OrderedIdDict()
try
for i in 1:N
@inbounds traced_args[i] = Reactant.make_tracer(
@inbounds traced_args[i] = Reactant.transmute(
seen_args0, args[i], (argprefix, i), inmode; toscalar, runtime
)
end
Expand Down Expand Up @@ -623,13 +623,11 @@ function finalize_mlir_fn(
seen_results = OrderedIdDict()
MLIR.IR.activate!(fnbody)
traced_result = try
traced_result = Reactant.make_tracer(
seen_results, result, (resprefix,), outmode; runtime
)
traced_result = Reactant.transmute(seen_results, result, (resprefix,), outmode; runtime)

# marks buffers to be donated
for i in 1:N
Reactant.make_tracer(
Reactant.transmute(
seen_results,
traced_args[i],
(resargprefix, i),
Expand Down Expand Up @@ -1162,7 +1160,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
end

seen_results = OrderedIdDict()
traced2_result = Reactant.make_tracer(
traced2_result = Reactant.transmute(
seen_results, result, (), Reactant.TracedSetPath; tobatch=OutShape
)

Expand Down
Loading
Loading