File tree Expand file tree Collapse file tree 1 file changed +6
-9
lines changed
jax/experimental/mosaic/gpu Expand file tree Collapse file tree 1 file changed +6
-9
lines changed Original file line number Diff line number Diff line change @@ -108,21 +108,18 @@ def c(val: int | float, ty):
108108 return arith .constant (ty , attr )
109109
110110def _debug_scalar_ty_format (arg ):
111- ty_format = None
112111 if ir .IndexType .isinstance (arg .type ):
113- return "%llu"
112+ return "%llu" , arg
114113 if ir .IntegerType .isinstance (arg .type ):
115- width = ir .IntegerType (arg .type ).width
116- ty_format = "%llu"
117- if width < 64 :
114+ if ir .IntegerType (arg .type ).width < 64 :
118115 arg = arith .extui (ir .IntegerType .get_signless (64 ), arg )
116+ return "%llu" , arg
119117 if ir .F32Type .isinstance (arg .type ):
120- ty_format = "%f"
118+ return "%f" , arg
121119 if ir .F16Type .isinstance (arg .type ):
122- ty_format = "%f"
123120 arg = arith .extf (ir .F32Type .get (), arg )
124-
125- return ty_format , arg
121+ return "%f" , arg
122+ raise NotImplementedError ( f"Can't print the type { arg . type } " )
126123
127124def debug_print (fmt , * args , uniform = True ):
128125 type_formats = []
You can’t perform that action at this time.
0 commit comments