|
19 | 19 | from pathlib import Path |
20 | 20 |
|
21 | 21 | import motmetrics as mm |
| 22 | +# Andreu |
| 23 | +from motmetrics.utils import is_in_region |
22 | 24 |
|
23 | 25 |
|
24 | 26 | def parse_args(): |
@@ -67,13 +69,25 @@ def parse_args(): |
67 | 69 | return parser.parse_args() |
68 | 70 |
|
69 | 71 |
|
| 72 | +# Andreu |
| 73 | +def filter_dets_by_zone(gt_df, test_df): |
| 74 | + ig_region = gt_df[['X_reg', 'Y_reg', 'W_reg', 'H_reg']].dropna().values |
| 75 | + dets = test_df[['X', 'Y', 'Width', 'Height']] |
| 76 | + for idx, row in dets.iterrows(): |
| 77 | + bbox = [dets.at[idx, 'X'], dets.at[idx, 'Y'], dets.at[idx, 'Width'], dets.at[idx, 'Height']] |
| 78 | + for reg in ig_region: |
| 79 | + if is_in_region(bbox, reg): |
| 80 | + test_df.drop([idx], axis=0, inplace=True) |
| 81 | + break |
| 82 | + |
70 | 83 | def compare_dataframes(gts, ts): |
71 | 84 | """Builds accumulator for each sequence.""" |
72 | 85 | accs = [] |
73 | 86 | names = [] |
74 | 87 | for k, tsacc in ts.items(): |
75 | 88 | if k in gts: |
76 | 89 | logging.info('Comparing %s...', k) |
| 90 | + filter_dets_by_zone(gts[k], tsacc) |
77 | 91 | accs.append(mm.utils.compare_to_groundtruth(gts[k], tsacc, 'iou', distth=0.5)) |
78 | 92 | names.append(k) |
79 | 93 | else: |
@@ -107,6 +121,17 @@ def main(): |
107 | 121 | gt = OrderedDict([(os.path.splitext(Path(f).parts[-1])[0], mm.io.loadtxt(f, fmt=args.gtfmt)) for f in gtfiles]) |
108 | 122 | ts = OrderedDict([(os.path.splitext(Path(f).parts[-1])[0], mm.io.loadtxt(f, fmt=args.tsfmt)) for f in tsfiles]) |
109 | 123 |
|
| 124 | + # # Debug |
| 125 | + # # f = gtfiles[1] |
| 126 | + # # seq = os.path.splitext(Path(f).parts[-1])[0] |
| 127 | + # seq = 'MVI_39511' |
| 128 | + # f_gt = gtfiles[0].split('/')[:-1] |
| 129 | + # f = '/'.join(f_gt) + '/' + seq + '.xml' |
| 130 | + # gt = OrderedDict([(seq, mm.io.loadtxt(f, fmt=args.gtfmt))]) |
| 131 | + # f_test = tsfiles[0].split('/')[:-1] |
| 132 | + # f = '/'.join(f_test) + '/' + seq + '.txt' |
| 133 | + # ts = OrderedDict([(seq, mm.io.loadtxt(f, fmt=args.tsfmt))]) |
| 134 | + |
110 | 135 | mh = mm.metrics.create() |
111 | 136 | accs, names = compare_dataframes(gt, ts) |
112 | 137 |
|
|
0 commit comments