1010
1111from pipeline .common .datasets import Statistics
1212from pipeline .common .downloads import read_lines , write_lines
13+ from pipeline .common .logging import get_logger
1314
15+ logger = get_logger (__file__ )
1416
1517CJK_LANGS = ["zh" , "ja" , "ko" ]
1618
@@ -26,10 +28,13 @@ class ConversionStep(Statistics):
2628 When converting data, count how many sentences were converted, and how many were visited.
2729 """
2830
def __init__(
    self,
    description: str,
    converted=0,
    filtered=0,
    dataset_path: Optional[Path] = None,
) -> None:
    """
    Set up the counters for a single conversion or filtering step.

    description: text explaining what this step counts.
    converted: starting value for the number of converted sentences.
    filtered: starting value for the number of filtered-out sentences.
    dataset_path: optional path passed through to the parent initializer.
    """
    super().__init__(dataset_path)
    self.description = description
    self.converted = converted
    self.filtered = filtered
    # Always starts at zero; callers increment it once per line they read.
    self.visited = 0
3439
3540
@@ -38,7 +43,7 @@ def __init__(self, dataset_path: Path, script: ChineseType) -> None:
3843 super ().__init__ (dataset_path )
3944 self .script = script
4045 self .script_conversion = ConversionStep (
41- f"How many sentences in the dataset were converted to { script .name } " ,
46+ f"How many sentences in the dataset were converted to { script .name } or filtered " ,
4247 )
4348
4449
@@ -50,6 +55,9 @@ def __init__(self):
5055 def convert_file (
5156 self , input_path : Path , output_path : Path , to : ChineseType
5257 ) -> DatasetStatistics :
58+ """
59+ Convert all lines to one variant of Chinese
60+ """
5361 stats = DatasetStatistics (output_path , to )
5462 with write_lines (output_path ) as out_file , read_lines (input_path ) as lines :
5563 for line in lines :
@@ -63,6 +71,51 @@ def convert_file(
6371 out_file .write (new_line )
6472 return stats
6573
def filter_file(self, input_path: Path, output_path: Path, variant: ChineseType):
    """
    Keep only the lines detected as the requested variant of Chinese.

    Every line read from `input_path` is counted as visited. A line whose detected
    type equals `variant` is written unchanged to `output_path`; any other line is
    dropped and counted as filtered. Returns the statistics object with the counters.
    """
    stats = DatasetStatistics(output_path, variant)
    counters = stats.script_conversion
    with write_lines(output_path) as out_file, read_lines(input_path) as lines:
        for line in lines:
            counters.visited += 1
            is_requested_variant = self._detect(line) == variant
            if not is_requested_variant:
                counters.filtered += 1
                continue
            out_file.write(line)

    return stats
89+
def filter_parallel_corpus(
    self,
    zh_path: Path,
    other_path: Path,
    zh_output_path: Path,
    other_output_path: Path,
    variant: ChineseType,
) -> DatasetStatistics:
    """
    Filter everything except the specified variant of Chinese in a parallel corpus.

    The Chinese file and the other-language file are read in lockstep. A sentence pair
    is written to both output files only when the Chinese line is detected as `variant`;
    otherwise the pair is dropped from both sides and counted as filtered. Every pair
    read is counted as visited.

    Returns the statistics object (created for `zh_output_path`) holding the counters.

    Raises:
        ValueError: if the two input files do not yield the same number of lines.
    """
    stats = DatasetStatistics(zh_output_path, variant)
    with (
        write_lines(zh_output_path) as zh_out_file,
        write_lines(other_output_path) as other_out_file,
        read_lines(zh_path) as zh_lines,
        read_lines(other_path) as other_lines,
    ):
        # strict=True: a length mismatch means the corpus is misaligned. Plain zip()
        # would silently stop at the shorter side and hide the problem.
        for zh_line, other_line in zip(zh_lines, other_lines, strict=True):
            stats.script_conversion.visited += 1
            ch_type = self._detect(zh_line)
            if ch_type == variant:
                zh_out_file.write(zh_line)
                other_out_file.write(other_line)
            else:
                stats.script_conversion.filtered += 1

    return stats
118+
66119 @staticmethod
67120 def _detect (text ) -> ChineseType :
68121 res = hanzidentifier .identify (text )
@@ -80,3 +133,56 @@ def _convert_line(self, text: str, to: ChineseType) -> str:
80133 elif to == ChineseType .traditional :
81134 return self .s2t .convert (text )
82135 raise ValueError (f"Unsupported type: { to } " )
136+
137+
def handle_chinese_mono(file_destination: Path, is_src: bool, variant: ChineseType):
    """
    Normalize a monolingual Chinese file in place to a single script variant.

    When `is_src` is true, every line is converted to `variant`; otherwise lines that
    are not detected as `variant` are filtered out. The result is written to a
    temporary ".converted.zst" sibling, which then replaces `file_destination`.
    The counters are printed and the statistics are saved as JSON.
    """
    tmp_path = file_destination.with_suffix(".converted.zst")
    converter = ChineseConverter()

    # Pick the log message and the processing method; both methods take
    # (input path, output path, variant) positionally.
    if is_src:
        message = f"Converting the output file to {variant}"
        process = converter.convert_file
    else:
        message = f"Filtering out everything except {variant} in the output file"
        process = converter.filter_file

    logger.info(message)
    stats = process(file_destination, tmp_path, variant)

    # Swap the processed file into the place of the original one.
    tmp_path.replace(file_destination)

    counts = stats.script_conversion
    print(
        f"Converted {counts.converted}, Filtered: {counts.filtered} Visited: {counts.visited}"
    )
    stats.save_json()
152+
153+
def handle_chinese_parallel(output_prefix: str, src: str, trg: str, variant: ChineseType):
    """
    Normalize the Chinese side of a parallel corpus in place to a single script variant.

    The corpus files are "<output_prefix>.<lang>.zst". When the source language is "zh",
    the source file is converted to `variant`. Otherwise the target is Chinese, and
    sentence pairs whose Chinese line is not detected as `variant` are removed from
    both files. Processed files replace the originals, the counters are printed and
    the statistics are saved as JSON.

    Raises ValueError when neither `src` nor `trg` is "zh".
    """
    if src != "zh" and trg != "zh":
        raise ValueError("Run only for Chinese")

    converter = ChineseConverter()

    def corpus_path(lang: str, tag: str = "") -> Path:
        # "<prefix>.<lang>.zst", or "<prefix><tag>.<lang>.zst" for temporary outputs.
        return Path(f"{output_prefix}{tag}.{lang}.zst")

    if src == "zh":
        logger.info(f"Converting the output file to {variant}")
        original = corpus_path(src)
        converted = corpus_path(src, ".converted")
        stats = converter.convert_file(
            input_path=original,
            output_path=converted,
            to=variant,
        )
        converted.replace(original)
    else:
        logger.info(f"Filtering out everything except {variant} from a parallel corpus")
        zh_in = corpus_path(trg)
        other_in = corpus_path(src)
        zh_out = corpus_path(trg, ".filtered")
        other_out = corpus_path(src, ".filtered")
        stats = converter.filter_parallel_corpus(
            zh_path=zh_in,
            other_path=other_in,
            zh_output_path=zh_out,
            other_output_path=other_out,
            variant=variant,
        )
        # Same replacement order as before: the non-Chinese side first, then Chinese.
        other_out.replace(other_in)
        zh_out.replace(zh_in)

    counts = stats.script_conversion
    print(
        f"Converted {counts.converted}, Filtered: {counts.filtered} Visited: {counts.visited}"
    )
    stats.save_json()
0 commit comments