Skip to content

Commit 6f9db5b

Browse files
author
nullptr
committed
fix: rtmdet infer
1 parent 2d49d9b commit 6f9db5b

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

sscma/deploy/backend/tflite_infer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def infer(self, input_data):
3939
data = (data / scale + zero_point).astype(
4040
input["dtype"]
4141
) # de-scale
42-
4342
self.interpreter.set_tensor(input["index"], data)
4443
self.interpreter.invoke()
4544
y = []
@@ -49,7 +48,7 @@ def infer(self, input_data):
4948
scale, zero_point = output["quantization"]
5049
x = (x.astype(np.float32) - zero_point) * scale # re-scale
5150
# numpy x convert NHWC to NCWH
52-
y.append(np.transpose(x, [0, 3, 1, 2]))
51+
y.append(np.transpose(x, [0, 3, 1, 2]) if len(x.shape) == 4 else x)
5352

5453
results.append(y)
5554
return results

sscma/deploy/models/rtmdet_infer.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
import math
23
import warnings
34
from typing import Dict, List, Tuple, Union
45

@@ -104,13 +105,17 @@ def _predict(
104105
for dt in data_tmp:
105106
tmp = [None for _ in range(6)]
106107
for d in dt:
107-
if d.shape[2:] in featmap_size:
108-
if d.shape[1] == 4:
109-
tmp[3 + featmap_size.index(d.shape[2:])] = d
108+
fs = int(math.sqrt(d.shape[1]))
109+
ts = (fs, fs)
110+
if ts in featmap_size:
111+
if d.shape[2] == 4:
112+
tmp[3 + featmap_size.index(ts)] = d
110113
else:
111-
tmp[featmap_size.index(d.shape[2:])] = d
114+
tmp[featmap_size.index(ts)] = d
112115

113116
data.append(tmp)
117+
118+
114119
for result, data_sample in zip(data, batch_data_samples):
115120
# check item in result is tensor or numpy
116121

0 commit comments

Comments
 (0)