77import spacy
88import numpy as np
99
# Configure root logging once at import time so the chunker's
# module-level logging.info/debug calls are actually emitted.
logging.basicConfig(level=logging.INFO)

1113class SemanticTextChunker :
1214 def __init__ (
1315 self ,
1416 num_surrounding_sentences : int = 1 ,
1517 similarity_threshold : float = 0.8 ,
1618 max_chunk_tokens : int = 200 ,
19+ min_chunk_tokens : int = 50 ,
1720 ):
1821 self .num_surrounding_sentences = num_surrounding_sentences
1922 self .similarity_threshold = similarity_threshold
2023 self .max_chunk_tokens = max_chunk_tokens
24+ self .min_chunk_tokens = min_chunk_tokens
2125 try :
2226 self ._nlp_model = spacy .load ("en_core_web_md" )
2327 except IOError as e :
2428 raise ValueError ("Spacy model 'en_core_web_md' not found." ) from e
2529
30+ def sentence_contains_figure_or_table_ending (self , text : str ):
31+ return "</figure>" in text or "</table>" in text
32+
2633 def sentence_contains_figure_or_table (self , text : str ):
27- return ("<figure" in text or "</figure>" in text ) or (
28- "<table>" in text or "</table>" in text
34+ return (
35+ ("<figure" in text or "</figure>" in text )
36+ or ("<table>" in text or "</table>" in text )
37+ or ("<th" in text or "th>" in text )
38+ or ("<td" in text or "td>" in text )
2939 )
3040
3141 def sentence_is_complete_figure_or_table (self , text : str ):
@@ -59,6 +69,7 @@ async def chunk(self, text: str) -> list[dict]:
5969 list(str): The list of chunks"""
6070
6171 sentences = self .split_into_sentences (text )
72+
6273 (
6374 grouped_sentences ,
6475 is_table_or_figure_map ,
@@ -68,15 +79,54 @@ async def chunk(self, text: str) -> list[dict]:
6879 grouped_sentences , is_table_or_figure_map
6980 )
7081
82+ logging .info (
83+ f"""Number of Forward pass chunks: {
84+ len (forward_pass_chunks )} """
85+ )
86+ logging .info (f"Forward pass chunks: { forward_pass_chunks } " )
87+
7188 backwards_pass_chunks , _ = self .merge_chunks (
7289 forward_pass_chunks , new_is_table_or_figure_map , forwards_direction = False
7390 )
7491
75- backwards_pass_chunks = list (
76- map (lambda x : x .strip (), reversed (backwards_pass_chunks ))
92+ reversed_backwards_pass_chunks = list (reversed (backwards_pass_chunks ))
93+
94+ logging .info (
95+ f"""Number of Backaward pass chunks: {
96+ len (reversed_backwards_pass_chunks )} """
7797 )
98+ logging .info (f"Backward pass chunks: { reversed_backwards_pass_chunks } " )
99+
100+ cleaned_final_chunks = []
101+ for chunk in reversed_backwards_pass_chunks :
102+ stripped_chunk = chunk .strip ()
103+ if len (stripped_chunk ) > 0 :
104+ cleaned_final_chunks .append (stripped_chunk )
105+
106+ logging .info (f"Number of final chunks: { len (cleaned_final_chunks )} " )
107+ logging .info (f"Chunks: { cleaned_final_chunks } " )
108+
109+ return cleaned_final_chunks
110+
111+ def filter_empty_figures (self , text ):
112+ # Regular expression to match <figure>...</figure> with only newlines or spaces in between
113+ pattern = r"<figure>\s*</figure>"
78114
79- return list (reversed (backwards_pass_chunks ))
115+ # Replace any matches of the pattern with an empty string
116+ filtered_text = re .sub (pattern , "" , text )
117+
118+ return filtered_text
119+
120+ def clean_new_lines (self , text ):
121+ # Remove single newlines surrounded by < and >
122+ cleaned_text = re .sub (r"(?<=>)(\n)(?=<)" , "" , text )
123+
124+ # Replace all other single newlines with space
125+ cleaned_text = re .sub (r"(?<!\n)\n(?!\n)" , " " , cleaned_text )
126+
127+ # Replace multiple consecutive newlines with a single space followed by \n\n
128+ cleaned_text = re .sub (r"\n{2,}" , " \n \n " , cleaned_text )
129+ return cleaned_text
80130
81131 def split_into_sentences (self , text : str ) -> list [str ]:
82132 """Splits a set of text into a list of sentences uses the Spacy NLP model.
@@ -88,22 +138,45 @@ def split_into_sentences(self, text: str) -> list[str]:
88138 list(str): The extracted sentences
89139 """
90140
91- def replace_newlines_outside_html (text ):
92- def replacement (match ):
93- # Only replace if \n is outside HTML tags
94- if "<" not in match .group (0 ) and ">" not in match .group (0 ):
95- return match .group (0 ).replace ("\n " , " " )
96- return match .group (0 )
141+ cleaned_text = self .clean_new_lines (text )
142+
143+ # Filter out empty <figure>...</figure> tags
144+ cleaned_text = self .filter_empty_figures (cleaned_text )
145+
146+ doc = self ._nlp_model (cleaned_text )
147+
148+ tag_split_sentences = []
149+ # Pattern to match the closing and opening tag junctions with whitespace in between
150+ split_pattern = r"(</table>\s*<table\b[^>]*>|</figure>\s*<figure\b[^>]*>)"
151+ for sent in doc .sents :
152+ split_result = re .split (split_pattern , sent .text )
153+ for part in split_result :
154+ # Match the junction and split it into two parts
155+ if re .match (split_pattern , part ):
156+ # Split at the first whitespace
157+ tag_split = part .split (" " , 1 )
158+ # Add the closing tag (e.g., </table>)
159+ tag_split_sentences .append (tag_split [0 ])
160+ if len (tag_split ) > 1 :
161+ # Add the rest of the string with leading space
162+ tag_split_sentences .append (" " + tag_split [1 ])
163+ else :
164+ tag_split_sentences .append (part )
97165
98- # Match sequences of non-whitespace characters with \n outside tags
99- return re . sub ( r"[^<>\s]+\n[^<>\s]+" , replacement , text )
166+ # Now apply a split pattern against markdown headings
167+ heading_split_sentences = []
100168
101- doc = self ._nlp_model (replace_newlines_outside_html (text ))
102- sentences = [sent .text for sent in doc .sents ]
169+ # Iterate through each sentence in tag_split_sentences
170+ for sent in tag_split_sentences :
171+ # Use re.split to split on \n\n and headings, but keep \n\n in the result
172+ split_result = re .split (r"(\n\n|#+ .*)" , sent )
103173
104- print (len (sentences ))
174+ # Extend the result with the correctly split parts, retaining \n\n before the heading
175+ for part in split_result :
176+ if part .strip (): # Only add non-empty parts
177+ heading_split_sentences .append (part )
105178
106- return sentences
179+ return heading_split_sentences
107180
108181 def group_figures_and_tables_into_sentences (self , sentences : list [str ]):
109182 grouped_sentences = []
@@ -125,7 +198,7 @@ def group_figures_and_tables_into_sentences(self, sentences: list[str]):
125198 is_table_or_figure_map .append (False )
126199 else :
127200 # check for ending case
128- if self .sentence_contains_figure_or_table (current_sentence ):
201+ if self .sentence_contains_figure_or_table_ending (current_sentence ):
129202 holding_sentences .append (current_sentence )
130203
131204 full_sentence = " " .join (holding_sentences )
@@ -137,6 +210,8 @@ def group_figures_and_tables_into_sentences(self, sentences: list[str]):
137210 else :
138211 holding_sentences .append (current_sentence )
139212
213+ assert len (holding_sentences ) == 0 , "Holding sentences should be empty"
214+
140215 return grouped_sentences , is_table_or_figure_map
141216
142217 def look_ahead_and_behind_sentences (
@@ -183,31 +258,50 @@ def look_ahead_and_behind_sentences(
183258
184259 def merge_similar_chunks (self , current_sentence , current_chunk , forwards_direction ):
185260 new_chunk = None
186- # Only compare when we have 2 or more chunks
187261
188- if forwards_direction is False :
189- directional_current_chunk = list (reversed (current_chunk ))
190- else :
191- directional_current_chunk = current_chunk
262+ def retrieve_current_chunk_up_to_n (n ):
263+ if forwards_direction :
264+ return " " .join (current_chunk [:- n ])
265+ else :
266+ return " " .join (reversed (current_chunk [:- n ]))
192267
193- if len (current_chunk ) >= 2 :
268+ def retrieve_current_chunks_from_n (n ):
269+ if forwards_direction :
270+ return " " .join (current_chunk [n :])
271+ else :
272+ return " " .join (reversed (current_chunk [:- n ]))
273+
274+ def retrive_current_chunk_at_n (n ):
275+ if forwards_direction :
276+ return current_chunk [n ]
277+ else :
278+ return current_chunk [n ]
279+
280+ current_chunk_tokens = self .num_tokens_from_string (" " .join (current_chunk ))
281+
282+ if len (current_chunk ) >= 2 and current_chunk_tokens >= self .min_chunk_tokens :
283+ logging .debug ("Comparing chunks" )
194284 cosine_sim = self .sentence_similarity (
195- " " . join ( directional_current_chunk [ - 2 :] ), current_sentence
285+ retrieve_current_chunks_from_n ( - 2 ), current_sentence
196286 )
197287 if (
198288 cosine_sim < self .similarity_threshold
199- or self .num_tokens_from_string (" " .join (directional_current_chunk ))
200- >= self .max_chunk_tokens
289+ or current_chunk_tokens >= self .max_chunk_tokens
201290 ):
202291 if len (current_chunk ) > 2 :
203- new_chunk = " " . join ( directional_current_chunk [: 1 ] )
204- current_chunk = [directional_current_chunk [ - 1 ] ]
292+ new_chunk = retrieve_current_chunk_up_to_n ( 1 )
293+ current_chunk = [retrive_current_chunk_at_n ( - 1 ) ]
205294 else :
206- new_chunk = current_chunk [0 ]
207- current_chunk = [current_chunk [1 ]]
295+ new_chunk = retrive_current_chunk_at_n (0 )
296+ current_chunk = [retrive_current_chunk_at_n (1 )]
297+ else :
298+ logging .debug ("Chunk too small to compare" )
208299
209300 return new_chunk , current_chunk
210301
302+ def is_markdown_heading (self , text ):
303+ return text .strip ().startswith ("#" )
304+
211305 def merge_chunks (self , sentences , is_table_or_figure_map , forwards_direction = True ):
212306 chunks = []
213307 current_chunk = []
@@ -230,6 +324,10 @@ def retrieve_current_chunk():
230324
231325 current_sentence = sentences [current_sentence_index ]
232326
327+ if len (current_sentence .strip ()) == 0 :
328+ index += 1
329+ continue
330+
233331 # Detect if table or figure
234332 if is_table_or_figure_map [current_sentence_index ]:
235333 if forwards_direction :
@@ -244,7 +342,7 @@ def retrieve_current_chunk():
244342 # On the backwards pass we don't want to add to the table chunk
245343 chunks .append (retrieve_current_chunk ())
246344 new_is_table_or_figure_map .append (True )
247- current_chunk . append ( current_sentence )
345+ current_chunk = [ current_sentence ]
248346
249347 index += 1
250348 continue
@@ -260,11 +358,18 @@ def retrieve_current_chunk():
260358 )
261359
262360 if is_table_or_figure_behind :
263- # Finish off
264- current_chunk .append (current_sentence )
265- chunks .append (retrieve_current_chunk ())
266- new_is_table_or_figure_map .append (False )
267- current_chunk = []
361+ # Check if Makrdown heading
362+ if self .is_markdown_heading (current_sentence ):
363+ # Start new chunk
364+ chunks .append (retrieve_current_chunk ())
365+ new_is_table_or_figure_map .append (False )
366+ current_chunk = [current_sentence ]
367+ else :
368+ # Finish off
369+ current_chunk .append (current_sentence )
370+ chunks .append (retrieve_current_chunk ())
371+ new_is_table_or_figure_map .append (False )
372+ current_chunk = []
268373
269374 index += 1
270375 continue
@@ -307,7 +412,7 @@ def retrieve_current_chunk():
307412 index += 1
308413
309414 if len (current_chunk ) > 0 :
310- final_chunk = " " . join ( current_chunk )
415+ final_chunk = retrieve_current_chunk ( )
311416 chunks .append (final_chunk )
312417
313418 new_is_table_or_figure_map .append (
@@ -322,7 +427,13 @@ def sentence_similarity(self, text_1, text_2):
322427
323428 dot_product = np .dot (vec1 , vec2 )
324429 magnitude = np .linalg .norm (vec1 ) * np .linalg .norm (vec2 )
325- return dot_product / magnitude if magnitude != 0 else 0.0
430+ similarity = dot_product / magnitude if magnitude != 0 else 0.0
431+
432+ logging .debug (
433+ f"""Similarity between '{ text_1 } ' and '{
434+ text_2 } ': { similarity } """
435+ )
436+ return similarity
326437
327438
328439async def process_semantic_text_chunker (record : dict , text_chunker ) -> dict :
0 commit comments