Skip to content

Commit 0139886

Browse files
Fix pathing (#1337)
* Fix pathing * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent bc20553 commit 0139886

File tree

8 files changed

+219
-13
lines changed

8 files changed

+219
-13
lines changed

src/Compiler.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,8 +1378,10 @@ function compile_mlir!(
13781378
seen_args,
13791379
ret,
13801380
linear_args,
1381+
skipped_args,
13811382
in_tys,
13821383
linear_results,
1384+
skipped_results,
13831385
is_sharded,
13841386
) = mlir_fn_res
13851387
compiled_f = mlir_fn_res.f
@@ -2107,8 +2109,10 @@ function compile_mlir!(
21072109
seen_args,
21082110
ret,
21092111
linear_args,
2112+
skipped_args,
21102113
in_tys,
21112114
linear_results2,
2115+
skipped_results,
21122116
mlir_fn_res.num_partitions,
21132117
mlir_fn_res.num_replicas,
21142118
mlir_fn_res.is_sharded,
@@ -2399,7 +2403,6 @@ The _linearized arguments_ do not directly refer to the are the arguments that
23992403
function codegen_flatten!(
24002404
linear_args,
24012405
seen_args,
2402-
result_stores,
24032406
is_sharded::Bool,
24042407
linear_parameter_shardings,
24052408
client,
@@ -2804,14 +2807,16 @@ function codegen_unflatten!(
28042807
if length(path) > 0
28052808
needs_cache_dict = true
28062809
# XXX: we might need to handle sharding here
2807-
unflatcode = :(traced_setfield_buffer!(
2808-
$(runtime),
2809-
$(cache_dict),
2810-
$(concrete_res_name_final),
2811-
$(unflatcode),
2812-
$(Meta.quot(path[end])),
2813-
$(path),
2814-
))
2810+
unflatcode = quote
2811+
traced_setfield_buffer!(
2812+
$(runtime),
2813+
$(cache_dict),
2814+
$(concrete_res_name_final),
2815+
$(unflatcode),
2816+
$(Meta.quot(path[end])),
2817+
$(path),
2818+
)
2819+
end
28152820
else
28162821
unflatcode = :(traced_setfield!(
28172822
$(unflatcode), :data, $(concrete_res_name_final), $(path)
@@ -3280,7 +3285,6 @@ function compile(f, args; sync=false, kwargs...)
32803285
flatten_arg_names, flatten_code, resharded_inputs = codegen_flatten!(
32813286
linear_args,
32823287
seen_args,
3283-
result_stores,
32843288
mlir_fn_res.is_sharded,
32853289
XLA.get_parameter_shardings(exec), # TODO: use the same workflow as output shardings to parse the tensor sharding attributes directly if possible
32863290
client,

src/Ops.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2329,7 +2329,7 @@ end
23292329
end
23302330

23312331
@noinline function call(f, args...)
2332-
seen = Dict()
2332+
seen = Reactant.OrderedIdDict()
23332333
cache_key = []
23342334
Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes)
23352335
cache = Reactant.Compiler.callcache()

src/TracedRArray.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module TracedRArrayOverrides
22

33
using Adapt: WrappedArray
4+
using Adapt: Adapt
45
using Base.Broadcast
56
using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate
67

@@ -506,6 +507,14 @@ function Base.similar(::TracedRArray, ::Type{T}, dims::Dims{N}) where {T,N}
506507
return Ops.fill(zero(unwrapped_eltype(T)), dims)
507508
end
508509

510+
function Base.show(io::IOty, X::AnyTracedRArray) where {IOty<:Union{IO,IOContext}}
511+
print(io, Core.Typeof(X), "(")
512+
if Adapt.parent(X) !== X
513+
Base.show(io, Adapt.parent(X))
514+
end
515+
return print(io, ")")
516+
end
517+
509518
function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOContext}}
510519
return print(io, "TracedRArray{", T, ",", N, "N}(", X.paths, ", size=", size(X), ")")
511520
# TODO this line segfaults if MLIR IR has not correctly been generated

src/TracedUtils.jl

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,10 @@ mutable struct CompiledMlirFnResult{F,TR,Re,Rt,LA,LR,PA,CR,M,MA,RS,GD,DA}
228228
seen_args::OrderedIdDict
229229
ret::Rt
230230
linear_args::Vector{LA}
231+
skipped_args::Vector{LA}
231232
in_tys::Vector{MLIR.IR.Type}
232233
linear_results::Vector{LR}
234+
skipped_results::Vector{LR}
233235
num_partitions::Int
234236
num_replicas::Int
235237
is_sharded::Bool
@@ -333,7 +335,7 @@ function make_mlir_fn(
333335
end
334336
end
335337

336-
(func2, traced_result, ret, linear_args, in_tys, linear_results, num_partitions, is_sharded, unique_meshes, mutated_args, global_device_ids) = finalize_mlir_fn(
338+
(func2, traced_result, ret, linear_args, in_tys, linear_results, skipped_results, num_partitions, is_sharded, unique_meshes, mutated_args, global_device_ids) = finalize_mlir_fn(
337339
result,
338340
traced_args,
339341
linear_args,
@@ -373,8 +375,10 @@ function make_mlir_fn(
373375
seen_args,
374376
ret,
375377
linear_args,
378+
skipped_args,
376379
in_tys,
377380
linear_results,
381+
skipped_results,
378382
num_partitions,
379383
num_replicas,
380384
is_sharded,
@@ -628,9 +632,49 @@ function finalize_mlir_fn(
628632
end
629633

630634
linear_results = Reactant.TracedType[]
635+
skipped_results = Reactant.TracedType[]
631636
for (k, v) in seen_results
632637
v isa Reactant.TracedType || continue
633638
if any(Base.Fix1(===, k), skipped_args)
639+
push!(skipped_results, v)
640+
641+
_, argpath = get_argidx(v, argprefix)
642+
643+
@assert has_idx(v, argprefix)
644+
645+
newpaths = Tuple[]
646+
for path in v.paths
647+
if length(path) == 0
648+
continue
649+
end
650+
if path[1] == argprefix
651+
continue
652+
end
653+
if path[1] == resargprefix
654+
original_arg = args[path[2]]
655+
for p in path[3:end]
656+
original_arg = Reactant.Compiler.traced_getfield(original_arg, p)
657+
end
658+
if !(
659+
original_arg isa Union{
660+
Reactant.ConcreteRNumber,
661+
Reactant.ConcreteRArray,
662+
Reactant.TracedType,
663+
}
664+
)
665+
continue
666+
end
667+
push!(newpaths, path)
668+
end
669+
if path[1] == resprefix
670+
push!(newpaths, path)
671+
end
672+
end
673+
674+
if length(newpaths) != 0
675+
push!(linear_results, Reactant.repath(v, (newpaths...,)))
676+
end
677+
634678
continue
635679
end
636680
if args_in_result != :all
@@ -924,6 +968,7 @@ function finalize_mlir_fn(
924968
linear_args,
925969
in_tys,
926970
linear_results,
971+
skipped_results,
927972
num_partitions,
928973
is_sharded,
929974
unique_meshes,

src/Tracing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,7 @@ const traced_type_cache = Dict{Tuple{TraceMode,Type,Any},Dict{Type,Type}}()
895895
Base.@assume_effects :total @inline function traced_type(
896896
T::Type, ::Val{mode}, track_numbers::Type, sharding, runtime
897897
) where {mode}
898-
if mode == TracedSetPath || mode == TracedTrack
898+
if mode == TracedSetPath || mode == TracedTrack || mode == TracedToTypes
899899
return T
900900
end
901901

src/Types.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ mutable struct TracedRNumber{T} <: RNumber{T}
5252
end
5353
end
5454

55+
function repath(x::TracedRNumber{T}, paths) where {T}
56+
return TracedRNumber{T}(paths, x.mlir_data)
57+
end
58+
5559
@leaf TracedRNumber
5660

5761
## TracedRArray
@@ -71,6 +75,10 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
7175
end
7276
end
7377

78+
function repath(x::TracedRArray{T,N}, paths) where {T,N}
79+
return TracedRArray{T,N}(paths, x.mlir_data, x.shape)
80+
end
81+
7482
@leaf TracedRArray
7583
Adapt.parent_type(::Type{TracedRArray{T,N}}) where {T,N} = TracedRArray{T,N}
7684

test/constructor.jl

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
using Reactant, Test, Adapt
2+
3+
struct MyGrid{FT,AT} <: AbstractVector{FT}
4+
data::AT
5+
radius::FT
6+
end
7+
8+
Adapt.parent(x::MyGrid) = x.data
9+
10+
Base.getindex(x::MyGrid, args...) = Base.getindex(x.data, args...)
11+
12+
Base.size(x::MyGrid) = Base.size(x.data)
13+
14+
function Base.show(io::IOty, X::MyGrid) where {IOty<:Union{IO,IOContext}}
15+
print(io, Core.Typeof(X), "(")
16+
if Adapt.parent(X) !== X
17+
Base.show(io, Adapt.parent(X))
18+
end
19+
return print(io, ")")
20+
end
21+
22+
Base.@nospecializeinfer function Reactant.traced_type_inner(
23+
@nospecialize(OA::Type{MyGrid{FT,AT}}),
24+
seen,
25+
mode::Reactant.TraceMode,
26+
@nospecialize(track_numbers::Type),
27+
@nospecialize(sharding),
28+
@nospecialize(runtime)
29+
) where {FT,AT}
30+
FT2 = Reactant.traced_type_inner(FT, seen, mode, track_numbers, sharding, runtime)
31+
AT2 = Reactant.traced_type_inner(AT, seen, mode, track_numbers, sharding, runtime)
32+
33+
for NF in (AT2,)
34+
FT2 = Reactant.promote_traced_type(FT2, eltype(NF))
35+
end
36+
37+
res = MyGrid{FT2,AT2}
38+
return res
39+
end
40+
41+
@inline Reactant.make_tracer(seen, @nospecialize(prev::MyGrid), args...; kwargs...) =
42+
Reactant.make_tracer_via_immutable_constructor(seen, prev, args...; kwargs...)
43+
44+
struct MyGrid2{FT,AT} <: AbstractVector{FT}
45+
data::AT
46+
radius::FT
47+
bar::FT
48+
end
49+
50+
Adapt.parent(x::MyGrid2) = x.data
51+
52+
Base.getindex(x::MyGrid2, args...) = Base.getindex(x.data, args...)
53+
54+
Base.size(x::MyGrid2) = Base.size(x.data)
55+
56+
function Base.show(io::IOty, X::MyGrid2) where {IOty<:Union{IO,IOContext}}
57+
print(io, Core.Typeof(X), "(")
58+
if Adapt.parent(X) !== X
59+
Base.show(io, Adapt.parent(X))
60+
end
61+
return print(io, ")")
62+
end
63+
64+
Base.@nospecializeinfer function Reactant.traced_type_inner(
65+
@nospecialize(OA::Type{MyGrid2{FT,AT}}),
66+
seen,
67+
mode::Reactant.TraceMode,
68+
@nospecialize(track_numbers::Type),
69+
@nospecialize(sharding),
70+
@nospecialize(runtime)
71+
) where {FT,AT}
72+
FT2 = Reactant.traced_type_inner(FT, seen, mode, track_numbers, sharding, runtime)
73+
AT2 = Reactant.traced_type_inner(AT, seen, mode, track_numbers, sharding, runtime)
74+
75+
for NF in (AT2,)
76+
FT2 = Reactant.promote_traced_type(FT2, eltype(NF))
77+
end
78+
79+
res = MyGrid2{FT2,AT2}
80+
return res
81+
end
82+
83+
@inline Reactant.make_tracer(seen, @nospecialize(prev::MyGrid2), args...; kwargs...) =
84+
Reactant.make_tracer_via_immutable_constructor(seen, prev, args...; kwargs...)
85+
86+
function update!(g)
87+
@allowscalar g.data[1] = g.radius
88+
return nothing
89+
end
90+
91+
function selfreturn(g)
92+
return g
93+
end
94+
95+
function call_update!(g)
96+
@trace update!(g)
97+
end
98+
99+
function call_selfreturn(g)
100+
@trace selfreturn(g)
101+
end
102+
103+
@testset "Custom construction" begin
104+
g = MyGrid([3.14, 1.59], 2.7)
105+
rg = Reactant.to_rarray(g)
106+
107+
@jit update!(rg)
108+
@test convert(Array, rg.data) == [2.7, 1.59]
109+
110+
rg = Reactant.to_rarray(g)
111+
res = @jit selfreturn(rg)
112+
@test convert(Array, res.data) == [3.14, 1.59]
113+
@test res.radius == 2.7
114+
@show typeof(res)
115+
@test typeof(res.radius) <: ConcreteRNumber
116+
117+
rg = Reactant.to_rarray(g)
118+
119+
@jit call_update!(rg)
120+
@test convert(Array, rg.data) == [2.7, 1.59]
121+
122+
rg = Reactant.to_rarray(g)
123+
res = @jit call_selfreturn(rg)
124+
@test convert(Array, res.data) == [3.14, 1.59]
125+
@test res.radius == 2.7
126+
@show typeof(res)
127+
@test typeof(res.radius) <: ConcreteRNumber
128+
end
129+
130+
@testset "Custom construction2 " begin
131+
g = Ref(MyGrid([3.14, 1.59], 2.7))
132+
g = (g, g)
133+
134+
rg = Reactant.to_rarray(g)
135+
res = @jit selfreturn(rg)
136+
@test convert(Array, res[1][].data) == [3.14, 1.59]
137+
@test convert(Array, res[2][].data) == [3.14, 1.59]
138+
@test res[1][].data == res[2][].data
139+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
1515
@safetestset "Layout" include("layout.jl")
1616
@safetestset "Tracing" include("tracing.jl")
1717
@safetestset "Basic" include("basic.jl")
18+
@safetestset "Constructor" include("constructor.jl")
1819
@safetestset "Autodiff" include("autodiff.jl")
1920
@safetestset "Complex" include("complex.jl")
2021
@safetestset "Broadcast" include("bcast.jl")

0 commit comments

Comments
 (0)