Skip to content

Commit fff730e

Browse files
greycookerJunnYu
authored andcommitted
fix_trainer_argparser (#7860)
1 parent b583f11 commit fff730e

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

llm/llama/tests/test_argparser.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class ArgparserTest(unittest.TestCase):
4545
"amp_master_grad": False,
4646
"adam_beta1": 0.9,
4747
"adam_beta2": 0.999,
48+
"amp_custom_black_list": ["reduce_sum", "sin", "cos"],
4849
"adam_epsilon": 1e-08,
4950
"bf16": False,
5051
"enable_linear_fused_grad_add": False,
@@ -68,7 +69,10 @@ class ArgparserTest(unittest.TestCase):
6869
def test_parse_cmd_lines(self):
6970
cmd_line_args = [ArgparserTest.script_name]
7071
for key, value in ArgparserTest.args_dict.items():
71-
cmd_line_args.extend([f"--{key}", str(value)])
72+
if isinstance(value, list):
73+
cmd_line_args.extend([f"--{key}", *[str(v) for v in value]])
74+
else:
75+
cmd_line_args.extend([f"--{key}", str(value)])
7276
with patch("sys.argv", cmd_line_args):
7377
model_args = vars(parse_args()[0])
7478
for key, value in ArgparserTest.args_dict.items():
@@ -93,7 +97,10 @@ def test_parse_json_file_and_cmd_lines(self):
9397
tmpfile_path = tmpfile.name
9498
cmd_line_args = [ArgparserTest.script_name, tmpfile_path]
9599
for key, value in cmd_line_part.items():
96-
cmd_line_args.extend([f"--{key}", str(value)])
100+
if isinstance(value, list):
101+
cmd_line_args.extend([f"--{key}", *[str(v) for v in value]])
102+
else:
103+
cmd_line_args.extend([f"--{key}", str(value)])
97104
with patch("sys.argv", cmd_line_args):
98105
model_args = vars(parse_args()[0])
99106
for key, value in ArgparserTest.args_dict.items():

paddlenlp/trainer/argparser.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,10 @@ def parse_json_file_and_cmd_lines(self) -> Tuple[DataClass, ...]:
269269
data = json.load(file)
270270
json_args = []
271271
for key, value in data.items():
272-
json_args.extend([f"--{key}", str(value)])
272+
if isinstance(value, list):
273+
json_args.extend([f"--{key}", *[str(v) for v in value]])
274+
else:
275+
json_args.extend([f"--{key}", str(value)])
273276
else:
274277
raise FileNotFoundError(f"The argument file {json_file} does not exist.")
275278
# In case of conflict, command line arguments take precedence

0 commit comments

Comments
 (0)