1
1
import inspect
2
- import os
3
2
import re
4
3
import warnings
5
- from collections .abc import Callable , Collection
4
+ from collections .abc import Callable , Collection , Iterable
5
+ from pathlib import Path
6
6
from re import Pattern
7
7
from typing import TYPE_CHECKING , Any , ClassVar , cast
8
8
@@ -288,28 +288,32 @@ class ExternalCOp(COp):
288
288
_cop_num_outputs : int | None = None
289
289
290
290
@classmethod
291
- def get_path (cls , f : str ) -> str :
291
+ def get_path (cls , f : Path ) -> Path :
292
292
"""Convert a path relative to the location of the class file into an absolute path.
293
293
294
294
Paths that are already absolute are passed through unchanged.
295
295
296
296
"""
297
- if not os . path . isabs ( f ):
297
+ if not f . is_absolute ( ):
298
298
class_file = inspect .getfile (cls )
299
- class_dir = os . path . dirname (class_file )
300
- f = os . path . realpath ( os . path . join ( class_dir , f ) )
299
+ class_dir = Path (class_file ). parent
300
+ f = ( class_dir / f ). resolve ( )
301
301
return f
302
302
303
- def __init__ (self , func_files : str | list [str ], func_name : str | None = None ):
303
+ def __init__ (
304
+ self ,
305
+ func_files : str | Path | list [str ] | list [Path ],
306
+ func_name : str | None = None ,
307
+ ):
304
308
"""
305
309
Sections are loaded from files in order with sections in later
306
310
files overriding sections in previous files.
307
311
308
312
"""
309
313
if not isinstance (func_files , list ):
310
- self .func_files = [func_files ]
314
+ self .func_files = [Path ( func_files ) ]
311
315
else :
312
- self .func_files = func_files
316
+ self .func_files = [ Path ( func_file ) for func_file in func_files ]
313
317
314
318
self .func_codes : list [str ] = []
315
319
# Keep the original name. If we reload old pickle, we want to
@@ -334,22 +338,20 @@ def __init__(self, func_files: str | list[str], func_name: str | None = None):
334
338
"Cannot have an `op_code_cleanup` section and specify `func_name`"
335
339
)
336
340
337
- def load_c_code (self , func_files : list [ str ]) -> None :
341
+ def load_c_code (self , func_files : Iterable [ Path ]) -> None :
338
342
"""Loads the C code to perform the `Op`."""
339
- func_files = [self .get_path (f ) for f in func_files ]
340
343
for func_file in func_files :
341
- with open (func_file ) as f :
342
- self .func_codes .append (f . read ( ))
344
+ func_file = self . get_path (func_file )
345
+ self .func_codes .append (func_file . read_text ( encoding = "utf-8" ))
343
346
344
347
# If both the old section markers and the new section markers are
345
348
# present, raise an error because we don't know which ones to follow.
346
- old_markers_present = False
347
- new_markers_present = False
348
- for code in self .func_codes :
349
- if self .backward_re .search (code ):
350
- old_markers_present = True
351
- if self .section_re .search (code ):
352
- new_markers_present = True
349
+ old_markers_present = any (
350
+ self .backward_re .search (code ) for code in self .func_codes
351
+ )
352
+ new_markers_present = any (
353
+ self .section_re .search (code ) for code in self .func_codes
354
+ )
353
355
354
356
if old_markers_present and new_markers_present :
355
357
raise ValueError (
@@ -359,7 +361,7 @@ def load_c_code(self, func_files: list[str]) -> None:
359
361
"be used at the same time."
360
362
)
361
363
362
- for i , code in enumerate ( self .func_codes ):
364
+ for func_file , code in zip ( func_files , self .func_codes ):
363
365
if self .backward_re .search (code ):
364
366
# This is backward compat code that will go away in a while
365
367
@@ -380,15 +382,15 @@ def load_c_code(self, func_files: list[str]) -> None:
380
382
if split [0 ].strip () != "" :
381
383
raise ValueError (
382
384
"Stray code before first #section "
383
- f"statement (in file { func_files [ i ] } ): { split [0 ]} "
385
+ f"statement (in file { func_file } ): { split [0 ]} "
384
386
)
385
387
386
388
# Separate the code into the proper sections
387
389
n = 1
388
390
while n < len (split ):
389
391
if split [n ] not in self .SECTIONS :
390
392
raise ValueError (
391
- f"Unknown section type (in file { func_files [ i ] } ): { split [n ]} "
393
+ f"Unknown section type (in file { func_file } ): { split [n ]} "
392
394
)
393
395
if split [n ] not in self .code_sections :
394
396
self .code_sections [split [n ]] = ""
@@ -397,7 +399,7 @@ def load_c_code(self, func_files: list[str]) -> None:
397
399
398
400
else :
399
401
raise ValueError (
400
- f"No valid section marker was found in file { func_files [ i ] } "
402
+ f"No valid section marker was found in file { func_file } "
401
403
)
402
404
403
405
def __get_op_params (self ) -> list [tuple [str , Any ]]:
0 commit comments