
Commit 38c473f
[Pallas/TPU] tpu_info.get_tpu_info optionally takes a TPU chip version argument

* The motivation is to make TPU info available for chips that are not the current device.
* If None, this matches the current behavior: the TPU info for the current device is returned.

PiperOrigin-RevId: 834318595
1 parent e45b7b6 commit 38c473f
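
A minimal usage sketch of the new argument, assuming the pltpu re-exports used by the test in this commit (pltpu.get_tpu_info, pltpu.ChipVersion) come from the usual jax.experimental.pallas.tpu import alias:

from jax.experimental.pallas import tpu as pltpu

# New: query a chip version other than the current device.
v5p_info = pltpu.get_tpu_info(chip_version=pltpu.ChipVersion.TPU_V5P)

# Unchanged: with no argument, info for the current device is returned.
local_info = pltpu.get_tpu_info()

# All values are per-TensorCore (see the docstring in the diff below).
print(v5p_info.generation, v5p_info.hbm_capacity_bytes)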

File tree: 2 files changed, +79 -59 lines changed

jax/_src/pallas/mosaic/tpu_info.py

Lines changed: 73 additions & 59 deletions
@@ -20,8 +20,8 @@
 
 from jax import numpy as jnp
 from jax._src import dtypes
-from jax._src.pallas.mosaic import core
 from jax._src import util as jax_util
+from jax._src.pallas.mosaic import core
 
 
 class ChipVersionBase:
@@ -41,9 +41,26 @@ class ChipVersion(ChipVersionBase, enum.Enum):
   def __str__(self) -> str:
     return self.value
 
+
+DEVICE_KIND_TO_CHIP_VERSION = {
+    "TPU v2": ChipVersion.TPU_V2,
+    "TPU v3": ChipVersion.TPU_V3,
+    "TPU v4": ChipVersion.TPU_V4,
+    "TPU v4 lite": ChipVersion.TPU_V4I,
+    "TPU v5e": ChipVersion.TPU_V5E,
+    "TPU v5 lite": ChipVersion.TPU_V5E,
+    "TPU v5": ChipVersion.TPU_V5P,
+    "TPU v5p": ChipVersion.TPU_V5P,
+    "TPU v6e": ChipVersion.TPU_V6E,
+    "TPU v6 lite": ChipVersion.TPU_V6E,
+    "TPU7x": ChipVersion.TPU_7X,
+}
+
+
 @dataclasses.dataclass(frozen=True, kw_only=True)
 class SparseCoreInfo:
   """SparseCore-specific information."""
+
   num_cores: int
   num_subcores: int
   num_lanes: int
@@ -122,10 +139,7 @@ def is_matmul_supported(
             or (lhs_dt in {U4, S4} and rhs_dt in {U4, S4})
         )
       case 7:
-        return (
-            lhs_dt in {F32, BF16}
-            and rhs_dt in {F32, BF16}
-        ) or (
+        return (lhs_dt in {F32, BF16} and rhs_dt in {F32, BF16}) or (
             lhs_dt in {F32, BF16, F8E5M2, F8E4M3FN}
             and rhs_dt in {F8E5M2, F8E4M3FN}
         )
@@ -154,46 +168,46 @@ def get_sublane_tiling(self, dtype: jnp.dtype) -> int:
 
 
 def is_tpu_device() -> bool:
-  return core.get_device_kind() in {
-      "TPU v2",
-      "TPU v3",
-      "TPU v4",
-      "TPU v4 lite",
-      "TPU v5e",
-      "TPU v5 lite",
-      "TPU v5",
-      "TPU v5p",
-      "TPU v6 lite",
-      "TPU v6e",
-      "TPU7x",
-  }
+  return core.get_device_kind() in DEVICE_KIND_TO_CHIP_VERSION.keys()
 
 
 registry: dict[str, Callable[[], TpuInfo]] = {}
 
+
 @jax_util.cache(trace_context_in_key=True)
-def get_tpu_info() -> TpuInfo:
-  """Returns the TPU hardware information for the current device.
+def get_tpu_info(chip_version: ChipVersion | None = None) -> TpuInfo:
+  """Returns the TPU hardware info for the current device or given TPU chip.
 
   Note that all information is *per-TensorCore* so you would need to multiply by
   `num_cores` to obtain the total for the chip.
 
+  Args:
+    chip_version: The TPU chip version to get the information for. If None, the
+      information for the current device is returned.
+
   Returns:
-    A TpuInfo object containing the hardware information for the current device.
+    A TpuInfo object containing the hardware information for the given TPU chip
+    version.
   """
-  device_kind = core.get_device_kind()
+  if chip_version is None:
+    device_kind = core.get_device_kind()
+    chip_version = DEVICE_KIND_TO_CHIP_VERSION.get(device_kind, None)
+    if chip_version is None:
+      if device_kind in registry:
+        return registry[device_kind]()
+      raise ValueError(f"Unsupported TPU device kind: {device_kind}")
 
   # Common parameters for all TensorCores
   NUM_LANES = 128
   NUM_SUBLANES = 8
   MXU_COLUMN_SIZE_GEN_LT_6 = 128
   MXU_COLUMN_SIZE_GEN_GE_6 = 256
 
-  match device_kind:
-    case "TPU v2":  # 2 TensorCores per chip
+  match chip_version:
+    case ChipVersion.TPU_V2:  # 2 TensorCores per chip
       num_chip_cores = 2
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V2,
+          chip_version=chip_version,
           generation=2,
           num_cores=core.get_num_device_cores(),
           num_lanes=NUM_LANES,
@@ -209,10 +223,10 @@ def get_tpu_info() -> TpuInfo:
           fp8_ops_per_second=0,  # Not Available
           int4_ops_per_second=0,  # Not Available
       )
-    case "TPU v3":  # 2 TensorCores per chip
+    case ChipVersion.TPU_V3:  # 2 TensorCores per chip
       num_chip_cores = 2
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V3,
+          chip_version=chip_version,
           generation=3,
           num_cores=core.get_num_device_cores(),
           num_lanes=NUM_LANES,
@@ -228,9 +242,9 @@ def get_tpu_info() -> TpuInfo:
           fp8_ops_per_second=0,  # Not Available
           int4_ops_per_second=0,  # Not Available
       )
-    case "TPU v4 lite":  # 1 TensorCore per chip
+    case ChipVersion.TPU_V4I:  # 1 TensorCore per chip
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V4I,
+          chip_version=chip_version,
           generation=4,
           num_cores=core.get_num_device_cores(),
           num_lanes=NUM_LANES,
@@ -246,10 +260,10 @@ def get_tpu_info() -> TpuInfo:
           fp8_ops_per_second=0,  # Not Available
           int4_ops_per_second=0,  # Not Available
       )
-    case "TPU v4":  # 2 TensorCores per chip
+    case ChipVersion.TPU_V4:  # 2 TensorCores per chip
       num_chip_cores = 2
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V4,
+          chip_version=chip_version,
           generation=4,
           num_cores=core.get_num_device_cores(),
           num_lanes=NUM_LANES,
@@ -265,9 +279,9 @@ def get_tpu_info() -> TpuInfo:
           fp8_ops_per_second=0,  # Not Available
           int4_ops_per_second=0,  # Not Available
       )
-    case "TPU v5 lite" | "TPU v5e":  # 1 TensorCore per chip
+    case ChipVersion.TPU_V5E:  # 1 TensorCore per chip
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V5E,
+          chip_version=chip_version,
           generation=5,
           num_cores=core.get_num_device_cores(),
           num_lanes=NUM_LANES,
@@ -283,10 +297,10 @@ def get_tpu_info() -> TpuInfo:
           fp8_ops_per_second=0,  # Not Available
           int4_ops_per_second=int(7.88e14),
       )
-    case "TPU v5" | "TPU v5p":  # 2 TensorCores per chip
+    case ChipVersion.TPU_V5P:  # 2 TensorCores per chip
       num_chip_cores = 2
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V5P,
+          chip_version=chip_version,
           generation=5,
           num_cores=core.get_num_device_cores(),
           num_lanes=NUM_LANES,
@@ -303,9 +317,9 @@ def get_tpu_info() -> TpuInfo:
           int4_ops_per_second=int(1.84e15 // num_chip_cores),
           sparse_core=SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8),
       )
-    case "TPU v6 lite" | "TPU v6e":  # 1 TensorCore per chip
+    case ChipVersion.TPU_V6E:  # 1 TensorCore per chip
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V6E,
+          chip_version=chip_version,
           generation=6,
           num_cores=core.get_num_device_cores(),
           num_lanes=NUM_LANES,
@@ -322,28 +336,28 @@ def get_tpu_info() -> TpuInfo:
           int4_ops_per_second=int(3.68e15),
           sparse_core=SparseCoreInfo(num_cores=2, num_subcores=16, num_lanes=8),
       )
-    case "TPU7x":
+    case ChipVersion.TPU_7X:
       num_cores = core.get_num_device_cores()
       num_chip_cores = 2
       return TpuInfo(
-          chip_version=ChipVersion.TPU_7X,
-          generation=7,
-          num_cores=num_cores,
-          num_lanes=128,
-          num_sublanes=8,
-          mxu_column_size=256,
-          vmem_capacity_bytes=64 * 1024 * 1024,  # 64 MiB per core
-          cmem_capacity_bytes=0,
-          smem_capacity_bytes=1024 * 1024,  # 1 MiB per core
-          hbm_capacity_bytes=206_000_000_000 // num_chip_cores,
-          mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores),
-          bf16_ops_per_second=int(2.31e15 // num_chip_cores),
-          int8_ops_per_second=0,  # Not Available
-          fp8_ops_per_second=int(4.60e15 // num_chip_cores),
-          int4_ops_per_second=0,  # Not Available
-          sparse_core=SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=16),
-      )
-    case _ as d:
-      if d in registry:
-        return registry[d]()
-      raise ValueError(f"Unsupported TPU device kind: {device_kind}")
+          chip_version=chip_version,
+          generation=7,
+          num_cores=num_cores,
+          num_lanes=128,
+          num_sublanes=8,
+          mxu_column_size=256,
+          vmem_capacity_bytes=64 * 1024 * 1024,  # 64 MiB per core
+          cmem_capacity_bytes=0,
+          smem_capacity_bytes=1024 * 1024,  # 1 MiB per core
+          hbm_capacity_bytes=206_000_000_000 // num_chip_cores,
+          mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores),
+          bf16_ops_per_second=int(2.31e15 // num_chip_cores),
+          int8_ops_per_second=0,  # Not Available
+          fp8_ops_per_second=int(4.60e15 // num_chip_cores),
+          int4_ops_per_second=0,  # Not Available
+          sparse_core=SparseCoreInfo(
+              num_cores=4, num_subcores=16, num_lanes=16
+          ),
+      )
+    case _:
+      raise ValueError(f"Unsupported TPU chip version: {chip_version}")

tests/pallas/tpu_info_test.py

Lines changed: 6 additions & 0 deletions
@@ -46,6 +46,12 @@ def test_get_tpu_info(self):
       case _:
         self.fail(f"Unexpected device kind: {device.device_kind}")
 
+  def test_get_tpu_info_given_chip_version(self):
+    for chip_version in pltpu.ChipVersion:
+      info = pltpu.get_tpu_info(chip_version=chip_version)
+      self.assertIsInstance(info, pltpu.TpuInfo)
+      self.assertEqual(info.chip_version, chip_version)
+
 
 if __name__ == "__main__":
   jax.config.parse_flags_with_absl()
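
The registry fallback for unrecognized device kinds is preserved, but it now only fires on the default path (chip_version=None); an explicit chip version that no case matches raises ValueError instead. A minimal sketch of that hook, with a hypothetical device-kind string and an elided TpuInfo constructor (the module is private, so this is illustrative only):

from jax._src.pallas.mosaic import tpu_info

def _my_chip_info() -> tpu_info.TpuInfo:
  ...  # hypothetical: build a TpuInfo describing the custom chip

tpu_info.registry["TPU v8"] = _my_chip_info  # hypothetical device kind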
