Skip to content

Commit b583f11

Browse files
greycookerJunnYu
authored andcommitted
Trainer support simultaneously parse JSON files and cmd arguments. (#7768)
* add parse_json_file_and_cmd_lines * change unit test file path * Change the way the JSON file is determined * Merge parameter parsing judgment branches and add comments. * remove the special handling of output_dir * Add remaining_args warning
1 parent 1743466 commit b583f11

File tree

4 files changed

+174
-6
lines changed

4 files changed

+174
-6
lines changed

llm/finetune_generation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def read_local_dataset(path):
5858
def main():
5959
# Arguments
6060
parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments))
61-
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
62-
gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file(
63-
json_file=os.path.abspath(sys.argv[1])
64-
)
61+
# Support format as "args.json --arg1 value1 --arg2 value2.”
62+
# In case of conflict, command line arguments take precedence.
63+
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
64+
gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines()
6565
else:
6666
gen_args, quant_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses()
6767
training_args.print_config(model_args, "Model")

llm/llama/tests/test_argparser.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import json
15+
import os
16+
import sys
17+
import tempfile
18+
import unittest
19+
from unittest.mock import patch
20+
21+
from llm.run_pretrain import PreTrainingArguments
22+
from paddlenlp.trainer.argparser import PdArgumentParser
23+
24+
25+
def parse_args():
26+
parser = PdArgumentParser((PreTrainingArguments,))
27+
# Support format as "args.json --arg1 value1 --arg2 value2.”
28+
# In case of conflict, command line arguments take precedence.
29+
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
30+
model_args = parser.parse_json_file_and_cmd_lines()
31+
else:
32+
model_args = parser.parse_args_into_dataclasses()
33+
return model_args
34+
35+
36+
def create_json_from_dict(data_dict, file_path):
37+
with open(file_path, "w") as f:
38+
json.dump(data_dict, f)
39+
40+
41+
class ArgparserTest(unittest.TestCase):
42+
script_name = "test_argparser.py"
43+
args_dict = {
44+
"max_steps": 3000,
45+
"amp_master_grad": False,
46+
"adam_beta1": 0.9,
47+
"adam_beta2": 0.999,
48+
"adam_epsilon": 1e-08,
49+
"bf16": False,
50+
"enable_linear_fused_grad_add": False,
51+
"eval_steps": 3216,
52+
"flatten_param_grads": False,
53+
"fp16": 1,
54+
"log_on_each_node": True,
55+
"logging_dir": "./checkpoints/llama2_pretrain_ckpts/runs/Dec27_04-28-35_instance-047hzlt0-4",
56+
"logging_first_step": False,
57+
"logging_steps": 1,
58+
"lr_end": 1e-07,
59+
"max_evaluate_steps": -1,
60+
"max_grad_norm": 1.0,
61+
"min_learning_rate": 3e-06,
62+
"no_cuda": False,
63+
"num_cycles": 0.5,
64+
"num_train_epochs": 3.0,
65+
"output_dir": "./checkpoints/llama2_pretrain_ckpts",
66+
}
67+
68+
def test_parse_cmd_lines(self):
69+
cmd_line_args = [ArgparserTest.script_name]
70+
for key, value in ArgparserTest.args_dict.items():
71+
cmd_line_args.extend([f"--{key}", str(value)])
72+
with patch("sys.argv", cmd_line_args):
73+
model_args = vars(parse_args()[0])
74+
for key, value in ArgparserTest.args_dict.items():
75+
self.assertEqual(model_args.get(key), value)
76+
77+
def test_parse_json_file(self):
78+
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile:
79+
create_json_from_dict(ArgparserTest.args_dict, tmpfile.name)
80+
tmpfile_path = tmpfile.name
81+
with patch("sys.argv", [ArgparserTest.script_name, tmpfile_path]):
82+
model_args = vars(parse_args()[0])
83+
for key, value in ArgparserTest.args_dict.items():
84+
self.assertEqual(model_args.get(key), value)
85+
os.remove(tmpfile_path)
86+
87+
def test_parse_json_file_and_cmd_lines(self):
88+
half_size = len(ArgparserTest.args_dict) // 2
89+
json_part = {k: ArgparserTest.args_dict[k] for k in list(ArgparserTest.args_dict)[:half_size]}
90+
cmd_line_part = {k: ArgparserTest.args_dict[k] for k in list(ArgparserTest.args_dict)[half_size:]}
91+
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile:
92+
create_json_from_dict(json_part, tmpfile.name)
93+
tmpfile_path = tmpfile.name
94+
cmd_line_args = [ArgparserTest.script_name, tmpfile_path]
95+
for key, value in cmd_line_part.items():
96+
cmd_line_args.extend([f"--{key}", str(value)])
97+
with patch("sys.argv", cmd_line_args):
98+
model_args = vars(parse_args()[0])
99+
for key, value in ArgparserTest.args_dict.items():
100+
self.assertEqual(model_args.get(key), value)
101+
os.remove(tmpfile_path)
102+
103+
def test_parse_json_file_and_cmd_lines_with_conflict(self):
104+
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile:
105+
json.dump(ArgparserTest.args_dict, tmpfile)
106+
tmpfile_path = tmpfile.name
107+
cmd_line_args = [
108+
ArgparserTest.script_name,
109+
tmpfile_path,
110+
"--min_learning_rate",
111+
"2e-5",
112+
"--max_steps",
113+
"3000",
114+
"--log_on_each_node",
115+
"False",
116+
]
117+
with patch("sys.argv", cmd_line_args):
118+
model_args = vars(parse_args()[0])
119+
self.assertEqual(model_args.get("min_learning_rate"), 2e-5)
120+
self.assertEqual(model_args.get("max_steps"), 3000)
121+
self.assertEqual(model_args.get("log_on_each_node"), False)
122+
for key, value in ArgparserTest.args_dict.items():
123+
if key not in ["min_learning_rate", "max_steps", "log_on_each_node"]:
124+
self.assertEqual(model_args.get(key), value)
125+
os.remove(tmpfile_path)

