@@ -69,18 +69,24 @@ class TPUInterpretParams:
6969
7070 Attributes:
7171 dma_execution_mode: If "eager", DMAs are executed as soon as they are
72- issued. If "on_wait", DMA reads or writes are only executed when a
73- device is waiting on a DMA semaphore that will be signaled when the read
74- or write is complete.
72+ issued. If "on_wait", DMA reads or writes are only executed when a device
73+ is waiting on a DMA semaphore that will be signaled when the read or write
74+ is complete.
7575 Default: "on_wait".
7676 detect_races: If True, a dynamic, happens-before race detector will be
7777 used to detect data races during kernel interpretation. If any races are
7878 detected, a message will be printed and `races.races_found` will be set
7979 to True.
8080 Default: False.
81+ skip_floating_point_ops: If True, operations that produce only floating
82+ point values will not be interpreted; instead, their results will be
83+ replaced with arrays filled with `jnp.inf`. Additionally, any floating point
84+ operands to any operation will be replaced with (arrays of) `jnp.inf`.
85+ Default: False.
8186 """
8287 dma_execution_mode : Literal ["eager" , "on_wait" ] = "on_wait"
8388 detect_races : bool = False
89+ skip_floating_point_ops : bool = False
8490
8591
8692VectorClock = np .ndarray
@@ -954,16 +960,32 @@ def _is_any(memory_space):
954960 return ((memory_space == mosaic_core .TPUMemorySpace .ANY ) or
955961 (memory_space == pallas_core .MemorySpace .ANY ))
956962
963+ def _is_float (dtype ):
964+ return jnp .issubdtype (dtype , jnp .floating )
965+
966+ _SENTINEL = jnp .inf
967+
968+ @dataclasses .dataclass (frozen = True )
969+ class Placeholder :
970+ """Placeholder for use in `_interpret_jaxpr` below instead of putting a concrete value into `env`."""
971+ shape : tuple [int , ...]
972+ dtype : jnp .dtype
973+
957974def _interpret_jaxpr (jaxpr , * args , compiler_params , interpret_params ):
958975 env = {}
959976
960977 def read (var ):
961978 if isinstance (var , jax_core .Literal ):
962- return var .val
979+ result = var .val
963980 else :
964- return env [var ]
981+ result = env [var ]
982+ if isinstance (result , Placeholder ):
983+ result = jax .lax .full (result .shape , _SENTINEL , result .dtype )
984+ return result
965985
966986 def write (var , value ):
987+ if interpret_params .skip_floating_point_ops and _is_float (value .dtype ):
988+ value = Placeholder (value .shape , value .dtype )
967989 env [var ] = value
968990
969991 jax .util .safe_map (write , jaxpr .constvars + jaxpr .invars , args )
@@ -987,11 +1009,16 @@ def write(var, value):
9871009 with source_info_util .user_context (
9881010 eqn .source_info .traceback , name_stack = eqn .source_info .name_stack ):
9891011 prim = eqn .primitive
990- invals = jax .util .safe_map (read , eqn .invars )
1012+ # We defer reading the values for `eqn.invars` into each of the branches
1013+ # of the if-elif-else statement below. This is because the else branch may
1014+ # not need to do any reads if `interpret_params.skip_floating_point_ops`
1015+ # is True. If this is the case, we want to avoid materializing the read
1016+ # array into the jaxpr when this function is traced.
1017+ deferred_invals = functools .partial (jax .util .safe_map , read , eqn .invars )
9911018
9921019 if prim is primitives .load_p :
9931020 (ref , transforms , mask , _ ) = jax .tree .unflatten (
994- eqn .params ['args_tree' ], invals )
1021+ eqn .params ['args_tree' ], deferred_invals () )
9951022 if mask is not None :
9961023 raise NotImplementedError ('masked load_p' )
9971024 out = callback .io_callback (
@@ -1005,7 +1032,7 @@ def write(var, value):
10051032
10061033 elif prim is primitives .swap_p :
10071034 (ref , transforms , val , mask ) = jax .tree .unflatten (
1008- eqn .params ['args_tree' ], invals )
1035+ eqn .params ['args_tree' ], deferred_invals () )
10091036 out = callback .io_callback (
10101037 functools .partial (swap , source_info = eqn .source_info ),
10111038 eqn .outvars [0 ].aval ,
@@ -1023,6 +1050,7 @@ def write(var, value):
10231050 elif prim is lax .cond_p :
10241051 def _make_branch (jaxpr ):
10251052 return lambda * args : _interpret (jaxpr , * args )
1053+ invals = deferred_invals ()
10261054 out = lax .switch (
10271055 invals [0 ],
10281056 [_make_branch (branch_jaxpr .jaxpr )
@@ -1031,7 +1059,9 @@ def _make_branch(jaxpr):
10311059
10321060 elif prim is lax .scan_p :
10331061 consts , init_carry , xs = split_list (
1034- invals , [eqn .params ['num_consts' ], eqn .params ['num_carry' ]])
1062+ deferred_invals (),
1063+ [eqn .params ['num_consts' ], eqn .params ['num_carry' ]],
1064+ )
10351065 def _scan_body (c , a ):
10361066 return split_list (
10371067 _interpret (eqn .params ['jaxpr' ].jaxpr , * consts , * c , * a ),
@@ -1041,8 +1071,10 @@ def _scan_body(c, a):
10411071 out = carry + out
10421072
10431073 elif prim is lax .while_p :
1044- cond_consts , body_consts , init_vals = split_list (
1045- invals , [eqn .params ['cond_nconsts' ], eqn .params ['body_nconsts' ]])
1074+ cond_consts , body_consts , init_vals = split_list (
1075+ deferred_invals (),
1076+ [eqn .params ['cond_nconsts' ], eqn .params ['body_nconsts' ]],
1077+ )
10461078 out = lax .while_loop (
10471079 lambda args : _interpret (
10481080 eqn .params ['cond_jaxpr' ].jaxpr , * cond_consts , * args )[0 ],
@@ -1056,6 +1088,7 @@ def _scan_body(c, a):
10561088 elif prim is pjit .pjit_p :
10571089 def f (* args , jaxpr ):
10581090 return _interpret (jaxpr .jaxpr , * jaxpr .consts , * args )
1091+ invals = deferred_invals ()
10591092 in_avals = tuple (jax_core .shaped_abstractify (i ) for i in invals )
10601093 new_jaxpr = _to_jaxpr (
10611094 lu .wrap_init (functools .partial (f , jaxpr = eqn .params ['jaxpr' ]),
@@ -1084,7 +1117,7 @@ def f(*args, jaxpr):
10841117 primitives .uninitialized_value (v .aval .shape , v .aval .dtype ),
10851118 ordered = True ))
10861119
1087- out = _interpret (eqn .params ['jaxpr' ], * invals , * allocs )
1120+ out = _interpret (eqn .params ['jaxpr' ], * deferred_invals () , * allocs )
10881121
10891122 for a in allocs :
10901123 if isinstance (a , tuple ):
@@ -1106,6 +1139,7 @@ def f(*args, jaxpr):
11061139 pass
11071140
11081141 elif prim is state_primitives .get_p :
1142+ invals = deferred_invals ()
11091143 out = callback .io_callback (
11101144 functools .partial (get , source_info = eqn .source_info ),
11111145 eqn .outvars [0 ].aval ,
@@ -1116,6 +1150,7 @@ def f(*args, jaxpr):
11161150 ordered = True )
11171151
11181152 elif prim is state_primitives .swap_p :
1153+ invals = deferred_invals ()
11191154 out = callback .io_callback (
11201155 functools .partial (swap , source_info = eqn .source_info ),
11211156 eqn .outvars [0 ].aval ,
@@ -1128,11 +1163,17 @@ def f(*args, jaxpr):
11281163 ordered = True )
11291164
11301165 elif prim is mosaic_primitives .dma_start_p :
1131- (src , src_transforms ,
1132- dst , dst_transforms ,
1133- dst_sem , dst_sem_transforms ,
1134- src_sem , src_sem_transforms ,
1135- target_device_id ) = jax .tree .unflatten (eqn .params ['tree' ], invals )
1166+ (
1167+ src ,
1168+ src_transforms ,
1169+ dst ,
1170+ dst_transforms ,
1171+ dst_sem ,
1172+ dst_sem_transforms ,
1173+ src_sem ,
1174+ src_sem_transforms ,
1175+ target_device_id ,
1176+ ) = jax .tree .unflatten (eqn .params ['tree' ], deferred_invals ())
11361177 target_device_id = _device_id_to_logical (
11371178 target_device_id , eqn .params ['device_id_type' ], axis_sizes )
11381179 (orig_src_ref , _ , orig_dst_ref , * _
@@ -1152,11 +1193,17 @@ def f(*args, jaxpr):
11521193 out = []
11531194
11541195 elif prim is mosaic_primitives .dma_wait_p :
1155- (src , src_transforms ,
1156- dst , dst_transforms ,
1157- dst_sem , dst_sem_transforms ,
1158- src_sem , src_sem_transforms ,
1159- target_device_id ) = jax .tree .unflatten (eqn .params ['tree' ], invals )
1196+ (
1197+ src ,
1198+ src_transforms ,
1199+ dst ,
1200+ dst_transforms ,
1201+ dst_sem ,
1202+ dst_sem_transforms ,
1203+ src_sem ,
1204+ src_sem_transforms ,
1205+ target_device_id ,
1206+ ) = jax .tree .unflatten (eqn .params ['tree' ], deferred_invals ())
11601207 read_shape , read_dtype = _compute_transformed_shape_and_dtype (
11611208 eqn .invars [0 ].aval .shape , eqn .invars [0 ].aval .dtype , src_transforms )
11621209 callback .io_callback (
@@ -1178,7 +1225,7 @@ def f(*args, jaxpr):
11781225
11791226 elif prim is mosaic_primitives .semaphore_signal_p :
11801227 sem , sem_transforms , inc , target_device_id , core_index = (
1181- jax .tree .unflatten (eqn .params ['args_tree' ], invals ))
1228+ jax .tree .unflatten (eqn .params ['args_tree' ], deferred_invals () ))
11821229 target_device_id = _device_id_to_logical (
11831230 target_device_id , eqn .params ['device_id_type' ], axis_sizes )
11841231 callback .io_callback (
@@ -1194,7 +1241,7 @@ def f(*args, jaxpr):
11941241
11951242 elif prim is mosaic_primitives .semaphore_wait_p :
11961243 sem , sem_transforms , value = (
1197- jax .tree .unflatten (eqn .params ['args_tree' ], invals ))
1244+ jax .tree .unflatten (eqn .params ['args_tree' ], deferred_invals () ))
11981245 callback .io_callback (
11991246 semaphore_wait ,
12001247 (),
@@ -1211,8 +1258,19 @@ def f(*args, jaxpr):
12111258 raise NotImplementedError ('atomic_cas_p' )
12121259
12131260 else :
1214- subfuns , bind_params = eqn .primitive .get_bind_params (eqn .params )
1215- out = prim .bind (* subfuns , * invals , ** bind_params )
1261+ if interpret_params .skip_floating_point_ops and all (
1262+ _is_float (ovar .aval .dtype ) for ovar in eqn .outvars
1263+ ):
1264+ # Skip `prim.bind` since `prim` only produces floating-point values.
1265+ # It is safe to populate `out` with avals since mapping `write` over
1266+ # `out` below only relies on the shape and dtype (for writing
1267+ # `Placeholder`s).
1268+ out = [ovar .aval for ovar in eqn .outvars ]
1269+ if not prim .multiple_results :
1270+ out = out [0 ]
1271+ else :
1272+ subfuns , bind_params = eqn .primitive .get_bind_params (eqn .params )
1273+ out = prim .bind (* subfuns , * deferred_invals (), ** bind_params )
12161274
12171275 out = out if prim .multiple_results else [out ]
12181276 jax .util .safe_map (write , eqn .outvars , out )
0 commit comments