It's slightly inconvenient to go through fp16/bf16 when we want to
(de)quantize mxfp from/to fp32, so this change also accepts fp32 as a
source dtype.
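
For context, here is a minimal pure-Python sketch of what dequantizing one MX block amounts to. It assumes the OCP Microscaling convention that each block shares a single E8M0 scale stored as a biased `uint8` (bias 127, with 255 reserved for NaN). The name `dequantize_block` and the plain-list representation are hypothetical; this is not the Triton kernel touched by this PR.

```python
import math

E8M0_BIAS = 127  # assumption: exponent bias of the OCP MX E8M0 scale format


def dequantize_block(elements, scale_u8):
    """Hypothetical helper: multiply already-decoded element values
    by the shared power-of-two block scale 2**(scale_u8 - 127)."""
    if not 0 <= scale_u8 <= 254:
        raise ValueError("scale must be in [0, 254]; 255 encodes NaN in E8M0")
    # ldexp(x, e) computes x * 2**e exactly (barring overflow/underflow).
    return [math.ldexp(x, scale_u8 - E8M0_BIAS) for x in elements]
```

Because the scale is a pure power of two, the multiply itself is exact in floating point as long as the result stays in range, which is why a wider destination type such as fp32 can be targeted directly instead of passing through fp16/bf16.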
# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests -> Added more test cases in `pytest -xs python/triton_kernels/tests/test_mxfp.py`
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
  tl.static_assert(mx_scale_ptr.dtype.element_ty==tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
- tl.static_assert((src_dtype==tl.bfloat16) or (src_dtype==tl.float16), f"{src_dtype=} must be bfloat16 or float16")
+ tl.static_assert((src_dtype==tl.bfloat16) or (src_dtype==tl.float16)or (src_dtype==tl.float32), f"{src_dtype=} must be bfloat16 or float16 or float32")
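
The widened check above can be mirrored on the host side as a rough sketch. Strings stand in for Triton dtypes here, and `check_src_dtype` and `ALLOWED_SRC_DTYPES` are hypothetical names that do not appear in the PR.

```python
# Assumption: dtype names as plain strings, standing in for tl.bfloat16 etc.
ALLOWED_SRC_DTYPES = ("bfloat16", "float16", "float32")


def check_src_dtype(src_dtype: str) -> None:
    """Hypothetical host-side analogue of the kernel's static assert."""
    if src_dtype not in ALLOWED_SRC_DTYPES:
        raise TypeError(f"{src_dtype=} must be bfloat16 or float16 or float32")
```

Keeping the accepted dtypes in one tuple means a later extension only touches a single line, whereas the chained `or` form has to be updated in both the condition and the message.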