Skip to content

Commit fecd8ed

Browse files
authored
fix pse_ctw1500 training error (#806)
* fix pse_ctw1500 training error * fix pse_ctw1500 training error * fix pse_ctw1500 training error * ci fix
1 parent 8e81c14 commit fecd8ed

File tree

4 files changed

+30
-19
lines changed

4 files changed

+30
-19
lines changed

mindocr/data/transforms/det_transforms.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import pyclipper
1414
from shapely.geometry import Polygon, box
1515

16+
from ...utils.misc import is_uneven_nested_list
17+
1618
__all__ = [
1719
"DetLabelEncode",
1820
"BorderMap",
@@ -595,9 +597,14 @@ def _shrink(self, text_polys, rate, max_shr=20):
595597
if not shrinked_bbox:
596598
shrinked_text_polys.append(bbox)
597599
continue
598-
599-
shrinked_bbox = np.array(shrinked_bbox)[0]
600-
shrinked_bbox = np.array(shrinked_bbox)
600+
if is_uneven_nested_list(shrinked_bbox):
601+
shrinked_bbox = np.array(shrinked_bbox, dtype=object)[0]
602+
else:
603+
shrinked_bbox = np.array(shrinked_bbox)[0]
604+
if is_uneven_nested_list(shrinked_bbox):
605+
shrinked_bbox = np.array(shrinked_bbox, dtype=object)
606+
else:
607+
shrinked_bbox = np.array(shrinked_bbox)
601608
if shrinked_bbox.shape[0] <= 2:
602609
shrinked_text_polys.append(bbox)
603610
continue

mindocr/losses/det_loss.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def __init__(self, alpha=0.7, ohem_ratio=3):
221221
self.zeros_like = ops.ZerosLike()
222222
self.add = ops.Add()
223223
self.gather = ops.Gather()
224-
self.upsample = nn.ResizeBilinear()
224+
self.upsample = ops.interpolate
225225

226226
def ohem_batch(self, scores, gt_texts, training_masks):
227227
"""
@@ -334,8 +334,10 @@ def construct(self, model_predict, gt_texts, gt_kernels, training_masks):
334334
Tensor: The computed loss value.
335335
"""
336336
batch_size = model_predict.shape[0]
337-
model_predict = self.upsample(model_predict, scale_factor=4)
338-
h, w = model_predict.shape[2:]
337+
scale_factor = 4
338+
origin_h, origin_w = model_predict.shape[2:]
339+
h, w = origin_h * scale_factor, origin_w * scale_factor
340+
model_predict = self.upsample(model_predict, size=(h, w), mode="bilinear")
339341
texts = self.slice(model_predict, (0, 0, 0, 0), (batch_size, 1, h, w))
340342
texts = self.reshape(texts, (batch_size, h, w))
341343
selected_masks_text = self.ohem_batch(texts, gt_texts, training_masks)

mindocr/postprocess/det_db_postprocess.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from mindspore import Tensor
88

99
from ..data.transforms.det_transforms import expand_poly
10+
from ..utils.misc import is_uneven_nested_list
1011
from .det_base_postprocess import DetBasePostprocess
1112

1213
__all__ = ["DBPostprocess"]
@@ -111,7 +112,7 @@ def _extract_preds(self, pred: np.ndarray, bitmap: np.ndarray):
111112

112113
poly = Polygon(points)
113114
poly_list = expand_poly(points, distance=poly.area * self._expand_ratio / poly.length)
114-
if self._is_uneven_nested_list(poly_list):
115+
if is_uneven_nested_list(poly_list):
115116
poly = np.array(poly_list, dtype=object)
116117
else:
117118
poly = np.array(poly_list)
@@ -138,18 +139,6 @@ def _extract_preds(self, pred: np.ndarray, bitmap: np.ndarray):
138139
return polys, scores
139140
return np.array(polys), np.array(scores).astype(np.float32)
140141

141-
def _is_uneven_nested_list(self, arr_list):
142-
if not isinstance(arr_list, list):
143-
return False
144-
145-
first_length = len(arr_list[0]) if isinstance(arr_list[0], list) else None
146-
147-
for sublist in arr_list:
148-
if not isinstance(sublist, list) or len(sublist) != first_length:
149-
return True
150-
151-
return False
152-
153142
@staticmethod
154143
def _fit_box(contour):
155144
"""

mindocr/utils/misc.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,16 @@ def is_ms_version_2():
7373
make compatibilities in differenct Mindspore version
7474
"""
7575
return version.parse(ms.__version__) >= version.parse("2.0.0rc")
76+
77+
78+
def is_uneven_nested_list(arr_list):
79+
if not isinstance(arr_list, list):
80+
return False
81+
82+
first_length = len(arr_list[0]) if isinstance(arr_list[0], list) else None
83+
84+
for sublist in arr_list:
85+
if not isinstance(sublist, list) or len(sublist) != first_length:
86+
return True
87+
88+
return False

0 commit comments

Comments
 (0)