Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions pathwaysutils/experimental/shared_pathways_service/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,19 @@ def validate_pathways_service(pathways_service: str) -> None:


def _validate_tpu_supported(tpu_instance_with_topology: str) -> None:
"""Checks if the given instance represents a valid single-host TPU.
"""Checks if the given instance represents a valid TPU type.

Args:
tpu_instance_with_topology: The TPU instance string, e.g., "tpuv6e:4x8".

Raises ValueError if the instance is not a valid TPU host.
Raises ValueError if the instance is not a valid TPU type.
"""
# Mapping from Cloud TPU type prefix to max chips per host.
single_host_max_chips = {
"tpuv6e": 8, # Cloud TPU v6e (2x4)
}

# Regex to extract topology
# Regex to extract TPU type and topology.
# Examples:
# ct5lp-hightpu-4t:4x8 -> ct5lp, 4x8
# ct5p:2x2x1 -> ct5p, 2x2x1
# tpuv6e:2x4 -> type='tpuv6e', topology='2x4'
# tpuv5p:2x2x1 -> type='tpuv5p', topology='2x2x1'
match = re.match(
r"^(?P<type>tpuv6e):(?P<topology>\d+(?:x\d+)*)$",
r"^(?P<type>tpuv(?:5e|5p|6e)):(?P<topology>\d+(?:x\d+)*)$",
tpu_instance_with_topology,
)

Expand Down Expand Up @@ -85,13 +80,6 @@ def _validate_tpu_supported(tpu_instance_with_topology: str) -> None:
f" '{tpu_instance_with_topology}'."
) from exc

if num_chips > single_host_max_chips[tpu_base_type]:
raise ValueError(
f"Topology '{tpu_instance_with_topology}' exceeds"
f" {single_host_max_chips[tpu_base_type]}, the maximum supported"
f" chips for {tpu_base_type}."
)

return

raise ValueError(
Expand Down