Skip to content

Commit cccd8d7

Browse files
committed
Remove os.path in link/c/op.py
1 parent 53f52a3 commit cccd8d7

File tree

1 file changed

+26
-24
lines changed

1 file changed

+26
-24
lines changed

pytensor/link/c/op.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import inspect
2-
import os
32
import re
43
import warnings
5-
from collections.abc import Callable, Collection
4+
from collections.abc import Callable, Collection, Iterable
5+
from pathlib import Path
66
from re import Pattern
77
from typing import TYPE_CHECKING, Any, ClassVar, cast
88

@@ -288,28 +288,32 @@ class ExternalCOp(COp):
288288
_cop_num_outputs: int | None = None
289289

290290
@classmethod
291-
def get_path(cls, f: str) -> str:
291+
def get_path(cls, f: Path) -> Path:
292292
"""Convert a path relative to the location of the class file into an absolute path.
293293
294294
Paths that are already absolute are passed through unchanged.
295295
296296
"""
297-
if not os.path.isabs(f):
297+
if not f.is_absolute():
298298
class_file = inspect.getfile(cls)
299-
class_dir = os.path.dirname(class_file)
300-
f = os.path.realpath(os.path.join(class_dir, f))
299+
class_dir = Path(class_file).parent
300+
f = (class_dir / f).resolve()
301301
return f
302302

303-
def __init__(self, func_files: str | list[str], func_name: str | None = None):
303+
def __init__(
304+
self,
305+
func_files: str | Path | list[str] | list[Path],
306+
func_name: str | None = None,
307+
):
304308
"""
305309
Sections are loaded from files in order with sections in later
306310
files overriding sections in previous files.
307311
308312
"""
309313
if not isinstance(func_files, list):
310-
self.func_files = [func_files]
314+
self.func_files = [Path(func_files)]
311315
else:
312-
self.func_files = func_files
316+
self.func_files = [Path(func_file) for func_file in func_files]
313317

314318
self.func_codes: list[str] = []
315319
# Keep the original name. If we reload old pickle, we want to
@@ -334,22 +338,20 @@ def __init__(self, func_files: str | list[str], func_name: str | None = None):
334338
"Cannot have an `op_code_cleanup` section and specify `func_name`"
335339
)
336340

337-
def load_c_code(self, func_files: list[str]) -> None:
341+
def load_c_code(self, func_files: Iterable[Path]) -> None:
338342
"""Loads the C code to perform the `Op`."""
339-
func_files = [self.get_path(f) for f in func_files]
340343
for func_file in func_files:
341-
with open(func_file) as f:
342-
self.func_codes.append(f.read())
344+
func_file = self.get_path(func_file)
345+
self.func_codes.append(func_file.read_text(encoding="utf-8"))
343346

344347
# If both the old section markers and the new section markers are
345348
# present, raise an error because we don't know which ones to follow.
346-
old_markers_present = False
347-
new_markers_present = False
348-
for code in self.func_codes:
349-
if self.backward_re.search(code):
350-
old_markers_present = True
351-
if self.section_re.search(code):
352-
new_markers_present = True
349+
old_markers_present = any(
350+
self.backward_re.search(code) for code in self.func_codes
351+
)
352+
new_markers_present = any(
353+
self.section_re.search(code) for code in self.func_codes
354+
)
353355

354356
if old_markers_present and new_markers_present:
355357
raise ValueError(
@@ -359,7 +361,7 @@ def load_c_code(self, func_files: list[str]) -> None:
359361
"be used at the same time."
360362
)
361363

362-
for i, code in enumerate(self.func_codes):
364+
for func_file, code in zip(func_files, self.func_codes):
363365
if self.backward_re.search(code):
364366
# This is backward compat code that will go away in a while
365367

@@ -380,15 +382,15 @@ def load_c_code(self, func_files: list[str]) -> None:
380382
if split[0].strip() != "":
381383
raise ValueError(
382384
"Stray code before first #section "
383-
f"statement (in file {func_files[i]}): {split[0]}"
385+
f"statement (in file {func_file}): {split[0]}"
384386
)
385387

386388
# Separate the code into the proper sections
387389
n = 1
388390
while n < len(split):
389391
if split[n] not in self.SECTIONS:
390392
raise ValueError(
391-
f"Unknown section type (in file {func_files[i]}): {split[n]}"
393+
f"Unknown section type (in file {func_file}): {split[n]}"
392394
)
393395
if split[n] not in self.code_sections:
394396
self.code_sections[split[n]] = ""
@@ -397,7 +399,7 @@ def load_c_code(self, func_files: list[str]) -> None:
397399

398400
else:
399401
raise ValueError(
400-
f"No valid section marker was found in file {func_files[i]}"
402+
f"No valid section marker was found in file {func_file}"
401403
)
402404

403405
def __get_op_params(self) -> list[tuple[str, Any]]:

0 commit comments

Comments
 (0)