Skip to content

Commit 8ab3366

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Add a variant of safe_map() that has no return value, named foreach().
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results. I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin. PiperOrigin-RevId: 736859418
1 parent 074216e commit 8ab3366

File tree

20 files changed

+184
-70
lines changed

20 files changed

+184
-70
lines changed

jax/_src/checkify.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from jax._src.tree_util import tree_unflatten
5252
from jax._src.typing import Array
5353
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
54-
unzip3, weakref_lru_cache, HashableWrapper)
54+
unzip3, weakref_lru_cache, HashableWrapper, foreach)
5555

5656
source_info_util.register_exclusion(__file__)
5757
traceback_util.register_exclusion(__file__)
@@ -413,8 +413,8 @@ def read_env(var: core.Atom):
413413
def write_env(var: core.Var, val: Any):
414414
env[var] = val
415415

416-
map(write_env, jaxpr.constvars, consts)
417-
map(write_env, jaxpr.invars, in_args)
416+
foreach(write_env, jaxpr.constvars, consts)
417+
foreach(write_env, jaxpr.invars, in_args)
418418

419419
# interpreter loop
420420
for eqn in jaxpr.eqns:
@@ -427,7 +427,7 @@ def write_env(var: core.Var, val: Any):
427427
error, outvals = checkify_rule(error, enabled_errors,
428428
*invals, **eqn.params)
429429
if eqn.primitive.multiple_results:
430-
map(write_env, eqn.outvars, outvals)
430+
foreach(write_env, eqn.outvars, outvals)
431431
else:
432432
write_env(eqn.outvars[0], outvals)
433433
core.clean_up_dead_vars(eqn, env, last_used)

jax/_src/core.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
5151
tuple_delete, cache,
5252
HashableFunction, HashableWrapper, weakref_lru_cache,
53-
partition_list, StrictABCMeta)
53+
partition_list, StrictABCMeta, foreach)
5454
import jax._src.pretty_printer as pp
5555
from jax._src.named_sharding import NamedSharding
5656
from jax._src.lib import jax_jit
@@ -578,8 +578,8 @@ def write(v: Var, val: Any) -> None:
578578
env[v] = val
579579

580580
env: dict[Var, Any] = {}
581-
map(write, jaxpr.constvars, consts)
582-
map(write, jaxpr.invars, args)
581+
foreach(write, jaxpr.constvars, consts)
582+
foreach(write, jaxpr.invars, args)
583583
lu = last_used(jaxpr)
584584
for eqn in jaxpr.eqns:
585585
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
@@ -589,7 +589,7 @@ def write(v: Var, val: Any) -> None:
589589
traceback, name_stack=name_stack), eqn.ctx.manager:
590590
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
591591
if eqn.primitive.multiple_results:
592-
map(write, eqn.outvars, ans)
592+
foreach(write, eqn.outvars, ans)
593593
else:
594594
write(eqn.outvars[0], ans)
595595
clean_up_dead_vars(eqn, env, lu)
@@ -2837,7 +2837,7 @@ def write(v: Var, a: AbstractValue) -> None:
28372837

28382838
# Check out_type matches the let-binders' annotation (after substitution).
28392839
out_type = substitute_vars_in_output_ty(out_type, eqn.invars, eqn.outvars)
2840-
map(write, eqn.outvars, out_type)
2840+
foreach(write, eqn.outvars, out_type)
28412841

28422842
except JaxprTypeError as e:
28432843
ctx, settings = ctx_factory()
@@ -2848,7 +2848,7 @@ def write(v: Var, a: AbstractValue) -> None:
28482848
raise JaxprTypeError(msg, eqn_idx) from None
28492849

28502850
# TODO(mattjj): include output type annotation on jaxpr and check it here
2851-
map(read, jaxpr.outvars)
2851+
foreach(read, jaxpr.outvars)
28522852

28532853
def check_type(
28542854
ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],

jax/_src/interpreters/ad.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from jax._src.dtypes import dtype, float0
4040
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
4141
as_hashable_function, weakref_lru_cache,
42-
partition_list, subs_list2)
42+
partition_list, subs_list2, foreach)
4343

4444
zip = safe_zip
4545
map = safe_map
@@ -344,10 +344,10 @@ def write_primal(v, val):
344344
primal_env[v] = val
345345

346346
primal_env: dict[Any, Any] = {}
347-
map(write_primal, jaxpr.constvars, consts)
347+
foreach(write_primal, jaxpr.constvars, consts)
348348
# FIXME: invars can contain both primal and tangent values, and this line
349349
# forces primal_in to contain UndefinedPrimals for tangent values!
350-
map(write_primal, jaxpr.invars, primals_in)
350+
foreach(write_primal, jaxpr.invars, primals_in)
351351

