@@ -1154,12 +1154,14 @@ Base.@nospecializeinfer function make_tracer(
1154
1154
@nospecialize (path),
1155
1155
mode;
1156
1156
@nospecialize (sharding = Sharding. NoSharding ()),
1157
+ @nospecialize (device = nothing ),
1158
+ @nospecialize (client = nothing ),
1157
1159
kwargs... ,
1158
1160
) where {T,N}
1159
1161
if mode == TracedToTypes
1160
1162
throw (" Cannot have ConcretePJRTArray as function call argument." )
1161
1163
end
1162
- mode == ArrayToConcrete && return ConcretePJRTArray (prev; sharding)
1164
+ mode == ArrayToConcrete && return ConcretePJRTArray (prev; sharding, device, client )
1163
1165
mode != ConcreteToTraced && throw (" Cannot trace concrete" )
1164
1166
haskey (seen, prev) && return seen[prev]:: TracedRArray{T,N}
1165
1167
res = TracedRArray {T,N} ((path,), nothing , size (prev))
@@ -1173,12 +1175,14 @@ Base.@nospecializeinfer function make_tracer(
1173
1175
@nospecialize (path),
1174
1176
mode;
1175
1177
@nospecialize (sharding = Sharding. NoSharding ()),
1178
+ @nospecialize (device = nothing ),
1179
+ @nospecialize (client = nothing ),
1176
1180
kwargs... ,
1177
1181
) where {T,N}
1178
1182
if mode == TracedToTypes
1179
1183
throw (" Cannot have ConcreteIFRTArray as function call argument." )
1180
1184
end
1181
- mode == ArrayToConcrete && return ConcreteIFRTArray (prev; sharding)
1185
+ mode == ArrayToConcrete && return ConcreteIFRTArray (prev; sharding, device, client )
1182
1186
mode != ConcreteToTraced && throw (" Cannot trace concrete" )
1183
1187
haskey (seen, prev) && return seen[prev]:: TracedRArray{T,N}
1184
1188
res = TracedRArray {T,N} ((path,), nothing , size (prev))
@@ -1192,12 +1196,14 @@ Base.@nospecializeinfer function make_tracer(
1192
1196
@nospecialize (path),
1193
1197
mode;
1194
1198
@nospecialize (sharding = Sharding. NoSharding ()),
1199
+ @nospecialize (device = nothing ),
1200
+ @nospecialize (client = nothing ),
1195
1201
kwargs... ,
1196
1202
) where {T}
1197
1203
if mode == TracedToTypes
1198
1204
throw (" Cannot have ConcretePJRTNumber as function call argument." )
1199
1205
end
1200
- mode == ArrayToConcrete && return ConcretePJRTNumber (prev; sharding)
1206
+ mode == ArrayToConcrete && return ConcretePJRTNumber (prev; sharding, device, client )
1201
1207
mode != ConcreteToTraced && throw (" Cannot trace existing trace type" )
1202
1208
haskey (seen, prev) && return seen[prev]:: TracedRNumber{T}
1203
1209
res = TracedRNumber {T} ((path,), nothing )
@@ -1211,12 +1217,14 @@ Base.@nospecializeinfer function make_tracer(
1211
1217
@nospecialize (path),
1212
1218
mode;
1213
1219
@nospecialize (sharding = Sharding. NoSharding ()),
1220
+ @nospecialize (device = nothing ),
1221
+ @nospecialize (client = nothing ),
1214
1222
kwargs... ,
1215
1223
) where {T}
1216
1224
if mode == TracedToTypes
1217
1225
throw (" Cannot have ConcreteIFRTNumber as function call argument." )
1218
1226
end
1219
- mode == ArrayToConcrete && return ConcreteIFRTNumber (prev; sharding)
1227
+ mode == ArrayToConcrete && return ConcreteIFRTNumber (prev; sharding, device, client )
1220
1228
mode != ConcreteToTraced && throw (" Cannot trace existing trace type" )
1221
1229
haskey (seen, prev) && return seen[prev]:: TracedRNumber{T}
1222
1230
res = TracedRNumber {T} ((path,), nothing )
@@ -1425,6 +1433,8 @@ Base.@nospecializeinfer function make_tracer(
1425
1433
@nospecialize (track_numbers:: Type = Union{}),
1426
1434
@nospecialize (sharding = Sharding. NoSharding ()),
1427
1435
@nospecialize (runtime = nothing ),
1436
+ @nospecialize (device = nothing ),
1437
+ @nospecialize (client = nothing ),
1428
1438
kwargs... ,
1429
1439
)
1430
1440
if mode == TracedToTypes
@@ -1434,8 +1444,10 @@ Base.@nospecializeinfer function make_tracer(
1434
1444
RT = Core. Typeof (prev)
1435
1445
if RT <: track_numbers && mode != TracedSetPath && mode != TracedTrack
1436
1446
if mode == ArrayToConcrete
1437
- runtime isa Val{:PJRT } && return ConcretePJRTNumber (prev; sharding)
1438
- runtime isa Val{:IFRT } && return ConcreteIFRTNumber (prev; sharding)
1447
+ runtime isa Val{:PJRT } &&
1448
+ return ConcretePJRTNumber (prev; sharding, device, client)
1449
+ runtime isa Val{:IFRT } &&
1450
+ return ConcreteIFRTNumber (prev; sharding, device, client)
1439
1451
error (" Unsupported runtime $runtime " )
1440
1452
else
1441
1453
if mode == TracedTrack || mode == NoStopTracedTrack
@@ -1511,6 +1523,8 @@ Base.@nospecializeinfer function make_tracer(
1511
1523
@nospecialize (track_numbers:: Type = Union{}),
1512
1524
@nospecialize (sharding = Sharding. NoSharding ()),
1513
1525
@nospecialize (runtime = nothing ),
1526
+ @nospecialize (device = nothing ),
1527
+ @nospecialize (client = nothing ),
1514
1528
kwargs... ,
1515
1529
)
1516
1530
RT = Core. Typeof (prev)
@@ -1527,9 +1541,9 @@ Base.@nospecializeinfer function make_tracer(
1527
1541
if eltype (RT) <: ReactantPrimitive
1528
1542
if mode == ArrayToConcrete
1529
1543
runtime isa Val{:PJRT } &&
1530
- (return seen[prev] = ConcretePJRTArray (prev; sharding))
1544
+ (return seen[prev] = ConcretePJRTArray (prev; sharding, device, client ))
1531
1545
runtime isa Val{:IFRT } &&
1532
- (return seen[prev] = ConcreteIFRTArray (prev; sharding))
1546
+ (return seen[prev] = ConcreteIFRTArray (prev; sharding, device, client ))
1533
1547
error (" Unsupported runtime $runtime " )
1534
1548
elseif mode == TracedToTypes
1535
1549
# Original array can get mutated so we store a copy:
@@ -1543,7 +1557,16 @@ Base.@nospecializeinfer function make_tracer(
1543
1557
if isassigned (prev, I)
1544
1558
pv = prev[I]
1545
1559
make_tracer (
1546
- seen, pv, path, mode; track_numbers, sharding, runtime, kwargs...
1560
+ seen,
1561
+ pv,
1562
+ path,
1563
+ mode;
1564
+ track_numbers,
1565
+ sharding,
1566
+ runtime,
1567
+ device,
1568
+ client,
1569
+ kwargs... ,
1547
1570
)
1548
1571
end
1549
1572
end
@@ -1564,6 +1587,8 @@ Base.@nospecializeinfer function make_tracer(
1564
1587
track_numbers,
1565
1588
sharding= Base. getproperty (sharding, I),
1566
1589
runtime,
1590
+ device,
1591
+ client,
1567
1592
kwargs... ,
1568
1593
)
1569
1594
if pv != = nv
@@ -1587,6 +1612,8 @@ Base.@nospecializeinfer function make_tracer(
1587
1612
@nospecialize (track_numbers:: Type = Union{}),
1588
1613
@nospecialize (sharding = Sharding. NoSharding ()),
1589
1614
@nospecialize (runtime = nothing ),
1615
+ @nospecialize (device = nothing ),
1616
+ @nospecialize (client = nothing ),
1590
1617
kwargs... ,
1591
1618
) where {Key,Value}
1592
1619
RT = Core. Typeof (prev)
@@ -1601,9 +1628,9 @@ Base.@nospecializeinfer function make_tracer(
1601
1628
if eltype (RT) <: ReactantPrimitive
1602
1629
if mode == ArrayToConcrete
1603
1630
runtime isa Val{:PJRT } &&
1604
- (return seen[prev] = ConcretePJRTArray (prev; sharding))
1631
+ (return seen[prev] = ConcretePJRTArray (prev; sharding, device, client ))
1605
1632
runtime isa Val{:IFRT } &&
1606
- (return seen[prev] = ConcreteIFRTArray (prev; sharding))
1633
+ (return seen[prev] = ConcreteIFRTArray (prev; sharding, device, client ))
1607
1634
error (" Unsupported runtime $runtime " )
1608
1635
elseif mode == TracedToTypes
1609
1636
# Original array can get mutated so we store a copy:
@@ -1614,8 +1641,30 @@ Base.@nospecializeinfer function make_tracer(
1614
1641
elseif mode == TracedToTypes
1615
1642
push! (path, RT)
1616
1643
for (k, v) in prev
1617
- make_tracer (seen, k, path, mode; track_numbers, sharding, runtime, kwargs... )
1618
- make_tracer (seen, v, path, mode; track_numbers, sharding, runtime, kwargs... )
1644
+ make_tracer (
1645
+ seen,
1646
+ k,
1647
+ path,
1648
+ mode;
1649
+ track_numbers,
1650
+ sharding,
1651
+ runtime,
1652
+ device,
1653
+ client,
1654
+ kwargs... ,
1655
+ )
1656
+ make_tracer (
1657
+ seen,
1658
+ v,
1659
+ path,
1660
+ mode;
1661
+ track_numbers,
1662
+ sharding,
1663
+ runtime,
1664
+ device,
1665
+ client,
1666
+ kwargs... ,
1667
+ )
1619
1668
end
1620
1669
return nothing
1621
1670
end
@@ -1780,20 +1829,32 @@ end
1780
1829
runtime:: Union{Nothing,Val{:IFRT},Val{:PJRT}} = nothing ,
1781
1830
track_numbers:: Union{Bool,Type} = false ,
1782
1831
sharding= Sharding. Sharding. NoSharding (),
1832
+ device= nothing ,
1833
+ client= nothing ,
1783
1834
)
1784
1835
runtime === nothing && (runtime = XLA. runtime ())
1785
1836
track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{})
1786
- return to_rarray_internal (x, track_numbers, sharding, runtime)
1837
+ return to_rarray_internal (x, track_numbers, sharding, runtime, device, client )
1787
1838
end
1788
1839
1789
1840
@inline function to_rarray_internal (
1790
1841
@nospecialize (x),
1791
1842
@nospecialize (track_numbers:: Type ),
1792
1843
@nospecialize (sharding),
1793
- @nospecialize (runtime)
1844
+ @nospecialize (runtime),
1845
+ @nospecialize (device),
1846
+ @nospecialize (client)
1794
1847
)
1795
1848
return make_tracer (
1796
- OrderedIdDict (), x, (), ArrayToConcrete; track_numbers, sharding, runtime
1849
+ OrderedIdDict (),
1850
+ x,
1851
+ (),
1852
+ ArrayToConcrete;
1853
+ track_numbers,
1854
+ sharding,
1855
+ runtime,
1856
+ device,
1857
+ client,
1797
1858
)
1798
1859
end
1799
1860
@@ -1802,7 +1863,9 @@ function to_rarray_internal(
1802
1863
@nospecialize (:: TracedRArray ),
1803
1864
@nospecialize (track_numbers:: Type ),
1804
1865
@nospecialize (sharding),
1805
- @nospecialize (runtime)
1866
+ @nospecialize (runtime),
1867
+ @nospecialize (device),
1868
+ @nospecialize (client)
1806
1869
)
1807
1870
return error (" Cannot convert TracedRArray to ConcreteArray" )
1808
1871
end
@@ -1812,27 +1875,33 @@ end
1812
1875
@nospecialize (track_numbers:: Type ),
1813
1876
@nospecialize (sharding),
1814
1877
:: Val{:PJRT} ,
1878
+ @nospecialize (device),
1879
+ @nospecialize (client)
1815
1880
)
1816
- return ConcretePJRTArray (x; sharding)
1881
+ return ConcretePJRTArray (x; sharding, device, client )
1817
1882
end
1818
1883
1819
1884
@inline function to_rarray_internal (
1820
1885
@nospecialize (x:: ConcreteIFRTArray ),
1821
1886
@nospecialize (track_numbers:: Type ),
1822
1887
@nospecialize (sharding),
1823
1888
:: Val{:IFRT} ,
1889
+ @nospecialize (device),
1890
+ @nospecialize (client)
1824
1891
)
1825
- return ConcreteIFRTArray (x; sharding)
1892
+ return ConcreteIFRTArray (x; sharding, device, client )
1826
1893
end
1827
1894
1828
1895
@inline function to_rarray_internal (
1829
1896
@nospecialize (x:: Array{<:ReactantPrimitive} ),
1830
1897
@nospecialize (track_numbers:: Type ),
1831
1898
@nospecialize (sharding),
1832
- @nospecialize (runtime)
1899
+ @nospecialize (runtime),
1900
+ @nospecialize (device),
1901
+ @nospecialize (client)
1833
1902
)
1834
- runtime isa Val{:PJRT } && return ConcretePJRTArray (x; sharding)
1835
- runtime isa Val{:IFRT } && return ConcreteIFRTArray (x; sharding)
1903
+ runtime isa Val{:PJRT } && return ConcretePJRTArray (x; sharding, device, client )
1904
+ runtime isa Val{:IFRT } && return ConcreteIFRTArray (x; sharding, device, client )
1836
1905
return error (" Unsupported runtime $runtime " )
1837
1906
end
1838
1907
@@ -1841,45 +1910,55 @@ end
1841
1910
@nospecialize (track_numbers:: Type ),
1842
1911
@nospecialize (sharding),
1843
1912
runtime,
1913
+ @nospecialize (device),
1914
+ @nospecialize (client)
1844
1915
) where {T<: Number }
1845
1916
if reactant_primitive (T) != = nothing
1846
1917
if runtime isa Val{:PJRT }
1847
- return ConcretePJRTArray (to_reactant_primitive .(x); sharding)
1918
+ return ConcretePJRTArray (to_reactant_primitive .(x); sharding, device, client )
1848
1919
elseif runtime isa Val{:IFRT }
1849
- return ConcreteIFRTArray (to_reactant_primitive .(x); sharding)
1920
+ return ConcreteIFRTArray (to_reactant_primitive .(x); sharding, device, client )
1850
1921
end
1851
1922
error (" Unsupported runtime $runtime " )
1852
1923
end
1853
- return @invoke to_rarray_internal (x:: Any , track_numbers:: Type , sharding, runtime)
1924
+ return @invoke to_rarray_internal (
1925
+ x:: Any , track_numbers:: Type , sharding, runtime, device, client
1926
+ )
1854
1927
end
1855
1928
1856
1929
@inline function to_rarray_internal (
1857
1930
@nospecialize (x:: ConcretePJRTNumber ),
1858
1931
@nospecialize (track_numbers:: Type ),
1859
1932
@nospecialize (sharding),
1860
1933
:: Val{:PJRT} ,
1934
+ @nospecialize (device),
1935
+ @nospecialize (client)
1861
1936
)
1862
- return ConcretePJRTNumber (x; sharding)
1937
+ return ConcretePJRTNumber (x; sharding, device, client )
1863
1938
end
1864
1939
1865
1940
@inline function to_rarray_internal (
1866
1941
@nospecialize (x:: ConcreteIFRTNumber ),
1867
1942
@nospecialize (track_numbers:: Type ),
1868
1943
@nospecialize (sharding),
1869
1944
:: Val{:IFRT} ,
1945
+ @nospecialize (device),
1946
+ @nospecialize (client)
1870
1947
)
1871
- return ConcreteIFRTNumber (x; sharding)
1948
+ return ConcreteIFRTNumber (x; sharding, device, client )
1872
1949
end
1873
1950
1874
1951
@inline function to_rarray_internal (
1875
1952
@nospecialize (x:: ReactantPrimitive ),
1876
1953
@nospecialize (track_numbers:: Type ),
1877
1954
@nospecialize (sharding),
1878
1955
runtime,
1956
+ @nospecialize (device),
1957
+ @nospecialize (client)
1879
1958
)
1880
1959
if typeof (x) <: track_numbers
1881
- runtime isa Val{:PJRT } && return ConcretePJRTNumber (x; sharding)
1882
- runtime isa Val{:IFRT } && return ConcreteIFRTNumber (x; sharding)
1960
+ runtime isa Val{:PJRT } && return ConcretePJRTNumber (x; sharding, device, client )
1961
+ runtime isa Val{:IFRT } && return ConcreteIFRTNumber (x; sharding, device, client )
1883
1962
error (" Unsupported runtime $runtime " )
1884
1963
end
1885
1964
return x
@@ -1890,15 +1969,19 @@ end
1890
1969
@nospecialize (track_numbers:: Type ),
1891
1970
@nospecialize (sharding),
1892
1971
runtime,
1972
+ @nospecialize (device),
1973
+ @nospecialize (client)
1893
1974
)
1894
1975
if reactant_primitive (typeof (x)) != = nothing
1895
1976
runtime isa Val{:PJRT } &&
1896
- return ConcretePJRTArray (to_reactant_primitive (x); sharding)
1977
+ return ConcretePJRTArray (to_reactant_primitive (x); sharding, device, client )
1897
1978
runtime isa Val{:IFRT } &&
1898
- return ConcreteIFRTArray (to_reactant_primitive (x); sharding)
1979
+ return ConcreteIFRTArray (to_reactant_primitive (x); sharding, device, client )
1899
1980
error (" Unsupported runtime $runtime " )
1900
1981
end
1901
- return @invoke to_rarray_internal (x:: Any , track_numbers:: Type , sharding, runtime)
1982
+ return @invoke to_rarray_internal (
1983
+ x:: Any , track_numbers:: Type , sharding, runtime, device, client
1984
+ )
1902
1985
end
1903
1986
1904
1987
function Reactant. traced_type_inner (
0 commit comments