
Commit 831fbdc

Merge pull request #133 from EnzymeAD/ap/view
feat: robust handling of wrapped arrays of reactant arrays
2 parents: e2ca620 + e8fe0c8

File tree

10 files changed: +213 -51 lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
@@ -4,6 +4,7 @@ authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]
 version = "0.2.1"
 
 [deps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
 Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -14,13 +15,11 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
 Scratch = "6c6a2e73-6563-6170-7368-637461726353"
 
 [weakdeps]
-Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [extensions]
-ReactantAdaptExt = "Adapt"
 ReactantArrayInterfaceExt = "ArrayInterface"
 ReactantNNlibExt = "NNlib"
 ReactantStatisticsExt = "Statistics"
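Adapt moves from [weakdeps] to [deps] here, so the ReactantAdaptExt extension below is deleted and its method is folded into the package itself (see src/ConcreteRArray.jl and src/Reactant.jl). A minimal sketch of the user-visible effect, assuming only that the package loads; the assertion is illustrative, not from this PR:

    # Adapt now loads together with Reactant itself; no `using Adapt`
    # is needed on the user side to trigger an extension module.
    using Reactant
    @assert Reactant.Adapt isa Module  # binding created by `using Adapt: Adapt`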

ext/ReactantAdaptExt.jl

Lines changed: 0 additions & 8 deletions
This file was deleted.

ext/ReactantNNlibExt.jl

Lines changed: 20 additions & 16 deletions
@@ -1,15 +1,15 @@
 module ReactantNNlibExt
 
 using NNlib
-using Reactant
+using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array
 
 for (jlop, hloop) in (
     (:(NNlib.tanh_fast), :tanh),
     (:(NNlib.sigmoid_fast), :logistic),
     (:(NNlib.sigmoid), :logistic),
 )
-    @eval function $(jlop)(x::Reactant.TracedRArray{T,0}) where {T}
-        return Reactant.TracedRArray{T,0}(
+    @eval function $(jlop)(x::TracedRArray{T,0}) where {T}
+        return TracedRArray{T,0}(
             (),
             Reactant.MLIR.IR.result(
                 Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
@@ -19,18 +19,16 @@ for (jlop, hloop) in (
     end
 end
 
-NNlib.relu(x::Reactant.TracedRArray{T,0}) where {T} = max(x, zero(T))
+NNlib.relu(x::TracedRArray{T,0}) where {T} = max(x, zero(T))
 
-function NNlib.gelu(x::Reactant.TracedRArray{T,0}) where {T}
+function NNlib.gelu(x::TracedRArray{T,0}) where {T}
     α = T(0.044715)
     λλ = T(√(8 / π))
     return x * sigmoid(λλ * x * muladd(x^2, α, one(T)))
 end
 
 # TODO handle non finite cases
-function NNlib.softmax!(
-    out::Reactant.TracedRArray{T,N}, x::AbstractArray; dims=1
-) where {T,N}
+function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
     max_ = NNlib.fast_maximum(x; dims)
     #if all(isfinite, max_)
     @fastmath out .= exp.(x .- max_)
@@ -43,8 +41,11 @@ function NNlib.softmax!(
 end
 
 function NNlib.conv(
-    x::Reactant.TracedRArray{T,N}, W::Reactant.TracedRArray{T}, cdims::DenseConvDims
+    x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
 ) where {T,N}
+    x = materialize_traced_array(x)
+    W = materialize_traced_array(W)
+
     kernel_size = NNlib.kernel_size(cdims)
     padding = NNlib.padding(cdims)
     stride = NNlib.stride(cdims)
@@ -119,10 +120,12 @@ function NNlib.conv(
         batch_group_count=1,
     )
 
-    return Reactant.TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
+    return TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
 end
 
-function reduce_window(f, x::Reactant.TracedRArray{T,N}, pdims; init) where {T,N}
+function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
+    x = materialize_traced_array(x)
+
     num_spatial_dims = N - 2
     input_spatial_dims = 1:num_spatial_dims
 
@@ -185,21 +188,22 @@ function reduce_window(f, x::Reactant.TracedRArray{T,N}, pdims; init) where {T,N
         body,
     )
 
-    return Reactant.TracedRArray{T,N}(
-        (), Reactant.MLIR.IR.result(reduction), size(result_type)
-    )
+    return TracedRArray{T,N}((), Reactant.MLIR.IR.result(reduction), size(result_type))
 end
 
-function NNlib.maxpool(x::Reactant.TracedRArray{T}, pdims::NNlib.PoolDims) where {T}
+function NNlib.maxpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
     return reduce_window(
         Reactant.MLIR.Dialects.stablehlo.maximum, x, pdims; init=typemin(T)
     )
 end
 
-function NNlib.meanpool(x::Reactant.TracedRArray{T}, pdims::NNlib.PoolDims) where {T}
+function NNlib.meanpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
     numel = prod(NNlib.kernel_size(pdims))
     return reduce_window(Reactant.MLIR.Dialects.stablehlo.add, x, pdims; init=zero(T)) ./
            T(numel)
 end
 
+NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3))
+NNlib.batched_adjoint(x::AnyTracedRArray{<:Real,3}) = NNlib.batched_transpose(x)
+
 end # module ReactantNNlibExt
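conv, reduce_window, maxpool, and meanpool now accept AnyTracedRArray and call materialize_traced_array up front, so wrapped traced arrays (views, transposes, permutations) are flattened into a plain TracedRArray before lowering to StableHLO. The new batched_transpose/batched_adjoint methods are defined via permutedims; a small plain-array sketch of the equivalence they encode (array names are illustrative, not from the PR):

    using NNlib

    x = rand(2, 3, 4)                        # (rows, cols, batch)
    y = permutedims(x, (2, 1, 3))            # swap the matrix dims of each batch slice
    @assert y == NNlib.batched_transpose(x)  # same entries; batched_transpose is lazy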

ext/ReactantStatisticsExt.jl

Lines changed: 5 additions & 3 deletions
@@ -1,16 +1,18 @@
 module ReactantStatisticsExt
 
-using Reactant: TracedRArray
+using Reactant: AnyTracedRArray, materialize_traced_array
 using Statistics: Statistics
 
-function Statistics.mean(A::TracedRArray{T,N}; dims=:) where {T,N}
+function Statistics.mean(A::AnyTracedRArray{T,N}; dims=:) where {T,N}
+    A = materialize_traced_array(A)
     denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
     return mapreduce(identity, +, A; dims) / denom
 end
 
 function Statistics.var(
-    A::TracedRArray{T,N}; dims=:, mean=nothing, corrected=true
+    A::AnyTracedRArray{T,N}; dims=:, mean=nothing, corrected=true
 ) where {T,N}
+    A = materialize_traced_array(A)
     mean === nothing && (mean = Statistics.mean(A; dims))
     denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected
     return mapreduce(abs2, +, A .- mean; dims) / denom
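mean and var now accept AnyTracedRArray and materialize the wrapper before reducing, so calls like mean(x') or var(view(x, :, 1:2)) trace correctly. The materialize step is just full-axes indexing; the same pattern sketched on plain arrays (names are illustrative):

    using Statistics

    x = rand(3, 4)
    xt = transpose(x)      # lazy wrapper around x
    xm = xt[axes(xt)...]   # what materialize_traced_array(::WrappedTracedRArray) does
    @assert xm isa Matrix
    @assert mean(xm; dims=1) ≈ mean(xt; dims=1)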

src/ConcreteRArray.jl

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ end
 
 ConcreteRArray(data::T) where {T<:Number} = ConcreteRArray{T,0}(data, ())
 
+Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x)
+
 function ConcreteRArray(
     data::Array{T,N}; client=XLA.default_backend[], idx=XLA.default_device_idx[]
 ) where {T,N}
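This adapt_storage rule replaces the deleted ReactantAdaptExt: adapting an AbstractArray to a ConcreteRArray just invokes the constructor. A hedged usage sketch, assuming a working XLA backend is configured (as in the package's own tests):

    using Adapt, Reactant

    x = rand(Float32, 4, 4)
    y = Adapt.adapt(Reactant.ConcreteRArray, x)  # dispatches to the new adapt_storage rule
    @assert y isa Reactant.ConcreteRArray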

src/Reactant.jl

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,7 @@
 module Reactant
 
+using Adapt: Adapt, WrappedArray
+
 # auxiliary types and functions
 include("OrderedIdDict.jl")
 
src/TracedRArray.jl

Lines changed: 55 additions & 22 deletions
@@ -16,11 +16,32 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
     end
 end
 
-function Base.getindex(a::TracedRArray{T,0}) where {T}
-    return a
+const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
+const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
+const AnyTracedRScalar{T} = AnyTracedRArray{T,0}
+const AnyTracedRVector{T} = AnyTracedRArray{T,1}
+const AnyTracedRMatrix{T} = AnyTracedRArray{T,2}
+const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
+
+materialize_traced_array(x::TracedRArray) = x
+materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...]
+
+get_mlir_data(x::TracedRArray) = x.mlir_data
+get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x))
+
+ancestor(x::TracedRArray) = x
+ancestor(x::WrappedTracedRArray) = ancestor(parent(x))
+
+get_ancestor_indices(::TracedRArray, indices...) = indices
+function get_ancestor_indices(
+    x::SubArray{T,N,<:AnyTracedRArray{T,N}}, indices...
+) where {T,N}
+    return get_ancestor_indices(parent(x), Base.reindex(x.indices, indices)...)
 end
 
-function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Integer,N}) where {T,N}
+Base.getindex(a::AnyTracedRScalar{T}) where {T} = a
+
+function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
     @warn(
         """Performing scalar indexing on task $(current_task()).
 Invocation resulted in scalar indexing of a TracedRArray.
@@ -47,9 +68,7 @@ and require expensive copies and synchronization each time and therefore should
     return TracedRArray{T,0}((), res2, ())
 end
 
-function Base.getindex(
-    a::TracedRArray{T,N}, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
-) where {T,N}
+function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
     indices = [i isa Colon ? (1:size(a, idx)) : i for (idx, i) in enumerate(indices)]
     res = MLIR.IR.result(
         MLIR.Dialects.stablehlo.slice(
@@ -62,14 +81,19 @@ function Base.getindex(
         ),
         1,
     )
-    return TracedRArray{T,N}((), res, Tuple(length.(indices)))
+    x = TracedRArray{T,N}((), res, Tuple(length.(indices)))
+    ddims = findall(x -> x isa Integer, indices)
+    !isempty(ddims) && return dropdims(x; dims=Tuple(ddims))
+    return x
 end
 
-function Base.view(
-    a::TracedRArray{T,N}, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
-) where {T,N}
-    # TODO: Implement before merging the PR
-    return error("view is not supported yet")
+# Prevent ambiguity
+function Base.getindex(a::WrappedTracedRArray, index::Int...)
+    return getindex(ancestor(a), get_ancestor_indices(a, index...)...)
+end
+
+function Base.getindex(a::WrappedTracedRArray, indices...)
+    return getindex(ancestor(a), get_ancestor_indices(a, indices...)...)
 end
 
 function Base.setindex!(
@@ -101,15 +125,15 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
     # return print(io, X.mlir_data, ")")
 end
 
-Base.only(A::TracedRArray{T,0}) where {T} = A
+Base.only(A::AnyTracedRScalar{T}) where {T} = A
 
-function Base.reshape(A::TracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
+function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
     prod(dims) == prod(size(A)) || Base._throw_dmrsa(dims, prod(size(A)))
 
     # HLO reshape semantics collapse the opposite way
     res1 = MLIR.IR.result(
         MLIR.Dialects.stablehlo.transpose(
-            A.mlir_data;
+            get_mlir_data(A);
             permutation=MLIR.IR.DenseArrayAttribute([Int64(N - 1 - i) for i in 0:(N - 1)]),
         ),
         1,
@@ -137,12 +161,12 @@ function Base.reshape(A::TracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
     return TracedRArray{T,NT}((), res3, dims)
 end
 
-function Base.permutedims(A::TracedRArray{T,N}, perm) where {T,N}
+function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N}
     return TracedRArray{T,N}(
         (),
         MLIR.IR.result(
             MLIR.Dialects.stablehlo.transpose(
-                A.mlir_data;
+                get_mlir_data(A);
                 permutation=MLIR.IR.DenseArrayAttribute([Int64(i - 1) for i in perm]),
             ),
             1,
@@ -151,13 +175,19 @@ function Base.permutedims(A::TracedRArray{T,N}, perm) where {T,N}
     )
 end
 
+function Base.transpose(A::AnyTracedRVecOrMat)
+    A = ndims(A) == 1 ? reshape(A, :, 1) : A
+    return permutedims(A, (2, 1))
+end
+Base.adjoint(A::AnyTracedRVecOrMat{<:Real}) = transpose(A)
+
 function Base.promote_rule(
     ::Type{TracedRArray{T,N}}, ::Type{TracedRArray{S,N}}
 ) where {T,S,N}
     return TracedRArray{Base.promote_type(T, S),N}
 end
 
-function Base.promote_rule(A::Type{T}, B::Type{TracedRArray{S,N}}) where {T,S,N}
+function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
    return TracedRArray{Base.promote_type(T, S),N}
 end
 
@@ -194,7 +224,7 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
     )
 end
 
-function promote_to(lhs::TracedRArray{T,N}, rhs) where {T,N}
+function promote_to(::TracedRArray{T,N}, rhs) where {T,N}
     return promote_to(TracedRArray{T,N}, rhs)
 end
 
@@ -668,6 +698,7 @@ function Base.mapreducedim!(
 end
 
 struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
+
 AbstractReactantArrayStyle(::Val{N}) where {N} = AbstractReactantArrayStyle{N}()
 AbstractReactantArrayStyle{M}(::Val{N}) where {N,M} = AbstractReactantArrayStyle{N}()
 
@@ -678,7 +709,9 @@ AbstractReactantArrayStyle{M}(::Val{N}) where {N,M} = AbstractReactantArrayStyle
 # copy(inst)
 # end
 
-BroadcastStyle(::Type{T}) where {T<:TracedRArray} = AbstractReactantArrayStyle{ndims(T)}()
+function BroadcastStyle(::Type{<:AnyTracedRArray{T,N}}) where {T,N}
+    return AbstractReactantArrayStyle{N}()
+end
 
 function Base.similar(
     bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
@@ -746,8 +779,8 @@ function broadcast_to_size(arg::AbstractArray, rsize)
     return arg
 end
 
-function broadcast_to_size(arg::TracedRArray, rsize)
-    return arg
+function broadcast_to_size(arg::AnyTracedRArray, rsize)
+    return materialize_traced_array(arg)
 end
 
 function broadcast_to_size(arg::Base.RefValue, rsize)
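The core of the PR: WrappedTracedRArray matches any Adapt.WrappedArray whose leaf is a TracedRArray, ancestor follows parent links down to that leaf, and get_ancestor_indices uses Base.reindex to translate indices on a SubArray into indices on its parent, so getindex on a wrapper collapses into getindex on the underlying traced array. The reindex step, sketched on plain arrays (variable names are illustrative):

    a = reshape(collect(1:12), 3, 4)
    v = view(a, 2:3, 1:2)   # SubArray with v.indices == (2:3, 1:2)

    # Base.reindex maps indices *into the view* onto indices *into the parent*:
    parent_idx = Base.reindex(v.indices, (1, 2))
    @assert parent_idx == (2, 2)
    @assert v[1, 2] == a[parent_idx...]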

test/basic.jl

Lines changed: 12 additions & 0 deletions
@@ -261,3 +261,15 @@ tuple_byref2(x) = abs2.(x), tuple_byref2(x)
     # @test r2[2].a.b.data === x.data
     # @test r2[1] == abs2.([1.0 -2.0; -3.0 4.0])
 end
+
+sum_xxᵀ(x) = sum(x .* x')
+
+@testset "sum(x .* x')" begin
+    @testset "size(x): $(size(x))" for x in (rand(4, 4), rand(4))
+        x_ca = Reactant.to_rarray(x)
+
+        sum_xxᵀ_compiled = @compile sum_xxᵀ(x_ca)
+
+        @test sum_xxᵀ_compiled(x_ca) ≈ sum_xxᵀ(x)
+    end
+end
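This test exercises the new paths end to end: for a traced x, x' is a lazy Adjoint wrapper around the TracedRArray, so x .* x' goes through the new BroadcastStyle and broadcast_to_size methods for AnyTracedRArray, and the vector case also hits the new Base.adjoint/Base.transpose definitions. A plain-array sketch of the wrapper involved (names are illustrative):

    using LinearAlgebra

    x = rand(4)
    @assert x' isa Adjoint              # a wrapped array, not a copy
    @assert parent(x') === x            # ancestor-style unwrapping reaches the leaf
    @assert sum(x .* x') ≈ sum(x * x')  # the outer product the test reduces over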

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ end
 @safetestset "Closure" include("closure.jl")
 @safetestset "Compile" include("compile.jl")
 @safetestset "Buffer Donation" include("buffer_donation.jl")
+@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
 
 @testset "Neural Networks" begin
     @safetestset "NNlib Primitives" include("nn/nnlib.jl")
