diff --git a/jax/_src/pallas/mosaic/tpu_info.py b/jax/_src/pallas/mosaic/tpu_info.py
index 40006dc94880..e0e8127c86da 100644
--- a/jax/_src/pallas/mosaic/tpu_info.py
+++ b/jax/_src/pallas/mosaic/tpu_info.py
@@ -20,8 +20,8 @@
 from jax import numpy as jnp
 
 from jax._src import dtypes
-from jax._src.pallas.mosaic import core
 from jax._src import util as jax_util
+from jax._src.pallas.mosaic import core
 
 
 class ChipVersionBase:
@@ -41,9 +41,26 @@ class ChipVersion(ChipVersionBase, enum.Enum):
   def __str__(self) -> str:
     return self.value
 
+
+DEVICE_KIND_TO_CHIP_VERSION = {
+    "TPU v2": ChipVersion.TPU_V2,
+    "TPU v3": ChipVersion.TPU_V3,
+    "TPU v4": ChipVersion.TPU_V4,
+    "TPU v4 lite": ChipVersion.TPU_V4I,
+    "TPU v5e": ChipVersion.TPU_V5E,
+    "TPU v5 lite": ChipVersion.TPU_V5E,
+    "TPU v5": ChipVersion.TPU_V5P,
+    "TPU v5p": ChipVersion.TPU_V5P,
+    "TPU v6e": ChipVersion.TPU_V6E,
+    "TPU v6 lite": ChipVersion.TPU_V6E,
+    "TPU7x": ChipVersion.TPU_7X,
+}
+
+
 @dataclasses.dataclass(frozen=True, kw_only=True)
 class SparseCoreInfo:
   """SparseCore-specific information."""
+
   num_cores: int
   num_subcores: int
   num_lanes: int
@@ -122,10 +139,7 @@ def is_matmul_supported(
             or (lhs_dt in {U4, S4} and rhs_dt in {U4, S4})
         )
       case 7:
-        return (
-            lhs_dt in {F32, BF16}
-            and rhs_dt in {F32, BF16}
-        ) or (
+        return (lhs_dt in {F32, BF16} and rhs_dt in {F32, BF16}) or (
            lhs_dt in {F32, BF16, F8E5M2, F8E4M3FN}
            and rhs_dt in {F8E5M2, F8E4M3FN}
        )
@@ -154,34 +168,34 @@ def get_sublane_tiling(self, dtype: jnp.dtype) -> int:
 
 
 def is_tpu_device() -> bool:
-  return core.get_device_kind() in {
-      "TPU v2",
-      "TPU v3",
-      "TPU v4",
-      "TPU v4 lite",
-      "TPU v5e",
-      "TPU v5 lite",
-      "TPU v5",
-      "TPU v5p",
-      "TPU v6 lite",
-      "TPU v6e",
-      "TPU7x",
-  }
+  return core.get_device_kind() in DEVICE_KIND_TO_CHIP_VERSION.keys()
 
 
 registry: dict[str, Callable[[], TpuInfo]] = {}
 
+
 @jax_util.cache(trace_context_in_key=True)
-def get_tpu_info() -> TpuInfo:
-  """Returns the TPU hardware information for the current device.
+def get_tpu_info(chip_version: ChipVersion | None = None) -> TpuInfo:
+  """Returns the TPU hardware info for the current device or given TPU chip.
 
   Note that all information is *per-TensorCore* so you would need to multiply
   by `num_cores` to obtain the total for the chip.
 
+  Args:
+    chip_version: The TPU chip version to get the information for. If None, the
+      information for the current device is returned.
+
   Returns:
-    A TpuInfo object containing the hardware information for the current device.
+    A TpuInfo object containing the hardware information for the given TPU chip
+    version.
   """
-  device_kind = core.get_device_kind()
+  if chip_version is None:
+    device_kind = core.get_device_kind()
+    chip_version = DEVICE_KIND_TO_CHIP_VERSION.get(device_kind, None)
+    if chip_version is None:
+      if device_kind in registry:
+        return registry[device_kind]()
+      raise ValueError(f"Unsupported TPU device kind: {device_kind}")
 
   # Common parameters for all TensorCores
   NUM_LANES = 128
@@ -189,11 +203,11 @@ def get_tpu_info() -> TpuInfo:
   MXU_COLUMN_SIZE_GEN_LT_6 = 128
   MXU_COLUMN_SIZE_GEN_GE_6 = 256
 
-  match device_kind:
-    case "TPU v2":  # 2 TensorCores per chip
+  match chip_version:
+    case ChipVersion.TPU_V2:  # 2 TensorCores per chip
       num_chip_cores = 2
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V2,
+          chip_version=chip_version,
          generation=2,
          num_cores=core.get_num_device_cores(),
          num_lanes=NUM_LANES,
@@ -209,10 +223,10 @@ def get_tpu_info() -> TpuInfo:
          fp8_ops_per_second=0,  # Not Available
          int4_ops_per_second=0,  # Not Available
       )
-    case "TPU v3":  # 2 TensorCores per chip
+    case ChipVersion.TPU_V3:  # 2 TensorCores per chip
       num_chip_cores = 2
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V3,
+          chip_version=chip_version,
          generation=3,
          num_cores=core.get_num_device_cores(),
          num_lanes=NUM_LANES,
@@ -228,9 +242,9 @@ def get_tpu_info() -> TpuInfo:
          fp8_ops_per_second=0,  # Not Available
          int4_ops_per_second=0,  # Not Available
       )
-    case "TPU v4 lite":  # 1 TensorCore per chip
+    case ChipVersion.TPU_V4I:  # 1 TensorCore per chip
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V4I,
+          chip_version=chip_version,
          generation=4,
          num_cores=core.get_num_device_cores(),
          num_lanes=NUM_LANES,
@@ -246,10 +260,10 @@ def get_tpu_info() -> TpuInfo:
          fp8_ops_per_second=0,  # Not Available
          int4_ops_per_second=0,  # Not Available
       )
-    case "TPU v4":  # 2 TensorCores per chip
+    case ChipVersion.TPU_V4:  # 2 TensorCores per chip
       num_chip_cores = 2
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V4,
+          chip_version=chip_version,
          generation=4,
          num_cores=core.get_num_device_cores(),
          num_lanes=NUM_LANES,
@@ -265,9 +279,9 @@ def get_tpu_info() -> TpuInfo:
          fp8_ops_per_second=0,  # Not Available
          int4_ops_per_second=0,  # Not Available
       )
-    case "TPU v5 lite" | "TPU v5e":  # 1 TensorCore per chip
+    case ChipVersion.TPU_V5E:  # 1 TensorCore per chip
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V5E,
+          chip_version=chip_version,
          generation=5,
          num_cores=core.get_num_device_cores(),
          num_lanes=NUM_LANES,
@@ -283,10 +297,10 @@ def get_tpu_info() -> TpuInfo:
          fp8_ops_per_second=0,  # Not Available
          int4_ops_per_second=int(7.88e14),
       )
-    case "TPU v5" | "TPU v5p":  # 2 TensorCores per chip
+    case ChipVersion.TPU_V5P:  # 2 TensorCores per chip
       num_chip_cores = 2
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V5P,
+          chip_version=chip_version,
          generation=5,
          num_cores=core.get_num_device_cores(),
          num_lanes=NUM_LANES,
@@ -303,9 +317,9 @@ def get_tpu_info() -> TpuInfo:
          int4_ops_per_second=int(1.84e15 // num_chip_cores),
          sparse_core=SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8),
       )
-    case "TPU v6 lite" | "TPU v6e":  # 1 TensorCore per chip
+    case ChipVersion.TPU_V6E:  # 1 TensorCore per chip
       return TpuInfo(
-          chip_version=ChipVersion.TPU_V6E,
+          chip_version=chip_version,
          generation=6,
          num_cores=core.get_num_device_cores(),
          num_lanes=NUM_LANES,
@@ -322,28 +336,28 @@ def get_tpu_info() -> TpuInfo:
          int4_ops_per_second=int(3.68e15),
          sparse_core=SparseCoreInfo(num_cores=2, num_subcores=16, num_lanes=8),
       )
-    case "TPU7x":
+    case ChipVersion.TPU_7X:
       num_cores = core.get_num_device_cores()
       num_chip_cores = 2
       return TpuInfo(
-          chip_version=ChipVersion.TPU_7X,
-          generation=7,
-          num_cores=num_cores,
-          num_lanes=128,
-          num_sublanes=8,
-          mxu_column_size=256,
-          vmem_capacity_bytes=64 * 1024 * 1024,  # 64 MiB per core
-          cmem_capacity_bytes=0,
-          smem_capacity_bytes=1024 * 1024,  # 1 MiB per core
-          hbm_capacity_bytes=206_000_000_000 // num_chip_cores,
-          mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores),
-          bf16_ops_per_second=int(2.31e15 // num_chip_cores),
-          int8_ops_per_second=0,  # Not Available
-          fp8_ops_per_second=int(4.60e15 // num_chip_cores),
-          int4_ops_per_second=0,  # Not Available
-          sparse_core=SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=16),
-      )
-    case _ as d:
-      if d in registry:
-        return registry[d]()
-      raise ValueError(f"Unsupported TPU device kind: {device_kind}")
+          chip_version=chip_version,
+          generation=7,
+          num_cores=num_cores,
+          num_lanes=128,
+          num_sublanes=8,
+          mxu_column_size=256,
+          vmem_capacity_bytes=64 * 1024 * 1024,  # 64 MiB per core
+          cmem_capacity_bytes=0,
+          smem_capacity_bytes=1024 * 1024,  # 1 MiB per core
+          hbm_capacity_bytes=206_000_000_000 // num_chip_cores,
+          mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores),
+          bf16_ops_per_second=int(2.31e15 // num_chip_cores),
+          int8_ops_per_second=0,  # Not Available
+          fp8_ops_per_second=int(4.60e15 // num_chip_cores),
+          int4_ops_per_second=0,  # Not Available
+          sparse_core=SparseCoreInfo(
+              num_cores=4, num_subcores=16, num_lanes=16
+          ),
+      )
+    case _:
+      raise ValueError(f"Unsupported TPU chip version: {chip_version}")
diff --git a/tests/pallas/tpu_info_test.py b/tests/pallas/tpu_info_test.py
index 9107a1f0d26b..6851a6a5ac3e 100644
--- a/tests/pallas/tpu_info_test.py
+++ b/tests/pallas/tpu_info_test.py
@@ -46,6 +46,12 @@ def test_get_tpu_info(self):
       case _:
         self.fail(f"Unexpected device kind: {device.device_kind}")
 
+  def test_get_tpu_info_given_chip_version(self):
+    for chip_version in pltpu.ChipVersion:
+      info = pltpu.get_tpu_info(chip_version=chip_version)
+      self.assertIsInstance(info, pltpu.TpuInfo)
+      self.assertEqual(info.chip_version, chip_version)
+
 
 if __name__ == "__main__":
   jax.config.parse_flags_with_absl()
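
Usage sketch (not part of the patch): the snippet below shows how the new keyword argument is meant to be called, mirroring the test added above. It assumes the public alias used by tpu_info_test.py, i.e. `from jax.experimental.pallas import tpu as pltpu`; that import line is not visible in the diff, so treat the import path as an assumption.

    from jax.experimental.pallas import tpu as pltpu

    # Look up the spec table for an explicit chip version instead of mapping
    # the locally attached device's kind.
    info = pltpu.get_tpu_info(chip_version=pltpu.ChipVersion.TPU_V5P)
    assert info.chip_version == pltpu.ChipVersion.TPU_V5P

    # Omitting chip_version keeps the old behavior: the current device's kind
    # is looked up via DEVICE_KIND_TO_CHIP_VERSION (or the registry fallback).
    local_info = pltpu.get_tpu_info()

Note that every case in get_tpu_info still calls core.get_num_device_cores(), so num_cores continues to reflect the locally attached backend even when chip_version is passed explicitly.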