352352
# Start with a forward pass to evaluate any side-effect-free JaxprEqns that
353353
# only operate on primals. This is required to support primitives with
@@ -367,15 +367,15 @@ def write_primal(v, val):
367367
traceback, name_stack=name_stack), eqn.ctx.manager:
368368
ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params)
369369
if eqn.primitive.multiple_results:
370-
map(write_primal, eqn.outvars, ans)
370+
foreach(write_primal, eqn.outvars, ans)
371371
else:
372372
write_primal(eqn.outvars[0], ans)
373373

374374
ct_env: dict[Any, Any] = {}
375375
ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
376376
else contextlib.nullcontext())
377377
with ctx:
378-
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
378+
foreach(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
379379
for eqn in lin_eqns[::-1]:
380380
if eqn.primitive.ref_primitive:
381381
if eqn.primitive is core.mutable_array_p:
@@ -417,7 +417,7 @@ def write_primal(v, val):
417417
raise e from None
418418
cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
419419
# FIXME: Some invars correspond to primals!
420-
map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
420+
foreach(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
421421

422422
cotangents_out = map(read_cotangent, jaxpr.invars)
423423
return cotangents_out

jax/_src/interpreters/mlir.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from jax._src.sharding_impls import (AUTO, NamedSharding,
5454
modify_sdy_sharding_wrt_axis_types,
5555
SdyArraySharding, SdyArrayShardingList)
56+
from jax._src.util import foreach
5657
from jax._src.lib import xla_client as xc
5758
from jax._src.lib import xla_extension, xla_extension_version
5859
from jax._src.lib.mlir import dialects, ir, passmanager
@@ -1941,8 +1942,8 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
19411942
assert len(args) == len(jaxpr.invars), (jaxpr, args)
19421943
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
19431944
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
1944-
map(write, jaxpr.constvars, consts)
1945-
map(write, jaxpr.invars, args)
1945+
foreach(write, jaxpr.constvars, consts)
1946+
foreach(write, jaxpr.invars, args)
19461947
last_used = core.last_used(jaxpr)
19471948
for eqn in jaxpr.eqns:
19481949
in_nodes = map(read, eqn.invars)
@@ -2009,7 +2010,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
20092010
f"{eqn}, got output {ans}") from e
20102011

20112012
assert len(ans) == len(eqn.outvars), (ans, eqn)
2012-
map(write, eqn.outvars, out_nodes)
2013+
foreach(write, eqn.outvars, out_nodes)
20132014
core.clean_up_dead_vars(eqn, env, last_used)
20142015
return tuple(read(v) for v in jaxpr.outvars), tokens
20152016

@@ -2101,11 +2102,11 @@ def lower_per_platform(ctx: LoweringRuleContext,
21012102
# If there is a single rule left just apply the rule, without conditionals.
21022103
if len(kept_rules) == 1:
21032104
output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
2104-
map(
2105+
foreach(
21052106
lambda o: wrap_compute_type_in_place(ctx, o.owner),
21062107
filter(_is_not_block_argument, flatten_ir_values(output)),
21072108
)
2108-
map(
2109+
foreach(
21092110
lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
21102111
flatten_ir_values(output),
21112112
)
@@ -2146,11 +2147,11 @@ def lower_per_platform(ctx: LoweringRuleContext,
21462147
except TypeError as e:
21472148
raise ValueError("Output of translation rule must be iterable: "
21482149
f"{description}, got output {output}") from e
2149-
map(
2150+
foreach(
21502151
lambda o: wrap_compute_type_in_place(ctx, o.owner),
21512152
filter(_is_not_block_argument, out_nodes),
21522153
)
2153-
map(
2154+
foreach(
21542155
lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
21552156
out_nodes,
21562157
)

jax/_src/interpreters/partial_eval.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
4848
merge_lists, partition_list, OrderedSet,
4949
as_hashable_function, weakref_lru_cache, subs_list,
50-
HashableFunction)
50+
HashableFunction, foreach)
5151

5252

5353
map, unsafe_map = safe_map, map
@@ -1085,8 +1085,8 @@ def has_effects(effects) -> bool:
10851085

10861086
newvar = core.gensym(suffix='_offload')
10871087
known_eqns, staged_eqns = [], []
1088-
map(write, in_unknowns, in_inst, jaxpr.invars)
1089-
map(partial(write, False, True), jaxpr.constvars)
1088+
foreach(write, in_unknowns, in_inst, jaxpr.invars)
1089+
foreach(partial(write, False, True), jaxpr.constvars)
10901090
for eqn in jaxpr.eqns:
10911091
unks_in, inst_in = unzip2(map(read, eqn.invars))
10921092
rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive)
@@ -1098,18 +1098,18 @@ def has_effects(effects) -> bool:
10981098
residual_refs.add(r)
10991099
else:
11001100
residuals.add(r)
1101-
map(write, unks_out, inst_out, eqn.outvars)
1101+
foreach(write, unks_out, inst_out, eqn.outvars)
11021102
elif any(unks_in):
11031103
inputs = map(ensure_instantiated, inst_in, eqn.invars)
11041104
staged_eqns.append(eqn.replace(invars=inputs))
1105-
map(partial(write, True, True), eqn.outvars)
1105+
foreach(partial(write, True, True), eqn.outvars)
11061106
else:
11071107
known_eqns.append(eqn)
11081108
# If it's an effectful primitive, we always to run and avoid staging it.
11091109
policy = ensure_enum(saveable(
11101110
eqn.primitive, *[x.aval for x in eqn.invars], **eqn.params))
11111111
if has_effects(eqn.effects) or isinstance(policy, SaveableType):
1112-
map(partial(write, False, False), eqn.outvars)
1112+
foreach(partial(write, False, False), eqn.outvars)
11131113
elif isinstance(policy, Offloadable):
11141114
# TODO(slebedev): This is a legit error which requires a BUILD fix.
11151115
from jax._src.dispatch import device_put_p, TransferToMemoryKind, CopySemantics # pytype: disable=import-error
@@ -1124,7 +1124,7 @@ def has_effects(effects) -> bool:
11241124
JaxprEqnContext(None, False))
11251125
known_eqns.append(offload_eqn)
11261126
# resvars are known and available in the backward jaxpr.
1127-
map(partial(write, False, True), resvars)
1127+
foreach(partial(write, False, True), resvars)
11281128
residuals.update(resvars)
11291129
reload_eqn = core.JaxprEqn(
11301130
resvars, eqn.outvars, device_put_p,
@@ -1135,12 +1135,12 @@ def has_effects(effects) -> bool:
11351135
JaxprEqnContext(None, False))
11361136
staged_eqns.append(reload_eqn)
11371137
# outvars are known and available in the backward jaxpr.
1138-
map(partial(write, False, True), eqn.outvars)
1138+
foreach(partial(write, False, True), eqn.outvars)
11391139
else:
11401140
assert isinstance(policy, RecomputeType)
11411141
inputs = map(ensure_instantiated, inst_in, eqn.invars)
11421142
staged_eqns.append(eqn.replace(invars=inputs))
1143-
map(partial(write, False, True), eqn.outvars)
1143+
foreach(partial(write, False, True), eqn.outvars)
11441144
unzipped = unzip2(map(read, jaxpr.outvars))
11451145
out_unknowns, out_inst = list(unzipped[0]), list(unzipped[1])
11461146
assert all(type(v) is Var for v in residuals), residuals
@@ -1441,14 +1441,14 @@ def write(x: Atom, b: bool) -> None:
14411441
env[x] = read(x) or b
14421442

14431443
new_eqns = []
1444-
map(write, jaxpr.outvars, used_outputs)
1444+
foreach(write, jaxpr.outvars, used_outputs)
14451445
for eqn in jaxpr.eqns[::-1]:
14461446
used_outs = map(read, eqn.outvars)
14471447
rule = dce_rules.get(eqn.primitive, _default_dce_rule)
14481448
used_ins, new_eqn = rule(used_outs, eqn)
14491449
if new_eqn is not None:
14501450
new_eqns.append(new_eqn)
1451-
map(write, eqn.invars, used_ins)
1451+
foreach(write, eqn.invars, used_ins)
14521452
used_inputs = map(read, jaxpr.invars)
14531453
used_inputs = map(op.or_, instantiate, used_inputs)
14541454

@@ -2495,15 +2495,15 @@ def read(x):
24952495
def write(v, val) -> None:
24962496
env[v] = val
24972497

2498-
map(write, jaxpr.constvars, consts)
2499-
map(write, jaxpr.invars, args)
2498+
foreach(write, jaxpr.constvars, consts)
2499+
foreach(write, jaxpr.invars, args)
25002500
last_used = core.last_used(jaxpr)
25012501
for eqn in jaxpr.eqns:
25022502
in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars]
25032503
out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars]
25042504
rule = padding_rules[eqn.primitive]
25052505
outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params)
2506-
map(write, eqn.outvars, outs)
2506+
foreach(write, eqn.outvars, outs)
25072507
core.clean_up_dead_vars(eqn, env, last_used)
25082508
return map(read, jaxpr.outvars)
25092509

