Skip to content

Commit 660af62

Browse files
Danil328irexyc
andauthored
Added the ability to build a project with PyTorch 2.0. (open-mmlab#2553)
* Added the ability to build a project with PyTorch 2.0. Namely, I added the flag -std=c++17 to extra_compile_args depending on the version of Torch. * Lost the condition for the presence of nvcc * Lost the condition for the presence of nvcc * Add parse_version * fix lint --------- Co-authored-by: Xin Chen <[email protected]>
1 parent db73d55 commit 660af62

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

setup.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import os
22

3+
from pkg_resources import parse_version
34
from setuptools import find_packages, setup
45

56
EXT_TYPE = ''
67
try:
8+
import torch
79
from torch.utils.cpp_extension import BuildExtension
810
cmd_class = {'build_ext': BuildExtension}
911
EXT_TYPE = 'torch'
@@ -139,7 +141,10 @@ def get_extensions():
139141
# to compile those cpp files, so there is no need to add the
140142
# argument
141143
if platform.system() != 'Windows':
142-
extra_compile_args['cxx'] = ['-std=c++14']
144+
if parse_version(torch.__version__) <= parse_version('1.12.1'):
145+
extra_compile_args['cxx'] = ['-std=c++14']
146+
else:
147+
extra_compile_args['cxx'] = ['-std=c++17']
143148

144149
include_dirs = []
145150

@@ -159,7 +164,10 @@ def get_extensions():
159164
# to compile those cpp files, so there is no need to add the
160165
# argument
161166
if 'nvcc' in extra_compile_args and platform.system() != 'Windows':
162-
extra_compile_args['nvcc'] += ['-std=c++14']
167+
if parse_version(torch.__version__) <= parse_version('1.12.1'):
168+
extra_compile_args['nvcc'] += ['-std=c++14']
169+
else:
170+
extra_compile_args['nvcc'] += ['-std=c++17']
163171

164172
ext_ops = extension(
165173
name=ext_name,

0 commit comments

Comments
 (0)