File tree Expand file tree Collapse file tree 2 files changed +11
-14
lines changed Expand file tree Collapse file tree 2 files changed +11
-14
lines changed Original file line number Diff line number Diff line change @@ -447,7 +447,16 @@ def module_is_offloaded(module):
447447 )
448448
449449 # Enable generic support for Intel Gaudi accelerator using GPU/HPU migration
450- if kwargs .pop ("hpu_migration" , True ) and is_hpu_available ():
450+ if device_type == "hpu" and kwargs .pop ("hpu_migration" , True ) and is_hpu_available ():
451+ os .environ ["PT_HPU_GPU_MIGRATION" ] = "1"
452+ logger .debug ("Environment variable set: PT_HPU_GPU_MIGRATION=1" )
453+
454+ import habana_frameworks .torch # noqa: F401
455+
456+ # HPU hardware check
457+ if not (hasattr (torch , "hpu" ) and torch .hpu .is_available ()):
458+ raise ValueError ("You are trying to call `.to('hpu')` but HPU device is unavailable." )
459+
451460 os .environ ["PT_HPU_MAX_COMPOUND_OP_SIZE" ] = "1"
452461 logger .debug ("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1" )
453462
Original file line number Diff line number Diff line change @@ -337,19 +337,7 @@ def is_timm_available():
337337
338338
339339def is_hpu_available ():
340- if (
341- importlib .util .find_spec ("habana_frameworks" ) is None
342- or importlib .util .find_spec ("habana_frameworks.torch" ) is None
343- ):
344- return False
345-
346- os .environ ["PT_HPU_GPU_MIGRATION" ] = "1"
347- logger .debug ("Environment variable set: PT_HPU_GPU_MIGRATION=1" )
348-
349- import habana_frameworks .torch # noqa: F401
350- import torch
351-
352- return hasattr (torch , "hpu" ) and torch .hpu .is_available ()
340+ return all (importlib .util .find_spec (lib ) for lib in ("habana_frameworks" , "habana_frameworks.torch" ))
353341
354342
355343# docstyle-ignore
You can’t perform that action at this time.
0 commit comments