Skip to content

Commit 986ba9b

Browse files
committed
Remove os.path in configdefaults.py
1 parent 910e0cf commit 986ba9b

File tree

1 file changed

+67
-70
lines changed

1 file changed

+67
-70
lines changed

pytensor/configdefaults.py

Lines changed: 67 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import socket
77
import sys
88
import textwrap
9+
from pathlib import Path
910

1011
import numpy as np
1112
from setuptools._distutils.spawn import find_executable
@@ -33,64 +34,60 @@
3334
_logger = logging.getLogger("pytensor.configdefaults")
3435

3536

36-
def get_cuda_root():
37+
def get_cuda_root() -> Path | None:
3738
# We look for the cuda path since we need headers from there
38-
v = os.getenv("CUDA_ROOT", "")
39-
if v:
40-
return v
41-
v = os.getenv("CUDA_PATH", "")
42-
if v:
43-
return v
44-
s = os.getenv("PATH")
45-
if not s:
46-
return ""
47-
for dir in s.split(os.path.pathsep):
48-
if os.path.exists(os.path.join(dir, "nvcc")):
49-
return os.path.dirname(os.path.abspath(dir))
50-
return ""
51-
52-
53-
def default_cuda_include():
39+
if (v := os.getenv("CUDA_ROOT")) is not None:
40+
return Path(v)
41+
if (v := os.getenv("CUDA_PATH")) is not None:
42+
return Path(v)
43+
if (s := os.getenv("PATH")) is None:
44+
return Path()
45+
for dir in s.split(os.pathsep):
46+
if (Path(dir) / "nvcc").exists():
47+
return Path(dir).absolute().parent
48+
return None
49+
50+
51+
def default_cuda_include() -> Path | None:
5452
if config.cuda__root:
55-
return os.path.join(config.cuda__root, "include")
56-
return ""
53+
return Path(config.cuda__root) / "include"
54+
return None
5755

5856

59-
def default_dnn_base_path():
57+
def default_dnn_base_path() -> Path | None:
6058
# We want to default to the cuda root if cudnn is installed there
61-
root = config.cuda__root
59+
root = Path(config.cuda__root)
6260
# The include doesn't change location between OS.
63-
if root and os.path.exists(os.path.join(root, "include", "cudnn.h")):
61+
if (root / "include/cudnn.h").exists():
6462
return root
65-
return ""
63+
return None
6664

6765

68-
def default_dnn_inc_path():
66+
def default_dnn_inc_path() -> Path | None:
6967
if config.dnn__base_path != "":
70-
return os.path.join(config.dnn__base_path, "include")
71-
return ""
68+
return Path(config.dnn__base_path) / "include"
69+
return None
7270

7371

74-
def default_dnn_lib_path():
72+
def default_dnn_lib_path() -> Path | None:
7573
if config.dnn__base_path != "":
7674
if sys.platform == "win32":
77-
path = os.path.join(config.dnn__base_path, "lib", "x64")
75+
path = Path(config.dnn__base__path) / "lib/x64"
7876
elif sys.platform == "darwin":
79-
path = os.path.join(config.dnn__base_path, "lib")
77+
path = Path(config.dnn__base__path) / "lib"
8078
else:
8179
# This is linux
82-
path = os.path.join(config.dnn__base_path, "lib64")
80+
path = Path(config.dnn__base__path) / "lib64"
8381
return path
84-
return ""
82+
return None
8583

8684

87-
def default_dnn_bin_path():
85+
def default_dnn_bin_path() -> Path | None:
8886
if config.dnn__base_path != "":
8987
if sys.platform == "win32":
90-
return os.path.join(config.dnn__base_path, "bin")
91-
else:
92-
return config.dnn__library_path
93-
return ""
88+
return Path(config.dnn__base_path) / "bin"
89+
return Path(config.dnn__library_path)
90+
return None
9491

9592

9693
def _filter_mode(val):
@@ -405,15 +402,11 @@ def add_compile_configvars():
405402
# Anaconda on Windows has mingw-w64 packages including GCC, but it may not be on PATH.
406403
if rc != 0:
407404
if sys.platform == "win32":
408-
mingw_w64_gcc = os.path.join(
409-
os.path.dirname(sys.executable), "Library", "mingw-w64", "bin", "g++"
410-
)
405+
mingw_w64_gcc = Path(sys.executable).parent / "Library/mingw-w64/bin/g++"
411406
try:
412407
rc = call_subprocess_Popen([mingw_w64_gcc, "-v"])
413408
if rc == 0:
414-
maybe_add_to_os_environ_pathlist(
415-
"PATH", os.path.dirname(mingw_w64_gcc)
416-
)
409+
maybe_add_to_os_environ_pathlist("PATH", mingw_w64_gcc.parent)
417410
except OSError:
418411
rc = 1
419412
if rc != 0:
@@ -1221,27 +1214,27 @@ def add_numba_configvars():
12211214
)
12221215

12231216

