Skip to content

Commit 338b840

Browse files
author
Robert Sachunsky
committed
re/segment join_baselines: adapt to Shapely, improve
1 parent 4673d9b commit 338b840

File tree

1 file changed

+94
-25
lines changed

1 file changed

+94
-25
lines changed

ocrd_cis/ocropy/segment.py

Lines changed: 94 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from shapely.prepared import prep
1212
from shapely.ops import unary_union, nearest_points
1313
from shapely.validation import explain_validity
14+
from shapely import set_precision
1415

1516
from ocrd_modelfactory import page_from_file
1617
from ocrd_models.ocrd_page import (
@@ -931,37 +932,105 @@ def join_polygons(polygons, loc='', scale=20):
931932

932933
def join_baselines(baselines, loc=''):
    """Merge several baseline fragments into a single LineString.

    All line parts are collected from ``baselines`` (unpacking
    MultiLineString and GeometryCollection members, skipping empty and
    point-like geometries). The parts are then ordered along the minimum
    spanning tree of their pairwise distances, and their (normalized)
    coordinates are concatenated in that order.

    Args:
        baselines: iterable of Shapely geometries.
        loc: description of the current location, only used in log messages.

    Returns:
        ``None`` if there is no usable line part, if the spanning tree
        branches (so the parts cannot be put into one sequence), or if the
        merged line is empty. The single part itself if there is only one.
        Otherwise a LineString snapped to a grid of size 1.0.
    """
    LOG = getLogger('processor.OcropyResegment')
    lines = []
    for baseline in baselines:
        if (baseline.is_empty or
            baseline.geom_type in ['Point', 'MultiPoint']):
            continue
        elif baseline.geom_type == 'MultiLineString':
            lines.extend(baseline.geoms)
        elif baseline.geom_type == 'LineString':
            lines.append(baseline)
        elif baseline.geom_type == 'GeometryCollection':
            for geom in baseline.geoms:
                if geom.geom_type == 'LineString':
                    lines.append(geom)
                elif geom.geom_type == 'MultiLineString':
                    # multi-part geometries must be unpacked via .geoms
                    # (they are not iterable themselves in Shapely 2)
                    lines.extend(geom.geoms)
                else:
                    LOG.warning("ignoring baseline subtype %s in %s", geom.geom_type, loc)
        else:
            LOG.warning("ignoring baseline type %s in %s", baseline.geom_type, loc)
    nlines = len(lines)
    if nlines == 0:
        return None
    if nlines == 1:
        return lines[0]
    # A plain Shapely line merge is not used here, because (according to
    # the original author's note) it cannot reorder the parts.
    # Instead, approximate a minimum-distance path through all parts by
    # the minimum spanning tree of their pairwise distances.
    dists = np.eye(nlines, dtype=float)
    for i, j in itertools.combinations(range(nlines), 2):
        dist = lines[i].distance(lines[j])
        if dist < 1e-5:
            dist = 1e-5 # if pair merely touches, we still need to get an edge
        dists[i, j] = dist
        dists[j, i] = dist
    dists = minimum_spanning_tree(dists, overwrite=True)
    assert dists.nonzero()[0].size, dists
    # Chain up the tree edges into node sequences: every edge either
    # extends one existing chain at one of its ends, joins two existing
    # chains, or starts a new chain.
    chains = []
    for prevl, nextl in zip(*dists.nonzero()):
        # (chain, index of the end which is part of this edge, other node of this edge)
        foundchains = []
        for chain in chains:
            if chain[0] == prevl:
                found = chain, 0, nextl
            elif chain[0] == nextl:
                found = chain, 0, prevl
            elif chain[-1] == prevl:
                found = chain, -1, nextl
            elif chain[-1] == nextl:
                found = chain, -1, prevl
            else:
                continue
            foundchains.append(found)
        if not foundchains:
            chains.append([prevl, nextl])
            continue
        assert len(foundchains) <= 2, foundchains
        chain, pos, node = foundchains.pop()
        if foundchains:
            # the edge connects two chains: join them at the matching ends
            otherchain, otherpos, othernode = foundchains.pop()
            assert node != othernode
            assert chain[pos] == othernode
            assert otherchain[otherpos] == node
            if pos < 0 and otherpos < 0:
                chain.extend(reversed(otherchain))
                chains.remove(otherchain)
            elif pos < 0 and otherpos == 0:
                chain.extend(otherchain)
                chains.remove(otherchain)
            elif pos == 0 and otherpos == 0:
                # both chains start at the edge, so the other chain must
                # be reversed and put in front (appending the reversed
                # chain would join the two far ends instead)
                chain[:0] = reversed(otherchain)
                chains.remove(otherchain)
            elif pos == 0 and otherpos < 0:
                otherchain.extend(chain)
                chains.remove(chain)
        elif pos < 0:
            chain.append(node)
        else:
            chain.insert(0, node)
    if len(chains) > 1:
        # some edge attached to the interior of a chain (branching tree)
        LOG.warning("baseline merge impossible (no spanning tree) in %s", loc)
        return None
    assert len(chains) == 1, chains
    assert len(chains[0]) == nlines, chains[0]
    path = chains[0]
    # get points
    # NOTE(review): only the parts are normalized, the direction of the
    # path itself is not checked - confirm that callers do not rely on it
    coords = []
    for node in path:
        coords.extend(lines[node].normalize().coords)
    result = LineString(coords)
    if result.is_empty:
        LOG.warning("baseline merge is empty in %s", loc)
        return None
    assert result.geom_type == 'LineString', result.wkt
    result = set_precision(result, 1.0)
    if result.geom_type != 'LineString' or not result.is_valid:
        # fall back to plain rounding of all merged points
        # (not just the points of the last part)
        result = LineString(np.round(coords))
    return result
9651034

9661035
def page_get_reading_order(ro, rogroup):
9671036
"""Add all elements from the given reading order group to the given dictionary.

0 commit comments

Comments
 (0)