2424THE SOFTWARE.
2525"""
2626
27+ import hashlib
2728import logging
2829import os
2930import re
3031import sys
3132from dataclasses import dataclass
33+ from typing import TYPE_CHECKING , Literal
3234
3335import pyopencl ._cl as _cl
3436
3537
3638logger = logging .getLogger (__name__ )
3739
3840
39- import hashlib
41+ if TYPE_CHECKING :
42+ from collections .abc import Sequence
4043
4144
4245new_hash = hashlib .md5
4346
4447
45- def _erase_dir (directory ):
48+ def _erase_dir (directory : str ):
4649 from os import listdir , rmdir , unlink
4750 from os .path import join
4851
@@ -343,19 +346,25 @@ class _SourceInfo:
343346 log : str | None
344347
345348
346- def _create_built_program_from_source_cached (ctx , src , options_bytes ,
347- devices , cache_dir , include_path ):
349+ def _create_built_program_from_source_cached (
350+ ctx : _cl .Context ,
351+ src : str | bytes ,
352+ options_bytes : bytes ,
353+ devices : Sequence [_cl .Device ] | None ,
354+ cache_dir : str | None ,
355+ include_path : Sequence [str ] | None ):
348356 from os .path import join
349357
350358 if cache_dir is None :
351359 import platformdirs
352360
353361 # Determine the cache directory in the same way as pytools.PersistentDict,
354362 # which PyOpenCL uses for invoker caches.
355- if sys .platform == "darwin" and os .getenv ("XDG_CACHE_HOME" ) is not None :
363+ xdg_cache_home = os .getenv ("XDG_CACHE_HOME" )
364+ if sys .platform == "darwin" and xdg_cache_home is not None :
356365 # platformdirs does not handle XDG_CACHE_HOME on macOS
357366 # https://github.com/platformdirs/platformdirs/issues/269
358- cache_dir = join (os . getenv ( "XDG_CACHE_HOME" ) , "pyopencl" )
367+ cache_dir = join (xdg_cache_home , "pyopencl" )
359368 else :
360369 cache_dir = platformdirs .user_cache_dir ("pyopencl" , "pyopencl" )
361370
@@ -371,7 +380,7 @@ def _create_built_program_from_source_cached(ctx, src, options_bytes,
371380 cache_keys = [get_cache_key (device , options_bytes , src ) for device in devices ]
372381
373382 binaries = []
374- to_be_built_indices = []
383+ to_be_built_indices : list [ int ] = []
375384 logs = []
376385 for i , (_device , cache_key ) in enumerate (zip (devices , cache_keys , strict = True )):
377386 cache_result = retrieve_from_cache (cache_dir , cache_key )
@@ -406,10 +415,13 @@ def _create_built_program_from_source_cached(ctx, src, options_bytes,
406415 already_built = False
407416 was_cached = not to_be_built_indices
408417
418+ if isinstance (src , str ):
419+ src = src .encode ()
420+
409421 if to_be_built_indices :
410422 # defeat implementation caches:
411423 from uuid import uuid4
412- src = src + "\n \n __constant int pyopencl_defeat_cache_%s = 0;" % (
424+ src = src + b "\n \n __constant int pyopencl_defeat_cache_%s = 0;" % (
413425 uuid4 ().hex )
414426
415427 logger .debug (
@@ -462,7 +474,7 @@ def _create_built_program_from_source_cached(ctx, src, options_bytes,
462474 binary_path = mod_cache_dir_m .sub ("binary" )
463475 source_path = mod_cache_dir_m .sub ("source.cl" )
464476
465- with open (source_path , "w " ) as outf :
477+ with open (source_path , "wb " ) as outf :
466478 outf .write (src )
467479
468480 with open (binary_path , "wb" ) as outf :
@@ -486,8 +498,14 @@ def _create_built_program_from_source_cached(ctx, src, options_bytes,
486498 return result , already_built , was_cached
487499
488500
489- def create_built_program_from_source_cached (ctx , src , options_bytes , devices = None ,
490- cache_dir = None , include_path = None ):
501+ def create_built_program_from_source_cached (
502+ ctx : _cl .Context ,
503+ src : str | bytes ,
504+ options_bytes : bytes ,
505+ devices : Sequence [_cl .Device ] | None = None ,
506+ cache_dir : str | Literal [False ] | None = None ,
507+ include_path : Sequence [str ] | None = None
508+ ):
491509 try :
492510 was_cached = False
493511 already_built = False
0 commit comments