Skip to content

Commit fc9b577

Browse files
committed
feat: getindex of subarrays
1 parent e6a9cd9 commit fc9b577

File tree

4 files changed

+63
-13
lines changed

4 files changed

+63
-13
lines changed

src/Reactant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module Reactant
22

3-
using Adapt: Adapt
3+
using Adapt: Adapt, WrappedArray
44

55
# auxiliary types and functions
66
include("OrderedIdDict.jl")

src/TracedRArray.jl

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,17 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
1616
end
1717
end
1818

19-
function Base.getindex(a::TracedRArray{T,0}) where {T}
20-
return a
21-
end
19+
const AnyTracedRArray{T,N} = Union{
20+
TracedRArray{T,N},WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
21+
}
22+
const AnyTracedRScalar{T} = AnyTracedRArray{T,0}
23+
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
24+
const AnyTracedRMatrix{T} = AnyTracedRArray{T,2}
25+
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
26+
27+
Base.getindex(a::AnyTracedRScalar{T}) where {T} = a
2228

23-
function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Integer,N}) where {T,N}
29+
function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
2430
@warn(
2531
"""Performing scalar indexing on task $(current_task()).
2632
Invocation resulted in scalar indexing of a TracedRArray.
@@ -48,7 +54,7 @@ and require expensive copies and synchronization each time and therefore should
4854
end
4955

5056
function Base.getindex(
51-
a::TracedRArray{T,N}, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
57+
a::TracedRArray{T,N}, indices::Vararg{Any,N}
5258
) where {T,N}
5359
indices = [i isa Colon ? (1:size(a, idx)) : i for (idx, i) in enumerate(indices)]
5460
res = MLIR.IR.result(
@@ -62,14 +68,23 @@ function Base.getindex(
6268
),
6369
1,
6470
)
65-
return TracedRArray{T,N}((), res, Tuple(length.(indices)))
71+
x = TracedRArray{T,N}((), res, Tuple(length.(indices)))
72+
ddims = findall(x -> x isa Integer, indices)
73+
!isempty(ddims) && return dropdims(x, dims=Tuple(ddims))
74+
return x
6675
end
6776

68-
function Base.view(
69-
a::TracedRArray{T,N}, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
77+
# Prevents ambiguity
78+
function Base.getindex(
79+
a::SubArray{T,N,<:AnyTracedRArray{T,N}}, indices::Int...
80+
) where {T,N}
81+
return getindex(parent(a), Base.reindex(a.indices, indices)...)
82+
end
83+
84+
function Base.getindex(
85+
a::SubArray{T,N,<:AnyTracedRArray{T,N}}, indices...
7086
) where {T,N}
71-
# TODO: Implement before merging the PR
72-
return error("view is not supported yet")
87+
return getindex(parent(a), Base.reindex(a.indices, indices)...)
7388
end
7489

7590
function Base.setindex!(
@@ -101,7 +116,7 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
101116
# return print(io, X.mlir_data, ")")
102117
end
103118

104-
Base.only(A::TracedRArray{T,0}) where {T} = A
119+
Base.only(A::AnyTracedRScalar{T}) where {T} = A
105120

106121
function Base.reshape(A::TracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
107122
prod(dims) == prod(size(A)) || Base._throw_dmrsa(dims, prod(size(A)))
@@ -194,7 +209,7 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
194209
)
195210
end
196211

197-
function promote_to(lhs::TracedRArray{T,N}, rhs) where {T,N}
212+
function promote_to(::TracedRArray{T,N}, rhs) where {T,N}
198213
return promote_to(TracedRArray{T,N}, rhs)
199214
end
200215

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ end
4848
@safetestset "Closure" include("closure.jl")
4949
@safetestset "Compile" include("compile.jl")
5050
@safetestset "Buffer Donation" include("buffer_donation.jl")
51+
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
5152

5253
@testset "Neural Networks" begin
5354
@safetestset "NNlib Primitives" include("nn/nnlib.jl")

test/wrapped_arrays.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using Reactant, Test
2+
3+
function view_getindex_1(x)
4+
x = view(x, 2:3, 1:2, :)
5+
return x[2, 1, 1]
6+
end
7+
8+
function view_getindex_2(x)
9+
x = view(x, 2:3, 1:2, :)
10+
return x[1:1, 1, :]
11+
end
12+
13+
function view_getindex_3(x)
14+
x = view(x, 2:3, 1:2, :)
15+
x2 = view(x, 1:1, 2:2, 1:2)
16+
return x2[1, 1, 1:1]
17+
end
18+
19+
@testset "view getindex" begin
20+
x = rand(4, 4, 3)
21+
x_ra = Reactant.to_rarray(x)
22+
23+
view_getindex_1_compiled = @compile view_getindex_1(x_ra)
24+
25+
@test view_getindex_1_compiled(x_ra) view_getindex_1(x)
26+
27+
view_getindex_2_compiled = @compile view_getindex_2(x_ra)
28+
29+
@test view_getindex_2_compiled(x_ra) view_getindex_2(x)
30+
31+
view_getindex_3_compiled = @compile view_getindex_3(x_ra)
32+
33+
@test view_getindex_3_compiled(x_ra) view_getindex_3(x)
34+
end

0 commit comments

Comments
 (0)