Skip to content

Conversation

@matthias-springer
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/116725.diff

2 Files Affected:

  • (modified) mlir/python/mlir/extras/types.py (+2)
  • (modified) mlir/test/python/ir/builtin_types.py (+5)
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index 34eee1edb57ff5..b875d639e9d406 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -21,6 +21,7 @@
     Float8E4M3Type,
     Float8E5M2Type,
     Float8E8M0FNUType,
+    FloatTF32Type,
     FunctionType,
     IndexType,
     IntegerType,
@@ -70,6 +71,7 @@ def ui(width):
 
 f16 = lambda: F16Type.get()
 f32 = lambda: F32Type.get()
+tf32 = lambda: FloatTF32Type.get()
 f64 = lambda: F64Type.get()
 bf16 = lambda: BF16Type.get()
 
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 48ddc8359ca0a1..6ce0fc12d80824 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -639,6 +639,7 @@ def testTypeIDs():
             (BF16Type, BF16Type.get()),
             (F16Type, F16Type.get()),
             (F32Type, F32Type.get()),
+            (FloatTF32Type, FloatTF32Type.get()),
             (F64Type, F64Type.get()),
             (NoneType, NoneType.get()),
             (ComplexType, ComplexType.get(f32)),
@@ -668,6 +669,7 @@ def testTypeIDs():
         # CHECK: BF16Type(bf16)
         # CHECK: F16Type(f16)
         # CHECK: F32Type(f32)
+        # CHECK: FloatTF32Type(tf32)
         # CHECK: F64Type(f64)
         # CHECK: NoneType(none)
         # CHECK: ComplexType(complex<f32>)
@@ -734,6 +736,9 @@ def print_downcasted(typ):
         # CHECK: F32Type
         # CHECK: F32Type(f32)
         print_downcasted(F32Type.get())
+        # CHECK: FloatTF32Type
+        # CHECK: FloatTF32Type(tf32)
+        print_downcasted(FloatTF32Type.get())
         # CHECK: F64Type
         # CHECK: F64Type(f64)
         print_downcasted(F64Type.get())

@matthias-springer matthias-springer merged commit e17c913 into main Nov 19, 2024
9 of 10 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/tf32_python branch November 19, 2024 02:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants