11import glob
2- import json
32import os
43import uuid
54from datetime import datetime
109import pandas as pd
1110import torch
1211from datasets import concatenate_datasets , load_from_disk
13- from torch .profiler import ProfilerActivity , profile
1412from torch .utils .data import DataLoader
1513from tqdm import tqdm
1614from transformers .utils import is_flash_attn_2_available , logging
@@ -222,9 +220,7 @@ def main():
222220 }
223221
224222 data_loaders = [("train" , train_loader ), ("test" , test_dataloader )]
225- training_metrics_file = Path (training_args .output_dir ) / "training_metrics.json"
226- start_time : datetime = datetime .now ()
227- total_gflops = 0
223+
228224 for split , data_loader in data_loaders :
229225
230226 # Ensure prediction folder exists
@@ -251,17 +247,8 @@ def main():
251247 labels = np .asarray ([labels ])
252248
253249 batch = {k : v .to (device ) for k , v in batch .items ()}
254- with profile (
255- activities = [ProfilerActivity .CPU , ProfilerActivity .CUDA ],
256- with_flops = True ,
257- ) as prof :
258- # Forward pass
259- cehrbert_output = cehrbert_model (** batch , output_attentions = False , output_hidden_states = False )
260-
261- for event in prof .key_averages ():
262- if hasattr (event , "flops" ) and event .flops > 0 :
263- # Convert to GFLOPs
264- total_gflops += event .flops / 1e9
250+ # Forward pass
251+ cehrbert_output = cehrbert_model (** batch , output_attentions = False , output_hidden_states = False )
265252
266253 cls_token_indices = batch ["input_ids" ] == cehrgpt_tokenizer .cls_token_index
267254 if cehrbert_args .sample_packing :
@@ -315,14 +302,6 @@ def main():
315302 features_pd ["gender_concept_id" ] = gender_concept_ids
316303 features_pd .to_parquet (feature_output_folder / f"{ uuid .uuid4 ()} .parquet" )
317304
318- # Save the training metrics to the output file
319- with open (training_metrics_file , "w" ) as output_file :
320- training_metrics = {
321- "duration_in_seconds" : (datetime .now () - start_time ).total_seconds (),
322- "total_flops" : total_gflops ,
323- }
324- json .dump (training_metrics , output_file )
325-
326305
327306if __name__ == "__main__" :
328307 main ()
0 commit comments