diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 0461414d2b975..bb381a49c7782 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -484,8 +484,9 @@ def load_checkpoint( torch.cuda.empty_cache() with suppress(AttributeError): - torch.xpu.empty_cache() - + if _LIGHTNING_XPU_AVAILABLE: + XPUAccelerator.teardown() + _, client_state = engine.load_checkpoint( path, tag="checkpoint",