Skip to content

Commit b446b08

Browse files
authored
Fallback to ACCELERATOR_TYPE for TPU flops (#19314)
1 parent 7cc79fe commit b446b08

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

src/lightning/fabric/utilities/throughput.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,9 @@ def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) ->
597597
else:
598598
from torch_xla.experimental import tpu
599599

600-
device_name = tpu.get_tpu_env()["TYPE"]
600+
tpu_env = tpu.get_tpu_env()
601+
# not all TPU generations define the "TYPE" envar. example: TYPE="V4", ACCELERATOR_TYPE="v4-8"
602+
device_name = tpu_env.get("TYPE") or tpu_env["ACCELERATOR_TYPE"].split("-")[0]
601603
chip = device_name.lower()
602604
assert isinstance(device_name, str)
603605
if chip not in _TPU_FLOPS:

tests/tests_fabric/utilities/test_throughput.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,19 @@ def test_get_available_flops(xla_available):
4949
from torch_xla.experimental import tpu
5050

5151
assert isinstance(tpu, Mock)
52-
tpu.get_tpu_env.return_value = {"TYPE": "V4"}
5352

53+
tpu.get_tpu_env.return_value = {"TYPE": "V4"}
5454
flops = get_available_flops(torch.device("xla"), torch.bfloat16)
5555
assert flops == 275e12
5656

5757
tpu.get_tpu_env.return_value = {"TYPE": "V1"}
5858
with pytest.warns(match="not found for TPU 'V1'"):
5959
assert get_available_flops(torch.device("xla"), torch.bfloat16) is None
6060

61+
tpu.get_tpu_env.return_value = {"ACCELERATOR_TYPE": "v3-8"}
62+
flops = get_available_flops(torch.device("xla"), torch.bfloat16)
63+
assert flops == 123e12
64+
6165
tpu.reset_mock()
6266

6367

0 commit comments

Comments
 (0)