@@ -28,9 +28,7 @@ def _get_cache_dir() -> str | None:
2828 """Returns the kernels cache directory."""
2929 cache_dir = os .environ .get ("HF_KERNELS_CACHE" , None )
3030 if cache_dir is not None :
31- logging .warning (
32- "HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
33- )
31+ logging .warning ("HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead" )
3432 return cache_dir
3533
3634 return os .environ .get ("KERNELS_CACHE" , None )
@@ -50,7 +48,11 @@ def _get_privateuse_backend_name() -> str | None:
5048def backend () -> str :
5149 import torch
5250
53- if torch .version .cuda is not None :
51+ if hasattr (torch , "neuron" ):
52+ # Needs to be sorted before specific Torch builds, since Neuron
53+ # extension can be loaded into e.g. CUDA Torch builds.
54+ return "neuron"
55+ elif torch .version .cuda is not None :
5456 return "cuda"
5557 elif torch .version .hip is not None :
5658 return "hip"
@@ -104,7 +106,11 @@ def build_variant() -> str:
104106def build_variant_noarch () -> str :
105107 import torch
106108
107- if torch .version .cuda is not None :
109+ if hasattr (torch , "neuron" ):
110+ # Needs to be sorted before specific Torch builds, since Neuron
111+ # extension can be loaded into e.g. CUDA Torch builds.
112+ return "torch-neuron"
113+ elif torch .version .cuda is not None :
108114 return "torch-cuda"
109115 elif torch .version .hip is not None :
110116 return "torch-rocm"
@@ -197,9 +203,7 @@ def install_kernel(
197203 try :
198204 return _find_kernel_in_repo_path (repo_path , package_name , variant_locks )
199205 except FileNotFoundError :
200- raise FileNotFoundError (
201- f"Cannot install kernel from repo { repo_id } (revision: { revision } )"
202- )
206+ raise FileNotFoundError (f"Cannot install kernel from repo { repo_id } (revision: { revision } )" )
203207
204208
205209def _find_kernel_in_repo_path (
@@ -264,9 +268,7 @@ def install_kernel_all_variants(
264268 if variant_lock is None :
265269 raise ValueError (f"No lock found for build variant: { variant } " )
266270
267- validate_kernel (
268- repo_path = repo_path , variant = variant , hash = variant_lock .hash
269- )
271+ validate_kernel (repo_path = repo_path , variant = variant , hash = variant_lock .hash )
270272
271273 return repo_path / "build"
272274
@@ -309,9 +311,7 @@ def get_kernel(
309311 ```
310312 """
311313 revision = select_revision_or_version (repo_id , revision = revision , version = version )
312- package_name , variant_path = install_kernel (
313- repo_id , revision = revision , user_agent = user_agent
314- )
314+ package_name , variant_path = install_kernel (repo_id , revision = revision , user_agent = user_agent )
315315 return _import_from_path (package_name , variant_path )
316316
317317
@@ -344,9 +344,7 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
344344 raise FileNotFoundError (f"Could not find package '{ package_name } ' in { repo_path } " )
345345
346346
347- def has_kernel (
348- repo_id : str , revision : str | None = None , version : int | str | None = None
349- ) -> bool :
347+ def has_kernel (repo_id : str , revision : str | None = None , version : int | str | None = None ) -> bool :
350348 """
351349 Check whether a kernel build exists for the current environment (Torch version and compute framework).
352350
@@ -419,9 +417,7 @@ def load_kernel(repo_id: str, *, lockfile: Path | None) -> ModuleType:
419417 )
420418
421419 try :
422- package_name , variant_path = _find_kernel_in_repo_path (
423- repo_path , package_name , variant_locks = None
424- )
420+ package_name , variant_path = _find_kernel_in_repo_path (repo_path , package_name , variant_locks = None )
425421 return _import_from_path (package_name , variant_path )
426422 except FileNotFoundError :
427423 raise FileNotFoundError (
@@ -447,9 +443,7 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleTyp
447443 if locked_sha is None :
448444 raise ValueError (f"Kernel `{ repo_id } ` is not locked" )
449445
450- package_name , variant_path = install_kernel (
451- repo_id , locked_sha , local_files_only = local_files_only
452- )
446+ package_name , variant_path = install_kernel (repo_id , locked_sha , local_files_only = local_files_only )
453447
454448 return _import_from_path (package_name , variant_path )
455449
0 commit comments