Skip to content

Commit 21b9c74

Browse files
Fix CI
1 parent b64126b commit 21b9c74

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

Tests/TestCCTPretrained.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def deepQuantTestCCT():
296296
# FBRANCASI: Load original floating point model
297297
originalModel = cct_2_3x2_32()
298298
checkpointPath = "./Tests/Data/checkpoint_epoch_200_cct2_cifar10.pth"
299-
checkpoint = torch.load(checkpointPath, map_location="cpu")
299+
checkpoint = torch.load(checkpointPath, map_location="cpu", weights_only=False)
300300
originalModel.load_state_dict(checkpoint["model_state_dict"])
301301
originalModel = originalModel.eval().to(device)
302302
print("Original CCT-2 loaded from checkpoint.")
@@ -316,7 +316,7 @@ def deepQuantTestCCT():
316316
FQAccuracy = evaluateModel(FQModel, valLoader, device, "FQ CCT-2")
317317

318318
sampleInput = torch.randn(1, 3, 32, 32).to("cpu")
319-
319+
320320
# FBRANCASI: Override the injectCustomForwards function in the module before DeepQuant.Export imports it
321321
import DeepQuant.Pipeline.Injection as injection_module
322322

0 commit comments

Comments
 (0)