@@ -16,11 +16,32 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
16
16
end
17
17
end
18
18
19
- function Base. getindex (a:: TracedRArray{T,0} ) where {T}
20
- return a
19
+ const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
20
+ const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
21
+ const AnyTracedRScalar{T} = AnyTracedRArray{T,0 }
22
+ const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
23
+ const AnyTracedRMatrix{T} = AnyTracedRArray{T,2 }
24
+ const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
25
+
26
+ materialize_traced_array (x:: TracedRArray ) = x
27
+ materialize_traced_array (x:: WrappedTracedRArray ) = x[axes (x)... ]
28
+
29
+ get_mlir_data (x:: TracedRArray ) = x. mlir_data
30
+ get_mlir_data (x:: AnyTracedRArray ) = get_mlir_data (materialize_traced_array (x))
31
+
32
+ ancestor (x:: TracedRArray ) = x
33
+ ancestor (x:: WrappedTracedRArray ) = ancestor (parent (x))
34
+
35
+ get_ancestor_indices (:: TracedRArray , indices... ) = indices
36
+ function get_ancestor_indices (
37
+ x:: SubArray{T,N,<:AnyTracedRArray{T,N}} , indices...
38
+ ) where {T,N}
39
+ return get_ancestor_indices (parent (x), Base. reindex (x. indices, indices)... )
21
40
end
22
41
23
- function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Integer,N} ) where {T,N}
42
+ Base. getindex (a:: AnyTracedRScalar{T} ) where {T} = a
43
+
44
+ function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Int,N} ) where {T,N}
24
45
@warn (
25
46
""" Performing scalar indexing on task $(current_task ()) .
26
47
Invocation resulted in scalar indexing of a TracedRArray.
@@ -47,9 +68,7 @@ and require expensive copies and synchronization each time and therefore should
47
68
return TracedRArray {T,0} ((), res2, ())
48
69
end
49
70
50
- function Base. getindex (
51
- a:: TracedRArray{T,N} , indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
52
- ) where {T,N}
71
+ function Base. getindex (a:: TracedRArray{T,N} , indices:: Vararg{Any,N} ) where {T,N}
53
72
indices = [i isa Colon ? (1 : size (a, idx)) : i for (idx, i) in enumerate (indices)]
54
73
res = MLIR. IR. result (
55
74
MLIR. Dialects. stablehlo. slice (
@@ -62,14 +81,19 @@ function Base.getindex(
62
81
),
63
82
1 ,
64
83
)
65
- return TracedRArray {T,N} ((), res, Tuple (length .(indices)))
84
+ x = TracedRArray {T,N} ((), res, Tuple (length .(indices)))
85
+ ddims = findall (x -> x isa Integer, indices)
86
+ ! isempty (ddims) && return dropdims (x; dims= Tuple (ddims))
87
+ return x
66
88
end
67
89
68
- function Base. view (
69
- a:: TracedRArray{T,N} , indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
70
- ) where {T,N}
71
- # TODO : Implement before merging the PR
72
- return error (" view is not supported yet" )
90
+ # Prevent ambiguity
91
+ function Base. getindex (a:: WrappedTracedRArray , index:: Int... )
92
+ return getindex (ancestor (a), get_ancestor_indices (a, index... )... )
93
+ end
94
+
95
+ function Base. getindex (a:: WrappedTracedRArray , indices... )
96
+ return getindex (ancestor (a), get_ancestor_indices (a, indices... )... )
73
97
end
74
98
75
99
function Base. setindex! (
@@ -101,15 +125,15 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
101
125
# return print(io, X.mlir_data, ")")
102
126
end
103
127
104
- Base. only (A:: TracedRArray{T,0 } ) where {T} = A
128
+ Base. only (A:: AnyTracedRScalar{T } ) where {T} = A
105
129
106
- function Base. reshape (A:: TracedRArray {T,N} , dims:: NTuple{NT,Int} ) where {T,N,NT}
130
+ function Base. reshape (A:: AnyTracedRArray {T,N} , dims:: NTuple{NT,Int} ) where {T,N,NT}
107
131
prod (dims) == prod (size (A)) || Base. _throw_dmrsa (dims, prod (size (A)))
108
132
109
133
# HLO reshape semantics collapse the opposite way
110
134
res1 = MLIR. IR. result (
111
135
MLIR. Dialects. stablehlo. transpose (
112
- A . mlir_data ;
136
+ get_mlir_data (A) ;
113
137
permutation= MLIR. IR. DenseArrayAttribute ([Int64 (N - 1 - i) for i in 0 : (N - 1 )]),
114
138
),
115
139
1 ,
@@ -137,12 +161,12 @@ function Base.reshape(A::TracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
137
161
return TracedRArray {T,NT} ((), res3, dims)
138
162
end
139
163
140
- function Base. permutedims (A:: TracedRArray {T,N} , perm) where {T,N}
164
+ function Base. permutedims (A:: AnyTracedRArray {T,N} , perm) where {T,N}
141
165
return TracedRArray {T,N} (
142
166
(),
143
167
MLIR. IR. result (
144
168
MLIR. Dialects. stablehlo. transpose (
145
- A . mlir_data ;
169
+ get_mlir_data (A) ;
146
170
permutation= MLIR. IR. DenseArrayAttribute ([Int64 (i - 1 ) for i in perm]),
147
171
),
148
172
1 ,
@@ -151,13 +175,19 @@ function Base.permutedims(A::TracedRArray{T,N}, perm) where {T,N}
151
175
)
152
176
end
153
177
178
+ function Base. transpose (A:: AnyTracedRVecOrMat )
179
+ A = ndims (A) == 1 ? reshape (A, :, 1 ) : A
180
+ return permutedims (A, (2 , 1 ))
181
+ end
182
+ Base. adjoint (A:: AnyTracedRVecOrMat{<:Real} ) = transpose (A)
183
+
154
184
function Base. promote_rule (
155
185
:: Type{TracedRArray{T,N}} , :: Type{TracedRArray{S,N}}
156
186
) where {T,S,N}
157
187
return TracedRArray{Base. promote_type (T, S),N}
158
188
end
159
189
160
- function Base. promote_rule (A :: Type{T} , B :: Type{TracedRArray{S,N}} ) where {T,S,N}
190
+ function Base. promote_rule (:: Type{T} , :: Type{TracedRArray{S,N}} ) where {T,S,N}
161
191
return TracedRArray{Base. promote_type (T, S),N}
162
192
end
163
193
@@ -194,7 +224,7 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
194
224
)
195
225
end
196
226
197
- function promote_to (lhs :: TracedRArray{T,N} , rhs) where {T,N}
227
+ function promote_to (:: TracedRArray{T,N} , rhs) where {T,N}
198
228
return promote_to (TracedRArray{T,N}, rhs)
199
229
end
200
230
@@ -668,6 +698,7 @@ function Base.mapreducedim!(
668
698
end
669
699
670
700
struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
701
+
671
702
AbstractReactantArrayStyle (:: Val{N} ) where {N} = AbstractReactantArrayStyle {N} ()
672
703
AbstractReactantArrayStyle {M} (:: Val{N} ) where {N,M} = AbstractReactantArrayStyle {N} ()
673
704
@@ -678,7 +709,9 @@ AbstractReactantArrayStyle{M}(::Val{N}) where {N,M} = AbstractReactantArrayStyle
678
709
# copy(inst)
679
710
# end
680
711
681
- BroadcastStyle (:: Type{T} ) where {T<: TracedRArray } = AbstractReactantArrayStyle {ndims(T)} ()
712
+ function BroadcastStyle (:: Type{<:AnyTracedRArray{T,N}} ) where {T,N}
713
+ return AbstractReactantArrayStyle {N} ()
714
+ end
682
715
683
716
function Base. similar (
684
717
bc:: Broadcasted{AbstractReactantArrayStyle{N}} , :: Type{T} , dims
@@ -746,8 +779,8 @@ function broadcast_to_size(arg::AbstractArray, rsize)
746
779
return arg
747
780
end
748
781
749
- function broadcast_to_size (arg:: TracedRArray , rsize)
750
- return arg
782
+ function broadcast_to_size (arg:: AnyTracedRArray , rsize)
783
+ return materialize_traced_array ( arg)
751
784
end
752
785
753
786
function broadcast_to_size (arg:: Base.RefValue , rsize)
0 commit comments