Skip to content

Commit eb324ab

Browse files
authored
Merge branch 'main' into ap/updated_no_nan
2 parents 4b18df2 + 7e21eca commit eb324ab

File tree

18 files changed

+395
-85
lines changed

18 files changed

+395
-85
lines changed

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.153"
4+
version = "0.2.155"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -32,6 +32,7 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3232
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3333
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3434
DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"
35+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
3536
Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
3637
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
3738
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
@@ -54,6 +55,7 @@ ReactantAbstractFFTsExt = "AbstractFFTs"
5455
ReactantArrayInterfaceExt = "ArrayInterface"
5556
ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
5657
ReactantDLFP8TypesExt = "DLFP8Types"
58+
ReactantFillArraysExt = "FillArrays"
5759
ReactantFloat8sExt = "Float8s"
5860
ReactantKernelAbstractionsExt = "KernelAbstractions"
5961
ReactantMPIExt = "MPI"
@@ -77,6 +79,7 @@ Downloads = "1.6"
7779
EnumX = "1"
7880
Enzyme = "0.13.49"
7981
EnzymeCore = "0.8.11"
82+
FillArrays = "1.13"
8083
Float8s = "0.1"
8184
Functors = "0.5"
8285
GPUArraysCore = "0.2"
@@ -97,15 +100,15 @@ PythonCall = "0.9.25"
97100
Random = "1.10"
98101
Random123 = "1.7"
99102
ReactantCore = "0.1.15"
100-
Reactant_jll = "0.0.232"
103+
Reactant_jll = "0.0.233"
101104
ScopedValues = "1.3.0"
102105
Scratch = "1.2"
103106
Sockets = "1.10"
104107
SpecialFunctions = "2.4"
105108
Statistics = "1.10"
106-
unzip_jll = "6"
107109
YaoBlocks = "0.13, 0.14"
108110
julia = "1.10"
111+
unzip_jll = "6"
109112

110113
[extras]
111114
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7-
ENZYMEXLA_COMMIT = "ad1e0392bcea0afaf75d12445608a08550055ea6"
7+
ENZYMEXLA_COMMIT = "45f1d0d5a47ee706500eff8841b710b9da112ec6"
88

99
ENZYMEXLA_SHA256 = ""
1010

docs/src/api/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ within_compile
2121
@trace
2222
```
2323

24+
## Reactant data types
25+
26+
```@docs
27+
ConcreteRArray
28+
ConcreteRNumber
29+
```
30+
2431
## Inspect Generated HLO
2532

2633
```@docs

docs/src/introduction/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ Pkg.add("Reactant")
1313

1414
## Quick Start
1515

16-
Reactant provides two new array types at its core, a ConcreteRArray and a TracedRArray. A
17-
ConcreteRArray is an underlying buffer to whatever device data you wish to store and can be
16+
Reactant provides two new array types at its core, a [`ConcreteRArray`](@ref) and a `TracedRArray`. A
17+
`ConcreteRArray` is an underlying buffer to whatever device data you wish to store and can be
1818
created by converting from a regular Julia Array.
1919

