132 changes: 73 additions & 59 deletions jax/_src/pallas/mosaic/tpu_info.py
@@ -20,8 +20,8 @@

from jax import numpy as jnp
from jax._src import dtypes
from jax._src.pallas.mosaic import core
from jax._src import util as jax_util
from jax._src.pallas.mosaic import core


class ChipVersionBase:
@@ -41,9 +41,26 @@ class ChipVersion(ChipVersionBase, enum.Enum):
def __str__(self) -> str:
return self.value


DEVICE_KIND_TO_CHIP_VERSION = {
"TPU v2": ChipVersion.TPU_V2,
"TPU v3": ChipVersion.TPU_V3,
"TPU v4": ChipVersion.TPU_V4,
"TPU v4 lite": ChipVersion.TPU_V4I,
"TPU v5e": ChipVersion.TPU_V5E,
"TPU v5 lite": ChipVersion.TPU_V5E,
"TPU v5": ChipVersion.TPU_V5P,
"TPU v5p": ChipVersion.TPU_V5P,
"TPU v6e": ChipVersion.TPU_V6E,
"TPU v6 lite": ChipVersion.TPU_V6E,
"TPU7x": ChipVersion.TPU_7X,
}
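
For illustration, a standalone sketch (not part of the diff) of the lookups this table supports, including the alias device-kind strings that resolve to a single enum member:

    # Sketch: alias strings ("TPU v5e" / "TPU v5 lite") map to the same chip.
    assert DEVICE_KIND_TO_CHIP_VERSION["TPU v5e"] is ChipVersion.TPU_V5E
    assert DEVICE_KIND_TO_CHIP_VERSION["TPU v5 lite"] is ChipVersion.TPU_V5E
    # Unknown kinds return None via .get(), which get_tpu_info() handles.
    assert DEVICE_KIND_TO_CHIP_VERSION.get("TPU v99") is None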


@dataclasses.dataclass(frozen=True, kw_only=True)
class SparseCoreInfo:
"""SparseCore-specific information."""

num_cores: int
num_subcores: int
num_lanes: int
@@ -122,10 +139,7 @@ def is_matmul_supported(
or (lhs_dt in {U4, S4} and rhs_dt in {U4, S4})
)
case 7:
return (
lhs_dt in {F32, BF16}
and rhs_dt in {F32, BF16}
) or (
return (lhs_dt in {F32, BF16} and rhs_dt in {F32, BF16}) or (
lhs_dt in {F32, BF16, F8E5M2, F8E4M3FN}
and rhs_dt in {F8E5M2, F8E4M3FN}
)
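
A standalone usage sketch of the generation-7 dtype rule above. The method's parameters are hidden in the collapsed context, so the signature here is assumed, as is the exact dtype argument form:

    # Sketch: per the return expression above, the gen-7 MXU accepts
    # bf16 x fp8 but not int8 x int8. Signature assumed to take the two
    # operand dtypes on a TpuInfo instance.
    info = get_tpu_info(chip_version=ChipVersion.TPU_7X)
    info.is_matmul_supported(jnp.bfloat16, jnp.float8_e4m3fn)  # -> True
    info.is_matmul_supported(jnp.int8, jnp.int8)               # -> False
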
@@ -154,46 +168,46 @@ def get_sublane_tiling(self, dtype: jnp.dtype) -> int:


def is_tpu_device() -> bool:
return core.get_device_kind() in {
"TPU v2",
"TPU v3",
"TPU v4",
"TPU v4 lite",
"TPU v5e",
"TPU v5 lite",
"TPU v5",
"TPU v5p",
"TPU v6 lite",
"TPU v6e",
"TPU7x",
}
return core.get_device_kind() in DEVICE_KIND_TO_CHIP_VERSION.keys()
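
A minimal sketch of the intended use (standalone, outside the diff):

    # Sketch: guard hardware queries on actually running on a known TPU.
    if is_tpu_device():
        info = get_tpu_info()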


registry: dict[str, Callable[[], TpuInfo]] = {}
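
Since get_tpu_info() falls back to this registry for device kinds missing from DEVICE_KIND_TO_CHIP_VERSION, a hypothetical out-of-tree chip could be supported like so (the device-kind string and factory below are made up):

    # Sketch: register a fallback TpuInfo factory for an unlisted device
    # kind; get_tpu_info() consults `registry` before raising.
    def _my_chip_info() -> TpuInfo:
        ...  # build and return a TpuInfo describing the custom chip

    registry["My custom TPU"] = _my_chip_info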


@jax_util.cache(trace_context_in_key=True)
def get_tpu_info() -> TpuInfo:
"""Returns the TPU hardware information for the current device.
def get_tpu_info(chip_version: ChipVersion | None = None) -> TpuInfo:
"""Returns the TPU hardware info for the current device or given TPU chip.

Note that all information is *per-TensorCore*, so you would need to multiply by
`num_cores` to obtain the total for the chip.

Args:
chip_version: The TPU chip version to get the information for. If None, the
information for the current device is returned.

Returns:
A TpuInfo object containing the hardware information for the current device.
A TpuInfo object containing the hardware information for the given TPU chip
version.
"""
device_kind = core.get_device_kind()
if chip_version is None:
device_kind = core.get_device_kind()
chip_version = DEVICE_KIND_TO_CHIP_VERSION.get(device_kind, None)
if chip_version is None:
if device_kind in registry:
return registry[device_kind]()
raise ValueError(f"Unsupported TPU device kind: {device_kind}")

# Common parameters for all TensorCores
NUM_LANES = 128
NUM_SUBLANES = 8
MXU_COLUMN_SIZE_GEN_LT_6 = 128
MXU_COLUMN_SIZE_GEN_GE_6 = 256

match device_kind:
case "TPU v2": # 2 TensorCores per chip
match chip_version:
case ChipVersion.TPU_V2: # 2 TensorCores per chip
num_chip_cores = 2
return TpuInfo(
chip_version=ChipVersion.TPU_V2,
chip_version=chip_version,
generation=2,
num_cores=core.get_num_device_cores(),
num_lanes=NUM_LANES,
@@ -209,10 +223,10 @@ def get_tpu_info() -> TpuInfo:
fp8_ops_per_second=0, # Not Available
int4_ops_per_second=0, # Not Available
)
case "TPU v3": # 2 TensorCores per chip
case ChipVersion.TPU_V3: # 2 TensorCores per chip
num_chip_cores = 2
return TpuInfo(
chip_version=ChipVersion.TPU_V3,
chip_version=chip_version,
generation=3,
num_cores=core.get_num_device_cores(),
num_lanes=NUM_LANES,
@@ -228,9 +242,9 @@ def get_tpu_info() -> TpuInfo:
fp8_ops_per_second=0, # Not Available
int4_ops_per_second=0, # Not Available
)
case "TPU v4 lite": # 1 TensorCore per chip
case ChipVersion.TPU_V4I: # 1 TensorCore per chip
return TpuInfo(
chip_version=ChipVersion.TPU_V4I,
chip_version=chip_version,
generation=4,
num_cores=core.get_num_device_cores(),
num_lanes=NUM_LANES,
@@ -246,10 +260,10 @@ def get_tpu_info() -> TpuInfo:
fp8_ops_per_second=0, # Not Available
int4_ops_per_second=0, # Not Available
)
case "TPU v4": # 2 TensorCores per chip
case ChipVersion.TPU_V4: # 2 TensorCores per chip
num_chip_cores = 2
return TpuInfo(
chip_version=ChipVersion.TPU_V4,
chip_version=chip_version,
generation=4,
num_cores=core.get_num_device_cores(),
num_lanes=NUM_LANES,
@@ -265,9 +279,9 @@ def get_tpu_info() -> TpuInfo:
fp8_ops_per_second=0, # Not Available
int4_ops_per_second=0, # Not Available
)
case "TPU v5 lite" | "TPU v5e": # 1 TensorCore per chip
case ChipVersion.TPU_V5E: # 1 TensorCore per chip
return TpuInfo(
chip_version=ChipVersion.TPU_V5E,
chip_version=chip_version,
generation=5,
num_cores=core.get_num_device_cores(),
num_lanes=NUM_LANES,
@@ -283,10 +297,10 @@ def get_tpu_info() -> TpuInfo:
fp8_ops_per_second=0, # Not Available
int4_ops_per_second=int(7.88e14),
)
case "TPU v5" | "TPU v5p": # 2 TensorCores per chip
case ChipVersion.TPU_V5P: # 2 TensorCores per chip
num_chip_cores = 2
return TpuInfo(
chip_version=ChipVersion.TPU_V5P,
chip_version=chip_version,
generation=5,
num_cores=core.get_num_device_cores(),
num_lanes=NUM_LANES,
@@ -303,9 +317,9 @@ def get_tpu_info() -> TpuInfo:
int4_ops_per_second=int(1.84e15 // num_chip_cores),
sparse_core=SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8),
)
case "TPU v6 lite" | "TPU v6e": # 1 TensorCore per chip
case ChipVersion.TPU_V6E: # 1 TensorCore per chip
return TpuInfo(
chip_version=ChipVersion.TPU_V6E,
chip_version=chip_version,
generation=6,
num_cores=core.get_num_device_cores(),
num_lanes=NUM_LANES,
@@ -322,28 +336,28 @@ def get_tpu_info() -> TpuInfo:
int4_ops_per_second=int(3.68e15),
sparse_core=SparseCoreInfo(num_cores=2, num_subcores=16, num_lanes=8),
)
case "TPU7x":
case ChipVersion.TPU_7X:
num_cores = core.get_num_device_cores()
num_chip_cores = 2
return TpuInfo(
chip_version=ChipVersion.TPU_7X,
generation=7,
num_cores=num_cores,
num_lanes=128,
num_sublanes=8,
mxu_column_size=256,
vmem_capacity_bytes=64 * 1024 * 1024, # 64 MiB per core
cmem_capacity_bytes=0,
smem_capacity_bytes=1024 * 1024, # 1 MiB per core
hbm_capacity_bytes=206_000_000_000 // num_chip_cores,
mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores),
bf16_ops_per_second=int(2.31e15 // num_chip_cores),
int8_ops_per_second=0, # Not Available
fp8_ops_per_second=int(4.60e15 // num_chip_cores),
int4_ops_per_second=0, # Not Available
sparse_core=SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=16),
)
case _ as d:
if d in registry:
return registry[d]()
raise ValueError(f"Unsupported TPU device kind: {device_kind}")
chip_version=chip_version,
generation=7,
num_cores=num_cores,
num_lanes=128,
num_sublanes=8,
mxu_column_size=256,
vmem_capacity_bytes=64 * 1024 * 1024, # 64 MiB per core
cmem_capacity_bytes=0,
smem_capacity_bytes=1024 * 1024, # 1 MiB per core
hbm_capacity_bytes=206_000_000_000 // num_chip_cores,
mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores),
bf16_ops_per_second=int(2.31e15 // num_chip_cores),
int8_ops_per_second=0, # Not Available
fp8_ops_per_second=int(4.60e15 // num_chip_cores),
int4_ops_per_second=0, # Not Available
sparse_core=SparseCoreInfo(
num_cores=4, num_subcores=16, num_lanes=16
),
)
case _:
raise ValueError(f"Unsupported TPU chip version: {chip_version}")
6 changes: 6 additions & 0 deletions tests/pallas/tpu_info_test.py
@@ -46,6 +46,12 @@ def test_get_tpu_info(self):
case _:
self.fail(f"Unexpected device kind: {device.device_kind}")

def test_get_tpu_info_given_chip_version(self):
for chip_version in pltpu.ChipVersion:
info = pltpu.get_tpu_info(chip_version=chip_version)
self.assertIsInstance(info, pltpu.TpuInfo)
self.assertEqual(info.chip_version, chip_version)


if __name__ == "__main__":
jax.config.parse_flags_with_absl()