Skip to content

Commit b01cbe1

Browse files
authored
Refined sanity_check() to handle tuple outputs from the model evaluation. (#71) (#72)
1 parent 484a4b6 commit b01cbe1

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

mipcandy/sanity_check.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,6 @@ def sanity_check(model: nn.Module, input_shape: Sequence[int], *, device: Device
3737
num_macs, num_params, layer_stats = model_complexity_info(model, input_shape)
3838
if num_macs is None or num_params is None:
3939
raise RuntimeError("Failed to validate model")
40-
output = model.to(device).eval()(torch.randn(1, *input_shape, device=device)).squeeze(0)
41-
return SanityCheckResult(num_macs, num_params, layer_stats, output)
40+
outputs = model.to(device).eval()(torch.randn(1, *input_shape, device=device))
41+
return SanityCheckResult(num_macs, num_params, layer_stats, (
42+
outputs[0] if isinstance(outputs, tuple) else outputs).squeeze(0))

0 commit comments

Comments
 (0)