Skip to content

Commit 2ec4e38

Browse files
authored
Simplify obvious choices in gen_cmake_config.py (#3006)
1 parent 81976e0 commit 2ec4e38

File tree

1 file changed

+54
-46
lines changed

1 file changed

+54
-46
lines changed

cmake/gen_cmake_config.py

Lines changed: 54 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import namedtuple
22

3-
Backend = namedtuple("Backend", ["name", "cmake_config_name", "prompt_str"])
3+
Backend = namedtuple("Backend", ["name", "cmake_config_name", "prompt_str", "parent"])
44

55
if __name__ == "__main__":
66
tvm_home = "" # pylint: disable=invalid-name
@@ -13,65 +13,73 @@
1313

1414
cmake_config_str = f"set(TVM_SOURCE_DIR {tvm_home})\n"
1515
cmake_config_str += "set(CMAKE_BUILD_TYPE RelWithDebInfo)\n"
16+
cuda_backend = Backend("CUDA", "USE_CUDA", "Use CUDA? (y/n): ", None)
17+
opencl_backend = Backend("OpenCL", "USE_OPENCL", "Use OpenCL? (y/n) ", None)
1618
backends = [
17-
Backend("CUDA", "USE_CUDA", "Use CUDA? (y/n): "),
18-
Backend("CUTLASS", "USE_CUTLASS", "Use CUTLASS? (y/n): "),
19-
Backend("CUBLAS", "USE_CUBLAS", "Use CUBLAS? (y/n): "),
20-
Backend("ROCm", "USE_ROCM", "Use ROCm? (y/n): "),
21-
Backend("Vulkan", "USE_VULKAN", "Use Vulkan? (y/n): "),
19+
cuda_backend,
20+
Backend("CUTLASS", "USE_CUTLASS", "Use CUTLASS? (y/n): ", cuda_backend),
21+
Backend("CUBLAS", "USE_CUBLAS", "Use CUBLAS? (y/n): ", cuda_backend),
22+
Backend("ROCm", "USE_ROCM", "Use ROCm? (y/n): ", None),
23+
Backend("Vulkan", "USE_VULKAN", "Use Vulkan? (y/n): ", None),
24+
Backend("Metal", "USE_METAL", "Use Metal (Apple M1/M2 GPU) ? (y/n): ", None),
25+
opencl_backend,
2226
Backend(
23-
"Metal",
24-
"USE_METAL",
25-
"Use Metal (Apple M1/M2 GPU) ? (y/n): ",
27+
"OpenCLHostPtr",
28+
"USE_OPENCL_ENABLE_HOST_PTR",
29+
"Use OpenCLHostPtr? (y/n): ",
30+
opencl_backend,
2631
),
27-
Backend(
28-
"OpenCL",
29-
"USE_OPENCL",
30-
"Use OpenCL? (y/n) ",
31-
),
32-
Backend("OpenCLHostPtr", "USE_OPENCL_ENABLE_HOST_PTR", "Use OpenCLHostPtr? (y/n): "),
3332
]
3433

3534
enabled_backends = set()
3635

3736
for backend in backends:
38-
while True:
39-
use_backend = input(backend.prompt_str)
40-
if use_backend in ["yes", "Y", "y"]:
41-
cmake_config_str += f"set({backend.cmake_config_name} ON)\n"
42-
enabled_backends.add(backend.name)
43-
break
44-
elif use_backend in ["no", "N", "n"]:
45-
cmake_config_str += f"set({backend.cmake_config_name} OFF)\n"
46-
break
47-
else:
48-
print(f"Invalid input: {use_backend}. Please input again.")
37+
if backend.parent is not None and backend.parent.name not in enabled_backends:
38+
cmake_config_str += f"set({backend.cmake_config_name} OFF)\n"
39+
else:
40+
while True:
41+
use_backend = input(backend.prompt_str)
42+
if use_backend in ["yes", "Y", "y"]:
43+
cmake_config_str += f"set({backend.cmake_config_name} ON)\n"
44+
enabled_backends.add(backend.name)
45+
break
46+
elif use_backend in ["no", "N", "n"]:
47+
cmake_config_str += f"set({backend.cmake_config_name} OFF)\n"
48+
break
49+
else:
50+
print(f"Invalid input: {use_backend}. Please input again.")
4951

5052
if "CUDA" in enabled_backends:
5153
cmake_config_str += f"set(USE_THRUST ON)\n"
5254

5355
# FlashInfer related
5456
use_flashInfer = False # pylint: disable=invalid-name
55-
while True:
56-
user_input = input("Use FlashInfer? (need CUDA w/ compute capability 80;86;89;90) (y/n): ")
57-
if user_input in ["yes", "Y", "y"]:
58-
cmake_config_str += "set(USE_FLASHINFER ON)\n"
59-
cmake_config_str += "set(FLASHINFER_ENABLE_FP8 OFF)\n"
60-
cmake_config_str += "set(FLASHINFER_ENABLE_BF16 OFF)\n"
61-
cmake_config_str += "set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)\n"
62-
cmake_config_str += "set(FLASHINFER_GEN_PAGE_SIZES 16)\n"
63-
cmake_config_str += "set(FLASHINFER_GEN_HEAD_DIMS 128)\n"
64-
cmake_config_str += "set(FLASHINFER_GEN_KV_LAYOUTS 0 1)\n"
65-
cmake_config_str += "set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1)\n"
66-
cmake_config_str += 'set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")\n'
67-
cmake_config_str += 'set(FLASHINFER_GEN_CASUALS "false" "true")\n'
68-
use_flashInfer = True # pylint: disable=invalid-name
69-
break
70-
elif user_input in ["no", "N", "n"]:
71-
cmake_config_str += "set(USE_FLASHINFER OFF)\n"
72-
break
73-
else:
74-
print(f"Invalid input: {use_flashInfer}. Please input again.")
57+
if "CUDA" in enabled_backends:
58+
while True:
59+
user_input = input(
60+
"Use FlashInfer? (need CUDA w/ compute capability 80;86;89;90) (y/n): "
61+
)
62+
if user_input in ["yes", "Y", "y"]:
63+
cmake_config_str += "set(USE_FLASHINFER ON)\n"
64+
cmake_config_str += "set(FLASHINFER_ENABLE_FP8 OFF)\n"
65+
cmake_config_str += "set(FLASHINFER_ENABLE_BF16 OFF)\n"
66+
cmake_config_str += "set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)\n"
67+
cmake_config_str += "set(FLASHINFER_GEN_PAGE_SIZES 16)\n"
68+
cmake_config_str += "set(FLASHINFER_GEN_HEAD_DIMS 128)\n"
69+
cmake_config_str += "set(FLASHINFER_GEN_KV_LAYOUTS 0 1)\n"
70+
cmake_config_str += "set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1)\n"
71+
cmake_config_str += 'set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")\n'
72+
cmake_config_str += 'set(FLASHINFER_GEN_CASUALS "false" "true")\n'
73+
use_flashInfer = True # pylint: disable=invalid-name
74+
break
75+
elif user_input in ["no", "N", "n"]:
76+
cmake_config_str += "set(USE_FLASHINFER OFF)\n"
77+
break
78+
else:
79+
print(f"Invalid input: {use_flashInfer}. Please input again.")
80+
else:
81+
cmake_config_str += "set(USE_FLASHINFER OFF)\n"
82+
7583
if use_flashInfer:
7684
while True:
7785
user_input = input("Enter your CUDA compute capability: ")

0 commit comments

Comments
 (0)