Skip to content

Commit fce399c

Browse files
authored
feat: support tracing scalars (#205)
* feat: support tracing scalars * test: add scalar tests * fix: return concrete scalars * refactor: rename union type to ConcreteRScalar
1 parent 6866f05 commit fce399c

File tree

7 files changed

+161
-43
lines changed

7 files changed

+161
-43
lines changed

deps/ReactantExtra/make-bindings.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@ function build_file(output_path)
66
dir=@__DIR__,
77
),
88
)
9-
Base.Filesystem.cp(
10-
joinpath(@__DIR__, "bazel-bin", file),
11-
output_path;
12-
force=true,
9+
return Base.Filesystem.cp(
10+
joinpath(@__DIR__, "bazel-bin", file), output_path; force=true
1311
)
1412
end
1513

src/Compiler.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import ..Reactant:
55
MLIR,
66
XLA,
77
ConcreteRArray,
8+
ConcreteRNumber,
89
TracedRArray,
910
TracedRNumber,
1011
OrderedIdDict,
@@ -30,6 +31,16 @@ function create_result(tocopy::T, path, result_stores) where {T}
3031
return Expr(:new, T, elems...)
3132
end
3233

34+
function create_result(tocopy::ConcreteRNumber{T}, path, result_stores) where {T}
35+
if haskey(result_stores, path)
36+
restore = result_stores[path]
37+
delete!(result_stores, path)
38+
return :(ConcreteRNumber{$T}($restore))
39+
end
40+
# We will set the data for this later
41+
return :(ConcreteRNumber{$T}($(tocopy.data)))
42+
end
43+
3344
function create_result(tocopy::ConcreteRArray{T,N}, path, result_stores) where {T,N}
3445
if haskey(result_stores, path)
3546
restore = result_stores[path]

src/ConcreteRArray.jl

Lines changed: 63 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,34 @@ end
44

55
mutable struct ConcreteRArray{T,N} <: RArray{T,N}
66
data::XLA.AsyncBuffer
7-
# data::XLAArray{T, N}
7+
# data::XLAArray{T, N}
88
shape::NTuple{N,Int}
99
end
1010

11-
ConcreteRArray(data::T) where {T<:Number} = ConcreteRArray(fill(data))
11+
mutable struct ConcreteRNumber{T} <: RNumber{T}
12+
data::XLA.AsyncBuffer
13+
end
14+
15+
function ConcreteRNumber(
16+
data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[]
17+
) where {T<:Number}
18+
crarray = ConcreteRArray(fill(data); client, idx)
19+
return ConcreteRNumber{T}(crarray.data)
20+
end
21+
22+
Base.size(::ConcreteRNumber) = ()
23+
24+
function ConcreteRArray(
25+
data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[]
26+
) where {T<:Number}
27+
Base.depwarn(
28+
"ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead",
29+
:ConcreteRArray,
30+
)
31+
return ConcreteRArray(fill(data); client, idx)
32+
end
33+
34+
const ConcreteRScalar{T} = Union{ConcreteRArray{T,0},ConcreteRNumber{T}}
1235

1336
Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x)
1437

@@ -48,7 +71,7 @@ function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,El
4871
# XLA.from_row_major(data)
4972
end
5073

51-
function synchronize(x::ConcreteRArray)
74+
function synchronize(x::Union{ConcreteRArray,ConcreteRNumber})
5275
XLA.synced_buffer(x.data)
5376
return nothing
5477
end
@@ -60,7 +83,7 @@ end
6083
# return ConcreteRArray{T,N}(x.data)
6184
# end
6285

63-
function to_float(X::ConcreteRArray{T,0}) where {T}
86+
function to_number(X::ConcreteRScalar{T}) where {T}
6487
data = Ref{T}()
6588
XLA.await(X.data)
6689
buf = X.data.buffer
@@ -70,36 +93,49 @@ function to_float(X::ConcreteRArray{T,0}) where {T}
7093
return data[]
7194
end
7295

73-
function Base.convert(::Type{T}, x::ConcreteRArray{T,0}) where {T}
74-
return to_float(x)
96+
Base.convert(::Type{T}, x::ConcreteRScalar{T}) where {T} = to_number(x)
97+
98+
for jlop in (
99+
:(Base.isless),
100+
:(Base.:+),
101+
:(Base.:-),
102+
:(Base.:*),
103+
:(Base.:/),
104+
:(Base.:^),
105+
:(Base.:(==)),
106+
),
107+
T in (ConcreteRNumber, ConcreteRArray{<:Any,0})
108+
109+
@eval begin
110+
$(jlop)(x::$(T), y::$(T)) = $(jlop)(to_number(x), to_number(y))
111+
$(jlop)(x::$(T), y::Number) = $(jlop)(to_number(x), y)
112+
$(jlop)(x::Number, y::$(T)) = $(jlop)(x, to_number(y))
113+
end
75114
end
76115

