
Commit 94dd538

[BENCH] fixed dependencies (#6436)
1 parent 3728fdf commit 94dd538

File tree

bench/triton_bench/matmul_ogs_details/_common.py
bench/triton_bench/matmul_ogs_details/opt_flags.py
bench/triton_bench/matmul_ogs_details/opt_flags_amd.py
bench/triton_bench/matmul_ogs_details/opt_flags_nvidia.py
bench/triton_bench/mxfp.py
bench/triton_bench/testing.py

6 files changed: +13 −21 lines changed


bench/triton_bench/matmul_ogs_details/_common.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
-from ki.safe_import import tl, triton
+import triton
+import triton.language as tl
 
 # -----------------------------------------------------------------------------
 # Utilities

bench/triton_bench/matmul_ogs_details/opt_flags.py

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 from dataclasses import dataclass
+import triton
 
 import torch
 
-from ki.meta import cuda_capability_geq
-from ki.safe_import import triton
+from triton_bench.meta import cuda_capability_geq
 
 from . import opt_flags_amd, opt_flags_nvidia
 
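
The opt_flags code gates its choices on cuda_capability_geq, which now comes from triton_bench.meta instead of ki.meta. The diff does not include that helper's body, so the snippet below is only a hypothetical sketch of such a capability check, built on torch.cuda.get_device_capability, to illustrate what the call site expects.

import torch

# Hypothetical sketch -- the real triton_bench.meta.cuda_capability_geq is not
# shown in this commit. It compares the current device's (major, minor) compute
# capability against a required minimum.
def cuda_capability_geq(major: int, minor: int = 0) -> bool:
    if not torch.cuda.is_available():
        return False
    cur_major, cur_minor = torch.cuda.get_device_capability()
    return (cur_major, cur_minor) >= (major, minor)

# Example: gate a code path on sm_90 (Hopper) or newer.
use_hopper_path = cuda_capability_geq(9, 0)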

bench/triton_bench/matmul_ogs_details/opt_flags_amd.py

Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,5 @@
 import torch
-
-from ki.safe_import import triton
+import triton
 
 
 def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, microscaling_ctx):

bench/triton_bench/matmul_ogs_details/opt_flags_nvidia.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 import torch
-
-from ki.meta import cuda_capability_geq
-from ki.safe_import import triton
+import triton
+from triton_bench.meta import cuda_capability_geq
 
 
 def compute_grid_size(routing_data, m, n, block_m, block_n):

bench/triton_bench/mxfp.py

Lines changed: 4 additions & 6 deletions
@@ -1,11 +1,9 @@
 from enum import Enum
-
+import triton
+import triton.language as tl
 import torch
 import torch.nn.functional as F
 
-from ki.meta import is_float8_dtype
-from ki.safe_import import tl, triton
-
 # -----------------------------------------------------------------------------
 # Dequantization / Quantization Utilities
 # -----------------------------------------------------------------------------
@@ -476,7 +474,7 @@ def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dty
     assert -ndim <= swizzle_axis < ndim, f"Invalid swizzle axis {swizzle_axis=}"
     swizzle_axis = swizzle_axis if swizzle_axis >= 0 else swizzle_axis + ndim
 
-    multiplier = 1 if is_float8_dtype(tensor.dtype) else 2
+    multiplier = 1 if "float8" in str(tensor.dtype) else 2
     logical_quant_dim_shape = tensor.shape[axis] * multiplier
     assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. "
                                        f"Got {tensor.ndim=} and {scale.ndim=}")
@@ -560,7 +558,7 @@ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype
     assert -ndim <= swizzle_axis < ndim, f"Invalid swizzle axis {swizzle_axis=}"
     swizzle_axis = swizzle_axis if swizzle_axis >= 0 else swizzle_axis + ndim
     is_fp4 = out_quant_type == torch.uint8
-    is_fp8 = is_float8_dtype(out_quant_type)
+    is_fp8 = "float8" in str(out_quant_type)
     assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}"
 
     device = src_tensor.device
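
The two hunks above inline the removed ki.meta.is_float8_dtype helper as a substring test on the dtype's string form. A minimal sanity check of that test against PyTorch's dtype names (the _is_float8 wrapper is only an illustration, not part of the diff):

import torch

# str() of every torch float8 dtype contains "float8", while float16/bfloat16
# do not, so the substring test selects exactly the fp8 dtypes.
def _is_float8(dtype: torch.dtype) -> bool:
    return "float8" in str(dtype)

assert _is_float8(torch.float8_e4m3fn)   # "torch.float8_e4m3fn"
assert _is_float8(torch.float8_e5m2)     # "torch.float8_e5m2"
assert not _is_float8(torch.float16)     # "torch.float16"
assert not _is_float8(torch.bfloat16)    # "torch.bfloat16"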

bench/triton_bench/testing.py

Lines changed: 2 additions & 7 deletions
@@ -6,7 +6,7 @@
 
 import torch
 
-from ki.meta import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
+from triton_bench.meta import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
 
 
 def assert_equal(ref, tri):
@@ -18,12 +18,7 @@ def assert_equal(ref, tri):
 
 def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
     if tri.dtype.itemsize == 1:
-        # TODO:
-        # switch to ref.to(tri.dtype) when Triton does
-        # RTNE on A100
-        from ki.tritium import type
-
-        ref_as_type = type(ref, tri.dtype)
+        ref_as_type = ref.to(tri.dtype)
         if ref.dtype == tri.dtype:
             assert torch.all(ref_as_type == tri)
             return
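
With the old TODO resolved, assert_close now casts the reference with torch's own Tensor.to instead of the removed ki.tritium.type helper. A rough usage sketch of that 1-byte-dtype fast path (the tensors here are made up; the comparison upcasts to float32 only to stay within widely supported float8 ops):

import torch

# Sketch of the simplified fast path: the reference is cast to the Triton
# result's dtype with plain .to() before the exact comparison.
ref = torch.randn(8, dtype=torch.float32)
tri = ref.to(torch.float8_e4m3fn)    # stand-in for a kernel's fp8 output
ref_as_type = ref.to(tri.dtype)      # the conversion introduced by this commit
assert torch.equal(ref_as_type.to(torch.float32), tri.to(torch.float32))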
