19
19
20
20
TracedRArray {T,N} (x:: TracedRArray{T,N} ) where {T,N} = x
21
21
22
+ mutable struct TracedRScalar{T} <: RScalar{T}
23
+ paths:: Tuple
24
+ mlir_data:: Union{Nothing,MLIR.IR.Value}
25
+
26
+ function TracedRScalar {T} (
27
+ paths:: Tuple , mlir_data:: Union{Nothing,MLIR.IR.Value}
28
+ ) where {T}
29
+ if ! isnothing (mlir_data)
30
+ @assert size (MLIR. IR. type (mlir_data)) == ()
31
+ end
32
+ return new {T} (paths, mlir_data)
33
+ end
34
+ end
35
+
22
36
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
23
37
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
24
- const AnyTracedRScalar{T} = AnyTracedRArray{T,0 }
25
38
const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
26
39
const AnyTracedRMatrix{T} = AnyTracedRArray{T,2 }
27
40
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
@@ -40,12 +53,12 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
40
53
return get_ancestor_indices (parent (x), Base. reindex (parentindices (x), indices)... )
41
54
end
42
55
43
- Base. getindex (a:: AnyTracedRScalar {T} ) where {T} = a
56
+ Base. getindex (a:: TracedRScalar {T} ) where {T} = a
44
57
45
- Base. zero (:: AnyTracedRScalar {T} ) where {T} = promote_to (TracedRArray{T, 0 }, zero (T))
46
- Base. one (:: AnyTracedRScalar {T} ) where {T} = promote_to (TracedRArray{T, 0 }, one (T))
58
+ Base. zero (:: TracedRScalar {T} ) where {T} = promote_to (TracedRScalar{T }, zero (T))
59
+ Base. one (:: TracedRScalar {T} ) where {T} = promote_to (TracedRScalar{T }, one (T))
47
60
48
- function Base. convert (:: Type{<:AnyTracedRScalar {T}} , x:: Number ) where {T}
61
+ function Base. convert (:: Type{<:TracedRScalar {T}} , x:: Number ) where {T}
49
62
return promote_to (TracedRArray{T,0 }, T (x))
50
63
end
51
64
@@ -73,7 +86,7 @@ and require expensive copies and synchronization each time and therefore should
73
86
),
74
87
1 ,
75
88
)
76
- return TracedRArray {T,0 } ((), res2, () )
89
+ return TracedRScalar {T } ((), res2)
77
90
end
78
91
79
92
function Base. getindex (a:: TracedRArray{T,N} , indices:: Vararg{Any,N} ) where {T,N}
@@ -133,7 +146,11 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
133
146
# return print(io, X.mlir_data, ")")
134
147
end
135
148
136
- Base. only (A:: AnyTracedRScalar{T} ) where {T} = A
149
+ function Base. show (io:: IOty , X:: TracedRScalar{T} ) where {T,IOty<: Union{IO,IOContext} }
150
+ return print (io, " TracedRScalar{" , T, " }(" , X. paths, " )" )
151
+ end
152
+
153
+ Base. only (A:: TracedRScalar{T} ) where {T} = A
137
154
138
155
function Base. reshape (A:: AnyTracedRArray{T,N} , dims:: NTuple{NT,Int} ) where {T,N,NT}
139
156
if prod (dims) != prod (size (A))
207
224
208
225
function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
209
226
if isa (rhs, TracedRArray)
210
- if typeof (rhs) == TracedRArray{T,N}
211
- return rhs
212
- end
227
+ rhs isa TracedRArray{T,N} && return rhs
213
228
return TracedRArray {T,N} (
214
229
(),
215
230
MLIR. IR. result (
@@ -222,11 +237,8 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
222
237
)
223
238
end
224
239
if isa (rhs, Number)
225
- attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRArray{T,N}, size (rhs)))
226
- ta = TracedRArray {T,N} (
227
- (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 ), size (rhs)
228
- )
229
- return ta
240
+ throw (ArgumentError (" Cannot promote number to `TracedRArray`. Use \
241
+ `TracedRScalar` instead." ))
230
242
end
231
243
T0 = eltype (rhs)
232
244
attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
@@ -238,9 +250,41 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
238
250
)
239
251
end
240
252
253
+ function promote_to (:: Type{TracedRScalar{T}} , rhs) where {T}
254
+ if isa (rhs, TracedRScalar)
255
+ rhs isa TracedRScalar{T} && return rhs
256
+ return TracedRScalar {T} (
257
+ (),
258
+ MLIR. IR. result (
259
+ MLIR. Dialects. stablehlo. convert (
260
+ rhs. mlir_data; result= mlir_type (TracedRScalar{T})
261
+ ),
262
+ 1 ,
263
+ ),
264
+ )
265
+ end
266
+ if isa (rhs, Number)
267
+ attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRScalar{T}))
268
+ return TracedRScalar {T} (
269
+ (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
270
+ )
271
+ end
272
+ T0 = eltype (rhs)
273
+ attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
274
+ return promote_to (
275
+ TracedRScalar{T},
276
+ TracedRScalar {T0} (
277
+ (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
278
+ ),
279
+ )
280
+ end
281
+
241
282
function promote_to (:: TracedRArray{T,N} , rhs) where {T,N}
242
283
return promote_to (TracedRArray{T,N}, rhs)
243
284
end
285
+ function promote_to (:: TracedRScalar{T} , rhs) where {T}
286
+ return promote_to (TracedRScalar{T}, rhs)
287
+ end
244
288
245
289
for (jlop, hloop) in (
246
290
(:(Base. min), :minimum ),
0 commit comments