@@ -98,7 +98,6 @@ def get_python_include_path():
9898 return None
9999
100100
101- # PYTORCH_INSTALL_PATH and LIBTORCH_ROOT
102101def get_torch_root_path ():
103102 try :
104103 import torch
@@ -115,6 +114,12 @@ def get_torch_mlu_root_path():
115114 except ImportError :
116115 return None
117116
117+ def get_nccl_root_path ():
118+ try :
119+ from nvidia import nccl
120+ return str (Path (nccl .__file__ ).parent )
121+ except ImportError :
122+ return None
118123
119124def set_npu_envs ():
120125 PYTORCH_NPU_INSTALL_PATH = os .getenv ("PYTORCH_NPU_INSTALL_PATH" )
@@ -212,7 +217,16 @@ def set_mlu_envs():
212217 os .environ ["LIBTORCH_ROOT" ] = get_torch_root_path ()
213218 os .environ ["PYTORCH_INSTALL_PATH" ] = get_torch_root_path ()
214219 os .environ ["PYTORCH_MLU_INSTALL_PATH" ] = get_torch_mlu_root_path ()
215-
220+
221+ def set_cuda_envs ():
222+ os .environ ["PYTHON_INCLUDE_PATH" ] = get_python_include_path ()
223+ os .environ ["PYTHON_LIB_PATH" ] = get_torch_root_path ()
224+ os .environ ["LIBTORCH_ROOT" ] = get_torch_root_path ()
225+ os .environ ["PYTORCH_INSTALL_PATH" ] = get_torch_root_path ()
226+ os .environ ["CUDA_TOOLKIT_ROOT_DIR" ] = "/usr/local/cuda"
227+ os .environ ["NCCL_ROOT" ] = get_nccl_root_path ()
228+ os .environ ["NCCL_VERSION" ] = "2"
229+
216230class CMakeExtension (Extension ):
217231 def __init__ (self , name : str , path : str , sourcedir : str = "" ) -> None :
218232 super ().__init__ (name , sources = [])
@@ -223,7 +237,7 @@ def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
223237class ExtBuild (build_ext ):
224238 user_options = build_ext .user_options + [
225239 ("base-dir=" , None , "base directory of xLLM project" ),
226- ("device=" , None , "target device type (a3 or a2 or mlu)" ),
240+ ("device=" , None , "target device type (a3 or a2 or mlu or cuda )" ),
227241 ("arch=" , None , "target arch type (x86 or arm)" ),
228242 ("install-xllm-kernels=" , None , "install xllm_kernels RPM package (true/false)" ),
229243 ]
@@ -302,8 +316,14 @@ def build_extension(self, ext: CMakeExtension):
302316 cmake_args += ["-DUSE_MLU=ON" ]
303317 # set mlu environment variables
304318 set_mlu_envs ()
319+ elif self .device == "cuda" :
320+ cuda_architectures = "80;89;90"
321+ cmake_args += ["-DUSE_CUDA=ON" ,
322+ f"-DCMAKE_CUDA_ARCHITECTURES={ cuda_architectures } " ]
323+ # set cuda environment variables
324+ set_cuda_envs ()
305325 else :
306- raise ValueError ("Please set --device to a2 or a3 or mlu." )
326+ raise ValueError ("Please set --device to a2 or a3 or mlu or cuda ." )
307327
308328
309329 # Adding CMake arguments set as environment variable
@@ -353,7 +373,7 @@ def build_extension(self, ext: CMakeExtension):
353373
354374class BuildDistWheel (bdist_wheel ):
355375 user_options = bdist_wheel .user_options + [
356- ("device=" , None , "target device type (a3 or a2 or mlu)" ),
376+ ("device=" , None , "target device type (a3 or a2 or mlu or cuda )" ),
357377 ("arch=" , None , "target arch type (x86 or arm)" ),
358378 ]
359379
@@ -530,7 +550,7 @@ def apply_patch():
530550 idx = sys .argv .index ('--device' )
531551 if idx + 1 < len (sys .argv ):
532552 device = sys .argv [idx + 1 ].lower ()
533- if device not in ('a2' , 'a3' , 'mlu' ):
553+ if device not in ('a2' , 'a3' , 'mlu' , 'cuda' ):
534554 print ("Error: --device must be a2 or a3 or mlu (case-insensitive)" )
535555 sys .exit (1 )
536556 # Remove the arguments so setup() doesn't see them
0 commit comments