Skip to content

Commit d1bb8de

Browse files
avik-palAvik Pal
andauthored
fix: expose device and client options in to_rarray (#1432)
* fix: expose device and client options in to_rarray * fix: more places --------- Co-authored-by: Avik Pal <[email protected]>
1 parent bd9f093 commit d1bb8de

File tree

4 files changed

+126
-37
lines changed

4 files changed

+126
-37
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.140"
4+
version = "0.2.141"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/Compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3340,7 +3340,7 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
33403340
if !isempty(devices_list)
33413341
if !allequal(devices_list)
33423342
msg = "Expected all arguments to be on the same device, got:\n"
3343-
for (i, device) in enumerate(devices_list)
3343+
for (i, device) in enumerate(unique(devices_list))
33443344
msg *= " Device $(i): $(string(device))\n"
33453345
end
33463346
throw(ArgumentError(msg))

src/Tracing.jl

Lines changed: 115 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,12 +1154,14 @@ Base.@nospecializeinfer function make_tracer(
11541154
@nospecialize(path),
11551155
mode;
11561156
@nospecialize(sharding = Sharding.NoSharding()),
1157+
@nospecialize(device = nothing),
1158+
@nospecialize(client = nothing),
11571159
kwargs...,
11581160
) where {T,N}
11591161
if mode == TracedToTypes
11601162
throw("Cannot have ConcretePJRTArray as function call argument.")
11611163
end
1162-
mode == ArrayToConcrete && return ConcretePJRTArray(prev; sharding)
1164+
mode == ArrayToConcrete && return ConcretePJRTArray(prev; sharding, device, client)
11631165
mode != ConcreteToTraced && throw("Cannot trace concrete")
11641166
haskey(seen, prev) && return seen[prev]::TracedRArray{T,N}
11651167
res = TracedRArray{T,N}((path,), nothing, size(prev))
@@ -1173,12 +1175,14 @@ Base.@nospecializeinfer function make_tracer(
11731175
@nospecialize(path),
11741176
mode;
11751177
@nospecialize(sharding = Sharding.NoSharding()),
1178+
@nospecialize(device = nothing),
1179+
@nospecialize(client = nothing),
11761180
kwargs...,
11771181
) where {T,N}
11781182
if mode == TracedToTypes
11791183
throw("Cannot have ConcreteIFRTArray as function call argument.")
11801184
end
1181-
mode == ArrayToConcrete && return ConcreteIFRTArray(prev; sharding)
1185+
mode == ArrayToConcrete && return ConcreteIFRTArray(prev; sharding, device, client)
11821186
mode != ConcreteToTraced && throw("Cannot trace concrete")
11831187
haskey(seen, prev) && return seen[prev]::TracedRArray{T,N}
11841188
res = TracedRArray{T,N}((path,), nothing, size(prev))
@@ -1192,12 +1196,14 @@ Base.@nospecializeinfer function make_tracer(
11921196
@nospecialize(path),
11931197
mode;
11941198
@nospecialize(sharding = Sharding.NoSharding()),
1199+
@nospecialize(device = nothing),
1200+
@nospecialize(client = nothing),
11951201
kwargs...,
11961202
) where {T}
11971203
if mode == TracedToTypes
11981204
throw("Cannot have ConcretePJRTNumber as function call argument.")
11991205
end
1200-
mode == ArrayToConcrete && return ConcretePJRTNumber(prev; sharding)
1206+
mode == ArrayToConcrete && return ConcretePJRTNumber(prev; sharding, device, client)
12011207
mode != ConcreteToTraced && throw("Cannot trace existing trace type")
12021208
haskey(seen, prev) && return seen[prev]::TracedRNumber{T}
12031209
res = TracedRNumber{T}((path,), nothing)
@@ -1211,12 +1217,14 @@ Base.@nospecializeinfer function make_tracer(
12111217
@nospecialize(path),
12121218
mode;
12131219
@nospecialize(sharding = Sharding.NoSharding()),
1220+
@nospecialize(device = nothing),
1221+
@nospecialize(client = nothing),
12141222
kwargs...,
12151223
) where {T}
12161224
if mode == TracedToTypes
12171225
throw("Cannot have ConcreteIFRTNumber as function call argument.")
12181226
end
1219-
mode == ArrayToConcrete && return ConcreteIFRTNumber(prev; sharding)
1227+
mode == ArrayToConcrete && return ConcreteIFRTNumber(prev; sharding, device, client)
12201228
mode != ConcreteToTraced && throw("Cannot trace existing trace type")
12211229
haskey(seen, prev) && return seen[prev]::TracedRNumber{T}
12221230
res = TracedRNumber{T}((path,), nothing)
@@ -1425,6 +1433,8 @@ Base.@nospecializeinfer function make_tracer(
14251433
@nospecialize(track_numbers::Type = Union{}),
14261434
@nospecialize(sharding = Sharding.NoSharding()),
14271435
@nospecialize(runtime = nothing),
1436+
@nospecialize(device = nothing),
1437+
@nospecialize(client = nothing),
14281438
kwargs...,
14291439
)
14301440
if mode == TracedToTypes
@@ -1434,8 +1444,10 @@ Base.@nospecializeinfer function make_tracer(
14341444
RT = Core.Typeof(prev)
14351445
if RT <: track_numbers && mode != TracedSetPath && mode != TracedTrack
14361446
if mode == ArrayToConcrete
1437-
runtime isa Val{:PJRT} && return ConcretePJRTNumber(prev; sharding)
1438-
runtime isa Val{:IFRT} && return ConcreteIFRTNumber(prev; sharding)
1447+
runtime isa Val{:PJRT} &&
1448+
return ConcretePJRTNumber(prev; sharding, device, client)
1449+
runtime isa Val{:IFRT} &&
1450+
return ConcreteIFRTNumber(prev; sharding, device, client)
14391451
error("Unsupported runtime $runtime")
14401452
else
14411453
if mode == TracedTrack || mode == NoStopTracedTrack
@@ -1511,6 +1523,8 @@ Base.@nospecializeinfer function make_tracer(
15111523
@nospecialize(track_numbers::Type = Union{}),
15121524
@nospecialize(sharding = Sharding.NoSharding()),
15131525
@nospecialize(runtime = nothing),
1526+
@nospecialize(device = nothing),
1527+
@nospecialize(client = nothing),
15141528
kwargs...,
15151529
)
15161530
RT = Core.Typeof(prev)
@@ -1527,9 +1541,9 @@ Base.@nospecializeinfer function make_tracer(
15271541
if eltype(RT) <: ReactantPrimitive
15281542
if mode == ArrayToConcrete
15291543
runtime isa Val{:PJRT} &&
1530-
(return seen[prev] = ConcretePJRTArray(prev; sharding))
1544+
(return seen[prev] = ConcretePJRTArray(prev; sharding, device, client))
15311545
runtime isa Val{:IFRT} &&
1532-
(return seen[prev] = ConcreteIFRTArray(prev; sharding))
1546+
(return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client))
15331547
error("Unsupported runtime $runtime")
15341548
elseif mode == TracedToTypes
15351549
# Original array can get mutated so we store a copy:
@@ -1543,7 +1557,16 @@ Base.@nospecializeinfer function make_tracer(
15431557
if isassigned(prev, I)
15441558
pv = prev[I]
15451559
make_tracer(
1546-
seen, pv, path, mode; track_numbers, sharding, runtime, kwargs...
1560+
seen,
1561+
pv,
1562+
path,
1563+
mode;
1564+
track_numbers,
1565+
sharding,
1566+
runtime,
1567+
device,
1568+
client,
1569+
kwargs...,
15471570
)
15481571
end
15491572
end
@@ -1564,6 +1587,8 @@ Base.@nospecializeinfer function make_tracer(
15641587
track_numbers,
15651588
sharding=Base.getproperty(sharding, I),
15661589
runtime,
1590+
device,
1591+
client,
15671592
kwargs...,
15681593
)
15691594
if pv !== nv
@@ -1587,6 +1612,8 @@ Base.@nospecializeinfer function make_tracer(
15871612
@nospecialize(track_numbers::Type = Union{}),
15881613
@nospecialize(sharding = Sharding.NoSharding()),
15891614
@nospecialize(runtime = nothing),
1615+
@nospecialize(device = nothing),
1616+
@nospecialize(client = nothing),
15901617
kwargs...,
15911618
) where {Key,Value}
15921619
RT = Core.Typeof(prev)
@@ -1601,9 +1628,9 @@ Base.@nospecializeinfer function make_tracer(
16011628
if eltype(RT) <: ReactantPrimitive
16021629
if mode == ArrayToConcrete
16031630
runtime isa Val{:PJRT} &&
1604-
(return seen[prev] = ConcretePJRTArray(prev; sharding))
1631+
(return seen[prev] = ConcretePJRTArray(prev; sharding, device, client))
16051632
runtime isa Val{:IFRT} &&
1606-
(return seen[prev] = ConcreteIFRTArray(prev; sharding))
1633+
(return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client))
16071634
error("Unsupported runtime $runtime")
16081635
elseif mode == TracedToTypes
16091636
# Original array can get mutated so we store a copy:
@@ -1614,8 +1641,30 @@ Base.@nospecializeinfer function make_tracer(
16141641
elseif mode == TracedToTypes
16151642
push!(path, RT)
16161643
for (k, v) in prev
1617-
make_tracer(seen, k, path, mode; track_numbers, sharding, runtime, kwargs...)
1618-
make_tracer(seen, v, path, mode; track_numbers, sharding, runtime, kwargs...)
1644+
make_tracer(
1645+
seen,
1646+
k,
1647+
path,
1648+
mode;
1649+
track_numbers,
1650+
sharding,
1651+
runtime,
1652+
device,
1653+
client,
1654+
kwargs...,
1655+
)
1656+
make_tracer(
1657+
seen,
1658+
v,
1659+
path,
1660+
mode;
1661+
track_numbers,
1662+
sharding,
1663+
runtime,
1664+
device,
1665+
client,
1666+
kwargs...,
1667+
)
16191668
end
16201669
return nothing
16211670
end
@@ -1780,20 +1829,32 @@ end
17801829
runtime::Union{Nothing,Val{:IFRT},Val{:PJRT}}=nothing,
17811830
track_numbers::Union{Bool,Type}=false,
17821831
sharding=Sharding.Sharding.NoSharding(),
1832+
device=nothing,
1833+
client=nothing,
17831834
)
17841835
runtime === nothing && (runtime = XLA.runtime())
17851836
track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{})
1786-
return to_rarray_internal(x, track_numbers, sharding, runtime)
1837+
return to_rarray_internal(x, track_numbers, sharding, runtime, device, client)
17871838
end
17881839

17891840
@inline function to_rarray_internal(
17901841
@nospecialize(x),
17911842
@nospecialize(track_numbers::Type),
17921843
@nospecialize(sharding),
1793-
@nospecialize(runtime)
1844+
@nospecialize(runtime),
1845+
@nospecialize(device),
1846+
@nospecialize(client)
17941847
)
17951848
return make_tracer(
1796-
OrderedIdDict(), x, (), ArrayToConcrete; track_numbers, sharding, runtime
1849+
OrderedIdDict(),
1850+
x,
1851+
(),
1852+
ArrayToConcrete;
1853+
track_numbers,
1854+
sharding,
1855+
runtime,
1856+
device,
1857+
client,
17971858
)
17981859
end
17991860

@@ -1802,7 +1863,9 @@ function to_rarray_internal(
18021863
@nospecialize(::TracedRArray),
18031864
@nospecialize(track_numbers::Type),
18041865
@nospecialize(sharding),
1805-
@nospecialize(runtime)
1866+
@nospecialize(runtime),
1867+
@nospecialize(device),
1868+
@nospecialize(client)
18061869
)
18071870
return error("Cannot convert TracedRArray to ConcreteArray")
18081871
end
@@ -1812,27 +1875,33 @@ end
18121875
@nospecialize(track_numbers::Type),
18131876
@nospecialize(sharding),
18141877
::Val{:PJRT},
1878+
@nospecialize(device),
1879+
@nospecialize(client)
18151880
)
1816-
return ConcretePJRTArray(x; sharding)
1881+
return ConcretePJRTArray(x; sharding, device, client)
18171882
end
18181883

18191884
@inline function to_rarray_internal(
18201885
@nospecialize(x::ConcreteIFRTArray),
18211886
@nospecialize(track_numbers::Type),
18221887
@nospecialize(sharding),
18231888
::Val{:IFRT},
1889+
@nospecialize(device),
1890+
@nospecialize(client)
18241891
)
1825-
return ConcreteIFRTArray(x; sharding)
1892+
return ConcreteIFRTArray(x; sharding, device, client)
18261893
end
18271894

18281895
@inline function to_rarray_internal(
18291896
@nospecialize(x::Array{<:ReactantPrimitive}),
18301897
@nospecialize(track_numbers::Type),
18311898
@nospecialize(sharding),
1832-
@nospecialize(runtime)
1899+
@nospecialize(runtime),
1900+
@nospecialize(device),
1901+
@nospecialize(client)
18331902
)
1834-
runtime isa Val{:PJRT} && return ConcretePJRTArray(x; sharding)
1835-
runtime isa Val{:IFRT} && return ConcreteIFRTArray(x; sharding)
1903+
runtime isa Val{:PJRT} && return ConcretePJRTArray(x; sharding, device, client)
1904+
runtime isa Val{:IFRT} && return ConcreteIFRTArray(x; sharding, device, client)
18361905
return error("Unsupported runtime $runtime")
18371906
end
18381907

@@ -1841,45 +1910,55 @@ end
18411910
@nospecialize(track_numbers::Type),
18421911
@nospecialize(sharding),
18431912
runtime,
1913+
@nospecialize(device),
1914+
@nospecialize(client)
18441915
) where {T<:Number}
18451916
if reactant_primitive(T) !== nothing
18461917
if runtime isa Val{:PJRT}
1847-
return ConcretePJRTArray(to_reactant_primitive.(x); sharding)
1918+
return ConcretePJRTArray(to_reactant_primitive.(x); sharding, device, client)
18481919
elseif runtime isa Val{:IFRT}
1849-
return ConcreteIFRTArray(to_reactant_primitive.(x); sharding)
1920+
return ConcreteIFRTArray(to_reactant_primitive.(x); sharding, device, client)
18501921
end
18511922
error("Unsupported runtime $runtime")
18521923
end
1853-
return @invoke to_rarray_internal(x::Any, track_numbers::Type, sharding, runtime)
1924+
return @invoke to_rarray_internal(
1925+
x::Any, track_numbers::Type, sharding, runtime, device, client
1926+
)
18541927
end
18551928

