Skip to content

Commit e75900f

Browse files
committed
allow GPU inference
1 parent 884b2bb commit e75900f

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

evaluation/model_eval.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,15 @@ def __getitem__(self,idx):
111111
return self.inp_dicts[idx]
112112

113113

114-
def test(checkpoint, testing_csv_file):
114+
def test(checkpoint, testing_csv_file, device = 'cpu'):
115115
tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
116116
model = AutoModelForMaskedLM.from_pretrained("microsoft/graphcodebert-base")
117117
base_model = AutoModelForMaskedLM.from_pretrained("microsoft/graphcodebert-base")
118118
model.load_state_dict(torch.load(checkpoint))
119119
model.eval()
120120
base_model.eval()
121+
base_model.to(device)
122+
model.to(device)
121123
myDs=MyDataset(testing_csv_file,tokenizer)
122124
train_loader=DataLoader(myDs,batch_size=1,shuffle=False)
123125

@@ -169,7 +171,7 @@ def test(checkpoint, testing_csv_file):
169171
m_y = random.choice(var_list[num_sub_tokens_label-1])
170172
m_ty = tokenizer.encode(m_y)[1:-1]
171173
print("Mock truth:", m_y)
172-
# input_ids, att_mask = input_ids.to(device),att_mask.to(device)
174+
input_ids, att_mask = input_ids.to(device),att_mask.to(device)
173175
outputs = model(input_ids, attention_mask = att_mask)
174176
base_outputs = base_model(input_ids, attention_mask = att_mask)
175177
last_hidden_state = outputs[0].squeeze()
@@ -270,6 +272,7 @@ def test(checkpoint, testing_csv_file):
270272
def parse_arguments():
271273
parser = argparse.ArgumentParser(description="Testing the language model that was trained for identifier renaming.")
272274
parser.add_argument("--testing_csv_file", help="Path to CSV file containing testing data")
275+
parser.add_argument("--device", help="Device to train the model on (default: cpu)", choices=["cuda", "cpu"] , default="cpu")
273276
parser.add_argument("--checkpoint", help="Model checkpoint")
274277
return parser.parse_args()
275278

@@ -278,7 +281,14 @@ def main():
278281
args = parse_arguments()
279282
testing_csv_file = args.testing_csv_file
280283
checkpoint = args.checkpoint
281-
test(checkpoint, testing_csv_file)
284+
device = args.device
285+
if device == "cuda" and not torch.cuda.is_available():
286+
print("CUDA is not available on this device. Using CPU instead.")
287+
device = "cpu"
288+
else:
289+
print(f"Using {device} for training.")
290+
device = torch.device(device)
291+
test(checkpoint, testing_csv_file, device)
282292

283293

284294

0 commit comments

Comments
 (0)