diff --git a/paddlenlp/trainer/utils/offload_optimizer.py b/paddlenlp/trainer/utils/offload_optimizer.py
index f20066f1e29b..a38a8d81e093 100644
--- a/paddlenlp/trainer/utils/offload_optimizer.py
+++ b/paddlenlp/trainer/utils/offload_optimizer.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 
 import paddle
 from paddle import _C_ops
@@ -37,6 +38,10 @@ def reload(tensor):
     assert new_tensor is tensor, "to_device must be inplace operation"
 
 
+def is_offload_opt_cache_master_weight():
+    return os.getenv("FLAGS_offload_opt_master_weight_cache", "0").lower() in ["true", "1"]
+
+
 def hack_offload_optimizer():
     # Step 1: mock _add_accumulator
     origin_add_accumulator = getattr(Optimizer, "_add_accumulator")
@@ -60,6 +65,10 @@ def new_opt_op(*args):
             ret = origin_op(*args)
             is_offload_opt = getattr(args[0], "is_offload_opt", False)
             for i, arg in enumerate(args):
+                # need_offload_arg = i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt
+                # if is_offload_opt_cache_master_weight():
+                #     need_offload_arg = need_offload_arg and i != 8
+                # if need_offload_arg:  # do not offload parameter and gradient
                 if (
                     i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt
                 ):  # do not offload parameter and gradient
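
Note (not part of the patch): the four lines added in the last hunk are commented out, so the new `is_offload_opt_cache_master_weight()` helper is defined but not yet wired into the offload path. Below is a minimal standalone sketch of the gating those comments describe, assuming, as the commented code implies, that the master weight is positional argument 8 of the `adam_`/`adamw_` call; the helper `need_offload_arg` is assembled here for illustration and does not exist in the patch.

    import os

    import paddle


    def is_offload_opt_cache_master_weight():
        # Same check as the helper added in this patch: the cache is enabled
        # via the FLAGS_offload_opt_master_weight_cache environment variable.
        return os.getenv("FLAGS_offload_opt_master_weight_cache", "0").lower() in ["true", "1"]


    def need_offload_arg(i, arg, is_offload_opt):
        # Hypothetical helper reconstructed from the commented-out lines:
        # decide whether positional argument i of the adam_/adamw_ call should
        # be offloaded after the optimizer step. Indices 0 and 1 are the
        # parameter and gradient, which always stay on device.
        need_offload = i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt
        if is_offload_opt_cache_master_weight():
            # Assumption taken from the commented code: index 8 is the master
            # weight; keep it cached on device instead of offloading it each step.
            need_offload = need_offload and i != 8
        return need_offload

If this gating were enabled, the cache would presumably be switched on at launch time, e.g. `FLAGS_offload_opt_master_weight_cache=1 python train.py ...`, trading extra device memory for skipping the per-step host/device round trip of the master weight.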