@@ -60,6 +60,7 @@ defmodule EXLA.Defn do
60
60
compile_options ,
61
61
used_buffers ,
62
62
used_inputs ,
63
+ _stream = true ,
63
64
comp_fun
64
65
)
65
66
@@ -258,7 +259,7 @@ defmodule EXLA.Defn do
258
259
callback = & to_root_computation ( & 1 , & 2 , & 3 , & 4 , Keyword . put ( compile_options , :client , client ) )
259
260
260
261
{ executable , used_inputs , outputs , outfeed , :ok , debug? } =
261
- compile ( client , key , vars , fun , compile_options , 0 , [ ] , callback )
262
+ compile ( client , key , vars , fun , compile_options , 0 , [ ] , _stream = false , callback )
262
263
263
264
fn [ args ] ->
264
265
{ time , lock } =
@@ -357,7 +358,17 @@ defmodule EXLA.Defn do
357
358
358
359
## Compile
359
360
360
- defp compile ( client , key , vars , fun , options , used_buffers , used_inputs , to_computation ) do
361
+ defp compile (
362
+ client ,
363
+ key ,
364
+ vars ,
365
+ fun ,
366
+ options ,
367
+ used_buffers ,
368
+ used_inputs ,
369
+ stream? ,
370
+ to_computation
371
+ ) do
361
372
{ { expr_cache_fun , comp_cache_fun } , options } =
362
373
case Keyword . pop ( options , :cache , true ) do
363
374
{ true , options } ->
@@ -385,7 +396,7 @@ defmodule EXLA.Defn do
385
396
386
397
{ eval_time , { expr , { ref , outputs , { used_inputs , defined_hooks } } } } =
387
398
:timer . tc ( fn ->
388
- expr_cache_fun . ( { key , args_key } , fn ->
399
+ expr_cache_fun . ( { key , args_key , lazy_transfers } , fn ->
389
400
expr = fun . ( vars )
390
401
inputs_and_hooks = Outfeed . used_inputs_and_hooks ( expr , used_inputs , lazy_transfers )
391
402
{ expr , { make_ref ( ) , Nx . to_template ( expr ) , inputs_and_hooks } }
@@ -432,10 +443,16 @@ defmodule EXLA.Defn do
432
443
end )
433
444
434
445
EXLA.MLIR.Module . new ( comp_arg_typespecs , out_typespecs , fn builder ->
446
+ # Only create the token when we know it will actually be
447
+ # used, that is: streaming, lazy transfers or hooks
435
448
outfeed =
436
- outfeed
437
- |> Outfeed . with_token ( Value . create_token ( builder ) )
438
- |> Outfeed . add_infeeds ( builder , reverse_infeeds )
449
+ if stream? or reverse_infeeds != [ ] or hooks != % { } or defined_hooks != % { } do
450
+ outfeed
451
+ |> Outfeed . with_token ( Value . create_token ( builder ) )
452
+ |> Outfeed . add_infeeds ( builder , reverse_infeeds )
453
+ else
454
+ outfeed
455
+ end
439
456
440
457
expr = Nx.Defn.Composite . traverse ( expr || fun . ( vars ) , & Nx . devectorize / 1 )
441
458
@@ -520,19 +537,30 @@ defmodule EXLA.Defn do
520
537
cache
521
538
) do
522
539
[ initial_arg , _arg , pred , body ] = args
523
- initial_with_token = { get_token ( cache ) , initial_arg }
524
540
525
- { initial , cache } = recur_composite ( initial_with_token , state , cache )
541
+ initial =
542
+ if token = get_token ( cache ) do
543
+ { token , initial_arg }
544
+ else
545
+ initial_arg
546
+ end
547
+
548
+ { initial , cache } = recur_composite ( initial , state , cache )
526
549
527
550
{ pred_computation , cache } = mlir_while_computation ( pred , initial , { :pred , 8 } , state , cache )
528
551
{ body_computation , cache } = mlir_while_computation ( body , initial , :with_token , state , cache )
529
552
530
- [ token | results ] =
553
+ results =
531
554
Value . while ( function , pred_computation , body_computation , List . flatten ( initial ) )
532
555
533
- result = wrap_tuple_result ( results , initial_arg )
534
-
535
- { result , update_token ( cache , token ) }
556
+ if get_token ( cache ) do
557
+ [ token | results ] = results
558
+ result = wrap_tuple_result ( results , initial_arg )
559
+ { result , update_token ( cache , token ) }
560
+ else
561
+ result = wrap_tuple_result ( results , initial_arg )
562
+ { result , cache }
563
+ end
536
564
end
537
565
538
566
defp cached_recur_operator ( :cond , % T { data: % Expr { args: args } } = t , state , cache ) do
@@ -688,16 +716,19 @@ defmodule EXLA.Defn do
688
716
{ computation , cache }
689
717
690
718
% { } ->
691
- { computation , cache } = token_computation ( "optional" , call_args , expr , state , cache )
719
+ { computation , cache } = optional_computation ( "optional" , call_args , expr , state , cache )
692
720
{ computation , Map . put ( cache , key , computation ) }
693
721
end
694
722
695
- typespecs = [ Typespec . token ( ) | container_to_typespecs ( expr ) ]
696
-
697
- [ token | result ] =
698
- Value . call ( state . builder , [ get_token ( cache ) | call_args ] , call_body , typespecs )
699
-
700
- { wrap_tuple_result ( result , expr ) , update_token ( cache , token ) }
723
+ if token = get_token ( cache ) do
724
+ typespecs = [ Typespec . token ( ) | container_to_typespecs ( expr ) ]
725
+ [ token | result ] = Value . call ( state . builder , [ token | call_args ] , call_body , typespecs )
726
+ { wrap_tuple_result ( result , expr ) , update_token ( cache , token ) }
727
+ else
728
+ typespecs = container_to_typespecs ( expr )
729
+ result = Value . call ( state . builder , call_args , call_body , typespecs )
730
+ { wrap_tuple_result ( result , expr ) , cache }
731
+ end
701
732
end
702
733
703
734
defp cached_recur_operator ( :attach_token , % T { data: % Expr { args: [ token , expr ] } } , state , cache ) do
@@ -1553,7 +1584,17 @@ defmodule EXLA.Defn do
1553
1584
defp mlir_while_computation ( expr , initial , type , state , cache ) do
1554
1585
arg_typespecs = Enum . map ( List . flatten ( initial ) , & Value . get_typespec / 1 )
1555
1586
1556
- { region , [ arg_token | arg_params ] } = Function . push_region ( state . builder , arg_typespecs )
1587
+ { region , args } = Function . push_region ( state . builder , arg_typespecs )
1588
+
1589
+ outer_token = get_token ( cache )
1590
+
1591
+ { inner_token , arg_params } =
1592
+ if outer_token do
1593
+ [ arg_token | arg_params ] = args
1594
+ { arg_token , arg_params }
1595
+ else
1596
+ { nil , args }
1597
+ end
1557
1598
1558
1599
params = Enum . with_index ( arg_params , & { & 2 , & 1 } )
1559
1600
@@ -1570,11 +1611,15 @@ defmodule EXLA.Defn do
1570
1611
expr
1571
1612
end
1572
1613
1573
- { res , comp_cache } = recur_composite ( expr , & & 1 , state , reset_token ( cache , arg_token ) )
1614
+ { res , comp_cache } = recur_composite ( expr , & & 1 , state , reset_token ( cache , inner_token ) )
1574
1615
1575
1616
res =
1576
1617
if type == :with_token do
1577
- [ get_token ( comp_cache ) | List . flatten ( res ) ]
1618
+ if outer_token do
1619
+ [ get_token ( comp_cache ) | List . flatten ( res ) ]
1620
+ else
1621
+ List . flatten ( res )
1622
+ end
1578
1623
else
1579
1624
Enum . map ( res , & to_type ( & 1 , type ) )
1580
1625
end
@@ -1585,21 +1630,34 @@ defmodule EXLA.Defn do
1585
1630
{ region , merge_outfeed ( cache , comp_cache ) }
1586
1631
end
1587
1632
1588
- defp token_computation ( name , args , expr , % { builder: % Function { } } = state , cache ) do
1633
+ defp optional_computation ( name , args , expr , % { builder: % Function { } } = state , cache ) do
1589
1634
% Function { module: module , name: name } = subbuilder ( state . builder , name )
1590
1635
1591
- token_typespec = Typespec . token ( )
1592
1636
arg_typespecs = Enum . map ( args , & Value . get_typespec / 1 )
1593
1637
out_typespecs = container_to_typespecs ( expr )
1594
1638
1595
- function =
1596
- EXLA.MLIR.Module . add_function ( module , name , [ token_typespec | arg_typespecs ] , [
1597
- token_typespec | out_typespecs
1598
- ] )
1639
+ outer_token = get_token ( cache )
1640
+ token_typespec = Typespec . token ( )
1641
+
1642
+ { arg_typespecs , out_typespecs } =
1643
+ if outer_token do
1644
+ { [ token_typespec | arg_typespecs ] , [ token_typespec | out_typespecs ] }
1645
+ else
1646
+ { arg_typespecs , out_typespecs }
1647
+ end
1599
1648
1600
- [ arg_token | tail ] = EXLA.MLIR.Function . get_arguments ( function )
1649
+ function = EXLA.MLIR.Module . add_function ( module , name , arg_typespecs , out_typespecs )
1650
+ args = EXLA.MLIR.Function . get_arguments ( function )
1601
1651
1602
- params = Enum . with_index ( tail , fn param , i -> { i , param } end )
1652
+ { inner_token , args } =
1653
+ if outer_token do
1654
+ [ arg_token | args ] = args
1655
+ { arg_token , args }
1656
+ else
1657
+ { nil , args }
1658
+ end
1659
+
1660
+ params = Enum . with_index ( args , fn param , i -> { i , param } end )
1603
1661
1604
1662
state = % {
1605
1663
state
@@ -1608,9 +1666,13 @@ defmodule EXLA.Defn do
1608
1666
scope_ids: Tree . scope_ids ( expr )
1609
1667
}
1610
1668
1611
- { res , comp_cache } = recur_composite ( expr , state , reset_token ( cache , arg_token ) )
1669
+ { res , comp_cache } = recur_composite ( expr , state , reset_token ( cache , inner_token ) )
1612
1670
1613
- Value . return ( function , [ get_token ( comp_cache ) | List . flatten ( res ) ] )
1671
+ if outer_token do
1672
+ Value . return ( function , [ get_token ( comp_cache ) | List . flatten ( res ) ] )
1673
+ else
1674
+ Value . return ( function , List . flatten ( res ) )
1675
+ end
1614
1676
1615
1677
{ function , merge_outfeed ( cache , comp_cache ) }
1616
1678
end
@@ -1786,10 +1848,10 @@ defmodule EXLA.Defn do
1786
1848
1787
1849
out_typespecs = container_to_typespecs ( on_true )
1788
1850
1789
- in_token = get_token ( cache )
1851
+ outer_token = get_token ( cache )
1790
1852
1791
1853
result_typespecs =
1792
- if in_token do
1854
+ if outer_token do
1793
1855
[ Typespec . token ( ) | out_typespecs ]
1794
1856
else
1795
1857
out_typespecs
@@ -1799,7 +1861,7 @@ defmodule EXLA.Defn do
1799
1861
{ false_computation , cache } = to_mlir_if_branch ( on_false , false_ids , state , cache )
1800
1862
if_results = Value . if_op ( pred_op , true_computation , false_computation , result_typespecs )
1801
1863
1802
- if in_token do
1864
+ if outer_token do
1803
1865
[ token | results ] = if_results
1804
1866
{ wrap_tuple_result ( results , on_true ) , update_token ( cache , token ) }
1805
1867
else
0 commit comments