Skip to content

Commit 8c23e6c

Browse files
authored
Add back support for buffer donation (#96)
* Redo call to `make_tracer` with `TracedTrack` to mark donated buffers * Test buffer donation * Fix `i64` to `f64` type conversion
1 parent 2bb871f commit 8c23e6c

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

src/utils.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,12 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
125125
seen_results, result, (:result,), concretein ? TracedTrack : TracedSetPath
126126
)
127127

128-
# retraced_args = ntuple(Val(N)) do i
129-
# Base.@_inline_meta
130-
# return make_tracer(
131-
# seen_results, traced_args[i], concretein ? (:resargs, i) : (), TracedTrack
132-
# )
133-
# end
128+
# marks buffers to be donated
129+
for i in 1:N
130+
make_tracer(
131+
seen_results, traced_args[i], concretein ? (:resargs, i) : (), TracedTrack
132+
)
133+
end
134134

135135
linear_results = TracedRArray[]
136136

test/buffer_donation.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using Test
2+
using Reactant
3+
4+
# TODO try again with `2` to check automatic conversion from int to float
5+
function donate_fill_x_with_2(x, y)
6+
x .= 2.0
7+
return nothing
8+
end
9+
10+
function donate_inplace_mul(x, y)
11+
x .*= y
12+
return nothing
13+
end
14+
15+
@testset "buffer_donation" begin
16+
a = Reactant.ConcreteRArray(ones(2, 2))
17+
b = Reactant.ConcreteRArray(3 * ones(2, 2))
18+
f = Reactant.compile(donate_fill_x_with_2, (a, b))
19+
f(a, b)
20+
@test convert(Array, a) == 2 * ones(2, 2)
21+
22+
a = Reactant.ConcreteRArray(2 * ones(2, 2))
23+
b = Reactant.ConcreteRArray(3 * ones(2, 2))
24+
f = Reactant.compile(donate_inplace_mul, (a, b))
25+
f(a, b)
26+
@test convert(Array, a) == 6 * ones(2, 2)
27+
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ include("layout.jl")
4343
include("tracing.jl")
4444
include("basic.jl")
4545
include("bcast.jl")
46-
include("nn.jl")
4746
include("struct.jl")
4847
include("closure.jl")
4948
include("compile.jl")
49+
include("buffer_donation.jl")
50+
include("nn.jl")
5051

5152
if VERSION v"1.10-" # Lux isn't supported on 1.9
5253
include("nn_lux.jl")

0 commit comments

Comments
 (0)