@@ -27,6 +27,18 @@ const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
27
27
get_mlir_data (x:: TracedRArray ) = x. mlir_data
28
28
get_mlir_data (x:: AnyTracedRArray ) = get_mlir_data (x[axes (x)... ])
29
29
30
+ ancestor (x:: TracedRArray ) = x
31
+ function ancestor (x:: WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} ) where {T,N}
32
+ return ancestor (parent (x))
33
+ end
34
+
35
+ get_ancestor_indices (:: TracedRArray , indices... ) = indices
36
+ function get_ancestor_indices (
37
+ x:: SubArray{T,N,<:AnyTracedRArray{T,N}} , indices...
38
+ ) where {T,N}
39
+ return get_ancestor_indices (parent (x), Base. reindex (x. indices, indices)... )
40
+ end
41
+
30
42
Base. getindex (a:: AnyTracedRScalar{T} ) where {T} = a
31
43
32
44
function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Int,N} ) where {T,N}
@@ -75,13 +87,17 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
75
87
return x
76
88
end
77
89
78
- # Prevents ambiguity
79
- function Base. getindex (a:: SubArray{T,N,<:AnyTracedRArray{T,N}} , indices:: Int... ) where {T,N}
80
- return getindex (parent (a), Base. reindex (a. indices, indices)... )
90
+ # Prevent ambiguity
91
+ function Base. getindex (
92
+ a:: WrappedArray{T,N,TracedRArray,<:TracedRArray{T,N}} , index:: Int...
93
+ ) where {T,N}
94
+ return getindex (ancestor (a), get_ancestor_indices (a, index... )... )
81
95
end
82
96
83
- function Base. getindex (a:: SubArray{T,N,<:AnyTracedRArray{T,N}} , indices... ) where {T,N}
84
- return getindex (parent (a), Base. reindex (a. indices, indices)... )
97
+ function Base. getindex (
98
+ a:: WrappedArray{T,N,TracedRArray,<:TracedRArray{T,N}} , indices...
99
+ ) where {T,N}
100
+ return getindex (ancestor (a), get_ancestor_indices (a, indices... )... )
85
101
end
86
102
87
103
function Base. setindex! (
0 commit comments