|
1 | 1 | from collections import namedtuple
|
2 | 2 |
|
3 |
| -Backend = namedtuple("Backend", ["name", "cmake_config_name", "prompt_str"]) |
| 3 | +Backend = namedtuple("Backend", ["name", "cmake_config_name", "prompt_str", "parent"]) |
4 | 4 |
|
5 | 5 | if __name__ == "__main__":
|
6 | 6 | tvm_home = "" # pylint: disable=invalid-name
|
|
13 | 13 |
|
14 | 14 | cmake_config_str = f"set(TVM_SOURCE_DIR {tvm_home})\n"
|
15 | 15 | cmake_config_str += "set(CMAKE_BUILD_TYPE RelWithDebInfo)\n"
|
| 16 | + cuda_backend = Backend("CUDA", "USE_CUDA", "Use CUDA? (y/n): ", None) |
| 17 | + opencl_backend = Backend("OpenCL", "USE_OPENCL", "Use OpenCL? (y/n) ", None) |
16 | 18 | backends = [
|
17 |
| - Backend("CUDA", "USE_CUDA", "Use CUDA? (y/n): "), |
18 |
| - Backend("CUTLASS", "USE_CUTLASS", "Use CUTLASS? (y/n): "), |
19 |
| - Backend("CUBLAS", "USE_CUBLAS", "Use CUBLAS? (y/n): "), |
20 |
| - Backend("ROCm", "USE_ROCM", "Use ROCm? (y/n): "), |
21 |
| - Backend("Vulkan", "USE_VULKAN", "Use Vulkan? (y/n): "), |
| 19 | + cuda_backend, |
| 20 | + Backend("CUTLASS", "USE_CUTLASS", "Use CUTLASS? (y/n): ", cuda_backend), |
| 21 | + Backend("CUBLAS", "USE_CUBLAS", "Use CUBLAS? (y/n): ", cuda_backend), |
| 22 | + Backend("ROCm", "USE_ROCM", "Use ROCm? (y/n): ", None), |
| 23 | + Backend("Vulkan", "USE_VULKAN", "Use Vulkan? (y/n): ", None), |
| 24 | + Backend("Metal", "USE_METAL", "Use Metal (Apple M1/M2 GPU) ? (y/n): ", None), |
| 25 | + opencl_backend, |
22 | 26 | Backend(
|
23 |
| - "Metal", |
24 |
| - "USE_METAL", |
25 |
| - "Use Metal (Apple M1/M2 GPU) ? (y/n): ", |
| 27 | + "OpenCLHostPtr", |
| 28 | + "USE_OPENCL_ENABLE_HOST_PTR", |
| 29 | + "Use OpenCLHostPtr? (y/n): ", |
| 30 | + opencl_backend, |
26 | 31 | ),
|
27 |
| - Backend( |
28 |
| - "OpenCL", |
29 |
| - "USE_OPENCL", |
30 |
| - "Use OpenCL? (y/n) ", |
31 |
| - ), |
32 |
| - Backend("OpenCLHostPtr", "USE_OPENCL_ENABLE_HOST_PTR", "Use OpenCLHostPtr? (y/n): "), |
33 | 32 | ]
|
34 | 33 |
|
35 | 34 | enabled_backends = set()
|
36 | 35 |
|
37 | 36 | for backend in backends:
|
38 |
| - while True: |
39 |
| - use_backend = input(backend.prompt_str) |
40 |
| - if use_backend in ["yes", "Y", "y"]: |
41 |
| - cmake_config_str += f"set({backend.cmake_config_name} ON)\n" |
42 |
| - enabled_backends.add(backend.name) |
43 |
| - break |
44 |
| - elif use_backend in ["no", "N", "n"]: |
45 |
| - cmake_config_str += f"set({backend.cmake_config_name} OFF)\n" |
46 |
| - break |
47 |
| - else: |
48 |
| - print(f"Invalid input: {use_backend}. Please input again.") |
| 37 | + if backend.parent is not None and backend.parent.name not in enabled_backends: |
| 38 | + cmake_config_str += f"set({backend.cmake_config_name} OFF)\n" |
| 39 | + else: |
| 40 | + while True: |
| 41 | + use_backend = input(backend.prompt_str) |
| 42 | + if use_backend in ["yes", "Y", "y"]: |
| 43 | + cmake_config_str += f"set({backend.cmake_config_name} ON)\n" |
| 44 | + enabled_backends.add(backend.name) |
| 45 | + break |
| 46 | + elif use_backend in ["no", "N", "n"]: |
| 47 | + cmake_config_str += f"set({backend.cmake_config_name} OFF)\n" |
| 48 | + break |
| 49 | + else: |
| 50 | + print(f"Invalid input: {use_backend}. Please input again.") |
49 | 51 |
|
50 | 52 | if "CUDA" in enabled_backends:
|
51 | 53 | cmake_config_str += f"set(USE_THRUST ON)\n"
|
52 | 54 |
|
53 | 55 | # FlashInfer related
|
54 | 56 | use_flashInfer = False # pylint: disable=invalid-name
|
55 |
| - while True: |
56 |
| - user_input = input("Use FlashInfer? (need CUDA w/ compute capability 80;86;89;90) (y/n): ") |
57 |
| - if user_input in ["yes", "Y", "y"]: |
58 |
| - cmake_config_str += "set(USE_FLASHINFER ON)\n" |
59 |
| - cmake_config_str += "set(FLASHINFER_ENABLE_FP8 OFF)\n" |
60 |
| - cmake_config_str += "set(FLASHINFER_ENABLE_BF16 OFF)\n" |
61 |
| - cmake_config_str += "set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)\n" |
62 |
| - cmake_config_str += "set(FLASHINFER_GEN_PAGE_SIZES 16)\n" |
63 |
| - cmake_config_str += "set(FLASHINFER_GEN_HEAD_DIMS 128)\n" |
64 |
| - cmake_config_str += "set(FLASHINFER_GEN_KV_LAYOUTS 0 1)\n" |
65 |
| - cmake_config_str += "set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1)\n" |
66 |
| - cmake_config_str += 'set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")\n' |
67 |
| - cmake_config_str += 'set(FLASHINFER_GEN_CASUALS "false" "true")\n' |
68 |
| - use_flashInfer = True # pylint: disable=invalid-name |
69 |
| - break |
70 |
| - elif user_input in ["no", "N", "n"]: |
71 |
| - cmake_config_str += "set(USE_FLASHINFER OFF)\n" |
72 |
| - break |
73 |
| - else: |
74 |
| - print(f"Invalid input: {use_flashInfer}. Please input again.") |
| 57 | + if "CUDA" in enabled_backends: |
| 58 | + while True: |
| 59 | + user_input = input( |
| 60 | + "Use FlashInfer? (need CUDA w/ compute capability 80;86;89;90) (y/n): " |
| 61 | + ) |
| 62 | + if user_input in ["yes", "Y", "y"]: |
| 63 | + cmake_config_str += "set(USE_FLASHINFER ON)\n" |
| 64 | + cmake_config_str += "set(FLASHINFER_ENABLE_FP8 OFF)\n" |
| 65 | + cmake_config_str += "set(FLASHINFER_ENABLE_BF16 OFF)\n" |
| 66 | + cmake_config_str += "set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)\n" |
| 67 | + cmake_config_str += "set(FLASHINFER_GEN_PAGE_SIZES 16)\n" |
| 68 | + cmake_config_str += "set(FLASHINFER_GEN_HEAD_DIMS 128)\n" |
| 69 | + cmake_config_str += "set(FLASHINFER_GEN_KV_LAYOUTS 0 1)\n" |
| 70 | + cmake_config_str += "set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1)\n" |
| 71 | + cmake_config_str += 'set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")\n' |
| 72 | + cmake_config_str += 'set(FLASHINFER_GEN_CASUALS "false" "true")\n' |
| 73 | + use_flashInfer = True # pylint: disable=invalid-name |
| 74 | + break |
| 75 | + elif user_input in ["no", "N", "n"]: |
| 76 | + cmake_config_str += "set(USE_FLASHINFER OFF)\n" |
| 77 | + break |
| 78 | + else: |
| 79 | + print(f"Invalid input: {use_flashInfer}. Please input again.") |
| 80 | + else: |
| 81 | + cmake_config_str += "set(USE_FLASHINFER OFF)\n" |
| 82 | + |
75 | 83 | if use_flashInfer:
|
76 | 84 | while True:
|
77 | 85 | user_input = input("Enter your CUDA compute capability: ")
|
|
0 commit comments