@@ -435,6 +435,47 @@ def _set_insertion_point_and_loc(self, ip, loc):
435
435
self .builder .restore_insertion_point (ip )
436
436
self .builder .set_loc (loc )
437
437
438
+ def _find_carries (self , node , liveins ):
439
+ # We must extract the handles before the value is edited in the loop: snapshot the flattened IR handles of every triton-value live-in, keyed by name, ahead of the dry visit below.
440
+ livehandles = {name : flatten_values_to_ir ([v ]) for name , v in liveins .items () if _is_triton_value (v )}
441
+ # Create a scratch block to hold the loop body; it is erased again once the dry visit is done.
442
+ block = self .builder .create_block ()
443
+ self .builder .set_insertion_point_to_start (block )
444
+ # Dry visit of the loop body, with this node pushed on scf_stack for the duration of the visit.
445
+ self .scf_stack .append (node )
446
+ self .visit_compound_statement (node .body )
447
+ self .scf_stack .pop ()
448
+ block .erase ()
449
+
450
+ # If a variable (name) has changed value within the loop, then it's
451
+ # a loop-carried variable. (The new and old value must be of the
452
+ # same type.) A change is detected by comparing the handles captured before the dry visit with those of the value now bound in lscope.
453
+ init_tys = []
454
+ init_handles = []
455
+ names = []
456
+
457
+ for name , live_val in liveins .items ():
458
+ if _is_triton_value (live_val ):
459
+ loop_val = self .lscope [name ]
460
+ assert type (live_val ) is type (loop_val ), f'Loop carried variable { name } changed type'
461
+
462
+ live_handles = livehandles [name ]
463
+ loop_handles = flatten_values_to_ir ([loop_val ])
464
+ if live_handles != loop_handles :
465
+ self ._verify_loop_carried_variable (name , loop_val , live_val )
466
+
467
+ names .append (name )
468
+ init_tys .append (live_val .type )
469
+ init_handles .extend (live_handles )
470
+ else :
471
+ assert name not in self .local_defs , f'Loop carried variable { name } is not a triton value'
472
+
473
+ # Reset the local scope to a copy of the live-ins and clear local_defs, so nothing defined during the dry run is picked up afterwards.
474
+ self .lscope = liveins .copy ()
475
+ self .local_defs = {}
476
+
477
+ return names , init_handles , init_tys
478
+
438
479
#
439
480
# AST visitor
440
481
#
@@ -918,8 +959,8 @@ def visit_UnaryOp(self, node):
918
959
}
919
960
920
961
def _verify_loop_carried_variable (self , name , loop_val , live_val ):
921
- assert _is_triton_value (loop_val ), f'cannot reassign constxpr { name } in the loop'
922
- assert _is_triton_value (live_val ), f'cannot reasign constexpr { name } in the loop'
962
+ assert _is_triton_value (loop_val ), f'cannot reassign constexpr { name } in the loop'
963
+ assert _is_triton_value (live_val ), f'cannot reassign constexpr { name } in the loop'
923
964
assert type (loop_val ) is type (live_val ), f'Loop carried variable { name } changed type'
924
965
assert not _is_triton_tensor (loop_val ) or loop_val .type == live_val .type , \
925
966
f'Loop-carried variable { name } has initial type { live_val .type } ' \
@@ -931,33 +972,9 @@ def visit_While(self, node):
931
972
liveins , insert_block = sr
932
973
ip , last_loc = self ._get_insertion_point_and_loc ()
933
974
934
- # loop body (the after region)
935
- # loop_block = self.builder.create_block()
936
- dummy = self .builder .create_block ()
937
- self .builder .set_insertion_point_to_start (dummy )
938
- self .scf_stack .append (node )
939
- self .visit_compound_statement (node .body )
940
- self .scf_stack .pop ()
941
- loop_defs = self .local_defs
942
- dummy .erase ()
943
-
944
- # collect loop-carried values
945
- names = []
946
- init_args = []
947
- for name in loop_defs :
948
- if name in liveins :
949
- # We should not def new constexpr
950
- loop_val = loop_defs [name ]
951
- live_val = liveins [name ]
952
- self ._verify_loop_carried_variable (name , loop_val , live_val )
953
-
954
- # these are loop-carried values
955
- names .append (name )
956
- init_args .append (live_val )
975
+ names , init_handles , init_fe_tys = self ._find_carries (node , liveins )
957
976
958
- init_handles = flatten_values_to_ir (init_args )
959
977
init_tys = [h .get_type () for h in init_handles ]
960
- init_fe_tys = [a .type for a in init_args ]
961
978
self ._set_insertion_point_and_loc (ip , last_loc )
962
979
while_op = self .builder .create_while_op (init_tys , init_handles )
963
980
# merge the condition region
@@ -985,13 +1002,9 @@ def visit_While(self, node):
985
1002
self .scf_stack .append (node )
986
1003
self .visit_compound_statement (node .body )
987
1004
self .scf_stack .pop ()
988
- loop_defs = self .local_defs
989
- yields = []
990
- for name in loop_defs :
991
- if name in liveins :
992
- loop_defs [name ]._flatten_ir (yields )
993
1005
994
- self .builder .create_yield_op (yields )
1006
+ yield_handles = flatten_values_to_ir (self .lscope [name ] for name in names )
1007
+ self .builder .create_yield_op (yield_handles )
995
1008
996
1009
# WhileOp defines new values, update the symbol table (lscope, local_defs)
997
1010
result_handles = [while_op .get_result (i ) for i in range (len (init_handles ))]
@@ -1097,34 +1110,10 @@ def visit_For(self, node):
1097
1110
liveins , insert_block = sr
1098
1111
ip , last_loc = self ._get_insertion_point_and_loc ()
1099
1112
1100
- # create loop body block
1101
- block = self .builder .create_block ()
1102
- self .builder .set_insertion_point_to_start (block )
1103
- # dry visit loop body
1104
- self .scf_stack .append (node )
1105
- self .visit_compound_statement (node .body )
1106
- self .scf_stack .pop ()
1107
- block .erase ()
1108
-
1109
- # If a variable (name) is defined in both its parent & itself, then it's
1110
- # a loop-carried variable. (They must be of the same type)
1111
- init_args = []
1112
- yields = []
1113
- names = []
1114
- for name in self .local_defs :
1115
- if name in liveins :
1116
- loop_val = self .local_defs [name ]
1117
- live_val = liveins [name ]
1118
- self ._verify_loop_carried_variable (name , loop_val , live_val )
1119
-
1120
- names .append (name )
1121
- init_args .append (live_val )
1122
- yields .append (loop_val )
1113
+ names , init_handles , init_tys = self ._find_carries (node , liveins )
1123
1114
1124
1115
# create ForOp
1125
1116
self ._set_insertion_point_and_loc (ip , last_loc )
1126
- init_handles = flatten_values_to_ir (init_args )
1127
- init_tys = [v .type for v in init_args ]
1128
1117
for_op = self .builder .create_for_op (lb , ub , step , init_handles )
1129
1118
if _unwrap_if_constexpr (num_stages ) is not None :
1130
1119
for_op .set_attr ("tt.num_stages" , self .builder .get_int32_attr (num_stages ))
@@ -1140,26 +1129,16 @@ def visit_For(self, node):
1140
1129
self .scf_stack .append (node )
1141
1130
for_op_body = for_op .get_body (0 )
1142
1131
self .builder .set_insertion_point_to_start (for_op_body )
1143
- # reset local scope to not pick up local defs from the previous dry run.
1144
- self .lscope = liveins .copy ()
1145
- self .local_defs = {}
1146
1132
block_handles = [for_op_body .arg (i + 1 ) for i in range (len (init_handles ))]
1147
1133
block_args = unflatten_ir_values (block_handles , init_tys )
1148
1134
for name , val in zip (names , block_args ):
1149
1135
self .set_value (name , val )
1150
1136
self .visit_compound_statement (node .body )
1151
1137
self .scf_stack .pop ()
1152
- yields = []
1153
- for name in self .local_defs :
1154
- if name in liveins :
1155
- local = self .local_defs [name ]
1156
- if isinstance (local , constexpr ):
1157
- local = self .semantic .to_tensor (local )
1158
- yields .append (local )
1138
+ yield_handles = flatten_values_to_ir (self .lscope [name ] for name in names )
1159
1139
1160
1140
# create YieldOp
1161
- if len (yields ) > 0 :
1162
- yield_handles = flatten_values_to_ir (yields )
1141
+ if len (yield_handles ) > 0 :
1163
1142
self .builder .create_yield_op (yield_handles )
1164
1143
for_op_region = for_op_body .get_parent ()
1165
1144
assert for_op_region .size () == 1 , "We use SCF, so the loop body should only have one block"
0 commit comments