1
1
import inspect
2
- import os
3
2
import re
4
3
import warnings
5
- from collections .abc import Callable , Collection
4
+ from collections .abc import Callable , Collection , Iterable
5
+ from pathlib import Path
6
6
from re import Pattern
7
7
from typing import TYPE_CHECKING , Any , ClassVar , cast
8
8
@@ -279,28 +279,32 @@ class ExternalCOp(COp):
279
279
_cop_num_outputs : int | None = None
280
280
281
281
@classmethod
282
- def get_path (cls , f : str ) -> str :
282
+ def get_path (cls , f : Path ) -> Path :
283
283
"""Convert a path relative to the location of the class file into an absolute path.
284
284
285
285
Paths that are already absolute are passed through unchanged.
286
286
287
287
"""
288
- if not os . path . isabs ( f ):
288
+ if not f . is_absolute ( ):
289
289
class_file = inspect .getfile (cls )
290
- class_dir = os . path . dirname (class_file )
291
- f = os . path . realpath ( os . path . join ( class_dir , f ) )
290
+ class_dir = Path (class_file ). parent
291
+ f = ( class_dir / f ). resolve ( )
292
292
return f
293
293
294
- def __init__ (self , func_files : str | list [str ], func_name : str | None = None ):
294
+ def __init__ (
295
+ self ,
296
+ func_files : str | Path | list [str ] | list [Path ],
297
+ func_name : str | None = None ,
298
+ ):
295
299
"""
296
300
Sections are loaded from files in order with sections in later
297
301
files overriding sections in previous files.
298
302
299
303
"""
300
304
if not isinstance (func_files , list ):
301
- self .func_files = [func_files ]
305
+ self .func_files = [Path ( func_files ) ]
302
306
else :
303
- self .func_files = func_files
307
+ self .func_files = [ Path ( func_file ) for func_file in func_files ]
304
308
305
309
self .func_codes : list [str ] = []
306
310
# Keep the original name. If we reload old pickle, we want to
@@ -325,22 +329,20 @@ def __init__(self, func_files: str | list[str], func_name: str | None = None):
325
329
"Cannot have an `op_code_cleanup` section and specify `func_name`"
326
330
)
327
331
328
- def load_c_code (self , func_files : list [ str ]) -> None :
332
+ def load_c_code (self , func_files : Iterable [ Path ]) -> None :
329
333
"""Loads the C code to perform the `Op`."""
330
- func_files = [self .get_path (f ) for f in func_files ]
331
334
for func_file in func_files :
332
- with open (func_file ) as f :
333
- self .func_codes .append (f . read ( ))
335
+ func_file = self . get_path (func_file )
336
+ self .func_codes .append (func_file . read_text ( encoding = "utf-8" ))
334
337
335
338
# If both the old section markers and the new section markers are
336
339
# present, raise an error because we don't know which ones to follow.
337
- old_markers_present = False
338
- new_markers_present = False
339
- for code in self .func_codes :
340
- if self .backward_re .search (code ):
341
- old_markers_present = True
342
- if self .section_re .search (code ):
343
- new_markers_present = True
340
+ old_markers_present = any (
341
+ self .backward_re .search (code ) for code in self .func_codes
342
+ )
343
+ new_markers_present = any (
344
+ self .section_re .search (code ) for code in self .func_codes
345
+ )
344
346
345
347
if old_markers_present and new_markers_present :
346
348
raise ValueError (
@@ -350,7 +352,7 @@ def load_c_code(self, func_files: list[str]) -> None:
350
352
"be used at the same time."
351
353
)
352
354
353
- for i , code in enumerate ( self .func_codes ):
355
+ for func_file , code in zip ( func_files , self .func_codes ):
354
356
if self .backward_re .search (code ):
355
357
# This is backward compat code that will go away in a while
356
358
@@ -371,15 +373,15 @@ def load_c_code(self, func_files: list[str]) -> None:
371
373
if split [0 ].strip () != "" :
372
374
raise ValueError (
373
375
"Stray code before first #section "
374
- f"statement (in file { func_files [ i ] } ): { split [0 ]} "
376
+ f"statement (in file { func_file } ): { split [0 ]} "
375
377
)
376
378
377
379
# Separate the code into the proper sections
378
380
n = 1
379
381
while n < len (split ):
380
382
if split [n ] not in self .SECTIONS :
381
383
raise ValueError (
382
- f"Unknown section type (in file { func_files [ i ] } ): { split [n ]} "
384
+ f"Unknown section type (in file { func_file } ): { split [n ]} "
383
385
)
384
386
if split [n ] not in self .code_sections :
385
387
self .code_sections [split [n ]] = ""
@@ -388,7 +390,7 @@ def load_c_code(self, func_files: list[str]) -> None:
388
390
389
391
else :
390
392
raise ValueError (
391
- f"No valid section marker was found in file { func_files [ i ] } "
393
+ f"No valid section marker was found in file { func_file } "
392
394
)
393
395
394
396
def __get_op_params (self ) -> list [tuple [str , Any ]]:
0 commit comments