Skip to content

Commit 0b803e9

Browse files
committed
refactor: cleaner handling of wrapper indices
1 parent a90efff commit 0b803e9

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

src/TracedRArray.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,18 @@ const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
2727
get_mlir_data(x::TracedRArray) = x.mlir_data
2828
get_mlir_data(x::AnyTracedRArray) = get_mlir_data(x[axes(x)...])
2929

30+
ancestor(x::TracedRArray) = x
31+
function ancestor(x::WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}) where {T,N}
32+
return ancestor(parent(x))
33+
end
34+
35+
get_ancestor_indices(::TracedRArray, indices...) = indices
36+
function get_ancestor_indices(
37+
x::SubArray{T,N,<:AnyTracedRArray{T,N}}, indices...
38+
) where {T,N}
39+
return get_ancestor_indices(parent(x), Base.reindex(x.indices, indices)...)
40+
end
41+
3042
Base.getindex(a::AnyTracedRScalar{T}) where {T} = a
3143

3244
function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
@@ -75,13 +87,17 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
7587
return x
7688
end
7789

78-
# Prevents ambiguity
79-
function Base.getindex(a::SubArray{T,N,<:AnyTracedRArray{T,N}}, indices::Int...) where {T,N}
80-
return getindex(parent(a), Base.reindex(a.indices, indices)...)
90+
# Prevent ambiguity
91+
function Base.getindex(
92+
a::WrappedArray{T,N,TracedRArray,<:TracedRArray{T,N}}, index::Int...
93+
) where {T,N}
94+
return getindex(ancestor(a), get_ancestor_indices(a, index...)...)
8195
end
8296

83-
function Base.getindex(a::SubArray{T,N,<:AnyTracedRArray{T,N}}, indices...) where {T,N}
84-
return getindex(parent(a), Base.reindex(a.indices, indices)...)
97+
function Base.getindex(
98+
a::WrappedArray{T,N,TracedRArray,<:TracedRArray{T,N}}, indices...
99+
) where {T,N}
100+
return getindex(ancestor(a), get_ancestor_indices(a, indices...)...)
85101
end
86102

87103
function Base.setindex!(

0 commit comments

Comments
 (0)