@@ -39,59 +39,33 @@ def validate_pathways_service(pathways_service: str) -> None:
3939
4040
4141def _validate_tpu_supported (tpu_instance_with_topology : str ) -> None :
42- """Checks if the given instance represents a valid single-host TPU.
42+ """Checks if the given instance represents a valid TPU type .
4343
4444 Args:
4545 tpu_instance_with_topology: The TPU instance string, e.g., "tpuv6e:4x8".
4646
47- Raises ValueError if the instance is not a valid TPU host .
47+ Raises ValueError if the instance is not a valid TPU type .
4848 """
49- # Mapping from Cloud TPU type prefix to max chips per host.
50- single_host_max_chips = {
51- "tpuv6e" : 8 , # Cloud TPU v6e (2x4)
52- }
53-
54- # Regex to extract topology
49+ # Regex to extract TPU type and topology.
5550 # Examples:
56- # ct5lp-hightpu-4t:4x8 -> ct5lp, 4x8
57- # ct5p :2x2x1 -> ct5p, 2x2x1
51+ # tpuv6e:2x4 -> type='tpuv6e', topology='2x4'
52+ # tpuv5p :2x2x1 -> type='tpuv5p', topology=' 2x2x1'
5853 match = re .match (
59- r"^(?P<type>tpuv6e) :(?P<topology>\d+(?:x\d+)* )$" ,
54+ r"^(?:tpuv(?:5e|5p|6e)) :(?P<topology>\d+(?:x\d+){1,2} )$" ,
6055 tpu_instance_with_topology ,
6156 )
6257
6358 if match :
64- tpu_base_type = match .group ("type" )
6559 topology_str = match .group ("topology" )
6660
67- if not tpu_base_type :
68- raise ValueError (
69- f"Unknown TPU type '{ type } ' from '{ tpu_instance_with_topology } '."
70- )
71-
7261 try :
73- dims = [int (d ) for d in topology_str .split ("x" )]
74- if len (dims ) < 2 or len (dims ) > 3 :
75- raise ValueError (
76- f"Error: Invalid topology format '{ topology_str } ', Expected either"
77- " 2 or 3 dimensions."
78- )
79- num_chips = 1
80- for dim in dims :
81- num_chips *= dim
62+ _ = [int (d ) for d in topology_str .split ("x" )]
8263 except ValueError as exc :
8364 raise ValueError (
8465 f"Error: Invalid topology format '{ topology_str } ' in"
85- f" '{ tpu_instance_with_topology } '."
66+ f" '{ tpu_instance_with_topology } '. Expected all numbers, e.g., 2x4. "
8667 ) from exc
8768
88- if num_chips > single_host_max_chips [tpu_base_type ]:
89- raise ValueError (
90- f"Topology '{ tpu_instance_with_topology } ' exceeds"
91- f" { single_host_max_chips [tpu_base_type ]} , the maximum supported"
92- f" chips for { tpu_base_type } ."
93- )
94-
9569 return
9670
9771 raise ValueError (
0 commit comments