
Commit d60f77b

carmocca authored and Borda committed

Update Habana integration to 1.2 (#18877)

Co-authored-by: Jirka Borovec <[email protected]>
(cherry picked from commit 182c30b)

1 parent 03101ce, commit d60f77b

File tree: 8 files changed, +129 -76 lines

docs/source-pytorch/conf.py

Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
 assist_local.AssistantCLI.pull_docs_files(
     gh_user_repo="Lightning-AI/lightning-Habana",
     target_dir="docs/source-pytorch/integrations/hpu",
-    checkout="tags/1.1.0",
+    checkout="tags/1.2.0",
 )
 assist_local.AssistantCLI.pull_docs_files(
     gh_user_repo="Lightning-AI/lightning-Graphcore",
Lines changed: 3 additions & 3 deletions

@@ -1,3 +1,3 @@
-# validation HPU connectors
-lightning-habana >=1.0.0
-lightning-graphcore >=0.1.0.rc4
+# validation accelerator connectors
+lightning-habana >=1.2.0, <1.3.0
+lightning-graphcore >=0.1.0, <0.2.0
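
For reference, the tightened pins can be probed at runtime with the RequirementCache helper that imports.py relies on below (a minimal sketch, not part of this commit; it assumes RequirementCache is importable from lightning-utilities, and the variable names are illustrative):

# Sketch: check the updated requirement pins at runtime (illustrative only).
from lightning_utilities.core.imports import RequirementCache

# The specifiers mirror the updated requirements file above.
habana_ok = bool(RequirementCache("lightning-habana>=1.2.0,<1.3.0"))
graphcore_ok = bool(RequirementCache("lightning-graphcore>=0.1.0,<0.2.0"))
print(f"lightning-habana pin satisfied: {habana_ok}")
print(f"lightning-graphcore pin satisfied: {graphcore_ok}")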

src/lightning/pytorch/trainer/configuration_validator.py

Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _lightning_graphcore_available
+from lightning.pytorch.utilities.imports import _graphcore_available_and_importable
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature

@@ -125,7 +125,7 @@ def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None:
     datahook_selector = trainer._data_connector._datahook_selector
     assert datahook_selector is not None
     for hook in batch_transfer_hooks:
-        if _lightning_graphcore_available():
+        if _graphcore_available_and_importable():
             from lightning_graphcore import IPUAccelerator

             # TODO: This code could be done in a hook in the IPUAccelerator as it's a simple error check

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 13 additions & 13 deletions

@@ -64,8 +64,8 @@
 from lightning.pytorch.utilities.imports import (
     _LIGHTNING_BAGUA_AVAILABLE,
     _LIGHTNING_COLOSSALAI_AVAILABLE,
-    _lightning_graphcore_available,
-    _lightning_habana_available,
+    _graphcore_available_and_importable,
+    _habana_available_and_importable,
 )
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn


@@ -347,12 +347,12 @@ def _choose_auto_accelerator(self) -> str:
         """Choose the accelerator type (str) based on availability."""
         if XLAAccelerator.is_available():
             return "tpu"
-        if _lightning_graphcore_available():
+        if _graphcore_available_and_importable():
             from lightning_graphcore import IPUAccelerator

             if IPUAccelerator.is_available():
                 return "ipu"
-        if _lightning_habana_available():
+        if _habana_available_and_importable():
             from lightning_habana import HPUAccelerator

             if HPUAccelerator.is_available():

@@ -435,7 +435,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:

     def _choose_strategy(self) -> Union[Strategy, str]:
         if self._accelerator_flag == "ipu":
-            if not _lightning_graphcore_available():
+            if not _graphcore_available_and_importable():
                 raise ImportError(
                     "You have passed `accelerator='ipu'` but the IPU integration is not installed."
                     " Please run `pip install lightning-graphcore` or check out"

@@ -445,7 +445,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:

             return IPUStrategy.strategy_name
         if self._accelerator_flag == "hpu":
-            if not _lightning_habana_available():
+            if not _habana_available_and_importable():
                 raise ImportError(
                     "You have asked for HPU but you miss install related integration."
                     " Please run `pip install lightning-habana` or see for further instructions"

@@ -514,7 +514,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
         if isinstance(self._precision_plugin_flag, PrecisionPlugin):
             return self._precision_plugin_flag

-        if _lightning_graphcore_available():
+        if _graphcore_available_and_importable():
             from lightning_graphcore import IPUAccelerator, IPUPrecision

             # TODO: For the strategies that have a fixed precision class, we don't really need this logic

@@ -524,7 +524,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
             if isinstance(self.accelerator, IPUAccelerator):
                 return IPUPrecision(self._precision_flag)

-        if _lightning_habana_available():
+        if _habana_available_and_importable():
             from lightning_habana import HPUAccelerator, HPUPrecisionPlugin

             if isinstance(self.accelerator, HPUAccelerator):

@@ -571,7 +571,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:

     def _validate_precision_choice(self) -> None:
         """Validate the combination of choices for precision, AMP type, and accelerator."""
-        if _lightning_habana_available():
+        if _habana_available_and_importable():
             from lightning_habana import HPUAccelerator

             if isinstance(self.accelerator, HPUAccelerator) and self._precision_flag not in (

@@ -626,7 +626,7 @@ def _lazy_init_strategy(self) -> None:
                 f" found {self.strategy.__class__.__name__}."
             )

-        if _lightning_habana_available():
+        if _habana_available_and_importable():
             from lightning_habana import HPUAccelerator, HPUParallelStrategy, SingleHPUStrategy

             if isinstance(self.accelerator, HPUAccelerator) and not isinstance(

@@ -645,7 +645,7 @@ def is_distributed(self) -> bool:
             DeepSpeedStrategy,
             XLAStrategy,
         ]
-        if _lightning_habana_available():
+        if _habana_available_and_importable():
             from lightning_habana import HPUParallelStrategy

             distributed_strategies.append(HPUParallelStrategy)

@@ -698,7 +698,7 @@ def _register_external_accelerators_and_strategies() -> None:
         if "bagua" not in StrategyRegistry:
             BaguaStrategy.register_strategies(StrategyRegistry)

-    if _lightning_habana_available():
+    if _habana_available_and_importable():
        from lightning_habana import HPUAccelerator, HPUParallelStrategy, SingleHPUStrategy

        # TODO: Prevent registering multiple times

@@ -709,7 +709,7 @@ def _register_external_accelerators_and_strategies() -> None:
         if "hpu_single" not in StrategyRegistry:
             SingleHPUStrategy.register_strategies(StrategyRegistry)

-    if _lightning_graphcore_available():
+    if _graphcore_available_and_importable():
        from lightning_graphcore import IPUAccelerator, IPUStrategy

        # TODO: Prevent registering multiple times
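
All of the hunks above follow the same convention that the renamed helpers enable: call the availability check first, import the optional package only inside the guarded branch, and raise an ImportError with an install hint when the user explicitly requested the accelerator. A distilled sketch of that pattern (the function name is hypothetical, the error text is paraphrased rather than copied, and it assumes a Lightning build that already includes this change):

# Sketch of the guard pattern used throughout the connector (illustrative only).
from lightning.pytorch.utilities.imports import _habana_available_and_importable


def _resolve_hpu_accelerator_cls():
    # Hypothetical helper: check availability before touching the optional package.
    if not _habana_available_and_importable():
        # Mirrors the ImportError-with-install-hint branches shown above.
        raise ImportError(
            "The HPU integration is not installed. Please run `pip install lightning-habana`."
        )
    # Safe to import: the package is installed, satisfies the pin, and is importable.
    from lightning_habana import HPUAccelerator

    return HPUAccelerator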

src/lightning/pytorch/trainer/connectors/data_connector.py

Lines changed: 3 additions & 3 deletions

@@ -34,7 +34,7 @@
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _lightning_graphcore_available
+from lightning.pytorch.utilities.imports import _graphcore_available_and_importable
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
 from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

@@ -165,7 +165,7 @@ def attach_datamodule(
         datamodule.trainer = trainer

     def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
-        if _lightning_graphcore_available():
+        if _graphcore_available_and_importable():
             from lightning_graphcore import IPUAccelerator

             # `DistributedSampler` is never used with `poptorch.DataLoader`

@@ -191,7 +191,7 @@ def _prepare_dataloader(self, dataloader: object, shuffle: bool, mode: RunningSt
         if not isinstance(dataloader, DataLoader):
             return dataloader

-        if _lightning_graphcore_available():
+        if _graphcore_available_and_importable():
             from lightning_graphcore import IPUAccelerator

             # IPUs use a custom `poptorch.DataLoader` which we might need to convert to

src/lightning/pytorch/trainer/setup.py

Lines changed: 5 additions & 5 deletions

@@ -28,7 +28,7 @@
     XLAProfiler,
 )
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _lightning_graphcore_available, _lightning_habana_available
+from lightning.pytorch.utilities.imports import _graphcore_available_and_importable, _habana_available_and_importable
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn



@@ -158,7 +158,7 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
     num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
     rank_zero_info(f"TPU available: {XLAAccelerator.is_available()}, using: {num_tpu_cores} TPU cores")

-    if _lightning_graphcore_available():
+    if _graphcore_available_and_importable():
         from lightning_graphcore import IPUAccelerator

         num_ipus = trainer.num_devices if isinstance(trainer.accelerator, IPUAccelerator) else 0

@@ -168,7 +168,7 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
         ipu_available = False
     rank_zero_info(f"IPU available: {ipu_available}, using: {num_ipus} IPUs")

-    if _lightning_habana_available():
+    if _habana_available_and_importable():
         from lightning_habana import HPUAccelerator

         num_hpus = trainer.num_devices if isinstance(trainer.accelerator, HPUAccelerator) else 0

@@ -192,13 +192,13 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
     if XLAAccelerator.is_available() and not isinstance(trainer.accelerator, XLAAccelerator):
         rank_zero_warn("TPU available but not used. You can set it by doing `Trainer(accelerator='tpu')`.")

-    if _lightning_graphcore_available():
+    if _graphcore_available_and_importable():
         from lightning_graphcore import IPUAccelerator

         if IPUAccelerator.is_available() and not isinstance(trainer.accelerator, IPUAccelerator):
             rank_zero_warn("IPU available but not used. You can set it by doing `Trainer(accelerator='ipu')`.")

-    if _lightning_habana_available():
+    if _habana_available_and_importable():
         from lightning_habana import HPUAccelerator

         if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):

src/lightning/pytorch/utilities/imports.py

Lines changed: 12 additions & 8 deletions

@@ -35,21 +35,25 @@ def _try_import_module(module_name: str) -> bool:
     try:
         __import__(module_name)
         return True
-    # added also AttributeError fro case of impoerts like pl.LightningModule
+    # Also on AttributeError for failed imports like pl.LightningModule
     except (ImportError, AttributeError) as err:
-        rank_zero_warn(f"Import of {module_name} package failed for some compatibility issues: \n{err}")
+        rank_zero_warn(f"Import of {module_name} package failed for some compatibility issues:\n{err}")
         return False


-@functools.lru_cache(maxsize=1)
-def _lightning_graphcore_available() -> bool:
+_LIGHTNING_GRAPHCORE_AVAILABLE = RequirementCache("lightning-graphcore>=0.1.0")
+
+
+def _graphcore_available_and_importable() -> bool:
     # This is defined as a function instead of a constant to avoid circular imports, because `lightning_graphcore`
     # also imports Lightning
-    return bool(RequirementCache("lightning-graphcore")) and _try_import_module("lightning_graphcore")
+    return bool(_LIGHTNING_GRAPHCORE_AVAILABLE) and _try_import_module("lightning_graphcore")
+
+
+_LIGHTNING_HABANA_AVAILABLE = RequirementCache("lightning-habana>=1.2.0")


-@functools.lru_cache(maxsize=1)
-def _lightning_habana_available() -> bool:
+def _habana_available_and_importable() -> bool:
     # This is defined as a function instead of a constant to avoid circular imports, because `lightning_habana`
     # also imports Lightning
-    return bool(RequirementCache("lightning-habana")) and _try_import_module("lightning_habana")
+    return bool(_LIGHTNING_HABANA_AVAILABLE) and _try_import_module("lightning_habana")
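
One behavioural nuance worth noting from this hunk: the module-level RequirementCache objects now carry minimum versions, so with an older lightning-habana installed (for example 1.1.x) the new helper returns False where the old unpinned check returned True, and it still returns False, rather than raising, when the package is missing entirely. A hypothetical quick check (not part of the commit; it assumes a Lightning build that includes this change):

from lightning.pytorch.utilities.imports import (
    _graphcore_available_and_importable,
    _habana_available_and_importable,
)

# Without the integrations installed, or with versions older than the pins above,
# both calls short-circuit on the RequirementCache and simply return False.
print("graphcore importable:", _graphcore_available_and_importable())
print("habana importable:", _habana_available_and_importable())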
