4747from jax ._src .util import (unzip2 , safe_zip , safe_map , toposort , split_list ,
4848 merge_lists , partition_list , OrderedSet ,
4949 as_hashable_function , weakref_lru_cache , subs_list ,
50- HashableFunction )
50+ HashableFunction , foreach )
5151
5252
5353map , unsafe_map = safe_map , map
@@ -1085,8 +1085,8 @@ def has_effects(effects) -> bool:
10851085
10861086 newvar = core .gensym (suffix = '_offload' )
10871087 known_eqns , staged_eqns = [], []
1088- map (write , in_unknowns , in_inst , jaxpr .invars )
1089- map (partial (write , False , True ), jaxpr .constvars )
1088+ foreach (write , in_unknowns , in_inst , jaxpr .invars )
1089+ foreach (partial (write , False , True ), jaxpr .constvars )
10901090 for eqn in jaxpr .eqns :
10911091 unks_in , inst_in = unzip2 (map (read , eqn .invars ))
10921092 rule = partial_eval_jaxpr_custom_rules .get (eqn .primitive )
@@ -1098,18 +1098,18 @@ def has_effects(effects) -> bool:
10981098 residual_refs .add (r )
10991099 else :
11001100 residuals .add (r )
1101- map (write , unks_out , inst_out , eqn .outvars )
1101+ foreach (write , unks_out , inst_out , eqn .outvars )
11021102 elif any (unks_in ):
11031103 inputs = map (ensure_instantiated , inst_in , eqn .invars )
11041104 staged_eqns .append (eqn .replace (invars = inputs ))
1105- map (partial (write , True , True ), eqn .outvars )
1105+ foreach (partial (write , True , True ), eqn .outvars )
11061106 else :
11071107 known_eqns .append (eqn )
11081108 # If it's an effectful primitive, we always to run and avoid staging it.
11091109 policy = ensure_enum (saveable (
11101110 eqn .primitive , * [x .aval for x in eqn .invars ], ** eqn .params ))
11111111 if has_effects (eqn .effects ) or isinstance (policy , SaveableType ):
1112- map (partial (write , False , False ), eqn .outvars )
1112+ foreach (partial (write , False , False ), eqn .outvars )
11131113 elif isinstance (policy , Offloadable ):
11141114 # TODO(slebedev): This is a legit error which requires a BUILD fix.
11151115 from jax ._src .dispatch import device_put_p , TransferToMemoryKind , CopySemantics # pytype: disable=import-error
@@ -1124,7 +1124,7 @@ def has_effects(effects) -> bool:
11241124 JaxprEqnContext (None , False ))
11251125 known_eqns .append (offload_eqn )
11261126 # resvars are known and available in the backward jaxpr.
1127- map (partial (write , False , True ), resvars )
1127+ foreach (partial (write , False , True ), resvars )
11281128 residuals .update (resvars )
11291129 reload_eqn = core .JaxprEqn (
11301130 resvars , eqn .outvars , device_put_p ,
@@ -1135,12 +1135,12 @@ def has_effects(effects) -> bool:
11351135 JaxprEqnContext (None , False ))
11361136 staged_eqns .append (reload_eqn )
11371137 # outvars are known and available in the backward jaxpr.
1138- map (partial (write , False , True ), eqn .outvars )
1138+ foreach (partial (write , False , True ), eqn .outvars )
11391139 else :
11401140 assert isinstance (policy , RecomputeType )
11411141 inputs = map (ensure_instantiated , inst_in , eqn .invars )
11421142 staged_eqns .append (eqn .replace (invars = inputs ))
1143- map (partial (write , False , True ), eqn .outvars )
1143+ foreach (partial (write , False , True ), eqn .outvars )
11441144 unzipped = unzip2 (map (read , jaxpr .outvars ))
11451145 out_unknowns , out_inst = list (unzipped [0 ]), list (unzipped [1 ])
11461146 assert all (type (v ) is Var for v in residuals ), residuals
@@ -1441,14 +1441,14 @@ def write(x: Atom, b: bool) -> None:
14411441 env [x ] = read (x ) or b
14421442
14431443 new_eqns = []
1444- map (write , jaxpr .outvars , used_outputs )
1444+ foreach (write , jaxpr .outvars , used_outputs )
14451445 for eqn in jaxpr .eqns [::- 1 ]:
14461446 used_outs = map (read , eqn .outvars )
14471447 rule = dce_rules .get (eqn .primitive , _default_dce_rule )
14481448 used_ins , new_eqn = rule (used_outs , eqn )
14491449 if new_eqn is not None :
14501450 new_eqns .append (new_eqn )
1451- map (write , eqn .invars , used_ins )
1451+ foreach (write , eqn .invars , used_ins )
14521452 used_inputs = map (read , jaxpr .invars )
14531453 used_inputs = map (op .or_ , instantiate , used_inputs )
14541454
@@ -2495,15 +2495,15 @@ def read(x):
24952495 def write (v , val ) -> None :
24962496 env [v ] = val
24972497
2498- map (write , jaxpr .constvars , consts )
2499- map (write , jaxpr .invars , args )
2498+ foreach (write , jaxpr .constvars , consts )
2499+ foreach (write , jaxpr .invars , args )
25002500 last_used = core .last_used (jaxpr )
25012501 for eqn in jaxpr .eqns :
25022502 in_avals = [_substitute_axis_sizes (env , v .aval ) for v in eqn .invars ]
25032503 out_avals = [_substitute_axis_sizes (env , v .aval ) for v in eqn .outvars ]
25042504 rule = padding_rules [eqn .primitive ]
25052505 outs = rule (in_avals , out_avals , * map (read , eqn .invars ), ** eqn .params )
2506- map (write , eqn .outvars , outs )
2506+ foreach (write , eqn .outvars , outs )
25072507 core .clean_up_dead_vars (eqn , env , last_used )
25082508 return map (read , jaxpr .outvars )
25092509
@@ -2580,7 +2580,7 @@ def inline_jaxpr_into_trace(
25802580 src_ = (src if not eqn .source_info .name_stack else
25812581 src .replace (name_stack = src .name_stack + eqn .source_info .name_stack ))
25822582 trace .frame .add_eqn (eqn .replace (invars , outvars , source_info = src_ ))
2583- map (env .setdefault , eqn .outvars , outvars )
2583+ foreach (env .setdefault , eqn .outvars , outvars )
25842584
25852585 tracer_env : dict [Var , Any ] = dict (zip ([* jaxpr .constvars , * jaxpr .invars ],
25862586 [* consts , * arg_tracers ]))
0 commit comments