@@ -50,6 +50,12 @@ def load_jsonl(path):
     default=8,
     help="set this if you want to change the number of tokens generated per sequence (1 prefill + max_new_tokens-1 decodes). Note: If this value is larger than 64, this may result in switching decode programs mid generation",
 )
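+# how many worker threads to use when processing dataset rows in parallel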
+parser.add_argument(
+    "--max_workers",
+    type=int,
+    default=8,
+    help="maximum number of worker threads used to process rows in parallel",
+)
 parser.add_argument(
     "--dataset_path",
     type=str,
@@ -90,12 +96,12 @@ def process_row(row):
         "validation": cpu_validation_info
     }
 
-validation_info = {}
-# results = []
-for row in dataset:
-    result = process_row(row)
-    # results.append(result)
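+# run process_row over every row using a pool of worker threads;
+# executor.map preserves input order, so results line up with the dataset rows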
+with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
+    results = list(executor.map(process_row, dataset))
 
+# collect the validation info from each result
+validation_info = {}
+for result in results:
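+    # grab the token ids generated for this row (the last max_new_tokens tokens)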
     tokens = result["validation"].get_info("tokens")
     generated_tokens_tensor = tokens[0][-max_new_tokens:]
     generated_tokens = [token.item() for token in generated_tokens_tensor]
@@ -104,6 +110,8 @@ def process_row(row):
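+    # for each decode step, normalize the logits into log-probs and keep the top 100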
     for step_num, logits_for_step in enumerate(logits[0]):
         logprob_for_step = torch.nn.functional.log_softmax(logits_for_step, dim=-1)
         values, indices = torch.topk(logprob_for_step, k=100)
+        # alternatively we could scatter the top-k values into a full-size tensor,
+        # but that would also cost extra memory
         # top_logprobs = torch.full_like(logprobs, float('-inf'))
         # top_logprobs.scatter_(1, indices, values)
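+        # keep only {token id: log-prob} pairs for the 100 most likely tokens at this step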
         top_logprob_dict = {
@@ -116,43 +124,8 @@ def process_row(row):
         "tokens": generated_tokens,
         "text": tokenizer.decode(generated_tokens)
     }
-    with open(f"{result["id"]} - cpu_validation_info123.json", "w") as f:
-        json.dump(validation_info, f, indent=4)
-
-
-    # torch.save(validation_info, f"{result["id"]} - cpu_validation_info_top_dict.pt")
-    # torch.save(result, f"{result["id"]} - cpu_validation_info1.pt")
-    # print("saved cpu validation info for id: ", result["id"])
-# with ThreadPoolExecutor(max_workers=5) as executor:
-#     results = list(executor.map(process_row, dataset))
-
-# save the results
-# validation_info = {}
-# for result in results:
-#     tokens = result["validation"].get_info("tokens")
-#     generated_tokens = tokens[0][-max_new_tokens:]
-#     logits = result["validation"].get_info("logits")
-#     logprobs = []
-#     top_logprob_dict_list = []
-#     for step_num, logits_for_step in enumerate(logits[0]):
-#         logprobs.append(torch.nn.functional.log_softmax(logits_for_step, dim=-1))
-#         values, indices = torch.topk(logprobs, k=100)
-#         # top_logprobs = torch.full_like(logprobs, float('-inf'))
-#         # top_logprobs.scatter_(1, indices, values)
-#         top_logprob_dict = {
-#             int(idx): float(val)
-#             for idx, val in zip(indices[0], values[0])
-#         }
-#     validation_info[result["id"]] = {
-#         "logprobs": top_logprob_dict,
-#         "tokens": generated_tokens,
-#         "text": "".join([tokenizer.decode(tensor.tolist(), \
-#             skip_special_tokens=True) for \
-#             tensor in generated_tokens])
-#     }
-
 
-# torch.save(validation_info, "cpu_validation_info_top_dict.pt")
+# save the final result
 with open("cpu_validation_info.json", "w") as f:
     json.dump(validation_info, f, indent=4)
 print("all done!")