
Commit c518bcb

✨ use multi-threading
Co-authored-by: Rafael Vasquez <[email protected]>
Signed-off-by: Prashant Gupta <[email protected]>
1 parent f4d0754 commit c518bcb


aiu_fms_testing_utils/scripts/save_cpu_data.py

Lines changed: 14 additions & 41 deletions
@@ -50,6 +50,12 @@ def load_jsonl(path):
     default=8,
     help="set this if you want to change the number of tokens generated per sequence (1 prefill + max_new_tokens-1 decodes). Note: If this value is larger than 64, this may result in switching decode programs mid generation",
 )
+parser.add_argument(
+    "--max_workers",
+    type=int,
+    default=8,
+    help="max workers to run in parallel",
+)
 parser.add_argument(
     "--dataset_path",
     type=str,
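
The new flag just rides the script's existing parser and is consumed further down as `args.max_workers`. A minimal sketch of this flag in isolation (the bare ArgumentParser here is a stand-in for the script's real parser setup):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_workers",
    type=int,
    default=8,
    help="max workers to run in parallel",
)

args = parser.parse_args(["--max_workers", "4"])
print(args.max_workers)  # 4; fed below to ThreadPoolExecutor(max_workers=args.max_workers)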
@@ -90,12 +96,12 @@ def process_row(row):
         "validation": cpu_validation_info
     }
 
-validation_info = {}
-# results = []
-for row in dataset:
-    result = process_row(row)
-    # results.append(result)
+with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
+    results = list(executor.map(process_row, dataset))
 
+# save the results
+validation_info = {}
+for result in results:
     tokens = result["validation"].get_info("tokens")
     generated_tokens_tensor = tokens[0][-max_new_tokens:]
     generated_tokens = [token.item() for token in generated_tokens_tensor]
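
This hunk is the heart of the commit: the sequential for-loop becomes executor.map, which fans process_row out across worker threads and yields results in input order, so the downstream loop can still consume them positionally. A self-contained sketch of the pattern, with a stand-in process_row (the real one runs model generation per row):

from concurrent.futures import ThreadPoolExecutor

def process_row(row):
    # stand-in for the script's per-row generation work
    return {"id": row["id"], "validation": row["text"].upper()}

dataset = [{"id": i, "text": f"sample {i}"} for i in range(4)]

with ThreadPoolExecutor(max_workers=8) as executor:
    # map() preserves input order; list() drains it so every row
    # finishes before the executor shuts down at the end of the with-block
    results = list(executor.map(process_row, dataset))

print([r["id"] for r in results])  # [0, 1, 2, 3]

Worth noting: under CPython's GIL, threads only speed this up to the extent process_row is I/O-bound or spends its time in operations that release the GIL, as PyTorch kernels generally do.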
@@ -104,6 +110,8 @@ def process_row(row):
     for step_num, logits_for_step in enumerate(logits[0]):
         logprob_for_step = torch.nn.functional.log_softmax(logits_for_step, dim=-1)
         values, indices = torch.topk(logprob_for_step, k=100)
+        # in case we want to save a new tensor?
+        # but this will also take memory
         # top_logprobs = torch.full_like(logprobs, float('-inf'))
         # top_logprobs.scatter_(1, indices, values)
         top_logprob_dict = {
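
For each decode step the script converts raw logits to log-probabilities and keeps only the top 100. A runnable sketch of the same transform, with a toy 10-token vocabulary and k=3 standing in for the script's k=100:

import torch

# fake logits for one decode step, shape (1, vocab_size)
logits_for_step = torch.randn(1, 10)

logprob_for_step = torch.nn.functional.log_softmax(logits_for_step, dim=-1)
values, indices = torch.topk(logprob_for_step, k=3)  # the script uses k=100

# same {token_id: logprob} mapping the script builds just below
top_logprob_dict = {
    int(idx): float(val)
    for idx, val in zip(indices[0], values[0])
}
print(top_logprob_dict)  # e.g. {7: -1.43, 2: -1.88, 5: -2.01}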
@@ -116,43 +124,8 @@ def process_row(row):
         "tokens": generated_tokens,
         "text": tokenizer.decode(generated_tokens)
     }
-    with open(f"{result["id"]} - cpu_validation_info123.json", "w") as f:
-        json.dump(validation_info, f, indent=4)
-
-
-    # torch.save(validation_info, f"{result["id"]} - cpu_validation_info_top_dict.pt")
-    # torch.save(result, f"{result["id"]} - cpu_validation_info1.pt")
-    # print("saved cpu validation info for id: ", result["id"])
-# with ThreadPoolExecutor(max_workers=5) as executor:
-#     results = list(executor.map(process_row, dataset))
-
-# save the results
-# validation_info = {}
-# for result in results:
-#     tokens = result["validation"].get_info("tokens")
-#     generated_tokens = tokens[0][-max_new_tokens:]
-#     logits = result["validation"].get_info("logits")
-#     logprobs = []
-#     top_logprob_dict_list = []
-#     for step_num, logits_for_step in enumerate(logits[0]):
-#         logprobs.append(torch.nn.functional.log_softmax(logits_for_step, dim=-1))
-#         values, indices = torch.topk(logprobs, k=100)
-#         # top_logprobs = torch.full_like(logprobs, float('-inf'))
-#         # top_logprobs.scatter_(1, indices, values)
-#         top_logprob_dict = {
-#             int(idx): float(val)
-#             for idx, val in zip(indices[0], values[0])
-#         }
-#     validation_info[result["id"]] = {
-#         "logprobs": top_logprob_dict,
-#         "tokens": generated_tokens,
-#         "text": "".join([tokenizer.decode(tensor.tolist(), \
-#             skip_special_tokens=True) for \
-#             tensor in generated_tokens])
-#     }
-
 
-# torch.save(validation_info, "cpu_validation_info_top_dict.pt")
+# save the final result
 with open("cpu_validation_info.json", "w") as f:
     json.dump(validation_info, f, indent=4)
 print("all done!")
