Skip to content

Commit 2d8afee

Browse files
committed
Put student in eval
1 parent dd76b26 commit 2d8afee

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

caduceus_distill/nt_eval.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,11 @@ def load_caduceus(
189189
else:
190190
model = StudentCaduceus.load_from_checkpoint(model_to_load)
191191

192+
# NOTE: this is really necessary just for the model loaded from a checkpoint,
193+
# but it doesn't hurt to do it for all models.
194+
model.eval()
195+
assert model.training is False, "Model should be in inference mode"
196+
192197
num_gpus = torch.cuda.device_count()
193198
# Sequential warm-up to prevent a race condition in the Triton kernel autotuner.
194199
# This is necessary when using nn.DataParallel with models that use Triton JIT,

0 commit comments

Comments
 (0)