Skip to content

Commit cdaba7c

Browse files
authored
Namedtuple (#123)
1 parent 157d230 commit cdaba7c

File tree

2 files changed

+121
-42
lines changed

2 files changed

+121
-42
lines changed

src/Compiler.jl

Lines changed: 105 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,9 @@ end
246246
macro code_hlo(options, maybe_call=nothing)
247247
call = something(maybe_call, options)
248248
options = isnothing(maybe_call) ? :(optimize = true) : options
249-
Meta.isexpr(call, :call) || error("@code_mlir: expected call, got $call")
249+
Meta.isexpr(call, :call) || error("@code_hlo: expected call, got $call")
250250
if !Meta.isexpr(options, :(=)) || options.args[1] != :optimize
251-
error("@code_mlir: expected options in format optimize=value, got $options")
251+
error("@code_hlo: expected options in format optimize=value, got $options")
252252
end
253253

254254
options = Expr(:tuple, Expr(:parameters, Expr(:kw, options.args...)))
@@ -269,6 +269,26 @@ macro code_hlo(options, maybe_call=nothing)
269269
end
270270
end
271271

272+
"""
273+
@compile f(args...)
274+
"""
275+
macro compile(options, maybe_call=nothing)
276+
call = something(maybe_call, options)
277+
options = isnothing(maybe_call) ? :(optimize = true) : options
278+
Meta.isexpr(call, :call) || error("@compile: expected call, got $call")
279+
if !Meta.isexpr(options, :(=)) || options.args[1] != :optimize
280+
error("@compile: expected options in format optimize=value, got $options")
281+
end
282+
283+
options = Expr(:tuple, Expr(:parameters, Expr(:kw, options.args...)))
284+
285+
quote
286+
f = $(esc(call.args[1]))
287+
args = $(esc(Expr(:tuple, call.args[2:end]...)))
288+
compile(f, args)
289+
end
290+
end
291+
272292
traced_getfield(obj, field) = Base.getfield(obj, field)
273293

274294
function create_result(tocopy::T, path, result_stores) where {T}
@@ -287,7 +307,9 @@ function create_result(tocopy::T, path, result_stores) where {T}
287307
end
288308

289309
function create_result(tocopy::ConcreteRArray{T,N}, path, result_stores) where {T,N}
290-
return :(ConcreteRArray{$T,$N}($(result_stores[path]), $(tocopy.shape)))
310+
restore = result_stores[path]
311+
delete!(result_stores, path)
312+
return :(ConcreteRArray{$T,$N}($restore, $(tocopy.shape)))
291313
end
292314

293315
function create_result(tocopy::Array{T,N}, path, result_stores) where {T,N}
@@ -353,9 +375,19 @@ function compile(f, args; pipeline_options="", client=nothing)
353375
closure_ty = typeof(fnwrap)
354376

355377
arg_syncs = Expr[]
378+
resarg_syncs = Expr[]
356379
topres = Symbol[]
357380
linearized_args = Union{Symbol,Expr}[]
358381

382+
concretize = Expr[]
383+
for (idx, _) in enumerate(linear_results)
384+
push!(concretize, :($(Symbol(:concrete_res_, idx)) = linearized_results[$idx]))
385+
end
386+
387+
delinearized_results = Expr[]
388+
389+
result_stores = Dict{Tuple,Symbol}()
390+
359391
for (i, arg) in enumerate(linear_args)
360392
paths = ((p for p in arg.paths if p[1] == :args)...,)
361393
path = if length(paths) == 1
@@ -367,25 +399,48 @@ function compile(f, args; pipeline_options="", client=nothing)
367399
for p in path[3:end]
368400
res = :(traced_getfield($res, $(Meta.quot(p))))
369401
end
402+
usym = Symbol("usbuf_$i")
403+
usbuf = :($usym = $res.data)
370404
sym = Symbol("sbuf_$i")
371-
sbuf = :($sym = XLA.synced_buffer($res.data))
405+
sbuf = :($sym = XLA.synced_buffer($usym))
406+
push!(arg_syncs, usbuf)
372407
push!(arg_syncs, sbuf)
373408

374409
push!(topres, sym)
375410

376411
res = :($sym.buffer)
377412
push!(linearized_args, res)
413+
414+
respaths = ((p for p in arg.paths if p[1] != :args)...,)
415+
416+
resarg = false
417+
for respath in respaths
418+
if respath[1] == :result
419+
res = Symbol("result")
420+
respath = respath[2:end]
421+
result_stores[respath] = usym
422+
resarg = true
423+
continue
424+
else
425+
@assert respath[1] == :resargs
426+
if respath[2] == path[2]
427+
continue
428+
end
429+
res = :(args[$(respath[2])])
430+
path = path[3:end]
431+
end
432+
for p in path
433+
res = :(traced_getfield($res, $(Meta.quot(p))))
434+
end
435+
resarg = true
436+
res = :($res.data = $usym)
437+
push!(delinearized_results, res)
438+
end
439+
if resarg
440+
push!(resarg_syncs, usbuf)
441+
end
378442
end
379-
380-
concretize = Expr[]
381-
for (idx, _) in enumerate(linear_results)
382-
push!(concretize, :($(Symbol(:concrete_res_, idx)) = linearized_results[$idx]))
383-
end
384-
385-
delinearized_results = Expr[]
386-
387-
result_stores = Dict{Tuple,Symbol}()
388-
443+
389444
for (idx, result) in enumerate(linear_results)
390445
paths = ((p for p in result.paths if p[1] != :args)...,)
391446
for path in paths
@@ -412,6 +467,38 @@ function compile(f, args; pipeline_options="", client=nothing)
412467
end
413468
end
414469

470+
donated_args_set = zeros(UInt8, length(linearized_args))
471+
preserved_argnums = [i for (_, i) in preserved_args]
472+
for (i, _) in enumerate(linear_args)
473+
if !in(i, preserved_argnums)
474+
donated_args_set[i] = 1
475+
end
476+
end
477+
donated_args_set = (donated_args_set...,)
478+
479+
exec_call = if length(linear_results) == 0
480+
quote
481+
$(resarg_syncs...)
482+
end
483+
else
484+
quote
485+
$(arg_syncs...)
486+
GC.@preserve $(topres...) begin
487+
linearized_results = XLA.ExecutableCall(
488+
$exec, # thunk.exec,
489+
($(linearized_args...),),
490+
$donated_args_set,
491+
Val($(length(linear_results))),
492+
)
493+
end
494+
end
495+
end
496+
497+
prevkeys = collect(keys(result_stores))
498+
resexpr = create_result(concrete_result, (), result_stores)
499+
postkeys = collect(keys(result_stores))
500+
used = [t for t in prevkeys if !in(t, postkeys)]
501+
415502
for (result, arg_idx) in preserved_args
416503
for path in result.paths
417504
arg = linear_args[arg_idx + 1]
@@ -420,6 +507,9 @@ function compile(f, args; pipeline_options="", client=nothing)
420507
if path[1] == :result
421508
res = Symbol("result")
422509
path = path[2:end]
510+
if in(path, used)
511+
continue
512+
end
423513
else
424514
@assert path[1] == :resargs || path[1] == :args
425515
# We can optimize cases where we set the arg to itself
@@ -433,7 +523,7 @@ function compile(f, args; pipeline_options="", client=nothing)
433523
res = :(traced_getfield($res, $(Meta.quot(p))))
434524
end
435525

436-
argres = :(args[argpath[2]])
526+
argres = :(args[$(argpath[2])])
437527
for p in argpath[3:end]
438528
argres = :(traced_getfield($argres, $(Meta.quot(p))))
439529
end
@@ -443,33 +533,6 @@ function compile(f, args; pipeline_options="", client=nothing)
443533
end
444534
end
445535

446-
donated_args_set = zeros(UInt8, length(linearized_args))
447-
preserved_argnums = [i for (_, i) in preserved_args]
448-
for (i, _) in enumerate(linear_args)
449-
if !in(i, preserved_argnums)
450-
donated_args_set[i] = 1
451-
end
452-
end
453-
donated_args_set = (donated_args_set...,)
454-
455-
exec_call = if length(linear_results) == 0
456-
:()
457-
else
458-
quote
459-
$(arg_syncs...)
460-
GC.@preserve $(topres...) begin
461-
linearized_results = XLA.ExecutableCall(
462-
$exec, # thunk.exec,
463-
($(linearized_args...),),
464-
$donated_args_set,
465-
Val($(length(linear_results))),
466-
)
467-
end
468-
end
469-
end
470-
471-
resexpr = create_result(concrete_result, (), result_stores)
472-
473536
fname = gensym(Symbol(Symbol(f), :_reactant))
474537

475538
expr = :(function $fname(args...)

test/basic.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,19 @@ end
245245
# get_indices_compiled = Reactant.compile(get_indices, (x_concrete,))
246246
# get_view_compiled = Reactant.compile(get_view, (x_concrete,))
247247
end
248+
249+
tuple_byref(x) = (; a =(; b=x))
250+
tuple_byref2(x) = abs2.(x), tuple_byref2(x)
251+
252+
@testset "Tuple byref" begin
253+
x = Reactant.to_rarray([1.0 -2.0; -3.0 4.0])
254+
f1 = Reactant.compile(tuple_byref, (x,))
255+
r1 = f1(x)
256+
@test r1.a.b.data === x.data
257+
258+
# TODO this seems to hang during compile
259+
# f2 = Reactant.compile(tuple_byref2, (x,))
260+
# r2 = f2(x)
261+
# @test r2[2].a.b.data === x.data
262+
# @test r2[1] == abs2.([1.0 -2.0; -3.0 4.0])
263+
end

0 commit comments

Comments
 (0)