@@ -111,13 +111,15 @@ def __getitem__(self,idx):
111111 return self .inp_dicts [idx ]
112112
113113
114- def test (checkpoint , testing_csv_file ):
114+ def test (checkpoint , testing_csv_file , device = 'cpu' ):
115115 tokenizer = AutoTokenizer .from_pretrained ("microsoft/graphcodebert-base" )
116116 model = AutoModelForMaskedLM .from_pretrained ("microsoft/graphcodebert-base" )
117117 base_model = AutoModelForMaskedLM .from_pretrained ("microsoft/graphcodebert-base" )
118118 model .load_state_dict (torch .load (checkpoint ))
119119 model .eval ()
120120 base_model .eval ()
121+ base_model .to (device )
122+ model .to (device )
121123 myDs = MyDataset (testing_csv_file ,tokenizer )
122124 train_loader = DataLoader (myDs ,batch_size = 1 ,shuffle = False )
123125
@@ -169,7 +171,7 @@ def test(checkpoint, testing_csv_file):
169171 m_y = random .choice (var_list [num_sub_tokens_label - 1 ])
170172 m_ty = tokenizer .encode (m_y )[1 :- 1 ]
171173 print ("Mock truth:" , m_y )
172- # input_ids, att_mask = input_ids.to(device),att_mask.to(device)
174+ input_ids , att_mask = input_ids .to (device ),att_mask .to (device )
173175 outputs = model (input_ids , attention_mask = att_mask )
174176 base_outputs = base_model (input_ids , attention_mask = att_mask )
175177 last_hidden_state = outputs [0 ].squeeze ()
@@ -270,6 +272,7 @@ def test(checkpoint, testing_csv_file):
270272def parse_arguments ():
271273 parser = argparse .ArgumentParser (description = "Testing the language model that was trained for identifier renaming." )
272274 parser .add_argument ("--testing_csv_file" , help = "Path to CSV file containing testing data" )
275+ parser .add_argument ("--device" , help = "Device to train the model on (default: cpu)" , choices = ["cuda" , "cpu" ] , default = "cpu" )
273276 parser .add_argument ("--checkpoint" , help = "Model checkpoint" )
274277 return parser .parse_args ()
275278
@@ -278,7 +281,14 @@ def main():
278281 args = parse_arguments ()
279282 testing_csv_file = args .testing_csv_file
280283 checkpoint = args .checkpoint
281- test (checkpoint , testing_csv_file )
284+ device = args .device
285+ if device == "cuda" and not torch .cuda .is_available ():
286+ print ("CUDA is not available on this device. Using CPU instead." )
287+ device = "cpu"
288+ else :
289+ print (f"Using { device } for training." )
290+ device = torch .device (device )
291+ test (checkpoint , testing_csv_file , device )
282292
283293
284294
0 commit comments