diff --git a/pathwaysutils/experimental/shared_pathways_service/validators.py b/pathwaysutils/experimental/shared_pathways_service/validators.py index bd3e7e6..40fd2fd 100644 --- a/pathwaysutils/experimental/shared_pathways_service/validators.py +++ b/pathwaysutils/experimental/shared_pathways_service/validators.py @@ -39,24 +39,19 @@ def validate_pathways_service(pathways_service: str) -> None: def _validate_tpu_supported(tpu_instance_with_topology: str) -> None: - """Checks if the given instance represents a valid single-host TPU. + """Checks if the given instance represents a valid TPU type. Args: tpu_instance_with_topology: The TPU instance string, e.g., "tpuv6e:4x8". - Raises ValueError if the instance is not a valid TPU host. + Raises ValueError if the instance is not a valid TPU type. """ - # Mapping from Cloud TPU type prefix to max chips per host. - single_host_max_chips = { - "tpuv6e": 8, # Cloud TPU v6e (2x4) - } - - # Regex to extract topology + # Regex to extract TPU type and topology. # Examples: - # ct5lp-hightpu-4t:4x8 -> ct5lp, 4x8 - # ct5p:2x2x1 -> ct5p, 2x2x1 + # tpuv6e:2x4 -> type='tpuv6e', topology='2x4' + # tpuv5p:2x2x1 -> type='tpuv5p', topology='2x2x1' match = re.match( - r"^(?Ptpuv6e):(?P\d+(?:x\d+)*)$", + r"^(?Ptpuv(?:5e|5p|6e)):(?P\d+(?:x\d+)*)$", tpu_instance_with_topology, ) @@ -85,13 +80,6 @@ def _validate_tpu_supported(tpu_instance_with_topology: str) -> None: f" '{tpu_instance_with_topology}'." ) from exc - if num_chips > single_host_max_chips[tpu_base_type]: - raise ValueError( - f"Topology '{tpu_instance_with_topology}' exceeds" - f" {single_host_max_chips[tpu_base_type]}, the maximum supported" - f" chips for {tpu_base_type}." - ) - return raise ValueError(