1717import os
1818
1919import numpy as np
20+ import paddle
2021import pandas as pd
2122from categories import categories , subcategories
2223from evaluator import ModelEvaluator
@@ -29,9 +30,9 @@ def main(args, evaluator):
2930 [f .split ("_test.csv" )[0 ] for f in os .listdir (os .path .join (args .data_dir , "test" )) if "_test.csv" in f ]
3031 )
3132 if not os .path .exists (args .output_dir ):
32- os .makedirs (args .output_dir )
33+ os .makedirs (args .output_dir , exist_ok = True )
3334 if not os .path .exists (os .path .join (args .output_dir , "results_{}" .format (args .model_name_or_path ))):
34- os .makedirs (os .path .join (args .output_dir , "results_{}" .format (args .model_name_or_path )))
35+ os .makedirs (os .path .join (args .output_dir , "results_{}" .format (args .model_name_or_path )), exist_ok = True )
3536
3637 all_cors = []
3738 subcat_cors = {subcat : [] for subcat_lists in subcategories .values () for subcat in subcat_lists }
@@ -95,15 +96,25 @@ def main(args, evaluator):
9596 parser .add_argument ("--data_dir" , "-d" , type = str , default = "data" )
9697 parser .add_argument ("--output_dir" , type = str , default = "results" )
9798 parser .add_argument ("--dtype" , default = "float32" , type = str )
99+ parser .add_argument ("--tensor_parallel_degree" , default = 1 , type = int )
98100
99101 args = parser .parse_args ()
100102 print (args )
101103
104+ if args .tensor_parallel_degree > 1 :
105+ strategy = paddle .distributed .fleet .DistributedStrategy ()
106+ strategy .hybrid_configs = {
107+ "mp_degree" : args .tensor_parallel_degree ,
108+ }
109+ # Set control in tensor parallel
110+ strategy .tensor_parallel_configs = {"tensor_init_seed" : 1234 }
111+ paddle .distributed .fleet .init (is_collective = True , strategy = strategy )
102112 evaluator = ModelEvaluator (
103113 model_name_or_path = args .model_name_or_path ,
104114 ntrain = args .ntrain ,
105115 temperature = args .temperature ,
106116 dtype = args .dtype ,
117+ tensor_parallel_degree = args .tensor_parallel_degree ,
107118 )
108119
109120 main (args , evaluator = evaluator )
0 commit comments