@@ -45,6 +45,7 @@ class ArgparserTest(unittest.TestCase):
45
45
"amp_master_grad" : False ,
46
46
"adam_beta1" : 0.9 ,
47
47
"adam_beta2" : 0.999 ,
48
+ "amp_custom_black_list" : ["reduce_sum" , "sin" , "cos" ],
48
49
"adam_epsilon" : 1e-08 ,
49
50
"bf16" : False ,
50
51
"enable_linear_fused_grad_add" : False ,
@@ -68,7 +69,10 @@ class ArgparserTest(unittest.TestCase):
68
69
def test_parse_cmd_lines (self ):
69
70
cmd_line_args = [ArgparserTest .script_name ]
70
71
for key , value in ArgparserTest .args_dict .items ():
71
- cmd_line_args .extend ([f"--{ key } " , str (value )])
72
+ if isinstance (value , list ):
73
+ cmd_line_args .extend ([f"--{ key } " , * [str (v ) for v in value ]])
74
+ else :
75
+ cmd_line_args .extend ([f"--{ key } " , str (value )])
72
76
with patch ("sys.argv" , cmd_line_args ):
73
77
model_args = vars (parse_args ()[0 ])
74
78
for key , value in ArgparserTest .args_dict .items ():
@@ -93,7 +97,10 @@ def test_parse_json_file_and_cmd_lines(self):
93
97
tmpfile_path = tmpfile .name
94
98
cmd_line_args = [ArgparserTest .script_name , tmpfile_path ]
95
99
for key , value in cmd_line_part .items ():
96
- cmd_line_args .extend ([f"--{ key } " , str (value )])
100
+ if isinstance (value , list ):
101
+ cmd_line_args .extend ([f"--{ key } " , * [str (v ) for v in value ]])
102
+ else :
103
+ cmd_line_args .extend ([f"--{ key } " , str (value )])
97
104
with patch ("sys.argv" , cmd_line_args ):
98
105
model_args = vars (parse_args ()[0 ])
99
106
for key , value in ArgparserTest .args_dict .items ():
0 commit comments