
Commit ff00fa9

superbobry authored and Google-ML-Automation committed
Removed unused jax_remat_opt_barrier config option
It defaults to True and is not flipped to False by any internal JAX users.

PiperOrigin-RevId: 744754343
1 parent 5a3fc60 commit ff00fa9

File tree (4 files changed, +21 -89 lines):

  jax/_src/ad_checkpoint.py
  jax/_src/config.py
  jax/experimental/jax2tf/jax2tf.py
  jax/experimental/jax2tf/tests/jax2tf_test.py

jax/_src/ad_checkpoint.py (15 additions & 71 deletions)

@@ -757,89 +757,34 @@ def _has_effects(effects) -> bool:
   return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)})
 
 
-def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
-                    differentiated: bool, is_gpu_platform: bool = False,
-                    **_):
+def remat_expansion(
+    *args, jaxpr: core.Jaxpr, prevent_cse: bool, differentiated: bool, **_
+):
   assert not jaxpr.constvars
 
   if differentiated and prevent_cse:
-    if config.remat_opt_barrier.value:
-      translation_rule = _remat_translation_using_opt_barrier
-    elif is_gpu_platform:
-      translation_rule = _remat_translation_using_while
-    else:
-      translation_rule = _remat_translation_using_cond
+    translation_rule = _remat_translation_using_opt_barrier
   else:
     translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)
 
   return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr)
 
+
 def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
   args = lax_internal.optimization_barrier(args)
   return core.eval_jaxpr(jaxpr, (), *args)
 
-# TODO(mattjj): add core utility for 'create dummy value for this type'?
-def _dummy_like(aval: core.AbstractValue) -> Any:
-  if aval is core.abstract_token:
-    return lax_internal.create_token()
-  elif isinstance(aval, (core.ShapedArray, core.DShapedArray)):
-    return lax_internal.broadcast(lax_internal.empty(aval.dtype), aval.shape)  # type: ignore
-  else:
-    raise ValueError(aval)
-
-def _remat_translation_using_while(*args, jaxpr: core.Jaxpr):
-  # Implements:
-  #  for(counter=0, result=0; counter < rng(1, 2); counter ++) {
-  #     result = eval_jaxpr(*args)
-  #  }
-  # The loop carry is a tuple: (counter, result, args)
-  from jax._src.lax import control_flow as lax_control_flow
-
-  avals_out = tuple(v.aval for v in jaxpr.outvars)
-  carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args)
-  def cond(carry):
-    counter, _, _ = carry
-    unif = lax_internal.rng_uniform(np.int32(1), np.int32(2), shape=())
-    return counter < unif
-
-  def body(carry):
-    counter, _, args = carry
-    results = core.eval_jaxpr(jaxpr, (), *args)
-    return (counter + 1, tuple(results), args)
-
-  carry_res = lax_control_flow.while_loop(cond, body, carry_init)
-  return carry_res[1]
-
-def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
-  # Implements:
-  #  if(rng(0, 1) < 2)
-  #    return eval_jaxpr(*args)
-  #  else:
-  #    return 0
-  from jax._src.lax import control_flow as lax_control_flow
-
-  avals_out = tuple(v.aval for v in jaxpr.outvars)
-
-  def remat_comp(*args):
-    return tuple(core.eval_jaxpr(jaxpr, (), *args))
-  def dummy_comp(*args):
-    return tuple(map(_dummy_like, avals_out))
-
-  unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=())
-  return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args)
-
-def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool,
-                    differentiated: bool, policy, is_gpu_platform=False):
+
+def _remat_lowering(
+    ctx,
+    *args,
+    jaxpr: core.Jaxpr,
+    prevent_cse: bool,
+    differentiated: bool,
+    policy,
+):
   jaxpr_args: Sequence[mlir.IrValues]
   if differentiated and prevent_cse:
-    # If we're using the loop or cond lowerings, use the slower lower_fun
-    # based path.
-    if not config.remat_opt_barrier.value:
-      return mlir.lower_fun(remat_expansion, multiple_results=True)(
-          ctx, *args, jaxpr=jaxpr, prevent_cse=prevent_cse,
-          differentiated=differentiated, policy=policy,
-          is_gpu_platform=is_gpu_platform)
-
     arg_types = map(mlir.aval_to_ir_type, ctx.avals_in)
     flat_args = mlir.flatten_ir_values(args)
     barrier_op = hlo.OptimizationBarrierOp(flat_args)
@@ -853,9 +798,8 @@ def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool,
   ctx.set_tokens_out(tokens_out)
   return outs
 
+
 mlir.register_lowering(remat_p, _remat_lowering)
-mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True),
-                       platform="gpu")
 
 
 def checkpoint_name(x, name):
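
For context, here is a minimal user-level sketch of the one remaining lowering path (this example is not part of the commit, and the exact "optimization_barrier" op name in the printed StableHLO is an assumption about the current printer):

# Under jax.jit + jax.grad, a jax.checkpoint'd function is differentiated
# with prevent_cse=True, so the rematerialized computation is fenced behind
# an optimization barrier that keeps XLA's CSE from folding it back into
# the forward pass.
import jax
import jax.numpy as jnp

@jax.checkpoint
def f(x):
  return jnp.sin(jnp.sin(x))

lowered = jax.jit(jax.grad(f)).lower(3.0)
print("optimization_barrier" in lowered.as_text())  # expected: True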

jax/_src/config.py (0 additions & 7 deletions)

@@ -1512,13 +1512,6 @@ def _update_disable_jit_thread_local(val):
     help=('Attempt constant folding during staging.'),
     include_in_jit_key=True)
 
-# This flag is temporary during rollout of the remat barrier.
-# TODO(parkers): Remove if there are no complaints.
-remat_opt_barrier = bool_state(
-    name='jax_remat_opt_barrier',
-    default=True,
-    help=('Enables using optimization-barrier op for lowering remat.'))
-
 enable_remat_opt_pass = bool_state(
     name='jax_compiler_enable_remat_pass',
    default=True,
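
For reference, a hypothetical usage sketch of the knob being deleted, shown only to illustrate what goes away (this code is intentionally obsolete once the commit lands):

import jax

# Before this change, either spelling flipped the same bool_state;
# after it, both fail because the option no longer exists.
jax.config.update("jax_remat_opt_barrier", False)
# or: set the JAX_REMAT_OPT_BARRIER=0 environment variable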

jax/experimental/jax2tf/jax2tf.py (5 additions & 6 deletions)

@@ -3173,12 +3173,11 @@ def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal:
     lax_control_flow._scan_impl,
     extra_name_stack="scan")
 
-tf_impl_with_avals[ad_checkpoint.remat_p] = \
-    _convert_jax_impl(partial(ad_checkpoint.remat_expansion,
-                              # TODO: jax2tf cannot discriminate by platform
-                              is_gpu_platform=False),
-                      multiple_results=True,
-                      extra_name_stack="checkpoint")
+tf_impl_with_avals[ad_checkpoint.remat_p] = _convert_jax_impl(
+    ad_checkpoint.remat_expansion,
+    multiple_results=True,
+    extra_name_stack="checkpoint",
+)
 
 tf_impl[ad_checkpoint.name_p] = lambda x, *, name: x
 
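
A short sketch of the conversion path this hunk simplifies: remat_p now always converts through remat_expansion, with no platform-specific partial. (jax2tf.convert is the public entry point; TfToHlo used in the test below is a test-only helper, not shown here.)

import jax
import jax.numpy as jnp
from jax.experimental import jax2tf

@jax.checkpoint
def f(x):
  return jnp.sin(jnp.sin(x))

# remat_p inside the converted gradient is handled by the single
# remat_expansion-based implementation registered above.
f_tf = jax2tf.convert(jax.grad(f))
print(f_tf(3.0))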

jax/experimental/jax2tf/tests/jax2tf_test.py (1 addition & 5 deletions)

@@ -832,11 +832,7 @@ def f(x1):
     arg = np.array(3.)
     f_tf = jax2tf.convert(jax.grad(remat_f))
     f_tf_hlo = self.TfToHlo(f_tf, arg)
-    if config.remat_opt_barrier.value:
-      self.assertRegex(f_tf_hlo, r"opt-barrier")
-    else:
-      self.assertRegex(f_tf_hlo,
-                       r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin')
+    self.assertRegex(f_tf_hlo, r"opt-barrier")
 
   def test_remat_free_var(self):
     def f(x):
