99
1010from ._alignment import Alignment
1111from ._bistr import bistr
12+ from ._token import CharacterTokenizer
1213
1314
1415@dataclass (frozen = True )
1516class AugmentedChar :
1617 """
17- A single character (code point ) augmented with extra information.
18+ A single character (grapheme cluster ) augmented with extra information.
1819 """
1920
20- folded : str
21+ top_category : str
2122 """
22- The case-folded form of the char.
23+ The top-level Unicode category of the char (L, P, Z, etc.) .
2324 """
2425
25- normalized : str
26+ category : str
2627 """
27- The Unicode compatibility normalized form of the char.
28+ The specific Unicode category of the char (Lu, Po, Zs, etc.) .
2829 """
2930
30- original : str
31+ root : str
3132 """
32- The original form of the char .
33+ The root code point of the grapheme cluster .
3334 """
3435
35- top_category : str
36+ folded : str
3637 """
37- The top-level Unicode category of the char (L, P, Z, etc.) .
38+ The case-folded form of the char.
3839 """
3940
40- category : str
41+ normalized : str
4142 """
42- The specific Unicode category of the char (Lu, Po, Zs, etc.).
43+ The Unicode compatibility normalized form of the char.
44+ """
45+
46+ original : str
47+ """
48+ The original form of the char.
4349 """
4450
4551 @classmethod
@@ -49,29 +55,22 @@ def cost_fn(cls, a: Optional[AugmentedChar], b: Optional[AugmentedChar]) -> int:
4955 """
5056
5157 if a is None or b is None :
52- if a :
53- top_category = a .top_category
54- elif b :
55- top_category = b .top_category
56- else :
57- assert False , 'Unreachable'
58-
59- if top_category == 'M' :
60- # Less penalty for combining marks
61- return 1
62- else :
63- # cost(insert) + cost(delete) (3 + 3) should be more than cost(substitute) (5)
64- return 3
58+ # cost(insert) + cost(delete) (4 + 4) should be more than cost(substitute) (6)
59+ return 4
6560
6661 result = 0
62+ result += int (a .top_category != b .top_category )
63+ result += int (a .category != b .category )
64+ result += int (a .root != b .root )
6765 result += int (a .folded != b .folded )
6866 result += int (a .normalized != b .normalized )
6967 result += int (a .original != b .original )
70- result += int (a .top_category != b .top_category )
71- result += int (a .category != b .category )
7268 return result
7369
7470
71+ TOKENIZER = CharacterTokenizer ('root' )
72+
73+
7574@dataclass (frozen = True )
7675class AugmentedString :
7776 """
@@ -97,21 +96,27 @@ class AugmentedString:
9796 def augment (cls , original : str ) -> AugmentedString :
9897 normalized = bistr (original ).normalize ('NFKD' )
9998 folded = bistr (normalized .modified ).casefold ()
99+ glyphs = TOKENIZER .tokenize (folded )
100100
101101 chars = []
102- for i , fold_c in enumerate (folded ):
103- norm_slice = folded .alignment .original_slice (i , i + 1 )
102+ for glyph in glyphs :
103+ fold_c = glyph .text .modified
104+ root = fold_c [0 ]
105+
106+ norm_slice = folded .alignment .original_slice (glyph .start , glyph .end )
104107 norm_c = folded .original [norm_slice ]
105108
106109 orig_slice = normalized .alignment .original_slice (norm_slice )
107110 orig_c = normalized .original [orig_slice ]
108111
109- cat = unicodedata .category (fold_c )
112+ cat = unicodedata .category (root )
110113 top_cat = cat [0 ]
111114
112- chars .append (AugmentedChar (fold_c , norm_c , orig_c , top_cat , cat ))
115+ chars .append (AugmentedChar (top_cat , cat , root , fold_c , norm_c , orig_c ))
113116
114- alignment = normalized .alignment .compose (folded .alignment )
117+ alignment = normalized .alignment
118+ alignment = alignment .compose (folded .alignment )
119+ alignment = alignment .compose (glyphs .alignment )
115120 return cls (original , chars , alignment )
116121
117122
0 commit comments