Skip to content

Commit df8ecb9

Browse files
cperivolGoogle-ML-Automation
authored andcommitted
[mgpu] Debug print for mlir vectors.
PiperOrigin-RevId: 700714031
1 parent d449f12 commit df8ecb9

File tree

1 file changed

+32
-14
lines changed

1 file changed

+32
-14
lines changed

jax/experimental/mosaic/gpu/utils.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -107,28 +107,46 @@ def c(val: int | float, ty):
107107
raise NotImplementedError(ty)
108108
return arith.constant(ty, attr)
109109

110+
def _debug_scalar_ty_format(arg):
111+
ty_format = None
112+
if ir.IndexType.isinstance(arg.type):
113+
return "%llu"
114+
if ir.IntegerType.isinstance(arg.type):
115+
width = ir.IntegerType(arg.type).width
116+
ty_format = "%llu"
117+
if width < 64:
118+
arg = arith.extui(ir.IntegerType.get_signless(64), arg)
119+
if ir.F32Type.isinstance(arg.type):
120+
ty_format = "%f"
121+
if ir.F16Type.isinstance(arg.type):
122+
ty_format = "%f"
123+
arg = arith.extf(ir.F32Type.get(), arg)
124+
125+
return ty_format, arg
110126

111127
def debug_print(fmt, *args, uniform=True):
112128
type_formats = []
113129
new_args = []
114130
for arg in args:
115-
ty_format = None
116-
if ir.IndexType.isinstance(arg.type):
117-
ty_format = "%llu"
118-
if ir.IntegerType.isinstance(arg.type):
119-
width = ir.IntegerType(arg.type).width
120-
ty_format = "%llu"
121-
if width < 64:
122-
arg = arith.extui(ir.IntegerType.get_signless(64), arg)
123-
if ir.F32Type.isinstance(arg.type):
124-
ty_format = "%f"
125-
if ir.F16Type.isinstance(arg.type):
126-
ty_format = "%f"
127-
arg = arith.extf(ir.F32Type.get(), arg)
131+
if ir.VectorType.isinstance(arg.type):
132+
index = ir.IndexType.get()
133+
vec_ty = ir.VectorType(arg.type)
134+
if len(vec_ty.shape) > 1:
135+
raise NotImplementedError(vec_ty)
136+
vec_args = [
137+
vector.extractelement(arg, position=c(i, index))
138+
for i in range(vec_ty.shape[0])
139+
]
140+
ty_formats, args = zip(*map(_debug_scalar_ty_format,vec_args))
141+
ty_format = f"[{','.join(ty_formats)}]"
142+
new_args += args
143+
else:
144+
ty_format, arg = _debug_scalar_ty_format(arg)
145+
new_args.append(arg)
146+
128147
if ty_format is None:
129148
raise NotImplementedError(arg.type)
130149
type_formats.append(ty_format)
131-
new_args.append(arg)
132150
ctx = (
133151
functools.partial(single_thread, per_block=False)
134152
if uniform

0 commit comments

Comments
 (0)