llm/run_pretrain.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,8 +372,10 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:
372372

373373
def main():
374374
parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments))
375-
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
376-
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
375+
# Support format as "args.json --arg1 value1 --arg2 value2.”
376+
# In case of conflict, command line arguments take precedence.
377+
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
378+
model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines()
377379
else:
378380
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
379381

paddlenlp/trainer/argparser.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import dataclasses
2020
import json
2121
import sys
22+
import warnings
2223
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
2324
from copy import copy
2425
from enum import Enum
@@ -247,6 +248,46 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
247248
outputs.append(obj)
248249
return (*outputs,)
249250

251+
def parse_json_file_and_cmd_lines(self) -> Tuple[DataClass, ...]:
252+
"""
253+
Extend the functionality of `parse_json_file` to handle command line arguments in addition to loading a JSON
254+
file.
255+
256+
When there is a conflict between the command line arguments and the JSON file configuration,
257+
the command line arguments will take precedence.
258+
259+
Returns:
260+
Tuple consisting of:
261+
262+
- the dataclass instances in the same order as they were passed to the initializer.abspath
263+
"""
264+
if not sys.argv[1].endswith(".json"):
265+
raise ValueError(f"The first argument should be a JSON file, but it is {sys.argv[1]}")
266+
json_file = Path(sys.argv[1])
267+
if json_file.exists():
268+
with open(json_file, "r") as file:
269+
data = json.load(file)
270+
json_args = []
271+
for key, value in data.items():
272+
json_args.extend([f"--{key}", str(value)])
273+
else:
274+
raise FileNotFoundError(f"The argument file {json_file} does not exist.")
275+
# In case of conflict, command line arguments take precedence
276+
args = json_args + sys.argv[2:]
277+
namespace, remaining_args = self.parse_known_args(args=args)
278+
outputs = []
279+
for dtype in self.dataclass_types:
280+
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
281+
inputs = {k: v for k, v in vars(namespace).items() if k in keys}
282+
for k in keys:
283+
delattr(namespace, k)
284+
obj = dtype(**inputs)
285+
outputs.append(obj)
286+
if remaining_args:
287+
warnings.warn(f"Some specified arguments are not used by the PdArgumentParser: {remaining_args}")
288+
289+
return (*outputs,)
290+
250291
def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
251292
"""
252293
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass

0 commit comments

Comments
 (0)