@@ -32,8 +32,8 @@ def get_version():
3232 version_txt = os .path .join (cwd , "version.txt" )
3333 with open (version_txt , "r" ) as f :
3434 version = f .readline ().strip ()
35- if os .getenv ("BUILD_VERSION " ):
36- version = os .getenv ("BUILD_VERSION " )
35+ if os .getenv ("TORCHRL_BUILD_VERSION " ):
36+ version = os .getenv ("TORCHRL_BUILD_VERSION " )
3737 elif sha != "Unknown" :
3838 version += "+" + sha [:7 ]
3939 return version
@@ -68,11 +68,13 @@ def write_version_file(version):
6868 f .write ("git_version = {}\n " .format (repr (sha )))
6969
7070
71- def _get_pytorch_version (is_nightly ):
71+ def _get_pytorch_version (is_nightly , is_local ):
7272 # if "PYTORCH_VERSION" in os.environ:
7373 # return f"torch=={os.environ['PYTORCH_VERSION']}"
7474 if is_nightly :
7575 return "torch>=2.4.0.dev"
76+ elif is_local :
77+ return "torch"
7678 return "torch>=2.3.0"
7779
7880
@@ -178,10 +180,12 @@ def _main(argv):
178180 else :
179181 version = get_version ()
180182 write_version_file (version )
183+ TORCHRL_BUILD_VERSION = os .getenv ("TORCHRL_BUILD_VERSION" )
181184 logging .info ("Building wheel {}-{}" .format (package_name , version ))
182- logging .info (f"BUILD_VERSION is { os . getenv ( 'BUILD_VERSION' ) } " )
185+ logging .info (f"TORCHRL_BUILD_VERSION is { TORCHRL_BUILD_VERSION } " )
183186
184- pytorch_package_dep = _get_pytorch_version (is_nightly )
187+ is_local = TORCHRL_BUILD_VERSION is None
188+ pytorch_package_dep = _get_pytorch_version (is_nightly , is_local )
185189 logging .info ("-- PyTorch dependency:" , pytorch_package_dep )
186190 # branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"])
187191 # tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"])
0 commit comments