Skip to content

Commit d5bfafb

Browse files
cperivolGoogle-ML-Automation
authored andcommitted
[mgpu] Added a missed case for debug_print types and raise a proper error if a type is unexpected.
PiperOrigin-RevId: 701003002
1 parent 14ddb81 commit d5bfafb

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

jax/experimental/mosaic/gpu/utils.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,21 +108,18 @@ def c(val: int | float, ty):
108108
return arith.constant(ty, attr)
109109

110110
def _debug_scalar_ty_format(arg):
111-
ty_format = None
112111
if ir.IndexType.isinstance(arg.type):
113-
return "%llu"
112+
return "%llu", arg
114113
if ir.IntegerType.isinstance(arg.type):
115-
width = ir.IntegerType(arg.type).width
116-
ty_format = "%llu"
117-
if width < 64:
114+
if ir.IntegerType(arg.type).width < 64:
118115
arg = arith.extui(ir.IntegerType.get_signless(64), arg)
116+
return "%llu", arg
119117
if ir.F32Type.isinstance(arg.type):
120-
ty_format = "%f"
118+
return "%f", arg
121119
if ir.F16Type.isinstance(arg.type):
122-
ty_format = "%f"
123120
arg = arith.extf(ir.F32Type.get(), arg)
124-
125-
return ty_format, arg
121+
return "%f", arg
122+
raise NotImplementedError(f"Can't print the type {arg.type}")
126123

127124
def debug_print(fmt, *args, uniform=True):
128125
type_formats = []

0 commit comments

Comments
 (0)