Skip to content

Commit dc65c0e

Browse files
author
Tavian Barnes
committed
infer: Use grapheme clusters rather than down-weighting combining marks
This makes inference perform better on input such as Zalgo text, which contains a great many combining marks.
1 parent 67ce972 commit dc65c0e

File tree

2 files changed

+40
-31
lines changed

2 files changed

+40
-31
lines changed

python/bistring/_infer.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,43 @@
99

1010
from ._alignment import Alignment
1111
from ._bistr import bistr
12+
from ._token import CharacterTokenizer
1213

1314

1415
@dataclass(frozen=True)
1516
class AugmentedChar:
1617
"""
17-
A single character (code point) augmented with extra information.
18+
A single character (grapheme cluster) augmented with extra information.
1819
"""
1920

20-
folded: str
21+
top_category: str
2122
"""
22-
The case-folded form of the char.
23+
The top-level Unicode category of the char (L, P, Z, etc.).
2324
"""
2425

25-
normalized: str
26+
category: str
2627
"""
27-
The Unicode compatibility normalized form of the char.
28+
The specific Unicode category of the char (Lu, Po, Zs, etc.).
2829
"""
2930

30-
original: str
31+
root: str
3132
"""
32-
The original form of the char.
33+
The root code point of the grapheme cluster.
3334
"""
3435

35-
top_category: str
36+
folded: str
3637
"""
37-
The top-level Unicode category of the char (L, P, Z, etc.).
38+
The case-folded form of the char.
3839
"""
3940

40-
category: str
41+
normalized: str
4142
"""
42-
The specific Unicode category of the char (Lu, Po, Zs, etc.).
43+
The Unicode compatibility normalized form of the char.
44+
"""
45+
46+
original: str
47+
"""
48+
The original form of the char.
4349
"""
4450

4551
@classmethod
@@ -49,29 +55,22 @@ def cost_fn(cls, a: Optional[AugmentedChar], b: Optional[AugmentedChar]) -> int:
4955
"""
5056

5157
if a is None or b is None:
52-
if a:
53-
top_category = a.top_category
54-
elif b:
55-
top_category = b.top_category
56-
else:
57-
assert False, 'Unreachable'
58-
59-
if top_category == 'M':
60-
# Less penalty for combining marks
61-
return 1
62-
else:
63-
# cost(insert) + cost(delete) (3 + 3) should be more than cost(substitute) (5)
64-
return 3
58+
# cost(insert) + cost(delete) (4 + 4) should be more than cost(substitute) (6)
59+
return 4
6560

6661
result = 0
62+
result += int(a.top_category != b.top_category)
63+
result += int(a.category != b.category)
64+
result += int(a.root != b.root)
6765
result += int(a.folded != b.folded)
6866
result += int(a.normalized != b.normalized)
6967
result += int(a.original != b.original)
70-
result += int(a.top_category != b.top_category)
71-
result += int(a.category != b.category)
7268
return result
7369

7470

71+
TOKENIZER = CharacterTokenizer('root')
72+
73+
7574
@dataclass(frozen=True)
7675
class AugmentedString:
7776
"""
@@ -97,21 +96,27 @@ class AugmentedString:
9796
def augment(cls, original: str) -> AugmentedString:
9897
normalized = bistr(original).normalize('NFKD')
9998
folded = bistr(normalized.modified).casefold()
99+
glyphs = TOKENIZER.tokenize(folded)
100100

101101
chars = []
102-
for i, fold_c in enumerate(folded):
103-
norm_slice = folded.alignment.original_slice(i, i + 1)
102+
for glyph in glyphs:
103+
fold_c = glyph.text.modified
104+
root = fold_c[0]
105+
106+
norm_slice = folded.alignment.original_slice(glyph.start, glyph.end)
104107
norm_c = folded.original[norm_slice]
105108

106109
orig_slice = normalized.alignment.original_slice(norm_slice)
107110
orig_c = normalized.original[orig_slice]
108111

109-
cat = unicodedata.category(fold_c)
112+
cat = unicodedata.category(root)
110113
top_cat = cat[0]
111114

112-
chars.append(AugmentedChar(fold_c, norm_c, orig_c, top_cat, cat))
115+
chars.append(AugmentedChar(top_cat, cat, root, fold_c, norm_c, orig_c))
113116

114-
alignment = normalized.alignment.compose(folded.alignment)
117+
alignment = normalized.alignment
118+
alignment = alignment.compose(folded.alignment)
119+
alignment = alignment.compose(glyphs.alignment)
115120
return cls(original, chars, alignment)
116121

117122

python/tests/test_bistr.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ def test_infer():
6565
assert bs[40:43].original == '🐶'
6666
assert bs[40:43].modified == 'dog'
6767

68+
bs = bistr.infer('Z̴̡̪̫̖̥̔̿̃̈̏̎͠͝á̸̪̠̖̻̬̖̪̞͙͇̮̠͎̆͋́̐͌̒͆̓l̶͉̭̳̤̬̮̩͎̟̯̜͇̥̠̘͑͐̌͂̄́̀̂̌̈͛̊̄̚͜ģ̸̬̼̞̙͇͕͎̌̾̒̐̿̎̆̿̌̃̏̌́̾̈͘͜o̶̢̭͕͔̩͐ ̴̡̡̜̥̗͔̘̦͉̣̲͚͙̐̈́t̵͈̰̉̀͒̎̈̿̔̄̽͑͝͠ẹ̵̫̲̫̄͜͜x̵͕̳͈̝̤̭̼̼̻͓̿̌̽̂̆̀̀̍̒͐́̈̀̚͝t̸̡̨̥̺̣̟͎̝̬̘̪͔͆́̄̅̚', 'Zalgo text')
69+
for i, c in enumerate(bs):
70+
assert bs[i:i+1].original.startswith(c)
71+
6872

6973
def test_concat():
7074
bs = bistr(' ', '')

0 commit comments

Comments
 (0)