77-
for jlop in (:(Base.isless), :(Base.:+), :(Base.:-), :(Base.:*), :(Base.:/), :(Base.:^))
116+
for T in (ConcreteRNumber, ConcreteRArray{<:Any,0})
78117
@eval begin
79-
function $jlop(x::ConcreteRArray{T,0}, y::ConcreteRArray{U,0}) where {T,U}
80-
return $jlop(to_float(x), to_float(y))
118+
function Base.isapprox(x::$(T), y::Number; kwargs...)
119+
return Base.isapprox(to_number(x), y; kwargs...)
81120
end
82-
function $jlop(x::ConcreteRArray{T,0}, y) where {T}
83-
return $jlop(to_float(x), y)
121+
122+
function Base.isapprox(x::Number, y::$(T); kwargs...)
123+
return Base.isapprox(x, to_number(y); kwargs...)
84124
end
85-
function $jlop(x, y::ConcreteRArray{U,0}) where {U}
86-
return $jlop(x, to_float(y))
125+
126+
function Base.isapprox(x::$(T), y::$(T); kwargs...)
127+
return Base.isapprox(to_number(x), to_number(y); kwargs...)
87128
end
88129
end
89130
end
90131

91-
function Base.isapprox(x::ConcreteRArray{T,0}, y; kwargs...) where {T}
92-
return Base.isapprox(to_float(x), y; kwargs...)
93-
end
94-
95-
function Base.isapprox(x, y::ConcreteRArray{T,0}; kwargs...) where {T}
96-
return Base.isapprox(x, to_float(y); kwargs...)
97-
end
98-
99-
function Base.isapprox(
100-
x::ConcreteRArray{T,0}, y::ConcreteRArray{T2,0}; kwargs...
101-
) where {T,T2}
102-
return Base.isapprox(to_float(x), to_float(y); kwargs...)
132+
function Base.show(io::IO, X::ConcreteRScalar{T}) where {T}
133+
if X.data == XLA.AsyncEmptyBuffer
134+
println(io, "<Empty buffer>")
135+
return nothing
136+
end
137+
str = sprint(show, to_number(X))
138+
return print(io, "$(typeof(X))($(str))")
103139
end
104140

105141
function Base.print_array(io::IO, X::ConcreteRArray)
@@ -115,7 +151,8 @@ function Base.show(io::IO, X::ConcreteRArray)
115151
println(io, "<Empty buffer>")
116152
return nothing
117153
end
118-
return Base.show(io, convert(Array, X))
154+
str = sprint(show, convert(Array, X))
155+
return print(io, "$(typeof(X))($(str))")
119156
end
120157

121158
const getindex_warned = Ref(false)

src/Reactant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ include("Tracing.jl")
9494
include("Compiler.jl")
9595

9696
using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
97-
export ConcreteRArray, @compile, @code_hlo, @jit
97+
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit
9898

9999
const registry = Ref{MLIR.IR.DialectRegistry}()
100100
function __init__()

src/Tracing.jl

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedType,m
189189
elseif mode == TracedToConcrete
190190
@inline base_typec(TV::TT) where {TT<:UnionAll} =
191191
UnionAll(TV.var, base_typec(TV.body))
192-
@inline base_typec(TV::TT) where {TT<:DataType} = ConcreteRArray{TV.parameters...}
192+
@inline base_typec(TV::TT) where {TT<:DataType} =
193+
(T <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...}
193194
return base_typec(T)
194195
elseif mode == TracedTrack || mode == TracedSetPath
195196
return T
@@ -232,6 +233,7 @@ function make_tracer(
232233
mode;
233234
toscalar=false,
234235
tobatch=nothing,
236+
kwargs...,
235237
) where {RT}
236238
if haskey(seen, prev)
237239
return seen[prev]
@@ -308,13 +310,29 @@ function make_tracer(
308310
return res
309311
end
310312

313+
function make_tracer(seen, prev::ConcreteRNumber{T}, path, mode; kwargs...) where {T}
314+
if mode == ArrayToConcrete
315+
return prev
316+
end
317+
if mode != ConcreteToTraced
318+
throw("Cannot trace existing trace type")
319+
end
320+
if haskey(seen, prev)
321+
return seen[prev]::TracedRNumber{T}
322+
end
323+
res = TracedRNumber{T}((path,), nothing)
324+
seen[prev] = res
325+
return res
326+
end
327+
311328
function make_tracer(
312329
seen,
313330
@nospecialize(prev::TracedRArray{T,N}),
314331
@nospecialize(path),
315332
mode;
316333
toscalar=false,
317334
tobatch=nothing,
335+
kwargs...,
318336
) where {T,N}
319337
if mode == ConcreteToTraced
320338
throw("Cannot trace existing trace type")
@@ -389,9 +407,9 @@ function make_tracer(
389407

390408
if mode == TracedToConcrete
391409
if haskey(seen, prev)
392-
return seen[prev]::ConcreteRArray{T,0}
410+
return seen[prev]::ConcreteRNumber{T}
393411
end
394-
res = ConcreteRArray{T,0}(XLA.AsyncEmptyBuffer, size(prev))
412+
res = ConcreteRNumber{T}(XLA.AsyncEmptyBuffer)
395413
seen[prev] = res
396414
return res
397415
end
@@ -400,8 +418,13 @@ function make_tracer(
400418
end
401419

402420
function make_tracer(
403-
seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...
404-
) where {RT<:AbstractFloat}
421+
seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs...
422+
) where {RT<:Number}
423+
if mode == ArrayToConcrete
424+
length(track_numbers) == 0 && return prev
425+
should_convert = any(Base.Fix1(<:, RT), track_numbers)
426+
return should_convert ? ConcreteRNumber(prev) : prev
427+
end
405428
return prev
406429
end
407430

@@ -414,10 +437,15 @@ function make_tracer(
414437
mode;
415438
toscalar=false,
416439
tobatch=nothing,
440+
kwargs...,
417441
) where {RT}
418442
return Complex(
419-
make_tracer(seen, prev.re, append_path(path, :re), mode; toscalar, tobatch),
420-
make_tracer(seen, prev.im, append_path(path, :im), mode; toscalar, tobatch),
443+
make_tracer(
444+
seen, prev.re, append_path(path, :re), mode; toscalar, tobatch, kwargs...
445+
),
446+
make_tracer(
447+
seen, prev.im, append_path(path, :im), mode; toscalar, tobatch, kwargs...
448+
),
421449
)
422450
end
423451

@@ -489,8 +517,9 @@ function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...)
489517
return res
490518
end
491519

492-
@inline function to_rarray(@nospecialize(x))
493-
return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete)
520+
@inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Tuple}=())
521+
track_numbers isa Bool && (track_numbers = track_numbers ? (Number,) : ())
522+
return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers)
494523
end
495524

