2020
2121from jax import numpy as jnp
2222from jax ._src import dtypes
23- from jax ._src .pallas .mosaic import core
2423from jax ._src import util as jax_util
24+ from jax ._src .pallas .mosaic import core
2525
2626
2727class ChipVersionBase :
@@ -41,9 +41,26 @@ class ChipVersion(ChipVersionBase, enum.Enum):
4141 def __str__ (self ) -> str :
4242 return self .value
4343
44+
# Maps the device-kind strings returned by `core.get_device_kind()` to the
# corresponding ChipVersion. Some chips are reported under more than one name
# ("TPU v5e" / "TPU v5 lite", "TPU v5" / "TPU v5p", "TPU v6e" / "TPU v6 lite"),
# so several keys intentionally share a value.
DEVICE_KIND_TO_CHIP_VERSION: dict[str, ChipVersion] = {
    "TPU v2": ChipVersion.TPU_V2,
    "TPU v3": ChipVersion.TPU_V3,
    "TPU v4": ChipVersion.TPU_V4,
    "TPU v4 lite": ChipVersion.TPU_V4I,
    "TPU v5e": ChipVersion.TPU_V5E,
    "TPU v5 lite": ChipVersion.TPU_V5E,
    "TPU v5": ChipVersion.TPU_V5P,
    "TPU v5p": ChipVersion.TPU_V5P,
    "TPU v6e": ChipVersion.TPU_V6E,
    "TPU v6 lite": ChipVersion.TPU_V6E,
    "TPU7x": ChipVersion.TPU_7X,
}
58+
59+
4460@dataclasses .dataclass (frozen = True , kw_only = True )
4561class SparseCoreInfo :
4662 """SparseCore-specific information."""
63+
4764 num_cores : int
4865 num_subcores : int
4966 num_lanes : int
@@ -122,10 +139,7 @@ def is_matmul_supported(
122139 or (lhs_dt in {U4 , S4 } and rhs_dt in {U4 , S4 })
123140 )
124141 case 7 :
125- return (
126- lhs_dt in {F32 , BF16 }
127- and rhs_dt in {F32 , BF16 }
128- ) or (
142+ return (lhs_dt in {F32 , BF16 } and rhs_dt in {F32 , BF16 }) or (
129143 lhs_dt in {F32 , BF16 , F8E5M2 , F8E4M3FN }
130144 and rhs_dt in {F8E5M2 , F8E4M3FN }
131145 )
@@ -154,46 +168,46 @@ def get_sublane_tiling(self, dtype: jnp.dtype) -> int:
154168
155169
def is_tpu_device() -> bool:
  """Returns whether the current device kind is a known TPU.

  The check is a membership test of `core.get_device_kind()` against the keys
  of `DEVICE_KIND_TO_CHIP_VERSION`; entries in `registry` are not consulted.

  Returns:
    True if the current device kind has an entry in
    `DEVICE_KIND_TO_CHIP_VERSION`, False otherwise.
  """
  # `in` on a dict already tests its keys, so no `.keys()` view is needed.
  return core.get_device_kind() in DEVICE_KIND_TO_CHIP_VERSION
170172
171173
172174registry : dict [str , Callable [[], TpuInfo ]] = {}
173175
176+
174177@jax_util .cache (trace_context_in_key = True )
175- def get_tpu_info () -> TpuInfo :
176- """Returns the TPU hardware information for the current device.
178+ def get_tpu_info (chip_version : ChipVersion | None = None ) -> TpuInfo :
179+ """Returns the TPU hardware info for the current device or given TPU chip .
177180
178181 Note that all information is *per-TensorCore* so you would need to multiply by
179182 `num_cores` to obtain the total for the chip.
180183
184+ Args:
185+ chip_version: The TPU chip version to get the information for. If None, the
186+ information for the current device is returned.
187+
181188 Returns:
182- A TpuInfo object containing the hardware information for the current device.
189+ A TpuInfo object containing the hardware information for the given TPU chip
190+ version.
183191 """
184- device_kind = core .get_device_kind ()
192+ if chip_version is None :
193+ device_kind = core .get_device_kind ()
194+ chip_version = DEVICE_KIND_TO_CHIP_VERSION .get (device_kind , None )
195+ if chip_version is None :
196+ if device_kind in registry :
197+ return registry [device_kind ]()
198+ raise ValueError (f"Unsupported TPU device kind: { device_kind } " )
185199
186200 # Common parameters for all TensorCores
187201 NUM_LANES = 128
188202 NUM_SUBLANES = 8
189203 MXU_COLUMN_SIZE_GEN_LT_6 = 128
190204 MXU_COLUMN_SIZE_GEN_GE_6 = 256
191205
192- match device_kind :
193- case "TPU v2" : # 2 TensorCores per chip
206+ match chip_version :
207+ case ChipVersion . TPU_V2 : # 2 TensorCores per chip
194208 num_chip_cores = 2
195209 return TpuInfo (
196- chip_version = ChipVersion . TPU_V2 ,
210+ chip_version = chip_version ,
197211 generation = 2 ,
198212 num_cores = core .get_num_device_cores (),
199213 num_lanes = NUM_LANES ,
@@ -209,10 +223,10 @@ def get_tpu_info() -> TpuInfo:
209223 fp8_ops_per_second = 0 , # Not Available
210224 int4_ops_per_second = 0 , # Not Available
211225 )
212- case "TPU v3" : # 2 TensorCores per chip
226+ case ChipVersion . TPU_V3 : # 2 TensorCores per chip
213227 num_chip_cores = 2
214228 return TpuInfo (
215- chip_version = ChipVersion . TPU_V3 ,
229+ chip_version = chip_version ,
216230 generation = 3 ,
217231 num_cores = core .get_num_device_cores (),
218232 num_lanes = NUM_LANES ,
@@ -228,9 +242,9 @@ def get_tpu_info() -> TpuInfo:
228242 fp8_ops_per_second = 0 , # Not Available
229243 int4_ops_per_second = 0 , # Not Available
230244 )
231- case "TPU v4 lite" : # 1 TensorCore per chip
245+ case ChipVersion . TPU_V4I : # 1 TensorCore per chip
232246 return TpuInfo (
233- chip_version = ChipVersion . TPU_V4I ,
247+ chip_version = chip_version ,
234248 generation = 4 ,
235249 num_cores = core .get_num_device_cores (),
236250 num_lanes = NUM_LANES ,
@@ -246,10 +260,10 @@ def get_tpu_info() -> TpuInfo:
246260 fp8_ops_per_second = 0 , # Not Available
247261 int4_ops_per_second = 0 , # Not Available
248262 )
249- case "TPU v4" : # 2 TensorCores per chip
263+ case ChipVersion . TPU_V4 : # 2 TensorCores per chip
250264 num_chip_cores = 2
251265 return TpuInfo (
252- chip_version = ChipVersion . TPU_V4 ,
266+ chip_version = chip_version ,
253267 generation = 4 ,
254268 num_cores = core .get_num_device_cores (),
255269 num_lanes = NUM_LANES ,
@@ -265,9 +279,9 @@ def get_tpu_info() -> TpuInfo:
265279 fp8_ops_per_second = 0 , # Not Available
266280 int4_ops_per_second = 0 , # Not Available
267281 )
268- case "TPU v5 lite" | "TPU v5e" : # 1 TensorCore per chip
282+ case ChipVersion . TPU_V5E : # 1 TensorCore per chip
269283 return TpuInfo (
270- chip_version = ChipVersion . TPU_V5E ,
284+ chip_version = chip_version ,
271285 generation = 5 ,
272286 num_cores = core .get_num_device_cores (),
273287 num_lanes = NUM_LANES ,
@@ -283,10 +297,10 @@ def get_tpu_info() -> TpuInfo:
283297 fp8_ops_per_second = 0 , # Not Available
284298 int4_ops_per_second = int (7.88e14 ),
285299 )
286- case "TPU v5" | "TPU v5p" : # 2 TensorCores per chip
300+ case ChipVersion . TPU_V5P : # 2 TensorCores per chip
287301 num_chip_cores = 2
288302 return TpuInfo (
289- chip_version = ChipVersion . TPU_V5P ,
303+ chip_version = chip_version ,
290304 generation = 5 ,
291305 num_cores = core .get_num_device_cores (),
292306 num_lanes = NUM_LANES ,
@@ -303,9 +317,9 @@ def get_tpu_info() -> TpuInfo:
303317 int4_ops_per_second = int (1.84e15 // num_chip_cores ),
304318 sparse_core = SparseCoreInfo (num_cores = 4 , num_subcores = 16 , num_lanes = 8 ),
305319 )
306- case "TPU v6 lite" | "TPU v6e" : # 1 TensorCore per chip
320+ case ChipVersion . TPU_V6E : # 1 TensorCore per chip
307321 return TpuInfo (
308- chip_version = ChipVersion . TPU_V6E ,
322+ chip_version = chip_version ,
309323 generation = 6 ,
310324 num_cores = core .get_num_device_cores (),
311325 num_lanes = NUM_LANES ,
@@ -322,28 +336,28 @@ def get_tpu_info() -> TpuInfo:
322336 int4_ops_per_second = int (3.68e15 ),
323337 sparse_core = SparseCoreInfo (num_cores = 2 , num_subcores = 16 , num_lanes = 8 ),
324338 )
325- case "TPU7x" :
339+ case ChipVersion . TPU_7X :
326340 num_cores = core .get_num_device_cores ()
327341 num_chip_cores = 2
328342 return TpuInfo (
329- chip_version = ChipVersion . TPU_7X ,
330- generation = 7 ,
331- num_cores = num_cores ,
332- num_lanes = 128 ,
333- num_sublanes = 8 ,
334- mxu_column_size = 256 ,
335- vmem_capacity_bytes = 64 * 1024 * 1024 , # 64 MiB per core
336- cmem_capacity_bytes = 0 ,
337- smem_capacity_bytes = 1024 * 1024 , # 1 MiB per core
338- hbm_capacity_bytes = 206_000_000_000 // num_chip_cores ,
339- mem_bw_bytes_per_second = int (7.40e12 // num_chip_cores ),
340- bf16_ops_per_second = int (2.31e15 // num_chip_cores ),
341- int8_ops_per_second = 0 , # Not Available
342- fp8_ops_per_second = int (4.60e15 // num_chip_cores ),
343- int4_ops_per_second = 0 , # Not Available
344- sparse_core = SparseCoreInfo (num_cores = 4 , num_subcores = 16 , num_lanes = 16 ),
345- )
346- case _ as d :
347- if d in registry :
348- return registry [ d ]()
349- raise ValueError (f"Unsupported TPU device kind : { device_kind } " )
343+ chip_version = chip_version ,
344+ generation = 7 ,
345+ num_cores = num_cores ,
346+ num_lanes = 128 ,
347+ num_sublanes = 8 ,
348+ mxu_column_size = 256 ,
349+ vmem_capacity_bytes = 64 * 1024 * 1024 , # 64 MiB per core
350+ cmem_capacity_bytes = 0 ,
351+ smem_capacity_bytes = 1024 * 1024 , # 1 MiB per core
352+ hbm_capacity_bytes = 206_000_000_000 // num_chip_cores ,
353+ mem_bw_bytes_per_second = int (7.40e12 // num_chip_cores ),
354+ bf16_ops_per_second = int (2.31e15 // num_chip_cores ),
355+ int8_ops_per_second = 0 , # Not Available
356+ fp8_ops_per_second = int (4.60e15 // num_chip_cores ),
357+ int4_ops_per_second = 0 , # Not Available
358+ sparse_core = SparseCoreInfo (
359+ num_cores = 4 , num_subcores = 16 , num_lanes = 16
360+ ),
361+ )
362+ case _:
363+ raise ValueError (f"Unsupported TPU chip version : { chip_version } " )
0 commit comments