
Commit 882b4d7

alibaba-mijiLLLLKKKK authored and committed

update - move cutlass_groupgemm position
1 parent 777513b commit 882b4d7

9 files changed (+75, -37 lines)

rtp_llm/models_py/BUILD

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ requirement(flashinfer)
 
 filegroup(
     name = "cutlass_moe_config",
-    srcs = glob(["configs/cutlass_groupgemm/*"]),
+    srcs = glob(["kernels/cuda/fp8_kernel/cutlass_groupgemm/*"]),
     visibility = ["//visibility:public"],
 )
 

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+from .get_best_config import load_all_configs
+
+# load all configs once at import time
+load_all_configs()
+from .fp8_kernel import (
+    cutlass_moe_mm_fp8_scaled,
+    get_best_config_swap_ab,
+    scaled_fp8_per_tensor_quant,
+    scaled_fp8_per_token_quant,
+    sgl_per_token_group_quant_fp8,
+)
+
+__all__ = [
+    "sgl_per_token_group_quant_fp8",
+    "scaled_fp8_per_tensor_quant",
+    "scaled_fp8_per_token_quant",
+    "cutlass_moe_mm_fp8_scaled",
+    "get_best_config_swap_ab",
+]
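
Since load_all_configs() now runs at import time, a plain package import is enough to populate the config map before any kernel call. A minimal usage sketch, assuming a CUDA build where the compiled ops import cleanly:

# Importing the package runs load_all_configs() once as a side effect,
# so every re-exported kernel sees a fully populated groupgemm config map.
from rtp_llm.models_py.kernels.cuda.fp8_kernel import (
    cutlass_moe_mm_fp8_scaled,
    scaled_fp8_per_token_quant,
)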

rtp_llm/models_py/configs/cutlass_groupgemm/E=20-N=5120-K=6144-device_name=NVIDIA_H20.json renamed to rtp_llm/models_py/kernels/cuda/fp8_kernel/cutlass_groupgemm/E=20-N=5120-K=6144-device_name=NVIDIA_H20.json

File renamed without changes.

rtp_llm/models_py/configs/cutlass_groupgemm/E=20-N=6144-K=2560-device_name=NVIDIA_H20.json renamed to rtp_llm/models_py/kernels/cuda/fp8_kernel/cutlass_groupgemm/E=20-N=6144-K=2560-device_name=NVIDIA_H20.json

File renamed without changes.

rtp_llm/models_py/configs/cutlass_groupgemm/E=32-N=3072-K=4096-device_name=NVIDIA_H20.json renamed to rtp_llm/models_py/kernels/cuda/fp8_kernel/cutlass_groupgemm/E=32-N=3072-K=4096-device_name=NVIDIA_H20.json

File renamed without changes.

rtp_llm/models_py/configs/cutlass_groupgemm/E=32-N=4096-K=1536-device_name=NVIDIA_H20.json renamed to rtp_llm/models_py/kernels/cuda/fp8_kernel/cutlass_groupgemm/E=32-N=4096-K=1536-device_name=NVIDIA_H20.json

File renamed without changes.
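
Each of these JSON filenames encodes the lookup key that load_all_configs() parses out of the name (see the get_best_config.py diff below): expert count E, the grouped-GEMM dimensions N and K, and the CUDA device name. A short sketch of that parse for the first renamed file, mirroring the parsing logic in this commit:

# Decompose "E=20-N=5120-K=6144-device_name=NVIDIA_H20.json" into the
# (E, N, K, device_name) tuple used as the key of the global config map.
filename = "E=20-N=5120-K=6144-device_name=NVIDIA_H20.json"
parts = filename.replace(".json", "").split("-")
E = int(parts[0].split("=")[1])       # 20
N = int(parts[1].split("=")[1])       # 5120
K = int(parts[2].split("=")[1])       # 6144
device_name = parts[3].split("=")[1]  # "NVIDIA_H20"
key = (E, N, K, device_name)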

rtp_llm/models_py/kernels/cuda/fp8_kernel.py renamed to rtp_llm/models_py/kernels/cuda/fp8_kernel/fp8_kernel.py

Lines changed: 4 additions & 2 deletions

@@ -7,7 +7,9 @@
 
 import torch
 
-from rtp_llm.models_py.configs.get_best_config import get_cutlass_groupgemm_best_config
+from rtp_llm.models_py.kernels.cuda.fp8_kernel.get_best_config import (
+    get_cutlass_groupgemm_best_config,
+)
 from rtp_llm.models_py.utils.arch import is_cuda
 from rtp_llm.models_py.utils.math import align
 
@@ -20,7 +22,7 @@
         per_token_quant_fp8,
     )
 else:
-    logging.warning("can't import from rtp_llm_ops, only support cuda!")
+    logging.info("skip import fp8 quant from rtp_llm_ops for non cuda platform")
 
 logger = logging.getLogger(__name__)
 
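The message is also downgraded from warning to info because the fallback branch is expected on non-CUDA platforms: the quant ops are only imported behind an is_cuda() guard. A sketch of that guard, reconstructed from the hunk context; the module path for the ops import is an assumption (the hunk only shows the imported name and the closing parenthesis):

import logging

from rtp_llm.models_py.utils.arch import is_cuda

if is_cuda():
    # NOTE: the ops module path below is assumed, by analogy with the
    # test-file import further down; the diff does not show this line.
    from rtp_llm.ops.compute_ops import (
        per_token_quant_fp8,
    )
else:
    logging.info("skip import fp8 quant from rtp_llm_ops for non cuda platform")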

rtp_llm/models_py/configs/get_best_config.py renamed to rtp_llm/models_py/kernels/cuda/fp8_kernel/get_best_config.py

Lines changed: 49 additions & 32 deletions

@@ -13,7 +13,7 @@
 _CUTLASS_GROUPGEMM_CONFIG_MAP: Dict[Tuple[int, int, int, str], Dict] = {}
 
 
-def _load_all_configs():
+def load_all_configs():
     """Load all cutlass groupgemm config files into the global map."""
     if _CUTLASS_GROUPGEMM_CONFIG_MAP:
         # Already loaded
@@ -23,38 +23,58 @@ def _load_all_configs():
 
     # Load open source config directory
     opensource_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), op_name)
-    if os.path.exists(opensource_dir):
-        pattern = os.path.join(opensource_dir, "E=*-N=*-K=*-device_name=*.json")
-        for config_file in glob.glob(pattern):
-            filename = os.path.basename(config_file)
-            try:
-                # Parse filename: E={E}-N={N}-K={K}-device_name={device_name}.json
-                parts = filename.replace(".json", "").split("-")
-                E = int(parts[0].split("=")[1])
-                N = int(parts[1].split("=")[1])
-                K = int(parts[2].split("=")[1])
-                device_name = parts[3].split("=")[1]
-
-                # Load config
-                with open(config_file) as f:
-                    config_data = json.load(f)
-                # Convert string keys to int
-                config_data = {int(key): val for key, val in config_data.items()}
-
-                # Store in global map
-                key = (E, N, K, device_name)
-                _CUTLASS_GROUPGEMM_CONFIG_MAP[key] = config_data
-                logging.debug(f"Loaded config from {config_file}")
-            except Exception as e:
-                logging.warning(f"Failed to load config from {config_file}: {e}")
+
+    # Try to get internal source config directory
+    # Collect all config directories to load
+    config_dirs = [opensource_dir]
+    try:
+        import internal_source.rtp_llm.models_py.kernels.cuda.fp8_kernel
+
+        internalsource_dir = os.path.join(
+            os.path.dirname(
+                os.path.realpath(
+                    internal_source.rtp_llm.models_py.kernels.cuda.fp8_kernel.__file__
+                )
+            ),
+            op_name,
+        )
+        config_dirs.append(internalsource_dir)
+    except ImportError:
+        logging.info("internal_source not found")
+
+    # Load configs from all directories
+    for config_dir in config_dirs:
+        if os.path.exists(config_dir):
+            logging.info(f"Loading configs from {config_dir}")
+            pattern = os.path.join(config_dir, "E=*-N=*-K=*-device_name=*.json")
+            for config_file in glob.glob(pattern):
+                filename = os.path.basename(config_file)
+                try:
+                    # Parse filename: E={E}-N={N}-K={K}-device_name={device_name}.json
+                    parts = filename.replace(".json", "").split("-")
+                    E = int(parts[0].split("=")[1])
+                    N = int(parts[1].split("=")[1])
+                    K = int(parts[2].split("=")[1])
+                    device_name = parts[3].split("=")[1]
+
+                    # Load config
+                    with open(config_file) as f:
+                        config_data = json.load(f)
+                    # Convert string keys to int
+                    config_data = {
+                        int(key): val for key, val in config_data.items()
+                    }
+
+                    # Store in global map
+                    key = (E, N, K, device_name)
+                    _CUTLASS_GROUPGEMM_CONFIG_MAP[key] = config_data
+                    logging.debug(f"Loaded config from {config_file}")
+                except Exception as e:
+                    logging.warning(f"Failed to load config from {config_file}: {e}")
 
     logging.info(
         f"Loaded {len(_CUTLASS_GROUPGEMM_CONFIG_MAP)} cutlass groupgemm configurations"
     )
-    try:
-        import internal_source.rtp_llm.utils.register_cutlass_configs
-    except:
-        logging.info("internal_source not found")
 
 
 def register_cutlass_groupgemm_config(
@@ -90,9 +110,6 @@ def get_cutlass_groupgemm_best_config(E: int, N: int, K: int) -> Optional[Dict]:
         Configuration dictionary mapping batch sizes to tile configurations,
         or None if no configuration is found.
     """
-    # Load all configs if not already loaded
-    _load_all_configs()
-
     device_name = torch.cuda.get_device_name().replace("-", "_").replace(" ", "_")
     key = (E, N, K, device_name)
 
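With eager loading moved to package import, get_cutlass_groupgemm_best_config() is now a pure dictionary lookup keyed on (E, N, K, device_name). A usage sketch; the E/N/K values mirror one of the shipped H20 config files, and the printed batch-size keys are whatever the JSON provides:

from rtp_llm.models_py.kernels.cuda.fp8_kernel.get_best_config import (
    get_cutlass_groupgemm_best_config,
)

# On an NVIDIA H20 the device name normalizes to "NVIDIA_H20", so this
# resolves key (20, 5120, 6144, "NVIDIA_H20"); returns None when no tuned
# config exists for the current device.
config = get_cutlass_groupgemm_best_config(E=20, N=5120, K=6144)
if config is not None:
    # Keys are batch sizes (converted to int at load time), values are
    # the tile configurations chosen by offline tuning.
    print(sorted(config.keys()))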

rtp_llm/models_py/kernels/cuda/test/per_token_group_quant_8bit_test.py

Lines changed: 2 additions & 2 deletions

@@ -8,11 +8,11 @@
 from torch import dtype as _dtype
 from torch.profiler import ProfilerActivity, profile, record_function
 
-from rtp_llm.models_py.kernels.cuda.fp8_kernel import (
+from rtp_llm.models_py.utils.arch import is_hip
+from rtp_llm.ops.compute_ops import (
     per_token_group_quant_fp8,
     per_token_group_quant_int8,
 )
-from rtp_llm.models_py.utils.arch import is_hip
 
 _is_hip = is_hip()
 