496525
to_rarray(x::ReactantPrimitive) = ConcreteRArray(x)

test/basic.jl

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ sum_compare(x) = sum(x) > 0
8080
# Ensure we are tracing as scalars. Else this will fail due to > not being defined on
8181
# arrays
8282
f = @compile sum_compare(a)
83-
# We need to use [] to unwrap the scalar. We will fix this in the future.
84-
@test f(a)[] == sum_compare(x)
83+
@test f(a) == sum_compare(x)
8584
end
8685

8786
function mysoftmax!(x)
@@ -445,3 +444,47 @@ end
445444
c = Reactant.compile(+, (a, b))(a, b)
446445
@test c == ones(CT, 2) + ones(CT, 2)
447446
end
447+
448+
@testset "Scalars" begin
449+
@testset "Only Scalars" begin
450+
x = (3, 3.14)
451+
452+
f1(x) = x[1] * x[2]
453+
454+
x_ra = Reactant.to_rarray(x; track_numbers=(Number,))
455+
f2 = @compile f1(x_ra)
456+
@test f2(Reactant.to_rarray((5, 5.2); track_numbers=(Number,))) 5 * 5.2
457+
@test f2(Reactant.to_rarray((5, 5.2); track_numbers=(Number,))) isa ConcreteRNumber
458+
459+
x_ra = Reactant.to_rarray(x)
460+
f3 = @compile f1(x_ra)
461+
@test f3(Reactant.to_rarray((5, 5.2))) f1(x)
462+
@test !(f3(Reactant.to_rarray((5, 5.2))) isa ConcreteRNumber)
463+
@test f3(Reactant.to_rarray((5, 5.2))) isa Number
464+
465+
x_ra = Reactant.to_rarray(x; track_numbers=(Int,))
466+
f4 = @compile f1(x_ra)
467+
@test f4(Reactant.to_rarray((5, 5.2); track_numbers=(Int,))) 5 * 3.14
468+
@test f4(Reactant.to_rarray((5, 5.2); track_numbers=(Int,))) isa ConcreteRNumber
469+
end
470+
471+
@testset "Mixed" begin
472+
x = (3, [3.14])
473+
474+
f1(x) = x[1] * x[2]
475+
476+
x_ra = Reactant.to_rarray(x; track_numbers=(Number,))
477+
478+
f2 = @compile f1(x_ra)
479+
res2 = f2(Reactant.to_rarray((5, [3.14]); track_numbers=(Number,)))
480+
@test only(res2) 5 * 3.14
481+
@test res2 isa ConcreteRArray
482+
483+
x_ra = Reactant.to_rarray(x)
484+
485+
f3 = @compile f1(x_ra)
486+
res3 = f3(Reactant.to_rarray((5, [3.14])))
487+
@test only(res3) only(f1(x))
488+
@test res3 isa ConcreteRArray
489+
end
490+
end

test/compile.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a=
1111

1212
f = @compile sum(x2)
1313

14-
@test f(x2) isa @NamedTuple{a::Reactant.ConcreteRArray{Float64,0}}
14+
@test f(x2) isa @NamedTuple{a::Reactant.ConcreteRNumber{Float64}}
1515
@test isapprox(f(x2).a, sum(x.a))
1616
end
1717
end

0 commit comments

Comments
 (0)