Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,12 +465,12 @@ def __init__(self, name: str, sourcedir: str) -> None:
print(name, sourcedir)
self.sourcedir = sourcedir

def get_cmake_abi_args(cmake_args):
def get_compile_abi_args(compile_args):
if torch.compiled_with_cxx11_abi():
cmake_args.append("-D_GLIBCXX_USE_CXX11_ABI=1")
compile_args.append("-D_GLIBCXX_USE_CXX11_ABI=1")
else:
cmake_args.append("-D_GLIBCXX_USE_CXX11_ABI=0")
return cmake_args
compile_args.append("-D_GLIBCXX_USE_CXX11_ABI=0")
return compile_args

class CMakeBuild(BuildExtension):

Expand Down Expand Up @@ -512,7 +512,7 @@ def build_extension(self, ext) -> None:
else:
raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set and XPU is not available.")

cmake_args = get_cmake_abi_args(cmake_args)
cmake_args = get_compile_abi_args(cmake_args)
# log cmake_args
print("CMake args:", cmake_args)

Expand Down Expand Up @@ -603,12 +603,12 @@ def build_extension(self, ext) -> None:
],
extra_compile_args={
'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
'nvcc': [
'nvcc': get_compile_abi_args([
'-O3',
# '--use_fast_math',
'-Xcompiler', '-fPIC',
'-DKTRANSFORMERS_USE_CUDA',
]
])
}
)
elif MUSA_HOME is not None:
Expand All @@ -627,11 +627,11 @@ def build_extension(self, ext) -> None:
],
extra_compile_args={
'cxx': ['force_mcc'],
'mcc': [
'mcc': get_compile_abi_args([
'-O3',
'-DKTRANSFORMERS_USE_MUSA',
'-DTHRUST_IGNORE_CUB_VERSION_CHECK',
]
])
}
)
elif torch.xpu.is_available(): #XPUExtension is not available now.
Expand Down