Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,54 @@ Base.@nospecializeinfer function traced_type_inner(
}
end

@static if isdefined(Core, :Memory)
Base.@nospecializeinfer function traced_type_inner(
@nospecialize(A::Type{<:Core.Memory}),
seen,
@nospecialize(mode::TraceMode),
@nospecialize(track_numbers::Type),
@nospecialize(ndevices),
@nospecialize(runtime)
)
T = eltype(A)
if A isa UnionAll
A´ = Base.unwrap_unionall(A)
traced_T = traced_type_inner(T, seen, mode, track_numbers, ndevices, runtime)

A_wrapper = A´.name.wrapper
if mode == ArrayToConcrete && T <: ReactantPrimitive
if runtime isa Val{:PJRT}
A_wrapper = ConcretePJRTArray{T,1} where {T}
elseif runtime isa Val{:IFRT}
A_wrapper = ConcreteIFRTArray{T,1} where {T}
else
error("Unsupported runtime $runtime")
end
end

A´´ = A_wrapper{traced_T}
return T isa Core.TypeVar ? UnionAll(traced_T, A´´) : A´´
else
if mode == ArrayToConcrete && T <: ReactantPrimitive
runtime isa Val{:PJRT} && return ConcretePJRTArray{T,1,_unwrap_val(ndevices)}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
runtime isa Val{:PJRT} && return ConcretePJRTArray{T,1,_unwrap_val(ndevices)}
runtime isa Val{:PJRT} &&
return ConcretePJRTArray{T,1,_unwrap_val(ndevices)}

if runtime isa Val{:IFRT}
# For IFRT, when ndevices is 1, it's not sharded
if ndevices isa Val{1}
return ConcreteIFRTArray{T,1,Nothing}
else
return ConcreteIFRTArray{T,1}
end
end
error("Unsupported runtime $runtime")
else
return Core.Memory{
traced_type_inner(T, seen, mode, track_numbers, ndevices, runtime),N
}
end
end
end
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(PT::Type{ReactantRNG{S}}),
seen,
Expand Down Expand Up @@ -1849,6 +1897,93 @@ Base.@nospecializeinfer function make_tracer(
end
end

Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::Memory),
@nospecialize(path),
mode;
@nospecialize(track_numbers::Type = Union{}),
@nospecialize(sharding = Sharding.NoSharding()),
@nospecialize(runtime = nothing),
@nospecialize(device = nothing),
@nospecialize(client = nothing),
kwargs...,
)
RT = Core.Typeof(prev)
if mode != NoStopTracedTrack && haskey(seen, prev)
if mode == TracedToTypes
visited = seen[prev]
push!(path, visited)
return nothing
end
return seen[prev]
end
if eltype(RT) <: ReactantPrimitive
if mode == ArrayToConcrete
runtime isa Val{:PJRT} &&
(return seen[prev] = ConcretePJRTArray(prev; sharding, device, client))
runtime isa Val{:IFRT} &&
(return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client))
error("Unsupported runtime $runtime")
elseif mode == TracedToTypes
# Original array can get mutated so we store a copy:
push!(path, copy(prev))
seen[prev] = VisitedObject(length(seen) + 1)
return nothing
end
elseif mode == TracedToTypes
push!(path, RT)
for I in eachindex(prev)
if isassigned(prev, I)
pv = prev[I]
make_tracer(
seen,
pv,
path,
mode;
track_numbers,
sharding,
runtime,
device,
client,
kwargs...,
)
end
end
return nothing
end
TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime)
newa = Memory{TT}(undef, length(prev))
seen[prev] = newa
same = true
for I in eachindex(prev)
if isassigned(prev, I)
pv = prev[I]
nv = make_tracer(
seen,
pv,
append_path(path, I),
mode;
track_numbers,
sharding=Base.getproperty(sharding, I),
runtime,
device,
client,
kwargs...,
)
if pv !== nv
same = false
end
@inbounds newa[I] = nv
end
end
if same
seen[prev] = prev
return prev
end
return newa
end

Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::Dict{Key,Value}),
Expand Down
28 changes: 27 additions & 1 deletion test/core/tracing.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Reactant, Test
using Reactant:
traced_type,
make_tracer,
TracedRArray,
TracedRNumber,
ConcreteToTraced,
Expand Down Expand Up @@ -37,7 +38,7 @@ struct MyFix{N,FT,XT} <: Base.Function
end

@testset "trace_type (mode = ConcreteToTraced)" begin
@testset "$origty" for (origty, targetty, targettynum) in [
testsuite = [
(Any, Any, Any),
(Real, Real, Real),
(Module, Module, Module),
Expand Down Expand Up @@ -272,6 +273,14 @@ end
),
(Wrapper, Wrapper, Wrapper),
]

if isdefined(Core, :Memory)
push!(testsuite, (Memory{UInt8}, Memory{UInt8}, Memory{TracedRNumber{UInt8}}))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wsmoses in the other PR, I think you said that converting here to TracedRNumber needs to be discussed because it seems a too edgy case. If so, I agree, but right now it feels awkward that we treat Memory the same as Array except for this point.

push!(testsuite, (Memory{ConcreteRArray{Float64,1}}, Memory{TracedRArray{Float64,1}}, Memory{TracedRArray{Float64,1}}))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
push!(testsuite, (Memory{ConcreteRArray{Float64,1}}, Memory{TracedRArray{Float64,1}}, Memory{TracedRArray{Float64,1}}))
push!(
testsuite,
(
Memory{ConcreteRArray{Float64,1}},
Memory{TracedRArray{Float64,1}},
Memory{TracedRArray{Float64,1}},
),
)

push!(testsuite, (Memory, Memory, Memory))
end

@testset "$origty" for (origty, targetty, targettynum) in testsuite
tracedty = traced_type(
origty,
Val(ConcreteToTraced),
Expand Down Expand Up @@ -423,3 +432,20 @@ end
@test iszero(t.compile_time)
end
end

@testset "make_tracer" begin
sharding = Sharding.NoSharding()
rt = Reactant.XLA.runtime()

m = Memory{Int64}(undef, 1)

# Memory maps to ConcreteRArray in ArrayToConcrete mode
mt = make_tracer(IdDict(), m, Val(ArrayToConcrete), Union{}, sharding, rt)
@test mt isa ConcreteRArray{Int64,1}
@test mt[1] == m[1]

# Memory should map to itself in other modes
mt = make_tracer(IdDict(), m, Val(ConcreteToTraced), Union{}, sharding, rt)
@test mt isa Memory{Int64}
@test mt === m
end
Loading