@@ -16,11 +16,17 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
16
16
end
17
17
end
18
18
19
- function Base. getindex (a:: TracedRArray{T,0} ) where {T}
20
- return a
21
- end
19
+ const AnyTracedRArray{T,N} = Union{
20
+ TracedRArray{T,N},WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
21
+ }
22
+ const AnyTracedRScalar{T} = AnyTracedRArray{T,0 }
23
+ const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
24
+ const AnyTracedRMatrix{T} = AnyTracedRArray{T,2 }
25
+ const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
26
+
27
+ Base. getindex (a:: AnyTracedRScalar{T} ) where {T} = a
22
28
23
- function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Integer ,N} ) where {T,N}
29
+ function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Int ,N} ) where {T,N}
24
30
@warn (
25
31
""" Performing scalar indexing on task $(current_task ()) .
26
32
Invocation resulted in scalar indexing of a TracedRArray.
@@ -48,7 +54,7 @@ and require expensive copies and synchronization each time and therefore should
48
54
end
49
55
50
56
function Base. getindex (
51
- a:: TracedRArray{T,N} , indices:: Vararg{Union{Base.AbstractUnitRange,Colon} ,N}
57
+ a:: TracedRArray{T,N} , indices:: Vararg{Any ,N}
52
58
) where {T,N}
53
59
indices = [i isa Colon ? (1 : size (a, idx)) : i for (idx, i) in enumerate (indices)]
54
60
res = MLIR. IR. result (
@@ -62,14 +68,23 @@ function Base.getindex(
62
68
),
63
69
1 ,
64
70
)
65
- return TracedRArray {T,N} ((), res, Tuple (length .(indices)))
71
+ x = TracedRArray {T,N} ((), res, Tuple (length .(indices)))
72
+ ddims = findall (x -> x isa Integer, indices)
73
+ ! isempty (ddims) && return dropdims (x, dims= Tuple (ddims))
74
+ return x
66
75
end
67
76
68
- function Base. view (
69
- a:: TracedRArray{T,N} , indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
77
+ # Prevents ambiguity
78
+ function Base. getindex (
79
+ a:: SubArray{T,N,<:AnyTracedRArray{T,N}} , indices:: Int...
80
+ ) where {T,N}
81
+ return getindex (parent (a), Base. reindex (a. indices, indices)... )
82
+ end
83
+
84
+ function Base. getindex (
85
+ a:: SubArray{T,N,<:AnyTracedRArray{T,N}} , indices...
70
86
) where {T,N}
71
- # TODO : Implement before merging the PR
72
- return error (" view is not supported yet" )
87
+ return getindex (parent (a), Base. reindex (a. indices, indices)... )
73
88
end
74
89
75
90
function Base. setindex! (
@@ -101,7 +116,7 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
101
116
# return print(io, X.mlir_data, ")")
102
117
end
103
118
104
- Base. only (A:: TracedRArray{T,0 } ) where {T} = A
119
+ Base. only (A:: AnyTracedRScalar{T } ) where {T} = A
105
120
106
121
function Base. reshape (A:: TracedRArray{T,N} , dims:: NTuple{NT,Int} ) where {T,N,NT}
107
122
prod (dims) == prod (size (A)) || Base. _throw_dmrsa (dims, prod (size (A)))
@@ -194,7 +209,7 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
194
209
)
195
210
end
196
211
197
- function promote_to (lhs :: TracedRArray{T,N} , rhs) where {T,N}
212
+ function promote_to (:: TracedRArray{T,N} , rhs) where {T,N}
198
213
return promote_to (TracedRArray{T,N}, rhs)
199
214
end
200
215
0 commit comments