|
1 | 1 | from enum import Enum
2 | | -
| 2 | +import triton |
| 3 | +import triton.language as tl |
3 | 4 | import torch |
4 | 5 | import torch.nn.functional as F |
5 | 6 |
|
6 | | -from ki.meta import is_float8_dtype |
7 | | -from ki.safe_import import tl, triton |
8 | | - |
9 | 7 | # ----------------------------------------------------------------------------- |
10 | 8 | # Dequantization / Quantization Utilities |
11 | 9 | # ----------------------------------------------------------------------------- |
@@ -476,7 +474,7 @@ def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dty |
476 | 474 | assert -ndim <= swizzle_axis < ndim, f"Invalid swizzle axis {swizzle_axis=}" |
477 | 475 | swizzle_axis = swizzle_axis if swizzle_axis >= 0 else swizzle_axis + ndim |
478 | 476 |
|
479 | | - multiplier = 1 if is_float8_dtype(tensor.dtype) else 2 |
| 477 | + multiplier = 1 if "float8" in str(tensor.dtype) else 2 |
480 | 478 | logical_quant_dim_shape = tensor.shape[axis] * multiplier |
481 | 479 | assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. " |
482 | 480 | f"Got {tensor.ndim=} and {scale.ndim=}") |
@@ -560,7 +558,7 @@ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype |
560 | 558 | assert -ndim <= swizzle_axis < ndim, f"Invalid swizzle axis {swizzle_axis=}" |
561 | 559 | swizzle_axis = swizzle_axis if swizzle_axis >= 0 else swizzle_axis + ndim |
562 | 560 | is_fp4 = out_quant_type == torch.uint8 |
563 | | - is_fp8 = is_float8_dtype(out_quant_type) |
| 561 | + is_fp8 = "float8" in str(out_quant_type) |
564 | 562 | assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}" |
565 | 563 |
|
566 | 564 | device = src_tensor.device |
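For context on the inlined check: the removed `ki.meta.is_float8_dtype` helper is presumably equivalent to the substring test this diff substitutes (an assumption; only the inline check appears in the diff). A minimal sketch of how that check behaves, assuming a recent PyTorch build that exposes the float8 dtypes:

```python
import torch

def is_float8_dtype(dtype: torch.dtype) -> bool:
    # Same check the diff inlines: torch dtypes stringify as e.g.
    # "torch.float8_e4m3fn", so a substring test is sufficient.
    return "float8" in str(dtype)

# Expected behaviour (assumes torch >= 2.1, which ships the float8 dtypes):
assert is_float8_dtype(torch.float8_e4m3fn)   # "torch.float8_e4m3fn" -> True
assert is_float8_dtype(torch.float8_e5m2)     # "torch.float8_e5m2"   -> True
assert not is_float8_dtype(torch.uint8)       # packed fp4 storage dtype -> False
assert not is_float8_dtype(torch.bfloat16)    # "torch.bfloat16" has no "float8" -> False
```

This is why `multiplier` stays 1 for fp8 tensors (one value per byte) and 2 for the uint8-packed fp4 case (two values per byte) in `upcast_from_mxfp`, and why `downcast_to_mxfp_torch` accepts either a float8 dtype or `torch.uint8` as the quantized output type.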
|