Skip to content

Commit 03f05a6

Browse files
authored
Revert "added the number of flops calculation to feature extraction (#108)"
This reverts commit 1340406.
1 parent 1340406 commit 03f05a6

File tree

1 file changed

+3 additions, -24 deletions (lines changed)

1 file changed

+3
-24
lines changed

src/cehrbert/linear_prob/compute_cehrbert_features.py

Lines changed: 3 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
import glob
2-
import json
32
import os
43
import uuid
54
from datetime import datetime
@@ -10,7 +9,6 @@
109
import pandas as pd
1110
import torch
1211
from datasets import concatenate_datasets, load_from_disk
13-
from torch.profiler import ProfilerActivity, profile
1412
from torch.utils.data import DataLoader
1513
from tqdm import tqdm
1614
from transformers.utils import is_flash_attn_2_available, logging
@@ -222,9 +220,7 @@ def main():
222220
}
223221

224222
data_loaders = [("train", train_loader), ("test", test_dataloader)]
225-
training_metrics_file = Path(training_args.output_dir) / "training_metrics.json"
226-
start_time: datetime = datetime.now()
227-
total_gflops = 0
223+
228224
for split, data_loader in data_loaders:
229225

230226
# Ensure prediction folder exists
@@ -251,17 +247,8 @@ def main():
251247
labels = np.asarray([labels])
252248

253249
batch = {k: v.to(device) for k, v in batch.items()}
254-
with profile(
255-
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
256-
with_flops=True,
257-
) as prof:
258-
# Forward pass
259-
cehrbert_output = cehrbert_model(**batch, output_attentions=False, output_hidden_states=False)
260-
261-
for event in prof.key_averages():
262-
if hasattr(event, "flops") and event.flops > 0:
263-
# Convert to GFLOPs
264-
total_gflops += event.flops / 1e9
250+
# Forward pass
251+
cehrbert_output = cehrbert_model(**batch, output_attentions=False, output_hidden_states=False)
265252

266253
cls_token_indices = batch["input_ids"] == cehrgpt_tokenizer.cls_token_index
267254
if cehrbert_args.sample_packing:
@@ -315,14 +302,6 @@ def main():
315302
features_pd["gender_concept_id"] = gender_concept_ids
316303
features_pd.to_parquet(feature_output_folder / f"{uuid.uuid4()}.parquet")
317304

318-
# Save the training metrics to the output file
319-
with open(training_metrics_file, "w") as output_file:
320-
training_metrics = {
321-
"duration_in_seconds": (datetime.now() - start_time).total_seconds(),
322-
"total_flops": total_gflops,
323-
}
324-
json.dump(training_metrics, output_file)
325-
326305

327306
if __name__ == "__main__":
328307
main()

0 commit comments

Comments
 (0)