Skip to content

Commit 320b39c

Browse files
committed
Fix eval and GH setup
1 parent 7d3d178 commit 320b39c

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,7 @@ onnx
1212
onnxruntime
1313
strenum
1414
tabulate
15-
pytest
15+
pytest
16+
scikit-learn
17+
torchmetrics
18+
pytorch-lightning

scripts/evaluate_pose_network.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,13 @@ def compute_pred_keys(loader: dtr.SampleBySampleLoader, net: eval.InferenceNetwo
193193
def report(net_filename, data_name, roi_config: RoiConfig, args: argparse.Namespace, builder: TableBuilder):
194194
alignment: AlignmentScheme = args.alignment_scheme
195195

196-
loader = trackertraincode.pipelines.make_validation_loader(data_name, use_head_roi=roi_config.use_head_roi)
196+
loader = trackertraincode.pipelines.make_validation_loader(
197+
data_name, use_head_roi=roi_config.use_head_roi, return_single_samples=True
198+
)
197199
net = load_pose_network(net_filename, args.device)
198200

199201
pred_keys = compute_pred_keys(loader, net)
200-
predictor = eval.Predictor(net, roi_config.expansion_factor, keep_keys=pred_keys)
202+
predictor = eval.Predictor(net, roi_config.expansion_factor)
201203

202204
metrics = torchmetrics.MetricCollection({'pose_errs': eval.NormalizedXYSError()})
203205
if alignment == 'none':
@@ -250,7 +252,7 @@ def report(net_filename, data_name, roi_config: RoiConfig, args: argparse.Namesp
250252
return []
251253

252254
order = np.ascontiguousarray(np.argsort(quantity)[::-1])
253-
loader = trackertraincode.pipelines.make_validation_loader(data_name, order=order)
255+
loader = trackertraincode.pipelines.make_validation_loader(data_name, order=order, return_single_samples=True)
254256

255257
def iter_gt_and_preds():
256258
for sample in loader:

trackertraincode/eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def normalize_crop_transform(self):
179179
@torch.no_grad()
180180
def predict_batch(self, images: List[Tensor], rois: Tensor):
181181
B = len(images)
182-
assert rois.shape == (B, 4), f"Bad roi shape: {rois.shape}"
182+
assert rois.shape == (B, 4), f"Bad roi shape: {rois.shape}, expected {(B,4)}"
183183
device = images[-1].device
184184
batch = [self._create_sample(i, r) for i, r in zip(images, rois)]
185185
batch = Batch.collate(batch)

0 commit comments

Comments
 (0)