|
11 | 11 | from shapely.prepared import prep
|
12 | 12 | from shapely.ops import unary_union, nearest_points
|
13 | 13 | from shapely.validation import explain_validity
|
| 14 | +from shapely import set_precision |
14 | 15 |
|
15 | 16 | from ocrd_modelfactory import page_from_file
|
16 | 17 | from ocrd_models.ocrd_page import (
|
@@ -931,37 +932,105 @@ def join_polygons(polygons, loc='', scale=20):
|
931 | 932 |
|
def join_baselines(baselines, loc=''):
    """Join baseline fragments into a single LineString.

    Collects all linear parts from ``baselines`` (LineString,
    MultiLineString, or such parts inside a GeometryCollection),
    skipping empty and point-like geometries and warning about any
    other type. The parts are then ordered along a minimum-distance
    path and concatenated.

    Args:
        baselines: iterable of Shapely geometries.
        loc: location description, used only in log messages.

    Returns:
        ``None`` if there is no linear part or the parts cannot be
        arranged into a single path; the sole part itself (unchanged)
        if there is exactly one; otherwise the merged LineString,
        snapped to a precision grid of size 1.0.
    """
    LOG = getLogger('processor.OcropyResegment')
    lines = []
    for baseline in baselines:
        if (baseline.is_empty or
            baseline.geom_type in ['Point', 'MultiPoint']):
            continue
        elif baseline.geom_type == 'MultiLineString':
            lines.extend(baseline.geoms)
        elif baseline.geom_type == 'LineString':
            lines.append(baseline)
        elif baseline.geom_type == 'GeometryCollection':
            for geom in baseline.geoms:
                if geom.geom_type == 'LineString':
                    lines.append(geom)
                elif geom.geom_type == 'MultiLineString':
                    # multi-part geometries are not iterable themselves
                    # in Shapely 2 - must go through .geoms
                    lines.extend(geom.geoms)
                else:
                    LOG.warning("ignoring baseline subtype %s in %s", geom.geom_type, loc)
        else:
            LOG.warning("ignoring baseline type %s in %s", baseline.geom_type, loc)
    nlines = len(lines)
    if nlines == 0:
        return None
    elif nlines == 1:
        return lines[0]
    # Shapely's line_merge cannot reorder the parts, so we
    # find a min-dist path through all lines ourselves
    # (spanning tree over the pairwise distances)
    pairs = itertools.combinations(range(nlines), 2)
    dists = np.eye(nlines, dtype=float)
    for i, j in pairs:
        dist = lines[i].distance(lines[j])
        if dist < 1e-5:
            dist = 1e-5 # if pair merely touches, we still need to get an edge
        dists[i, j] = dist
        dists[j, i] = dist
    dists = minimum_spanning_tree(dists, overwrite=True)
    assert dists.nonzero()[0].size, dists
    # get path: grow chains of line indices edge by edge, extending
    # a chain at whichever end matches, and joining two chains
    # when an edge connects their ends
    chains = []
    for prevl, nextl in zip(*dists.nonzero()):
        foundchains = []
        for chain in chains:
            if chain[0] == prevl:
                found = chain, 0, nextl
            elif chain[0] == nextl:
                found = chain, 0, prevl
            elif chain[-1] == prevl:
                found = chain, -1, nextl
            elif chain[-1] == nextl:
                found = chain, -1, prevl
            else:
                continue
            foundchains.append(found)
        if foundchains:
            assert len(foundchains) <= 2, foundchains
            chain, pos, node = foundchains.pop()
            if foundchains:
                # edge connects the ends of two existing chains: join them
                otherchain, otherpos, othernode = foundchains.pop()
                assert node != othernode
                assert chain[pos] == othernode
                assert otherchain[otherpos] == node
                if pos < 0 and otherpos < 0:
                    chain.extend(reversed(otherchain))
                    chains.remove(otherchain)
                elif pos < 0 and otherpos == 0:
                    chain.extend(otherchain)
                    chains.remove(otherchain)
                elif pos == 0 and otherpos == 0:
                    otherchain.extend(reversed(chain))
                    chains.remove(chain)
                elif pos == 0 and otherpos < 0:
                    otherchain.extend(chain)
                    chains.remove(chain)
            elif pos < 0:
                chain.append(node)
            else:
                chain.insert(0, node)
        else:
            chains.append([prevl, nextl])
    if len(chains) > 1:
        # the tree branches, so it cannot be traversed as one path
        LOG.warning("baseline merge impossible (no spanning tree) in %s", loc)
        return None
    assert len(chains) == 1, chains
    assert len(chains[0]) == nlines, chains[0]
    path = chains[0]
    # bring every part into canonical direction
    lines = [line.normalize() for line in lines]
    # The chain direction depends on input order and tree edge order,
    # so it may run opposite to the canonical direction of the parts;
    # orient it by the same (lexicographic) criterion to avoid zigzags.
    if tuple(lines[path[0]].coords[0]) > tuple(lines[path[-1]].coords[0]):
        path = path[::-1]
    # get points
    coords = []
    for node in path:
        coords.extend(lines[node].coords)
    result = LineString(coords)
    if result.is_empty:
        LOG.warning("baseline merge is empty in %s", loc)
        return None
    assert result.geom_type == 'LineString', result.wkt
    result = set_precision(result, 1.0)
    if result.geom_type != 'LineString' or not result.is_valid:
        # fall back to plain rounding of the complete merged path
        # (not just of the last part)
        result = LineString(np.round(coords))
    return result
965 | 1034 |
|
966 | 1035 | def page_get_reading_order(ro, rogroup):
|
967 | 1036 | """Add all elements from the given reading order group to the given dictionary.
|
|
0 commit comments