Skip to content

Commit 10d364c

Browse files
committed
feat: TracedRScalar
1 parent 72e23c8 commit 10d364c

File tree

3 files changed

+66
-15
lines changed

3 files changed

+66
-15
lines changed

src/Reactant.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ include("OrderedIdDict.jl")
88
using Enzyme
99

1010
abstract type RArray{T,N} <: AbstractArray{T,N} end
11+
abstract type RScalar{T} <: Number end
1112

1213
function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}})
1314
return reshape(A, Base._reshape_uncolon(A, dims))

src/TracedRArray.jl

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,22 @@ end
1919

2020
TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x
2121

22+
mutable struct TracedRScalar{T} <: RScalar{T}
23+
paths::Tuple
24+
mlir_data::Union{Nothing,MLIR.IR.Value}
25+
26+
function TracedRScalar{T}(
27+
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
28+
) where {T}
29+
if !isnothing(mlir_data)
30+
@assert size(MLIR.IR.type(mlir_data)) == ()
31+
end
32+
return new{T}(paths, mlir_data)
33+
end
34+
end
35+
2236
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
2337
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
24-
const AnyTracedRScalar{T} = AnyTracedRArray{T,0}
2538
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
2639
const AnyTracedRMatrix{T} = AnyTracedRArray{T,2}
2740
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
@@ -40,12 +53,12 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
4053
return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
4154
end
4255

43-
Base.getindex(a::AnyTracedRScalar{T}) where {T} = a
56+
Base.getindex(a::TracedRScalar{T}) where {T} = a
4457

45-
Base.zero(::AnyTracedRScalar{T}) where {T} = promote_to(TracedRArray{T,0}, zero(T))
46-
Base.one(::AnyTracedRScalar{T}) where {T} = promote_to(TracedRArray{T,0}, one(T))
58+
Base.zero(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, zero(T))
59+
Base.one(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, one(T))
4760

48-
function Base.convert(::Type{<:AnyTracedRScalar{T}}, x::Number) where {T}
61+
function Base.convert(::Type{<:TracedRScalar{T}}, x::Number) where {T}
4962
return promote_to(TracedRArray{T,0}, T(x))
5063
end
5164

@@ -73,7 +86,7 @@ and require expensive copies and synchronization each time and therefore should
7386
),
7487
1,
7588
)
76-
return TracedRArray{T,0}((), res2, ())
89+
return TracedRScalar{T}((), res2)
7790
end
7891

7992
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
@@ -133,7 +146,11 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
133146
# return print(io, X.mlir_data, ")")
134147
end
135148

136-
Base.only(A::AnyTracedRScalar{T}) where {T} = A
149+
function Base.show(io::IOty, X::TracedRScalar{T}) where {T,IOty<:Union{IO,IOContext}}
150+
return print(io, "TracedRScalar{", T, "}(", X.paths, ")")
151+
end
152+
153+
Base.only(A::TracedRScalar{T}) where {T} = A
137154

138155
function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
139156
if prod(dims) != prod(size(A))
@@ -207,9 +224,7 @@ end
207224

208225
function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
209226
if isa(rhs, TracedRArray)
210-
if typeof(rhs) == TracedRArray{T,N}
211-
return rhs
212-
end
227+
rhs isa TracedRArray{T,N} && return rhs
213228
return TracedRArray{T,N}(
214229
(),
215230
MLIR.IR.result(
@@ -222,11 +237,8 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
222237
)
223238
end
224239
if isa(rhs, Number)
225-
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRArray{T,N}, size(rhs)))
226-
ta = TracedRArray{T,N}(
227-
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs)
228-
)
229-
return ta
240+
throw(ArgumentError("Cannot promote number to `TracedRArray`. Use \
241+
`TracedRScalar` instead."))
230242
end
231243
T0 = eltype(rhs)
232244
attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
@@ -238,9 +250,41 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
238250
)
239251
end
240252

253+
function promote_to(::Type{TracedRScalar{T}}, rhs) where {T}
254+
if isa(rhs, TracedRScalar)
255+
rhs isa TracedRScalar{T} && return rhs
256+
return TracedRScalar{T}(
257+
(),
258+
MLIR.IR.result(
259+
MLIR.Dialects.stablehlo.convert(
260+
rhs.mlir_data; result=mlir_type(TracedRScalar{T})
261+
),
262+
1,
263+
),
264+
)
265+
end
266+
if isa(rhs, Number)
267+
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRScalar{T}))
268+
return TracedRScalar{T}(
269+
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
270+
)
271+
end
272+
T0 = eltype(rhs)
273+
attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
274+
return promote_to(
275+
TracedRScalar{T},
276+
TracedRScalar{T0}(
277+
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
278+
),
279+
)
280+
end
281+
241282
function promote_to(::TracedRArray{T,N}, rhs) where {T,N}
242283
return promote_to(TracedRArray{T,N}, rhs)
243284
end
285+
function promote_to(::TracedRScalar{T}, rhs) where {T}
286+
return promote_to(TracedRScalar{T}, rhs)
287+
end
244288

245289
for (jlop, hloop) in (
246290
(:(Base.min), :minimum),

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@ function mlir_type(x::RArray{T,N}) where {T,N}
22
return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T))
33
end
44

5+
mlir_type(::RScalar{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T))
6+
57
function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N}
68
@assert length(shape) == N
79
return MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
810
end
911

12+
function mlir_type(::Type{<:RScalar{T}}) where {T}
13+
return MLIR.IR.TensorType((), MLIR.IR.Type(T))
14+
end
15+
1016
function transpose_ty(mlirty)
1117
return MLIR.IR.TensorType([reverse(size(mlirty))...], eltype(mlirty))
1218
end

0 commit comments

Comments
 (0)