18561929
@inline function to_rarray_internal(
18571930
@nospecialize(x::ConcretePJRTNumber),
18581931
@nospecialize(track_numbers::Type),
18591932
@nospecialize(sharding),
18601933
::Val{:PJRT},
1934+
@nospecialize(device),
1935+
@nospecialize(client)
18611936
)
1862-
return ConcretePJRTNumber(x; sharding)
1937+
return ConcretePJRTNumber(x; sharding, device, client)
18631938
end
18641939

18651940
@inline function to_rarray_internal(
18661941
@nospecialize(x::ConcreteIFRTNumber),
18671942
@nospecialize(track_numbers::Type),
18681943
@nospecialize(sharding),
18691944
::Val{:IFRT},
1945+
@nospecialize(device),
1946+
@nospecialize(client)
18701947
)
1871-
return ConcreteIFRTNumber(x; sharding)
1948+
return ConcreteIFRTNumber(x; sharding, device, client)
18721949
end
18731950

18741951
@inline function to_rarray_internal(
18751952
@nospecialize(x::ReactantPrimitive),
18761953
@nospecialize(track_numbers::Type),
18771954
@nospecialize(sharding),
18781955
runtime,
1956+
@nospecialize(device),
1957+
@nospecialize(client)
18791958
)
18801959
if typeof(x) <: track_numbers
1881-
runtime isa Val{:PJRT} && return ConcretePJRTNumber(x; sharding)
1882-
runtime isa Val{:IFRT} && return ConcreteIFRTNumber(x; sharding)
1960+
runtime isa Val{:PJRT} && return ConcretePJRTNumber(x; sharding, device, client)
1961+
runtime isa Val{:IFRT} && return ConcreteIFRTNumber(x; sharding, device, client)
18831962
error("Unsupported runtime $runtime")
18841963
end
18851964
return x
@@ -1890,15 +1969,19 @@ end
18901969
@nospecialize(track_numbers::Type),
18911970
@nospecialize(sharding),
18921971
runtime,
1972+
@nospecialize(device),
1973+
@nospecialize(client)
18931974
)
18941975
if reactant_primitive(typeof(x)) !== nothing
18951976
runtime isa Val{:PJRT} &&
1896-
return ConcretePJRTArray(to_reactant_primitive(x); sharding)
1977+
return ConcretePJRTArray(to_reactant_primitive(x); sharding, device, client)
18971978
runtime isa Val{:IFRT} &&
1898-
return ConcreteIFRTArray(to_reactant_primitive(x); sharding)
1979+
return ConcreteIFRTArray(to_reactant_primitive(x); sharding, device, client)
18991980
error("Unsupported runtime $runtime")
19001981
end
1901-
return @invoke to_rarray_internal(x::Any, track_numbers::Type, sharding, runtime)
1982+
return @invoke to_rarray_internal(
1983+
x::Any, track_numbers::Type, sharding, runtime, device, client
1984+
)
19021985
end
19031986

19041987
function Reactant.traced_type_inner(

0 commit comments

Comments
 (0)