@@ -2580,7 +2580,7 @@ def inline_jaxpr_into_trace(
25802580
src_ = (src if not eqn.source_info.name_stack else
25812581
src.replace(name_stack=src.name_stack + eqn.source_info.name_stack))
25822582
trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_))
2583-
map(env.setdefault, eqn.outvars, outvars)
2583+
foreach(env.setdefault, eqn.outvars, outvars)
25842584

25852585
tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars],
25862586
[*consts, *arg_tracers]))

jax/_src/lax/lax.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@
7171
PartitionSpec as P, canonicalize_sharding)
7272
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
7373
from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis,
74-
safe_map, safe_zip, split_list, weakref_lru_cache)
74+
safe_map, safe_zip, split_list, weakref_lru_cache,
75+
foreach)
7576

7677
_max = builtins.max
7778
_min = builtins.min
@@ -106,7 +107,7 @@ def _check_static_shape(shape: Shape):
106107
# pass dynamic shapes through unchecked
107108
return
108109
else:
109-
map(_check_static_shape, shapes)
110+
foreach(_check_static_shape, shapes)
110111

111112
def _try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]:
112113
"""

jax/_src/pallas/fuser/fusable_dtype.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from jax._src.pallas.fuser.fusable import fusable_p
3838
from jax._src.state import discharge as state_discharge
3939
from jax._src.state import primitives as state_primitives
40+
from jax._src.util import foreach
4041

4142
# TODO(sharadmv): Enable type checking.
4243
# mypy: ignore-errors
@@ -216,11 +217,11 @@ def read_env(var: core.Atom):
216217
def write_env(var: core.Var, val: Any):
217218
env[var] = val
218219

219-
map(write_env, jaxpr.constvars, consts)
220+
foreach(write_env, jaxpr.constvars, consts)
220221
assert len(jaxpr.invars) == len(
221222
args
222223
), f"Length mismatch: {jaxpr.invars} != {args}"
223-
map(write_env, jaxpr.invars, args)
224+
foreach(write_env, jaxpr.invars, args)
224225

225226
for eqn in jaxpr.eqns:
226227
invals = list(map(read_env, eqn.invars))
@@ -248,7 +249,7 @@ def write_env(var: core.Var, val: Any):
248249

249250
if eqn.primitive.multiple_results:
250251
assert len(outvals) == len(eqn.outvars)
251-
map(write_env, eqn.outvars, outvals)
252+
foreach(write_env, eqn.outvars, outvals)
252253
else:
253254
write_env(eqn.outvars[0], outvals)
254255

jax/_src/pallas/hlo_interpreter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from jax._src.state import discharge as state_discharge
4141
from jax._src import util
4242
from jax._src.util import (
43+
foreach,
4344
safe_map,
4445
safe_zip,
4546
split_list,
@@ -199,8 +200,8 @@ def write(v: jax_core.Var, val: Any) -> None:
199200
env[v] = val
200201

201202
env: dict[jax_core.Var, Any] = {}
202-
map(write, jaxpr.constvars, consts)
203-
map(write, jaxpr.invars, args)
203+
foreach(write, jaxpr.constvars, consts)
204+
foreach(write, jaxpr.invars, args)
204205
lu = jax_core.last_used(jaxpr)
205206
for eqn in jaxpr.eqns:
206207
in_vals = map(read, eqn.invars)
@@ -216,7 +217,7 @@ def write(v: jax_core.Var, val: Any) -> None:
216217
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
217218
ans = eqn.primitive.bind(*subfuns, *in_vals, **bind_params)
218219
if eqn.primitive.multiple_results:
219-
map(write, eqn.outvars, ans)
220+
foreach(write, eqn.outvars, ans)
220221
else:
221222
write(eqn.outvars[0], ans)
222223
jax_core.clean_up_dead_vars(eqn, env, lu)

jax/_src/pallas/mosaic/lowering.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from jax._src.state.types import RefBitcaster, RefReshaper
6666
from jax._src.state.utils import dtype_bitwidth
6767
from jax._src.typing import Array, DTypeLike
68+
from jax._src.util import foreach
6869
from jax._src.util import safe_map
6970
from jax._src.util import safe_zip
7071
from jax._src.util import split_list
@@ -950,7 +951,7 @@ def write_env(var: jax_core.Var, val):
950951

951952
for invar, bs in zip(jaxpr.invars, ctx.block_shapes):
952953
block_shape_env[invar] = bs
953-
map(write_env, jaxpr.invars, args)
954+
foreach(write_env, jaxpr.invars, args)
954955

955956
initial_name_stack = [scope.name for scope in ctx.name_stack.stack]
956957
current_name_stack: list[str] = []
@@ -1011,7 +1012,7 @@ def write_env(var: jax_core.Var, val):
10111012
f"{eqn.primitive.name}. "
10121013
"Please file an issue on https://github.com/jax-ml/jax/issues.")
10131014
if eqn.primitive.multiple_results:
1014-
map(write_env, eqn.outvars, ans)
1015+
foreach(write_env, eqn.outvars, ans)
10151016
else:
10161017
write_env(eqn.outvars[0], ans)
10171018

0 commit comments

Comments
 (0)