@@ -757,89 +757,34 @@ def _has_effects(effects) -> bool:
   return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)})
 
 
-def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
-                    differentiated: bool, is_gpu_platform: bool = False,
-                    **_):
+def remat_expansion(
+    *args, jaxpr: core.Jaxpr, prevent_cse: bool, differentiated: bool, **_
+):
   assert not jaxpr.constvars
 
   if differentiated and prevent_cse:
-    if config.remat_opt_barrier.value:
-      translation_rule = _remat_translation_using_opt_barrier
-    elif is_gpu_platform:
-      translation_rule = _remat_translation_using_while
-    else:
-      translation_rule = _remat_translation_using_cond
+    translation_rule = _remat_translation_using_opt_barrier
   else:
     translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)
 
   return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr)
 
+
 def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
   args = lax_internal.optimization_barrier(args)
   return core.eval_jaxpr(jaxpr, (), *args)
 
-# TODO(mattjj): add core utility for 'create dummy value for this type'?
-def _dummy_like(aval: core.AbstractValue) -> Any:
-  if aval is core.abstract_token:
-    return lax_internal.create_token()
-  elif isinstance(aval, (core.ShapedArray, core.DShapedArray)):
-    return lax_internal.broadcast(lax_internal.empty(aval.dtype), aval.shape)  # type: ignore
-  else:
-    raise ValueError(aval)
-
-def _remat_translation_using_while(*args, jaxpr: core.Jaxpr):
-  # Implements:
-  #  for(counter=0, result=0; counter < rng(1, 2); counter++) {
-  #    result = eval_jaxpr(*args)
-  #  }
-  # The loop carry is a tuple: (counter, result, args)
-  from jax._src.lax import control_flow as lax_control_flow
-
-  avals_out = tuple(v.aval for v in jaxpr.outvars)
-  carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args)
-  def cond(carry):
-    counter, _, _ = carry
-    unif = lax_internal.rng_uniform(np.int32(1), np.int32(2), shape=())
-    return counter < unif
-
-  def body(carry):
-    counter, _, args = carry
-    results = core.eval_jaxpr(jaxpr, (), *args)
-    return (counter + 1, tuple(results), args)
-
-  carry_res = lax_control_flow.while_loop(cond, body, carry_init)
-  return carry_res[1]
-
-def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
-  # Implements:
-  #  if(rng(0, 1) < 2)
-  #    return eval_jaxpr(*args)
-  #  else:
-  #    return 0
-  from jax._src.lax import control_flow as lax_control_flow
-
-  avals_out = tuple(v.aval for v in jaxpr.outvars)
-
-  def remat_comp(*args):
-    return tuple(core.eval_jaxpr(jaxpr, (), *args))
-  def dummy_comp(*args):
-    return tuple(map(_dummy_like, avals_out))
-
-  unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=())
-  return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args)
-
-def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool,
-                    differentiated: bool, policy, is_gpu_platform=False):
+
+def _remat_lowering(
+    ctx,
+    *args,
+    jaxpr: core.Jaxpr,
+    prevent_cse: bool,
+    differentiated: bool,
+    policy,
+):
   jaxpr_args: Sequence[mlir.IrValues]
   if differentiated and prevent_cse:
-    # If we're using the loop or cond lowerings, use the slower lower_fun
-    # based path.
-    if not config.remat_opt_barrier.value:
-      return mlir.lower_fun(remat_expansion, multiple_results=True)(
-          ctx, *args, jaxpr=jaxpr, prevent_cse=prevent_cse,
-          differentiated=differentiated, policy=policy,
-          is_gpu_platform=is_gpu_platform)
-
     arg_types = map(mlir.aval_to_ir_type, ctx.avals_in)
     flat_args = mlir.flatten_ir_values(args)
     barrier_op = hlo.OptimizationBarrierOp(flat_args)
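The hunk above collapses three CSE-prevention strategies (optimization barrier, while-loop on GPU, and cond) into the single optimization-barrier path. As a minimal standalone sketch of the retained strategy, not the internal code itself: the hypothetical helper recompute_with_barrier below uses the public jax.lax.optimization_barrier wrapper to make a recomputation opaque to XLA's common-subexpression elimination.

import jax
import jax.numpy as jnp

def recompute_with_barrier(f, *args):
  # optimization_barrier is an identity at runtime, but XLA treats it as
  # opaque, so the evaluation of f below cannot be CSE'd against an earlier
  # evaluation of f on the same arguments.
  args = jax.lax.optimization_barrier(args)
  return f(*args)

x = jnp.ones((4,))
print(recompute_with_barrier(jnp.sin, x))  # numerically identical to jnp.sin(x)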
@@ -853,9 +798,8 @@ def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool,
   ctx.set_tokens_out(tokens_out)
   return outs
 
+
 mlir.register_lowering(remat_p, _remat_lowering)
-mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True),
-                       platform="gpu")
 
 
 def checkpoint_name(x, name):
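For context, a hedged usage sketch of the public entry point this lowering serves, jax.checkpoint (also exported as jax.remat): under jax.grad, the decorated function's intermediates are recomputed on the backward pass, and the barrier inserted by _remat_translation_using_opt_barrier keeps XLA from folding that recomputation back into the saved forward-pass values.

import jax
import jax.numpy as jnp

@jax.checkpoint
def g(x):
  # Intermediates here are not saved for the backward pass; they are
  # recomputed, behind an optimization barrier, when grad needs them.
  return jnp.sin(jnp.sin(x))

print(jax.grad(lambda x: g(x).sum())(jnp.ones(4)))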