File tree Expand file tree Collapse file tree 1 file changed +10
-0
lines changed Expand file tree Collapse file tree 1 file changed +10
-0
lines changed Original file line number Diff line number Diff line change 64
64
init_chat_template ,
65
65
)
66
66
from paddlenlp .utils .log import logger
67
+ from paddlenlp .utils .tools import get_env_device
67
68
68
69
# Fine-tune Environment Variables to support sharding stage1 overlap optimization.
69
70
os .environ ["USE_CASUAL_MASK" ] = "False"
@@ -105,6 +106,15 @@ def main():
105
106
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
106
107
)
107
108
109
+ if get_env_device () == "xpu" and training_args .gradient_accumulation_steps > 1 :
110
+ try :
111
+ from paddle_xpu .layers .nn .linear import LinearConfig # noqa: F401
112
+ LinearConfig .enable_accumulate_steps_opt ()
113
+ LinearConfig .set_accumulate_steps (training_args .gradient_accumulation_steps )
114
+ except ImportError :
115
+ # It's OK, not use accumulate_steps optimization
116
+ pass
117
+
108
118
# Load model
109
119
if training_args .fp16_opt_level == "O2" :
110
120
if training_args .fp16 :
You can’t perform that action at this time.
0 commit comments