We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent dd76b26 commit 2d8afee — Copy full SHA for 2d8afee
caduceus_distill/nt_eval.py
@@ -189,6 +189,11 @@ def load_caduceus(
189
else:
190
model = StudentCaduceus.load_from_checkpoint(model_to_load)
191
192
+ # NOTE: this is really only necessary for a model loaded from a checkpoint,
193
+ # but it doesn't hurt to do it for all models.
194
+ model.eval()
195
+ assert model.training is False, "Model should be in inference mode"
196
+
197
num_gpus = torch.cuda.device_count()
198
# Sequential warm-up to prevent a race condition in the Triton kernel autotuner.
199
# This is necessary when using nn.DataParallel with models that use Triton JIT,
0 commit comments