Skip to content

Commit 2f0b2ad

Browse files
committed
Add some types to pyopencl.cache
1 parent 28c2ac1 commit 2f0b2ad

File tree

1 file changed

+29
-11
lines changed

1 file changed

+29
-11
lines changed

pyopencl/cache.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,28 @@
2424
THE SOFTWARE.
2525
"""
2626

27+
import hashlib
2728
import logging
2829
import os
2930
import re
3031
import sys
3132
from dataclasses import dataclass
33+
from typing import TYPE_CHECKING, Literal
3234

3335
import pyopencl._cl as _cl
3436

3537

3638
logger = logging.getLogger(__name__)
3739

3840

39-
import hashlib
41+
if TYPE_CHECKING:
42+
from collections.abc import Sequence
4043

4144

4245
new_hash = hashlib.md5
4346

4447

45-
def _erase_dir(directory):
48+
def _erase_dir(directory: str):
4649
from os import listdir, rmdir, unlink
4750
from os.path import join
4851

@@ -343,19 +346,25 @@ class _SourceInfo:
343346
log: str | None
344347

345348

346-
def _create_built_program_from_source_cached(ctx, src, options_bytes,
347-
devices, cache_dir, include_path):
349+
def _create_built_program_from_source_cached(
350+
ctx: _cl.Context,
351+
src: str | bytes,
352+
options_bytes: bytes,
353+
devices: Sequence[_cl.Device] | None,
354+
cache_dir: str | None,
355+
include_path: Sequence[str] | None):
348356
from os.path import join
349357

350358
if cache_dir is None:
351359
import platformdirs
352360

353361
# Determine the cache directory in the same way as pytools.PersistentDict,
354362
# which PyOpenCL uses for invoker caches.
355-
if sys.platform == "darwin" and os.getenv("XDG_CACHE_HOME") is not None:
363+
xdg_cache_home = os.getenv("XDG_CACHE_HOME")
364+
if sys.platform == "darwin" and xdg_cache_home is not None:
356365
# platformdirs does not handle XDG_CACHE_HOME on macOS
357366
# https://github.com/platformdirs/platformdirs/issues/269
358-
cache_dir = join(os.getenv("XDG_CACHE_HOME"), "pyopencl")
367+
cache_dir = join(xdg_cache_home, "pyopencl")
359368
else:
360369
cache_dir = platformdirs.user_cache_dir("pyopencl", "pyopencl")
361370

@@ -371,7 +380,7 @@ def _create_built_program_from_source_cached(ctx, src, options_bytes,
371380
cache_keys = [get_cache_key(device, options_bytes, src) for device in devices]
372381

373382
binaries = []
374-
to_be_built_indices = []
383+
to_be_built_indices: list[int] = []
375384
logs = []
376385
for i, (_device, cache_key) in enumerate(zip(devices, cache_keys, strict=True)):
377386
cache_result = retrieve_from_cache(cache_dir, cache_key)
@@ -406,10 +415,13 @@ def _create_built_program_from_source_cached(ctx, src, options_bytes,
406415
already_built = False
407416
was_cached = not to_be_built_indices
408417

418+
if isinstance(src, str):
419+
src = src.encode()
420+
409421
if to_be_built_indices:
410422
# defeat implementation caches:
411423
from uuid import uuid4
412-
src = src + "\n\n__constant int pyopencl_defeat_cache_%s = 0;" % (
424+
src = src + b"\n\n__constant int pyopencl_defeat_cache_%s = 0;" % (
413425
uuid4().hex)
414426

415427
logger.debug(
@@ -462,7 +474,7 @@ def _create_built_program_from_source_cached(ctx, src, options_bytes,
462474
binary_path = mod_cache_dir_m.sub("binary")
463475
source_path = mod_cache_dir_m.sub("source.cl")
464476

465-
with open(source_path, "w") as outf:
477+
with open(source_path, "wb") as outf:
466478
outf.write(src)
467479

468480
with open(binary_path, "wb") as outf:
@@ -486,8 +498,14 @@ def _create_built_program_from_source_cached(ctx, src, options_bytes,
486498
return result, already_built, was_cached
487499

488500

489-
def create_built_program_from_source_cached(ctx, src, options_bytes, devices=None,
490-
cache_dir=None, include_path=None):
501+
def create_built_program_from_source_cached(
502+
ctx: _cl.Context,
503+
src: str | bytes,
504+
options_bytes: bytes,
505+
devices: Sequence[_cl.Device] | None = None,
506+
cache_dir: str | Literal[False] | None = None,
507+
include_path: Sequence[str] | None = None
508+
):
491509
try:
492510
was_cached = False
493511
already_built = False

0 commit comments

Comments
 (0)