2020
```@example quickstart

ext/ReactantFillArraysExt.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
module ReactantFillArraysExt

using Reactant: Reactant, TracedUtils, TracedRNumber, Ops, Sharding, unwrapped_eltype
using ReactantCore: ReactantCore
using FillArrays: FillArrays, AbstractFill, Fill, Ones, Zeros, OneElement
using GPUArraysCore: @allowscalar

# --- Tracing support -------------------------------------------------------
# FillArrays types are "leaf" wrappers for tracing purposes: they carry no
# parent array, so they act as their own ancestor.
Reactant._parent_type(T::Type{<:AbstractFill}) = T
Reactant._parent_type(T::Type{<:OneElement}) = T

# The element type parameter is a plain number; tracing the container means
# tracing that scalar type while keeping the FillArrays wrapper intact.
for FT in (Fill, Ones, Zeros)
    @eval Base.@nospecializeinfer function Reactant.traced_type_inner(
        @nospecialize(FA::Type{$(FT){T,N,Axes}}),
        seen,
        mode::Reactant.TraceMode,
        @nospecialize(track_numbers::Type),
        @nospecialize(sharding),
        @nospecialize(runtime)
    ) where {T,N,Axes}
        # T will be a number so we need to trace it
        return $(FT){
            Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime),N,Axes
        }
    end
end

# A `Fill` stores one scalar value: trace the scalar, keep the axes.
Base.@nospecializeinfer function Reactant.make_tracer(
    seen, @nospecialize(prev::Fill{T,N,Axes}), @nospecialize(path), mode; kwargs...
) where {T,N,Axes}
    return Fill(
        Reactant.make_tracer(
            seen, prev.value, (path..., 1), mode; kwargs..., track_numbers=Number
        ),
        prev.axes,
    )
end

# `Ones`/`Zeros` carry no stored value at all, so only the element type needs
# to be traced; the wrapper is rebuilt with the traced type and original axes.
Base.@nospecializeinfer function Reactant.make_tracer(
    seen,
    @nospecialize(prev::Ones{T,N,Axes}),
    @nospecialize(path),
    mode;
    @nospecialize(sharding = Sharding.NoSharding()),
    @nospecialize(runtime = nothing),
    kwargs...,
) where {T,N,Axes}
    return Ones(
        Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime), prev.axes
    )
end

Base.@nospecializeinfer function Reactant.make_tracer(
    seen,
    @nospecialize(prev::Zeros{T,N,Axes}),
    @nospecialize(path),
    mode;
    @nospecialize(sharding = Sharding.NoSharding()),
    @nospecialize(runtime = nothing),
    kwargs...,
) where {T,N,Axes}
    return Zeros(
        Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime), prev.axes
    )
end

Base.@nospecializeinfer function Reactant.traced_type_inner(
    @nospecialize(FA::Type{OneElement{T,N,I,A}}),
    seen,
    mode::Reactant.TraceMode,
    @nospecialize(track_numbers::Type),
    @nospecialize(sharding),
    @nospecialize(runtime)
) where {T,N,I,A}
    # T will be a number so we need to trace it
    return OneElement{
        Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime),N,I,A
    }
end

# A `OneElement` stores one scalar (`val`) at one index (`ind`): trace the
# scalar, keep index and axes untouched.
Base.@nospecializeinfer function Reactant.make_tracer(
    seen, @nospecialize(prev::OneElement{T,N,I,A}), @nospecialize(path), mode; kwargs...
) where {T,N,I,A}
    return OneElement(
        Reactant.make_tracer(
            seen, prev.val, (path..., 1), mode; kwargs..., track_numbers=Number
        ),
        prev.ind,
        prev.axes,
    )
end

# --- Materialization into dense traced arrays ------------------------------
function ReactantCore.materialize_traced_array(x::Fill{T}) where {T}
    return TracedUtils.broadcast_to_size(
        TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(T)}, x.value), size(x)
    )
end

function ReactantCore.materialize_traced_array(x::Ones{T}) where {T}
    return TracedUtils.broadcast_to_size(one(unwrapped_eltype(T)), size(x))
end

function ReactantCore.materialize_traced_array(x::Zeros{T}) where {T}
    return TracedUtils.broadcast_to_size(zero(unwrapped_eltype(T)), size(x))
end

function ReactantCore.materialize_traced_array(x::OneElement{T}) where {T}
    out = TracedUtils.broadcast_to_size(zero(unwrapped_eltype(T)), size(x))
    # Scalar indexing is intentional here: exactly one entry is written.
    @allowscalar setindex!(out, x.val, x.ind...)
    return out
end

# --- Performance overloads -------------------------------------------------
# `similar` on a traced fill-like array should yield a dense traced buffer
# directly rather than going through generic (slow) fallbacks.
for FT in (Fill, Ones, Zeros, OneElement)
    @eval function Base.similar(x::$FT{<:TracedRNumber}, ::Type{T}, dims::Dims) where {T}
        return TracedUtils.broadcast_to_size(zero(unwrapped_eltype(T)), dims)
    end
end

end

ext/ReactantKernelAbstractionsExt.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ using Adapt: Adapt
1010

1111
export ReactantBackend
1212

13+
# TODO: Include the XLA client, device, and sharding in the ReactantBackend struct to
14+
# support more complex applications? If so, need to adapt implementation of
15+
# `KA.get_backend` and `KA.allocate` accordingly.
1316
struct ReactantBackend <: KA.GPU end
1417

1518
function Base.getproperty(x::ReactantBackend, sym::Symbol)
@@ -22,16 +25,23 @@ function Base.getproperty(x::ReactantBackend, sym::Symbol)
2225
end
2326
end
2427

25-
KA.allocate(n::ReactantBackend, ::Type{T}, dims::Tuple) where {T} = KA.zeros(b, T, dims)
26-
function KA.zeros(::ReactantBackend, ::Type{T}, dims::Tuple) where {T}
27-
return Reactant.to_rarray(zeros(T, dims))
28+
function KA.allocate(::ReactantBackend, ::Type{T}, dims::Tuple) where {T}
29+
return ConcreteRArray{T}(undef, dims)
2830
end
29-
function KA.ones(::ReactantBackend, ::Type{T}, dims::Tuple) where {T}
30-
return Reactant.to_rarray(ones(T, dims))
31+
32+
function KA.zeros(b::ReactantBackend, ::Type{T}, dims::Tuple) where {T}
33+
A = KA.allocate(b, T, dims)
34+
isempty(A) || fill!(A, zero(T))
35+
return A
36+
end
37+
function KA.ones(b::ReactantBackend, ::Type{T}, dims::Tuple) where {T}
38+
A = KA.allocate(b, T, dims)
39+
isempty(A) || fill!(A, one(T))
40+
return A
3141
end
3242

3343
KA.get_backend(::Reactant.AnyTracedRArray) = ReactantBackend()
34-
KA.get_backend(::Reactant.AnyConcretePJRTArray) = ReactantBackend()
44+
KA.get_backend(::Reactant.AnyConcreteRArray) = ReactantBackend()
3545
function KA.synchronize(::ReactantBackend) end
3646

3747
Adapt.adapt_storage(::ReactantBackend, a::Array) = a

src/Compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ function create_result(
288288
sym = Symbol("result", var_idx[])
289289
var_idx[] += 1
290290

291-
@assert haskey(result_stores, path)
291+
@assert haskey(result_stores, path) "Expected $(path) in $(keys(result_stores))"
292292
restore = result_stores[path]
293293
delete!(result_stores, path)
294294
if path_to_shard_info !== nothing && haskey(path_to_shard_info, path)

src/ConcreteRArray.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ for T in Base.uniontypes(ReactantPrimitive)
8383
end
8484

8585
function Base.convert(::Type{T}, x::AbstractConcreteNumber) where {T<:Number}
86+
T == typeof(x) && return x
8687
return convert(T, to_number(x))
8788
end
8889

@@ -377,16 +378,8 @@ end
377378
device::Union{Nothing,XLA.PJRT.Device}=nothing,
378379
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
379380
) where {S}
380-
client = client === nothing ? XLA.default_backend() : client
381-
382-
if idx isa Int && device === nothing
383-
device = XLA.get_device(client, idx)
384-
end
385-
386-
sdata, sharding = sharding(client, device, S, dims)
387-
388-
return ConcretePJRTArray{S,length(dims),length(sdata),typeof(sharding)}(
389-
sdata, dims, sharding
381+
return ConcretePJRTArray{S}(
382+
undef, dims; client=client, idx=idx, device=device, sharding=sharding
390383
)
391384
end
392385

@@ -417,7 +410,7 @@ function Base.similar(a::ConcreteIFRTArray{T}, ::Type{S}=T, dims::Dims=size(a))
417410
end
418411
Base.similar(a::ConcreteIFRTArray, dims::Dims) = similar(a, eltype(a), dims)
419412
function Base.similar(::Type{ConcreteIFRTArray{T}}, dims) where {T}
420-
return ConcreteIFRTArray(similar(Array{T}, dims))
413+
return ConcreteIFRTArray{T}(undef, dims)
421414
end
422415

423416
# Broadcasting interface

src/Reactant.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,20 @@ function ancestor(T::Type{<:AbstractArray})
4242
p_T == T && return T
4343
return ancestor(p_T)
4444
end
45+
if applicable(_parent_type, T)
46+
p_T = _parent_type(T)
47+
p_T == T && return T
48+
return ancestor(p_T)
49+
end
4550
@warn "`Adapt.parent_type` is not implemented for $(T). Assuming $T isn't a wrapped \
4651
array." maxlog = 1
4752
return T
4853
end
4954

55+
# A lot of packages don't define `Adapt.parent_type`. We use `_parent_type` as a way to
56+
# define the parent type of an array without type-piracy.
57+
function _parent_type end
58+
5059
include("accelerators/Accelerators.jl")
5160

5261
using .Accelerators.TPU: has_tpu

src/TracedRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ function overloaded_mapreduce(
568568
original_dims = dims
569569
dims isa Int && (dims = Int64[dims])
570570
dims isa Colon && (dims = collect(Int64, 1:N))
571-
dims isa AbstractVector{<:Integer} || (dims = collect(Int64, dims))
571+
dims isa Vector{Int64} || (dims = collect(Int64, dims))
572572

573573
op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Tuple{T}))
574574
reduce_init = __default_init(op_in_T, op)

0 commit comments

Comments
 (0)