1616# Copyright (c) 2025 DeepSeek
1717# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE
1818
19- import functools
2019import hashlib
20+ import functools
2121import os
2222import re
2323import subprocess
@@ -75,9 +75,7 @@ def get_nvcc_compiler() -> Tuple[str, str]:
7575 match = version_pattern .search (os .popen (f"{ path } --version" ).read ())
7676 version = match .group (1 )
7777 assert match , f"Cannot get the version of NVCC compiler { path } "
78- assert (
79- version >= least_version_required
80- ), f"NVCC { path } version { version } is lower than { least_version_required } "
78+ assert version >= least_version_required , f"NVCC { path } version { version } is lower than { least_version_required } "
8179 return path , version
8280 raise RuntimeError ("Cannot find any available NVCC compiler" )
8381
@@ -117,18 +115,13 @@ def put(path, data, is_binary=False):
117115
118116def build (name : str , arg_defs : tuple , code : str ) -> Runtime :
119117 # Compiler flags
120- nvcc_flags = [
121- "-std=c++17" ,
122- "-shared" ,
123- "-O3" ,
124- "--expt-relaxed-constexpr" ,
125- "--expt-extended-lambda" ,
126- "-gencode=arch=compute_90a,code=sm_90a" ,
127- "--ptxas-options=--register-usage-level=10" + (",--verbose" if "DG_PTXAS_VERBOSE" in os .environ else "" ),
128- # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
129- "--diag-suppress=177,174,940" ,
130- ]
131- cxx_flags = ["-fPIC" , "-O3" , "-Wno-deprecated-declarations" , "-Wno-abi" ]
118+ cpp_standard = int (os .getenv ("DG_NVCC_OVERRIDE_CPP_STANDARD" , 20 ))
119+ nvcc_flags = [f"-std=c++{ cpp_standard } " , "-shared" , "-O3" , "--expt-relaxed-constexpr" , "--expt-extended-lambda" ,
120+ "-gencode=arch=compute_90a,code=sm_90a" ,
121+ "--ptxas-options=--register-usage-level=10" + (",--verbose" if "DG_PTXAS_VERBOSE" in os .environ else "" ),
122+ # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
123+ "--diag-suppress=39,174,177,940" ]
124+ cxx_flags = ["-fPIC" , "-O3" , "-Wno-deprecated-declarations" , "-Wno-abi" , "-fconcepts" ]
132125 flags = [* nvcc_flags , f'--compiler-options={ "," .join (cxx_flags )} ' ]
133126 include_dirs = [get_jit_include_dir ()]
134127
@@ -155,8 +148,12 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
155148 # Compile into a temporary SO file
156149 so_path = f"{ path } /kernel.so"
157150 tmp_so_path = f"{ make_tmp_dir ()} /nvcc.tmp.{ str (uuid .uuid4 ())} .{ hash_to_hex (so_path )} .so"
151+
158152 # Compile
159- command = [get_nvcc_compiler ()[0 ], src_path , "-o" , tmp_so_path , * flags , * [f"-I{ d } " for d in include_dirs ]]
153+ command = [get_nvcc_compiler ()[0 ],
154+ src_path , "-o" , tmp_so_path ,
155+ * flags ,
156+ * [f"-I{ d } " for d in include_dirs ]]
160157 if os .getenv ("DG_JIT_DEBUG" , None ) or os .getenv ("DG_JIT_PRINT_NVCC_COMMAND" , False ):
161158 print (f"Compiling JIT runtime { name } with command { command } " )
162159 return_code = subprocess .check_call (command )
0 commit comments