246
246
macro code_hlo (options, maybe_call= nothing )
247
247
call = something (maybe_call, options)
248
248
options = isnothing (maybe_call) ? :(optimize = true ) : options
249
- Meta. isexpr (call, :call ) || error (" @code_mlir : expected call, got $call " )
249
+ Meta. isexpr (call, :call ) || error (" @code_hlo : expected call, got $call " )
250
250
if ! Meta. isexpr (options, :(= )) || options. args[1 ] != :optimize
251
- error (" @code_mlir : expected options in format optimize=value, got $options " )
251
+ error (" @code_hlo : expected options in format optimize=value, got $options " )
252
252
end
253
253
254
254
options = Expr (:tuple , Expr (:parameters , Expr (:kw , options. args... )))
@@ -269,6 +269,26 @@ macro code_hlo(options, maybe_call=nothing)
269
269
end
270
270
end
271
271
272
+ """
273
+ @compile f(args...)
274
+ """
275
+ macro compile (options, maybe_call= nothing )
276
+ call = something (maybe_call, options)
277
+ options = isnothing (maybe_call) ? :(optimize = true ) : options
278
+ Meta. isexpr (call, :call ) || error (" @compile: expected call, got $call " )
279
+ if ! Meta. isexpr (options, :(= )) || options. args[1 ] != :optimize
280
+ error (" @compile: expected options in format optimize=value, got $options " )
281
+ end
282
+
283
+ options = Expr (:tuple , Expr (:parameters , Expr (:kw , options. args... )))
284
+
285
+ quote
286
+ f = $ (esc (call. args[1 ]))
287
+ args = $ (esc (Expr (:tuple , call. args[2 : end ]. .. )))
288
+ compile (f, args)
289
+ end
290
+ end
291
+
272
292
traced_getfield (obj, field) = Base. getfield (obj, field)
273
293
274
294
function create_result (tocopy:: T , path, result_stores) where {T}
@@ -287,7 +307,9 @@ function create_result(tocopy::T, path, result_stores) where {T}
287
307
end
288
308
289
309
function create_result (tocopy:: ConcreteRArray{T,N} , path, result_stores) where {T,N}
290
- return :(ConcreteRArray {$T,$N} ($ (result_stores[path]), $ (tocopy. shape)))
310
+ restore = result_stores[path]
311
+ delete! (result_stores, path)
312
+ return :(ConcreteRArray {$T,$N} ($ restore, $ (tocopy. shape)))
291
313
end
292
314
293
315
function create_result (tocopy:: Array{T,N} , path, result_stores) where {T,N}
@@ -353,9 +375,19 @@ function compile(f, args; pipeline_options="", client=nothing)
353
375
closure_ty = typeof (fnwrap)
354
376
355
377
arg_syncs = Expr[]
378
+ resarg_syncs = Expr[]
356
379
topres = Symbol[]
357
380
linearized_args = Union{Symbol,Expr}[]
358
381
382
+ concretize = Expr[]
383
+ for (idx, _) in enumerate (linear_results)
384
+ push! (concretize, :($ (Symbol (:concrete_res_ , idx)) = linearized_results[$ idx]))
385
+ end
386
+
387
+ delinearized_results = Expr[]
388
+
389
+ result_stores = Dict {Tuple,Symbol} ()
390
+
359
391
for (i, arg) in enumerate (linear_args)
360
392
paths = ((p for p in arg. paths if p[1 ] == :args ). .. ,)
361
393
path = if length (paths) == 1
@@ -367,25 +399,48 @@ function compile(f, args; pipeline_options="", client=nothing)
367
399
for p in path[3 : end ]
368
400
res = :(traced_getfield ($ res, $ (Meta. quot (p))))
369
401
end
402
+ usym = Symbol (" usbuf_$i " )
403
+ usbuf = :($ usym = $ res. data)
370
404
sym = Symbol (" sbuf_$i " )
371
- sbuf = :($ sym = XLA. synced_buffer ($ res. data))
405
+ sbuf = :($ sym = XLA. synced_buffer ($ usym))
406
+ push! (arg_syncs, usbuf)
372
407
push! (arg_syncs, sbuf)
373
408
374
409
push! (topres, sym)
375
410
376
411
res = :($ sym. buffer)
377
412
push! (linearized_args, res)
413
+
414
+ respaths = ((p for p in arg. paths if p[1 ] != :args ). .. ,)
415
+
416
+ resarg = false
417
+ for respath in respaths
418
+ if respath[1 ] == :result
419
+ res = Symbol (" result" )
420
+ respath = respath[2 : end ]
421
+ result_stores[respath] = usym
422
+ resarg = true
423
+ continue
424
+ else
425
+ @assert respath[1 ] == :resargs
426
+ if respath[2 ] == path[2 ]
427
+ continue
428
+ end
429
+ res = :(args[$ (respath[2 ])])
430
+ path = path[3 : end ]
431
+ end
432
+ for p in path
433
+ res = :(traced_getfield ($ res, $ (Meta. quot (p))))
434
+ end
435
+ resarg = true
436
+ res = :($ res. data = $ usym)
437
+ push! (delinearized_results, res)
438
+ end
439
+ if resarg
440
+ push! (resarg_syncs, usbuf)
441
+ end
378
442
end
379
-
380
- concretize = Expr[]
381
- for (idx, _) in enumerate (linear_results)
382
- push! (concretize, :($ (Symbol (:concrete_res_ , idx)) = linearized_results[$ idx]))
383
- end
384
-
385
- delinearized_results = Expr[]
386
-
387
- result_stores = Dict {Tuple,Symbol} ()
388
-
443
+
389
444
for (idx, result) in enumerate (linear_results)
390
445
paths = ((p for p in result. paths if p[1 ] != :args ). .. ,)
391
446
for path in paths
@@ -412,6 +467,38 @@ function compile(f, args; pipeline_options="", client=nothing)
412
467
end
413
468
end
414
469
470
+ donated_args_set = zeros (UInt8, length (linearized_args))
471
+ preserved_argnums = [i for (_, i) in preserved_args]
472
+ for (i, _) in enumerate (linear_args)
473
+ if ! in (i, preserved_argnums)
474
+ donated_args_set[i] = 1
475
+ end
476
+ end
477
+ donated_args_set = (donated_args_set... ,)
478
+
479
+ exec_call = if length (linear_results) == 0
480
+ quote
481
+ $ (resarg_syncs... )
482
+ end
483
+ else
484
+ quote
485
+ $ (arg_syncs... )
486
+ GC. @preserve $ (topres... ) begin
487
+ linearized_results = XLA. ExecutableCall (
488
+ $ exec, # thunk.exec,
489
+ ($ (linearized_args... ),),
490
+ $ donated_args_set,
491
+ Val ($ (length (linear_results))),
492
+ )
493
+ end
494
+ end
495
+ end
496
+
497
+ prevkeys = collect (keys (result_stores))
498
+ resexpr = create_result (concrete_result, (), result_stores)
499
+ postkeys = collect (keys (result_stores))
500
+ used = [t for t in prevkeys if ! in (t, postkeys)]
501
+
415
502
for (result, arg_idx) in preserved_args
416
503
for path in result. paths
417
504
arg = linear_args[arg_idx + 1 ]
@@ -420,6 +507,9 @@ function compile(f, args; pipeline_options="", client=nothing)
420
507
if path[1 ] == :result
421
508
res = Symbol (" result" )
422
509
path = path[2 : end ]
510
+ if in (path, used)
511
+ continue
512
+ end
423
513
else
424
514
@assert path[1 ] == :resargs || path[1 ] == :args
425
515
# We can optimize cases where we set the arg to itself
@@ -433,7 +523,7 @@ function compile(f, args; pipeline_options="", client=nothing)
433
523
res = :(traced_getfield ($ res, $ (Meta. quot (p))))
434
524
end
435
525
436
- argres = :(args[argpath[2 ]])
526
+ argres = :(args[$ ( argpath[2 ]) ])
437
527
for p in argpath[3 : end ]
438
528
argres = :(traced_getfield ($ argres, $ (Meta. quot (p))))
439
529
end
@@ -443,33 +533,6 @@ function compile(f, args; pipeline_options="", client=nothing)
443
533
end
444
534
end
445
535
446
- donated_args_set = zeros (UInt8, length (linearized_args))
447
- preserved_argnums = [i for (_, i) in preserved_args]
448
- for (i, _) in enumerate (linear_args)
449
- if ! in (i, preserved_argnums)
450
- donated_args_set[i] = 1
451
- end
452
- end
453
- donated_args_set = (donated_args_set... ,)
454
-
455
- exec_call = if length (linear_results) == 0
456
- :()
457
- else
458
- quote
459
- $ (arg_syncs... )
460
- GC. @preserve $ (topres... ) begin
461
- linearized_results = XLA. ExecutableCall (
462
- $ exec, # thunk.exec,
463
- ($ (linearized_args... ),),
464
- $ donated_args_set,
465
- Val ($ (length (linear_results))),
466
- )
467
- end
468
- end
469
- end
470
-
471
- resexpr = create_result (concrete_result, (), result_stores)
472
-
473
536
fname = gensym (Symbol (Symbol (f), :_reactant ))
474
537
475
538
expr = :(function $fname (args... )
0 commit comments