@@ -83,10 +83,15 @@ class TPUInterpretParams:
       replaced with arrays all of `jnp.inf`. Additionally, any floating point
       operands to any operation will be replaced with (arrays of) `jnp.inf`.
       Default: False.
+    uninitialized_memory: If "nan", allocated buffers are initialized to
+      contain all NaNs (or to their maximum possible value for integers).
+      If "zero", allocated buffers are initialized to all zeros.
+      Default: "nan".
   """
   dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
   detect_races: bool = False
   skip_floating_point_ops: bool = False
+  uninitialized_memory: Literal["nan", "zero"] = "nan"


 VectorClock = np.ndarray
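For context, a minimal usage sketch (not part of this diff). The import path and the assumption that `pallas_call`'s `interpret=` argument accepts a `TPUInterpretParams` instance are mine; treat it as illustrative only.

# Illustrative sketch; assumed import path and interpret= plumbing.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax._src.pallas.mosaic import interpret as mosaic_interpret  # assumed path

def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

# Request zero-initialized buffers instead of the default NaN sentinel.
params = mosaic_interpret.TPUInterpretParams(uninitialized_memory="zero")
x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
y = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=params,  # assumed: forwarded to interpret_pallas_call
)(x)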
@@ -1114,7 +1119,8 @@ def f(*args, jaxpr):
             jax.ShapeDtypeStruct((), jnp.int16),
             device_id,
             TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
-            primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
+            _uninitialized_value(
+                v.aval.shape, v.aval.dtype, interpret_params),
             ordered=True))

     out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs)
@@ -1279,16 +1285,19 @@ def f(*args, jaxpr):

 def _initialize_output_vals(
     block_mappings_output: Iterable[BlockMapping],
-    input_args, input_output_aliases) -> Sequence[jax.Array]:
+    input_args, input_output_aliases,
+    interpret_params: TPUInterpretParams,
+) -> Sequence[jax.Array]:
   oi_map = {v: k for k, v in input_output_aliases}
   output_vals = []
   for i, bm in enumerate(block_mappings_output):
     if i in oi_map:
       output_vals.append(input_args[oi_map[i]])
     else:
-      output_vals.append(primitives.uninitialized_value(
+      output_vals.append(_uninitialized_value(
           bm.array_shape_dtype.shape,
-          bm.array_shape_dtype.dtype))
+          bm.array_shape_dtype.dtype,
+          interpret_params))
   return output_vals

 def _compute_start_indices(block_mapping, loop_idx, *args):
@@ -1319,7 +1328,20 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
                      dtype=np.bool_)])
   return lax.squeeze(output, squeeze_dims)

-def _pad_to_block_dimension(value, block_shape):
+def _uninitialized_value(shape, dtype, interpret_params):
+  if interpret_params.uninitialized_memory == 'nan':
+    if jnp.issubdtype(dtype, jnp.floating):
+      return jnp.full(shape, jnp.nan, dtype)
+    elif jnp.issubdtype(dtype, jnp.integer):
+      return jnp.full(shape, jnp.iinfo(dtype).max, dtype)
+    elif jnp.issubdtype(dtype, jnp.bool):
+      return jnp.full(shape, False, dtype)
+  if interpret_params.uninitialized_memory == 'zero':
+    return jnp.full(shape, 0, dtype)
+  raise NotImplementedError(
+      interpret_params.uninitialized_memory + ' + ' + str(dtype))
+
+def _pad_to_block_dimension(value, block_shape, interpret_params):
   """Pads values so the shape evenly divides into block dimensions.

   For example, if values has a shape of (33, 2, 5) with a block_shape of
@@ -1338,7 +1360,7 @@ def _pad_to_block_dimension(value, block_shape):
   )
   if padded_shape != value.shape:
     pad_width = tuple((0, a - b) for a, b in zip(padded_shape, value.shape))
-    pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
+    pad_value = _uninitialized_value((), value.dtype, interpret_params)
     value = jnp.pad(value, pad_width, constant_values=pad_value)
   return value

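To make the sentinel semantics concrete, here is a small standalone sketch (not part of the diff) that mirrors `_uninitialized_value` and the rounding-up that `_pad_to_block_dimension` performs. The shape and block_shape numbers are hypothetical, since the docstring's own example is truncated in this view.

# Standalone illustration of the fill values and the rounding-up behaviour.
import math
import jax.numpy as jnp

def fill_value(dtype, mode="nan"):
  # Mirrors _uninitialized_value: NaN for floats, iinfo.max for ints,
  # False for bools when mode == "nan"; zeros when mode == "zero".
  if mode == "nan":
    if jnp.issubdtype(dtype, jnp.floating):
      return jnp.nan
    if jnp.issubdtype(dtype, jnp.integer):
      return jnp.iinfo(dtype).max
    if jnp.issubdtype(dtype, jnp.bool_):
      return False
  return 0

print(fill_value(jnp.float32))  # nan
print(fill_value(jnp.int32))    # 2147483647

# Hypothetical shapes: each dimension is rounded up to a multiple of the
# corresponding block dimension, and the padded region is filled with the
# sentinel above.
value_shape, block_shape = (33, 2, 5), (32, 2, 4)
padded_shape = tuple(b * math.ceil(d / b) for d, b in zip(value_shape, block_shape))
print(padded_shape)             # (64, 2, 8)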
@@ -1397,7 +1419,7 @@ def interpret_pallas_call(
   ]
   num_inputs = grid_mapping.num_inputs
   input_args = [
-      _pad_to_block_dimension(a, bs)
+      _pad_to_block_dimension(a, bs, interpret_params)
       for a, bs in zip(input_args, block_shapes[:num_inputs])
   ]

@@ -1407,11 +1429,12 @@ def interpret_pallas_call(
   output_vals = _initialize_output_vals(
       grid_mapping.block_mappings_output,
       scalars + input_args,
-      input_output_aliases)
+      input_output_aliases,
+      interpret_params)
   num_outputs = grid_mapping.num_outputs
   output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
   for out_val, bs in zip(output_vals, output_block_shapes):
-    padded_val = _pad_to_block_dimension(out_val, bs)
+    padded_val = _pad_to_block_dimension(out_val, bs, interpret_params)
     output_buffer_shapes.append(padded_val.shape)
     output_buffer_ids.append(callback.io_callback(
         _allocate_buffer,
@@ -1466,7 +1489,8 @@ def interpret_pallas_call(
         jax.ShapeDtypeStruct((), jnp.int16),
         device_id,
         TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
-        primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
+        _uninitialized_value(
+            var.aval.shape, var.aval.dtype, interpret_params),
         ordered=True))

   _, input_ids, kernel_output_ids, _ = split_list(