@@ -16,21 +16,21 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
16
16
end
17
17
end
18
18
19
- const AnyTracedRArray{T,N} = Union{
20
- TracedRArray{T,N},WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
21
- }
19
+ const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
20
+ const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
22
21
const AnyTracedRScalar{T} = AnyTracedRArray{T,0 }
23
22
const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
24
23
const AnyTracedRMatrix{T} = AnyTracedRArray{T,2 }
25
24
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
26
25
26
+ materialize_traced_array (x:: TracedRArray ) = x
27
+ materialize_traced_array (x:: WrappedTracedRArray ) = x[axes (x)... ]
28
+
27
29
get_mlir_data (x:: TracedRArray ) = x. mlir_data
28
- get_mlir_data (x:: AnyTracedRArray ) = get_mlir_data (x[ axes (x)... ] )
30
+ get_mlir_data (x:: AnyTracedRArray ) = get_mlir_data (materialize_traced_array (x))
29
31
30
32
ancestor (x:: TracedRArray ) = x
31
- function ancestor (x:: WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} ) where {T,N}
32
- return ancestor (parent (x))
33
- end
33
+ ancestor (x:: WrappedTracedRArray ) = ancestor (parent (x))
34
34
35
35
get_ancestor_indices (:: TracedRArray , indices... ) = indices
36
36
function get_ancestor_indices (
@@ -88,15 +88,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
88
88
end
89
89
90
90
# Prevent ambiguity
91
- function Base. getindex (
92
- a:: WrappedArray{T,N,TracedRArray,<:TracedRArray{T,N}} , index:: Int...
93
- ) where {T,N}
91
+ function Base. getindex (a:: WrappedTracedRArray , index:: Int... )
94
92
return getindex (ancestor (a), get_ancestor_indices (a, index... )... )
95
93
end
96
94
97
- function Base. getindex (
98
- a:: WrappedArray{T,N,TracedRArray,<:TracedRArray{T,N}} , indices...
99
- ) where {T,N}
95
+ function Base. getindex (a:: WrappedTracedRArray , indices... )
100
96
return getindex (ancestor (a), get_ancestor_indices (a, indices... )... )
101
97
end
102
98
@@ -783,14 +779,8 @@ function broadcast_to_size(arg::AbstractArray, rsize)
783
779
return arg
784
780
end
785
781
786
- function broadcast_to_size (
787
- arg:: WrappedArray{T,N,TracedRArray,<:TracedRArray{T,N}} , rsize
788
- ) where {T,N}
789
- return broadcast_to_size (arg[axes (arg)... ], rsize)
790
- end
791
-
792
- function broadcast_to_size (arg:: TracedRArray , rsize)
793
- return arg
782
+ function broadcast_to_size (arg:: AnyTracedRArray , rsize)
783
+ return materialize_traced_array (arg)
794
784
end
795
785
796
786
function broadcast_to_size (arg:: Base.RefValue , rsize)
0 commit comments