diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index b5756896c65a..58c8514432f0 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -1360,8 +1360,8 @@ def recompute_disable(self): """ def fn(layer): - if hasattr(layer, "enable_recompute") and (layer.enable_recompute is False or layer.enable_recompute == 0): - layer.enable_recompute = True + if hasattr(layer, "enable_recompute") and (layer.enable_recompute is True or layer.enable_recompute == 1): + layer.enable_recompute = False self.apply(fn)