1+ import contextlib
12import os
23import pathlib
34import platform
45import re
56import shutil
67import sys
78import sysconfig
9+ from importlib .util import module_from_spec , spec_from_file_location
810
9- from setuptools import setup
11+ from setuptools import Extension , setup
12+ from setuptools .command .build_ext import build_ext
1013
1114
12- try :
13- from pybind11 .setup_helpers import Pybind11Extension as Extension
14- from pybind11 .setup_helpers import build_ext
15- except ImportError :
16- from setuptools import Extension
17- from setuptools .command .build_ext import build_ext
18-
1915HERE = pathlib .Path (__file__ ).absolute ().parent
20- VERSION_FILE = HERE / 'torchopt' / 'version.py'
21-
22- sys .path .insert (0 , str (VERSION_FILE .parent ))
23- import version # noqa
2416
2517
2618class CMakeExtension (Extension ):
@@ -47,7 +39,6 @@ def build_extension(self, ext):
4739 build_temp .mkdir (parents = True , exist_ok = True )
4840
4941 config = 'Debug' if self .debug else 'Release'
50-
5142 cmake_args = [
5243 f'-DCMAKE_BUILD_TYPE={ config } ' ,
5344 f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{ config .upper ()} ={ ext_path .parent } ' ,
@@ -83,13 +74,53 @@ def build_extension(self, ext):
8374
8475 build_args .extend (['--target' , ext .target , '--' ])
8576
77+ cwd = os .getcwd ()
8678 try :
8779 os .chdir (build_temp )
8880 self .spawn ([cmake , ext .source_dir , * cmake_args ])
8981 if not self .dry_run :
9082 self .spawn ([cmake , '--build' , '.' , * build_args ])
9183 finally :
92- os .chdir (HERE )
84+ os .chdir (cwd )
85+
86+
87+ @contextlib .contextmanager
88+ def vcs_version (name , path ):
89+ path = pathlib .Path (path ).absolute ()
90+ assert path .is_file ()
91+ module_spec = spec_from_file_location (name = name , location = path )
92+ assert module_spec is not None
93+ assert module_spec .loader is not None
94+ module = sys .modules .get (name )
95+ if module is None :
96+ module = module_from_spec (module_spec )
97+ sys .modules [name ] = module
98+ module_spec .loader .exec_module (module )
99+
100+ if module .__release__ :
101+ yield module
102+ return
103+
104+ content = None
105+ try :
106+ try :
107+ content = path .read_text (encoding = 'utf-8' )
108+ path .write_text (
109+ data = re .sub (
110+ r"""__version__\s*=\s*('[^']+'|"[^"]+")""" ,
111+ f'__version__ = { module .__version__ !r} ' ,
112+ string = content ,
113+ ),
114+ encoding = 'utf-8' ,
115+ )
116+ except OSError :
117+ content = None
118+
119+ yield module
120+ finally :
121+ if content is not None :
122+ with path .open (mode = 'wt' , encoding = 'utf-8' , newline = '' ) as file :
123+ file .write (content )
93124
94125
95126CIBUILDWHEEL = os .getenv ('CIBUILDWHEEL' , '0' ) == '1'
@@ -112,29 +143,9 @@ def build_extension(self, ext):
112143 ext_kwargs .clear ()
113144
114145
115- VERSION_CONTENT = None
116-
117- try :
118- if not version .__release__ :
119- try :
120- VERSION_CONTENT = VERSION_FILE .read_text (encoding = 'utf-8' )
121- VERSION_FILE .write_text (
122- data = re .sub (
123- r"""__version__\s*=\s*('[^']+'|"[^"]+")""" ,
124- f'__version__ = { version .__version__ !r} ' ,
125- string = VERSION_CONTENT ,
126- ),
127- encoding = 'utf-8' ,
128- )
129- except OSError :
130- VERSION_CONTENT = None
131-
146+ with vcs_version (name = 'torchopt.version' , path = (HERE / 'torchopt' / 'version.py' )) as version :
132147 setup (
133148 name = 'torchopt' ,
134149 version = version .__version__ ,
135150 ** ext_kwargs ,
136151 )
137- finally :
138- if VERSION_CONTENT is not None :
139- with VERSION_FILE .open (mode = 'wt' , encoding = 'utf-8' , newline = '' ) as file :
140- file .write (VERSION_CONTENT )
0 commit comments