1224-
def _default_compiledirname():
1217+
def _default_compiledirname() -> str:
12251218
formatted = config.compiledir_format % _compiledir_format_dict
12261219
safe = re.sub(r"[\(\)\s,]+", "_", formatted)
12271220
return safe
12281221

12291222

1230-
def _filter_base_compiledir(path):
1223+
def _filter_base_compiledir(path: Path) -> Path:
12311224
# Expand '~' in path
1232-
return os.path.expanduser(str(path))
1225+
return path.expanduser()
12331226

12341227

1235-
def _filter_compiledir(path):
1228+
def _filter_compiledir(path: Path) -> Path:
12361229
# Expand '~' in path
1237-
path = os.path.expanduser(path)
1230+
path = path.expanduser()
12381231
# Turn path into the 'real' path. This ensures that:
12391232
# 1. There is no relative path, which would fail e.g. when trying to
12401233
# import modules from the compile dir.
12411234
# 2. The path is stable w.r.t. e.g. symlinks (which makes it easier
12421235
# to re-use compiled modules).
1243-
path = os.path.realpath(path)
1244-
if os.access(path, os.F_OK): # Do it exist?
1236+
path = path.resolve()
1237+
if path.exists(): # Does it exist?
12451238
if not os.access(path, os.R_OK | os.W_OK | os.X_OK):
12461239
# If it exist we need read, write and listing access
12471240
raise ValueError(
@@ -1250,7 +1243,9 @@ def _filter_compiledir(path):
12501243
)
12511244
else:
12521245
try:
1253-
os.makedirs(path, 0o770) # read-write-execute for user and group
1246+
path.mkdir(
1247+
mode=0o770, parents=True
1248+
) # read-write-execute for user and group
12541249
except OSError as e:
12551250
# Maybe another parallel execution of pytensor was trying to create
12561251
# the same directory at the same time.
@@ -1264,36 +1259,38 @@ def _filter_compiledir(path):
12641259
# os.system('touch') returned -1 for an unknown reason; the
12651260
# alternate approach here worked in all cases... it was weird.
12661261
# No error should happen as we checked the permissions.
1267-
init_file = os.path.join(path, "__init__.py")
1268-
if not os.path.exists(init_file):
1262+
init_file = path / "__init__.py"
1263+
if not init_file.exists():
12691264
try:
12701265
with open(init_file, "w"):
12711266
pass
12721267
except OSError as e:
1273-
if os.path.exists(init_file):
1268+
if init_file.exists():
12741269
pass # has already been created
12751270
else:
1276-
e.args += (f"{path} exist? {os.path.exists(path)}",)
1271+
e.args += (f"{path} exist? {path.exists()}",)
12771272
raise
12781273
return path
12791274

12801275

1281-
def _get_home_dir():
1276+
def _get_home_dir() -> Path:
12821277
"""
12831278
Return location of the user's home directory.
12841279
12851280
"""
1286-
home = os.getenv("HOME")
1287-
if home is None:
1288-
# This expanduser usually works on Windows (see discussion on
1289-
# theano-users, July 13 2010).
1290-
home = os.path.expanduser("~")
1291-
if home == "~":
1292-
# This might happen when expanduser fails. Although the cause of
1293-
# failure is a mystery, it has been seen on some Windows system.
1294-
home = os.getenv("USERPROFILE")
1295-
assert home is not None
1296-
return home
1281+
if (env_home := os.getenv("HOME")) is not None:
1282+
return Path(env_home)
1283+
1284+
# This usually works on Windows (see discussion on theano-users, July 13 2010).
1285+
path_home = Path.home()
1286+
if str(path_home) != "~":
1287+
return path_home
1288+
1289+
# This might happen when expanduser fails. Although the cause of
1290+
# failure is a mystery, it has been seen on some Windows system.
1291+
windowsfail_home = os.getenv("USERPROFILE")
1292+
assert windowsfail_home is not None
1293+
return Path(windowsfail_home)
12971294

12981295

12991296
_compiledir_format_dict = {
@@ -1309,8 +1306,8 @@ def _get_home_dir():
13091306
}
13101307

13111308

1312-
def _default_compiledir():
1313-
return os.path.join(config.base_compiledir, _default_compiledirname())
1309+
def _default_compiledir() -> Path:
1310+
return Path(config.base_compiledir) / _default_compiledirname()
13141311

13151312

13161313
def add_caching_dir_configvars():
@@ -1343,9 +1340,9 @@ def add_caching_dir_configvars():
13431340
# part of the roaming part of the user profile. Instead we use the local part
13441341
# of the user profile, when available.
13451342
if sys.platform == "win32" and os.getenv("LOCALAPPDATA") is not None:
1346-
default_base_compiledir = os.path.join(os.getenv("LOCALAPPDATA"), "PyTensor")
1343+
default_base_compiledir = Path(os.getenv("LOCALAPPDATA")) / "PyTensor"
13471344
else:
1348-
default_base_compiledir = os.path.join(_get_home_dir(), ".pytensor")
1345+
default_base_compiledir = _get_home_dir() / ".pytensor"
13491346

13501347
config.add(
13511348
"base_compiledir",

0 commit comments

Comments
 (0)