99from .. import language
1010from .._C .libtriton import ir
1111from ..language import constexpr , tensor , str_to_ty
12- from ..language .core import _unwrap_if_constexpr , nv_tma_desc_type
12+ from ..language .core import _unwrap_if_constexpr , nv_tma_desc_type , _value
1313from ..runtime .jit import _normalize_ty , get_jit_fn_file_line
1414# ideally we wouldn't need any runtime component
1515from ..runtime import JITFunction
@@ -47,6 +47,10 @@ def mangle_fn(name, arg_tys, constants):
4747 return ret
4848
4949
50+ def _is_triton_value (o : Any ) -> bool :
51+ return isinstance (o , _value )
52+
53+
5054def _is_triton_tensor (o : Any ) -> bool :
5155 return isinstance (o , tensor )
5256
@@ -501,7 +505,7 @@ def visit_Assign(self, node):
501505 # by default, constexpr are assigned into python variable
502506 value = _unwrap_if_constexpr (value )
503507 if value is not None and \
504- not _is_triton_tensor (value ) and \
508+ not _is_triton_value (value ) and \
505509 not isinstance (value , native_nontensor_types ):
506510 value = language .semantic .to_tensor (value , self .builder )
507511 self .set_value (name , value )
@@ -802,6 +806,15 @@ def visit_UnaryOp(self, node):
802806 ast .USub : '__neg__' , ast .UAdd : '__pos__' , ast .Not : '__not__' , ast .Invert : '__invert__'
803807 }
804808
809+ def _verify_loop_carried_variable (self , name , loop_val , live_val ):
810+ assert _is_triton_value (loop_val ), f'cannot reassign constxpr { name } in the loop'
811+ assert _is_triton_value (live_val ), f'cannot reasign constexpr { name } in the loop'
812+ assert type (loop_val ) == type (live_val ), f'Loop carried variable { name } changed type'
813+ assert not _is_triton_tensor (loop_val ) or loop_val .type == live_val .type , \
814+ f'Loop-carried variable { name } has initial type { live_val .type } ' \
815+ f'but is re-assigned to { loop_val .type } in loop! ' \
816+ f'Please make sure that the type stays consistent.'
817+
805818 def visit_While (self , node ):
806819 with enter_sub_region (self ) as sr :
807820 liveins , insert_block = sr
@@ -824,17 +837,14 @@ def visit_While(self, node):
824837 for name in loop_defs :
825838 if name in liveins :
826839 # We should not def new constexpr
827- assert _is_triton_tensor (loop_defs [name ]), f'cannot reassign constxpr { name } in the loop'
828- assert _is_triton_tensor (liveins [name ]), f'cannot reasign constexpr { name } in the loop'
829- assert loop_defs [name ].type == liveins [name ].type , \
830- f'Loop-carried variable { name } has initial type { liveins [name ].type } ' \
831- f'but is re-assigned to { loop_defs [name ].type } in loop! ' \
832- f'Please make sure that the type stays consistent.'
840+ loop_val = loop_defs [name ]
841+ live_val = liveins [name ]
842+ self ._verify_loop_carried_variable (name , loop_val , live_val )
833843
834844 # these are loop-carried values
835845 names .append (name )
836- ret_types .append (loop_defs [ name ] .type )
837- init_args .append (liveins [ name ] )
846+ ret_types .append (loop_val .type )
847+ init_args .append (live_val )
838848
839849 self ._set_insertion_point_and_loc (ip , last_loc )
840850 while_op = self .builder .create_while_op ([ty .to_ir (self .builder ) for ty in ret_types ],
@@ -972,16 +982,13 @@ def visit_For(self, node):
972982 names = []
973983 for name in self .local_defs :
974984 if name in liveins :
975- assert _is_triton_tensor (self .local_defs [name ]), f'cannot reassign constxpr { name } in the loop'
976- assert _is_triton_tensor (liveins [name ]), f'cannot reassign constxpr { name } in the loop'
977- assert self .local_defs [name ].type == liveins [name ].type , \
978- f'Loop-carried variable { name } has initial type { liveins [name ].type } ' \
979- f'but is re-assigned to { self .local_defs [name ].type } in loop! ' \
980- f'Please make sure that the type stays consistent.'
985+ loop_val = self .local_defs [name ]
986+ live_val = liveins [name ]
987+ self ._verify_loop_carried_variable (name , loop_val , live_val )
981988
982989 names .append (name )
983- init_args .append (language . semantic . to_tensor ( liveins [ name ], self . builder ) )
984- yields .append (language . semantic . to_tensor ( self . local_defs [ name ], self . builder ) )
990+ init_args .append (live_val )
991+ yields .append (loop_val )
985992
986993 # create ForOp
987994 self ._set_insertion_point_and_loc (ip , last_loc )
@@ -1051,7 +1058,7 @@ def visit_Assert(self, node) -> Any:
10511058 def call_JitFunction (self , fn : JITFunction , args , kwargs ):
10521059 args = inspect .getcallargs (fn .fn , * args , ** kwargs )
10531060 args = [args [name ] for name in fn .arg_names ]
1054- args = [arg if _is_triton_tensor (arg ) else constexpr (arg ) for arg in args ]
1061+ args = [arg if _is_triton_value (arg ) else constexpr (arg ) for arg in args ]
10551062 # generate function def
10561063 attributes = {}
10571064 constexprs = [i for i , arg in enumerate (args ) if _is_constexpr (arg )]
@@ -1110,7 +1117,7 @@ def visit_Call(self, node):
11101117 if isinstance (fn , JITFunction ):
11111118 _check_fn_args (node , fn , args )
11121119 return self .call_JitFunction (fn , args , kws )
1113- if (hasattr (fn , '__self__' ) and _is_triton_tensor (fn .__self__ )) or language .core .is_builtin (fn ):
1120+ if (hasattr (fn , '__self__' ) and _is_triton_value (fn .__self__ )) or language .core .is_builtin (fn ):
11141121 extra_kwargs = {"_builder" : self .builder }
11151122 sig = inspect .signature (fn )
11161123 if '_generator' in sig .parameters :
0 commit comments