Skip to content

Commit 3d7d918

Browse files
committed
avoid gradients tracking, for faster inference and less memory consumption
1 parent b355319 commit 3d7d918

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

chebifier/prediction_models/nn_predictor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def __init__(
2222

2323
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2424
self.model = self.init_model(ckpt_path=ckpt_path)
25+
self.model.eval()
26+
2527
self.target_labels = [
2628
line.strip() for line in open(target_labels_path, encoding="utf-8")
2729
]
@@ -32,6 +34,7 @@ def init_model(self, ckpt_path: str, **kwargs):
3234
"Model initialization must be implemented in subclasses."
3335
)
3436

37+
@torch.inference_mode()
3538
def calculate_results(self, batch):
3639
collator = self.reader_cls.COLLATOR()
3740
dat = self.model._process_batch(collator(batch).to(self.device), 0)

0 commit comments

Comments
 (0)