Skip to content

Commit 7cdb17a

Browse files
committed
Remove os.path in link/c/op.py
1 parent 21fdfa8 commit 7cdb17a

File tree

1 file changed

+26
-24
lines changed

1 file changed

+26
-24
lines changed

pytensor/link/c/op.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import inspect
2-
import os
32
import re
43
import warnings
5-
from collections.abc import Callable, Collection
4+
from collections.abc import Callable, Collection, Iterable
5+
from pathlib import Path
66
from re import Pattern
77
from typing import TYPE_CHECKING, Any, ClassVar, cast
88

@@ -279,28 +279,32 @@ class ExternalCOp(COp):
279279
_cop_num_outputs: int | None = None
280280

281281
@classmethod
282-
def get_path(cls, f: str) -> str:
282+
def get_path(cls, f: Path) -> Path:
283283
"""Convert a path relative to the location of the class file into an absolute path.
284284
285285
Paths that are already absolute are passed through unchanged.
286286
287287
"""
288-
if not os.path.isabs(f):
288+
if not f.is_absolute():
289289
class_file = inspect.getfile(cls)
290-
class_dir = os.path.dirname(class_file)
291-
f = os.path.realpath(os.path.join(class_dir, f))
290+
class_dir = Path(class_file).parent
291+
f = (class_dir / f).resolve()
292292
return f
293293

294-
def __init__(self, func_files: str | list[str], func_name: str | None = None):
294+
def __init__(
295+
self,
296+
func_files: str | Path | list[str] | list[Path],
297+
func_name: str | None = None,
298+
):
295299
"""
296300
Sections are loaded from files in order with sections in later
297301
files overriding sections in previous files.
298302
299303
"""
300304
if not isinstance(func_files, list):
301-
self.func_files = [func_files]
305+
self.func_files = [Path(func_files)]
302306
else:
303-
self.func_files = func_files
307+
self.func_files = [Path(func_file) for func_file in func_files]
304308

305309
self.func_codes: list[str] = []
306310
# Keep the original name. If we reload old pickle, we want to
@@ -325,22 +329,20 @@ def __init__(self, func_files: str | list[str], func_name: str | None = None):
325329
"Cannot have an `op_code_cleanup` section and specify `func_name`"
326330
)
327331

328-
def load_c_code(self, func_files: list[str]) -> None:
332+
def load_c_code(self, func_files: Iterable[Path]) -> None:
329333
"""Loads the C code to perform the `Op`."""
330-
func_files = [self.get_path(f) for f in func_files]
331334
for func_file in func_files:
332-
with open(func_file) as f:
333-
self.func_codes.append(f.read())
335+
func_file = self.get_path(func_file)
336+
self.func_codes.append(func_file.read_text(encoding="utf-8"))
334337

335338
# If both the old section markers and the new section markers are
336339
# present, raise an error because we don't know which ones to follow.
337-
old_markers_present = False
338-
new_markers_present = False
339-
for code in self.func_codes:
340-
if self.backward_re.search(code):
341-
old_markers_present = True
342-
if self.section_re.search(code):
343-
new_markers_present = True
340+
old_markers_present = any(
341+
self.backward_re.search(code) for code in self.func_codes
342+
)
343+
new_markers_present = any(
344+
self.section_re.search(code) for code in self.func_codes
345+
)
344346

345347
if old_markers_present and new_markers_present:
346348
raise ValueError(
@@ -350,7 +352,7 @@ def load_c_code(self, func_files: list[str]) -> None:
350352
"be used at the same time."
351353
)
352354

353-
for i, code in enumerate(self.func_codes):
355+
for func_file, code in zip(func_files, self.func_codes):
354356
if self.backward_re.search(code):
355357
# This is backward compat code that will go away in a while
356358

@@ -371,15 +373,15 @@ def load_c_code(self, func_files: list[str]) -> None:
371373
if split[0].strip() != "":
372374
raise ValueError(
373375
"Stray code before first #section "
374-
f"statement (in file {func_files[i]}): {split[0]}"
376+
f"statement (in file {func_file}): {split[0]}"
375377
)
376378

377379
# Separate the code into the proper sections
378380
n = 1
379381
while n < len(split):
380382
if split[n] not in self.SECTIONS:
381383
raise ValueError(
382-
f"Unknown section type (in file {func_files[i]}): {split[n]}"
384+
f"Unknown section type (in file {func_file}): {split[n]}"
383385
)
384386
if split[n] not in self.code_sections:
385387
self.code_sections[split[n]] = ""
@@ -388,7 +390,7 @@ def load_c_code(self, func_files: list[str]) -> None:
388390

389391
else:
390392
raise ValueError(
391-
f"No valid section marker was found in file {func_files[i]}"
393+
f"No valid section marker was found in file {func_file}"
392394
)
393395

394396
def __get_op_params(self) -> list[tuple[str, Any]]:

0 commit comments

Comments
 (0)