Skip to content

Commit 06a3303

Browse files
committed
update amp
1 parent 5cdd9e7 commit 06a3303

File tree

2 files changed

+2
-2
lines changed
  • src/lightning
    • fabric/plugins/precision
    • pytorch/plugins/precision

2 files changed

+2
-2
lines changed

src/lightning/fabric/plugins/precision/amp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
if _TORCH_GREATER_EQUAL_2_4
5757
else getattr(
5858
torch,
59-
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
59+
"cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
6060
).amp.GradScaler()
6161
)
6262
if scaler is not None and self.precision == "bf16-mixed":

src/lightning/pytorch/plugins/precision/amp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
if _TORCH_GREATER_EQUAL_2_4
5757
else getattr(
5858
torch,
59-
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
59+
"cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
6060
).amp.GradScaler()
6161
)
6262
if scaler is not None and self.precision == "bf16-mixed":

0 commit comments

Comments
 (0)