Commit 45959d7

fix FSDP2 test case failure on XPU (#3771)
* fix FSDP2 test case failure on XPU

  Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

  Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
1 parent 8b49352 commit 45959d7

File tree

1 file changed: +6 −13 lines changed

src/accelerate/accelerator.py

Lines changed: 6 additions & 13 deletions
@@ -568,25 +568,18 @@ def __init__(
             and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
         ):
             self.native_amp = True
-            if self.device.type not in (
-                "xpu",
-                "cuda",
-                "npu",
-                "xla",
-                "mlu",
-                "musa",
-                "hpu",
-                "sdaa",
-                "mps",
-            ) or is_torch_xla_available(check_is_tpu=True):
-                raise ValueError(f"fp16 mixed precision requires a GPU or MPS device (not {self.device.type!r}).")
+            supported_device = ("xpu", "cuda", "npu", "xla", "mlu", "musa", "hpu", "sdaa", "mps")
+            if self.device.type not in supported_device or is_torch_xla_available(check_is_tpu=True):
+                raise ValueError(
+                    f"fp16 mixed precision requires a device in {supported_device} (not {self.device.type!r})."
+                )
             if self.device.type == "mps" and not is_torch_version(">=", "2.5.0"):
                 raise ValueError("fp16 mixed precision with MPS device requires a Pytorch >= 2.5.0")
             kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}

             # FSDP2 doesn't use ShardedGradScaler, don't want to modify `get_grad_scaler`, rather create a simple utility
             if self.is_fsdp2:
-                self.scaler = get_fsdp2_grad_scaler(**kwargs)
+                self.scaler = get_fsdp2_grad_scaler(device=self.device.type, **kwargs)
             else:
                 self.scaler = get_grad_scaler(self.distributed_type, **kwargs)

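For illustration, here is a minimal, self-contained sketch of the device check introduced in this diff, lifted out of Accelerator.__init__. The helper name require_fp16_capable_device and the standalone constant are hypothetical, used only to show the behavior, and are not Accelerate's API. The second hunk presumably follows the same motivation: passing device=self.device.type into get_fsdp2_grad_scaler lets the FSDP2 grad scaler be built for the active accelerator (e.g. "xpu") rather than implicitly assuming CUDA.

# Minimal sketch (hypothetical helper, not Accelerate's API) of the supported-device check above.
SUPPORTED_FP16_DEVICES = ("xpu", "cuda", "npu", "xla", "mlu", "musa", "hpu", "sdaa", "mps")

def require_fp16_capable_device(device_type: str, is_tpu: bool = False) -> None:
    """Raise if fp16 mixed precision cannot run on the given device type."""
    if device_type not in SUPPORTED_FP16_DEVICES or is_tpu:
        raise ValueError(
            f"fp16 mixed precision requires a device in {SUPPORTED_FP16_DEVICES} (not {device_type!r})."
        )

require_fp16_capable_device("xpu")   # passes: XPU is listed as a supported accelerator
try:
    require_fp16_capable_device("cpu")
except ValueError as err:
    print(err)                       # the error now names the full set of supported devices

A side effect of the refactor, visible in the new error string, is that the message lists every supported backend instead of the stale "GPU or MPS" wording.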
0 commit comments
