We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f18df8f commit 1bc9df4Copy full SHA for 1bc9df4
jax/experimental/mosaic/gpu/utils.py
@@ -296,6 +296,12 @@ def globaltimer(kind: Literal["low", "high"] | None = None):
296
297
298
def bytewidth(ty: ir.Type):
299
+ # The actual width of TF32 is 19 bits. However, sinc we need to treat it as
300
+ # 32 bits for compatibility reasons. TF32 used to be 32 bits wide in upstream
301
+ # MLIR, but it changed in
302
+ # https://github.com/llvm/llvm-project/commit/67a1fdb014790a38a205d28e1748634de34471dd.
303
+ if ir.FloatTF32Type.isinstance(ty):
304
+ return 4
305
if ir.IntegerType.isinstance(ty):
306
return ir.IntegerType(ty).width // 8
307
if ir.FloatType.isinstance(ty):
0 commit comments