@@ -107,28 +107,46 @@ def c(val: int | float, ty):
107107 raise NotImplementedError (ty )
108108 return arith .constant (ty , attr )
109109
def _debug_scalar_ty_format(arg):
    """Pick a printf conversion for a scalar value and widen it if needed.

    Args:
        arg: A scalar IR value; only its ``.type`` is inspected.

    Returns:
        A ``(ty_format, arg)`` pair on every path. ``ty_format`` is the
        printf conversion for the value, or ``None`` when the type is not
        one of index / integer / f32 / f16. ``arg`` is the value to print:
        integers narrower than 64 bits are extended with ``arith.extui`` to
        a signless 64-bit integer and f16 values are extended with
        ``arith.extf`` to f32; anything else is returned unchanged.
    """
    if ir.IndexType.isinstance(arg.type):
        # Must be a pair like every other branch: callers unpack two values,
        # and a bare "%llu" string cannot be unpacked into (format, value).
        return "%llu", arg
    if ir.IntegerType.isinstance(arg.type):
        if ir.IntegerType(arg.type).width < 64:
            # "%llu" is used for all integers, so widen narrower ones to
            # 64 bits before they are printed.
            arg = arith.extui(ir.IntegerType.get_signless(64), arg)
        return "%llu", arg
    if ir.F32Type.isinstance(arg.type):
        return "%f", arg
    if ir.F16Type.isinstance(arg.type):
        # f16 is printed through the same "%f" conversion as f32.
        return "%f", arith.extf(ir.F32Type.get(), arg)
    # Unsupported type: signal it with a None format and leave the value as is.
    return None, arg
110126
111127def debug_print (fmt , * args , uniform = True ):
112128 type_formats = []
113129 new_args = []
114130 for arg in args :
115- ty_format = None
116- if ir .IndexType .isinstance (arg .type ):
117- ty_format = "%llu"
118- if ir .IntegerType .isinstance (arg .type ):
119- width = ir .IntegerType (arg .type ).width
120- ty_format = "%llu"
121- if width < 64 :
122- arg = arith .extui (ir .IntegerType .get_signless (64 ), arg )
123- if ir .F32Type .isinstance (arg .type ):
124- ty_format = "%f"
125- if ir .F16Type .isinstance (arg .type ):
126- ty_format = "%f"
127- arg = arith .extf (ir .F32Type .get (), arg )
131+ if ir .VectorType .isinstance (arg .type ):
132+ index = ir .IndexType .get ()
133+ vec_ty = ir .VectorType (arg .type )
134+ if len (vec_ty .shape ) > 1 :
135+ raise NotImplementedError (vec_ty )
136+ vec_args = [
137+ vector .extractelement (arg , position = c (i , index ))
138+ for i in range (vec_ty .shape [0 ])
139+ ]
140+ ty_formats , args = zip (* map (_debug_scalar_ty_format ,vec_args ))
141+ ty_format = f"[{ ',' .join (ty_formats )} ]"
142+ new_args += args
143+ else :
144+ ty_format , arg = _debug_scalar_ty_format (arg )
145+ new_args .append (arg )
146+
128147 if ty_format is None :
129148 raise NotImplementedError (arg .type )
130149 type_formats .append (ty_format )
131- new_args .append (arg )
132150 ctx = (
133151 functools .partial (single_thread , per_block = False )
134152 if uniform
0 commit comments