
Commit 09b671b: Update the chunker

1 parent 010fe03 · commit 09b671b

2 files changed: +160, -42 lines

adi_function_app/adi_2_ai_search.py (10 additions, 3 deletions)

@@ -38,12 +38,19 @@ async def build_and_clean_markdown_for_response(
         str: The cleaned Markdown text.
     """

-    comment_patterns = r"<!-- PageNumber=\"[^\"]*\" -->|<!-- PageHeader=\"[^\"]*\" -->|<!-- PageFooter=\"[^\"]*\" -->|<!-- PageBreak -->|<!-- Footnote=\"[^\"]*\" -->"
-    cleaned_text = re.sub(comment_patterns, "", markdown_text, flags=re.DOTALL)
+    # Patterns to match the comment start `<!--` and comment end `-->`.
+    # The start pattern matches `<!--` and everything up to the next `<`.
+    comment_start_pattern = r"<!--[^<]*"
+    comment_end_pattern = r"(-->|\<)"
+
+    # Use re.sub to remove comments
+    cleaned_text = re.sub(
+        f"{comment_start_pattern}.*?{comment_end_pattern}", "", markdown_text
+    )

     # Remove irrelevant figures
     if remove_irrelevant_figures:
-        irrelevant_figure_pattern = r"<!-- FigureContent=\"Irrelevant Image\" -->\s*"
+        irrelevant_figure_pattern = r"<figure[^>]*>.*?Irrelevant Image.*?</figure>"
         cleaned_text = re.sub(
            irrelevant_figure_pattern, "", cleaned_text, flags=re.DOTALL
         )
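
Review note: a quick sanity check of the new comment-stripping pattern. This is a minimal sketch; the sample string is invented for illustration:

import re

comment_start_pattern = r"<!--[^<]*"
comment_end_pattern = r"(-->|\<)"

# Invented sample of Azure Document Intelligence markdown output
sample = 'Section intro. <!-- PageNumber="1" -->'

cleaned = re.sub(f"{comment_start_pattern}.*?{comment_end_pattern}", "", sample)
print(repr(cleaned))  # 'Section intro. '

Because the end pattern accepts either `-->` or a bare `<`, a comment that is never closed is still terminated at the next tag rather than running to the end of the document.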

adi_function_app/semantic_text_chunker.py (150 additions, 39 deletions)
@@ -7,25 +7,35 @@
 import spacy
 import numpy as np

+logging.basicConfig(level=logging.INFO)
+

 class SemanticTextChunker:
     def __init__(
         self,
         num_surrounding_sentences: int = 1,
         similarity_threshold: float = 0.8,
         max_chunk_tokens: int = 200,
+        min_chunk_tokens: int = 50,
     ):
         self.num_surrounding_sentences = num_surrounding_sentences
         self.similarity_threshold = similarity_threshold
         self.max_chunk_tokens = max_chunk_tokens
+        self.min_chunk_tokens = min_chunk_tokens
         try:
             self._nlp_model = spacy.load("en_core_web_md")
         except IOError as e:
             raise ValueError("Spacy model 'en_core_web_md' not found.") from e

+    def sentence_contains_figure_or_table_ending(self, text: str):
+        return "</figure>" in text or "</table>" in text
+
     def sentence_contains_figure_or_table(self, text: str):
-        return ("<figure" in text or "</figure>" in text) or (
-            "<table>" in text or "</table>" in text
+        return (
+            ("<figure" in text or "</figure>" in text)
+            or ("<table>" in text or "</table>" in text)
+            or ("<th" in text or "th>" in text)
+            or ("<td" in text or "td>" in text)
         )

     def sentence_is_complete_figure_or_table(self, text: str):
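
Review note: a minimal usage sketch of the widened table/figure detection (sample strings invented; constructing the chunker assumes the en_core_web_md spaCy model is installed):

chunker = SemanticTextChunker(max_chunk_tokens=200, min_chunk_tokens=50)

# Table cell and header fragments now count as table content
print(chunker.sentence_contains_figure_or_table("<td>42</td>"))      # True
# Only closing tags count as an ending
print(chunker.sentence_contains_figure_or_table_ending("</table>"))  # True
print(chunker.sentence_contains_figure_or_table_ending("<table>"))   # False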
@@ -59,6 +69,7 @@ async def chunk(self, text: str) -> list[dict]:
         list(str): The list of chunks"""

         sentences = self.split_into_sentences(text)
+
         (
             grouped_sentences,
             is_table_or_figure_map,
@@ -68,15 +79,54 @@ async def chunk(self, text: str) -> list[dict]:
             grouped_sentences, is_table_or_figure_map
         )

+        logging.info(
+            f"""Number of Forward pass chunks: {
+                len(forward_pass_chunks)}"""
+        )
+        logging.info(f"Forward pass chunks: {forward_pass_chunks}")
+
         backwards_pass_chunks, _ = self.merge_chunks(
             forward_pass_chunks, new_is_table_or_figure_map, forwards_direction=False
         )

-        backwards_pass_chunks = list(
-            map(lambda x: x.strip(), reversed(backwards_pass_chunks))
+        reversed_backwards_pass_chunks = list(reversed(backwards_pass_chunks))
+
+        logging.info(
+            f"""Number of Backward pass chunks: {
+                len(reversed_backwards_pass_chunks)}"""
         )
+        logging.info(f"Backward pass chunks: {reversed_backwards_pass_chunks}")
+
+        cleaned_final_chunks = []
+        for chunk in reversed_backwards_pass_chunks:
+            stripped_chunk = chunk.strip()
+            if len(stripped_chunk) > 0:
+                cleaned_final_chunks.append(stripped_chunk)
+
+        logging.info(f"Number of final chunks: {len(cleaned_final_chunks)}")
+        logging.info(f"Chunks: {cleaned_final_chunks}")
+
+        return cleaned_final_chunks
+
+    def filter_empty_figures(self, text):
+        # Regular expression to match <figure>...</figure> with only newlines or spaces in between
+        pattern = r"<figure>\s*</figure>"

-        return list(reversed(backwards_pass_chunks))
+        # Replace any matches of the pattern with an empty string
+        filtered_text = re.sub(pattern, "", text)
+
+        return filtered_text
+
+    def clean_new_lines(self, text):
+        # Remove single newlines surrounded by < and >
+        cleaned_text = re.sub(r"(?<=>)(\n)(?=<)", "", text)
+
+        # Replace all other single newlines with a space
+        cleaned_text = re.sub(r"(?<!\n)\n(?!\n)", " ", cleaned_text)
+
+        # Replace multiple consecutive newlines with a single space followed by \n\n
+        cleaned_text = re.sub(r"\n{2,}", " \n\n", cleaned_text)
+        return cleaned_text

     def split_into_sentences(self, text: str) -> list[str]:
         """Splits a set of text into a list of sentences using the Spacy NLP model.
@@ -88,22 +138,45 @@ def split_into_sentences(self, text: str) -> list[str]:
         list(str): The extracted sentences
         """

-        def replace_newlines_outside_html(text):
-            def replacement(match):
-                # Only replace if \n is outside HTML tags
-                if "<" not in match.group(0) and ">" not in match.group(0):
-                    return match.group(0).replace("\n", " ")
-                return match.group(0)
+        cleaned_text = self.clean_new_lines(text)
+
+        # Filter out empty <figure>...</figure> tags
+        cleaned_text = self.filter_empty_figures(cleaned_text)
+
+        doc = self._nlp_model(cleaned_text)
+
+        tag_split_sentences = []
+        # Pattern to match the closing and opening tag junctions with whitespace in between
+        split_pattern = r"(</table>\s*<table\b[^>]*>|</figure>\s*<figure\b[^>]*>)"
+        for sent in doc.sents:
+            split_result = re.split(split_pattern, sent.text)
+            for part in split_result:
+                # Match the junction and split it into two parts
+                if re.match(split_pattern, part):
+                    # Split at the first whitespace
+                    tag_split = part.split(" ", 1)
+                    # Add the closing tag (e.g., </table>)
+                    tag_split_sentences.append(tag_split[0])
+                    if len(tag_split) > 1:
+                        # Add the rest of the string with a leading space
+                        tag_split_sentences.append(" " + tag_split[1])
+                else:
+                    tag_split_sentences.append(part)

-        # Match sequences of non-whitespace characters with \n outside tags
-        return re.sub(r"[^<>\s]+\n[^<>\s]+", replacement, text)
+        # Now apply a split pattern against markdown headings
+        heading_split_sentences = []

-        doc = self._nlp_model(replace_newlines_outside_html(text))
-        sentences = [sent.text for sent in doc.sents]
+        # Iterate through each sentence in tag_split_sentences
+        for sent in tag_split_sentences:
+            # Use re.split to split on \n\n and headings, keeping the separators in the result
+            split_result = re.split(r"(\n\n|#+ .*)", sent)

-        print(len(sentences))
+            # Keep only the non-empty parts
+            for part in split_result:
+                if part.strip():  # Only add non-empty parts
+                    heading_split_sentences.append(part)

-        return sentences
+        return heading_split_sentences

     def group_figures_and_tables_into_sentences(self, sentences: list[str]):
         grouped_sentences = []
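
Review note: the two split passes in isolation (invented strings). The capturing group keeps each junction in the result, the junction is then split into its closing and opening tags, and the non-empty filter in the heading pass keeps heading matches but drops bare \n\n separators:

import re

split_pattern = r"(</table>\s*<table\b[^>]*>|</figure>\s*<figure\b[^>]*>)"
parts = re.split(split_pattern, "<table>r1</table> <table id='t2'>r2</table>")
# ['<table>r1', "</table> <table id='t2'>", 'r2</table>']

closing, opening = parts[1].split(" ", 1)
# closing == '</table>', opening == "<table id='t2'>"

heading_parts = re.split(r"(\n\n|#+ .*)", "Intro text \n\n# Results")
# ['Intro text ', '\n\n', '', '# Results', '']
# after the part.strip() filter: ['Intro text ', '# Results']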
@@ -125,7 +198,7 @@ def group_figures_and_tables_into_sentences(self, sentences: list[str]):
                is_table_or_figure_map.append(False)
            else:
                # check for ending case
-                if self.sentence_contains_figure_or_table(current_sentence):
+                if self.sentence_contains_figure_or_table_ending(current_sentence):
                    holding_sentences.append(current_sentence)

                    full_sentence = " ".join(holding_sentences)
@@ -137,6 +210,8 @@ def group_figures_and_tables_into_sentences(self, sentences: list[str]):
            else:
                holding_sentences.append(current_sentence)

+        assert len(holding_sentences) == 0, "Holding sentences should be empty"
+
        return grouped_sentences, is_table_or_figure_map

    def look_ahead_and_behind_sentences(
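
Review note: the new assert encodes the invariant that every figure or table opened during the pass has been closed by the end, so the holding buffer must drain. With an invented input like the one below, the expected behaviour, based on the join-with-space logic visible in this method, is roughly:

chunker = SemanticTextChunker()

sentences = ["Before.", "<figure>", "a bar chart", "</figure>", "After."]
grouped, is_map = chunker.group_figures_and_tables_into_sentences(sentences)
# Expected: the three figure parts merge into one grouped sentence flagged True,
# e.g. ["Before.", "<figure> a bar chart </figure>", "After."] with
# is_map == [False, True, False]; an unterminated <figure> would trip the assert.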
@@ -183,31 +258,50 @@ def look_ahead_and_behind_sentences(

     def merge_similar_chunks(self, current_sentence, current_chunk, forwards_direction):
         new_chunk = None
-        # Only compare when we have 2 or more chunks

-        if forwards_direction is False:
-            directional_current_chunk = list(reversed(current_chunk))
-        else:
-            directional_current_chunk = current_chunk
+        def retrieve_current_chunk_up_to_n(n):
+            if forwards_direction:
+                return " ".join(current_chunk[:-n])
+            else:
+                return " ".join(reversed(current_chunk[:-n]))

-        if len(current_chunk) >= 2:
+        def retrieve_current_chunks_from_n(n):
+            if forwards_direction:
+                return " ".join(current_chunk[n:])
+            else:
+                return " ".join(reversed(current_chunk[:-n]))
+
+        def retrieve_current_chunk_at_n(n):
+            if forwards_direction:
+                return current_chunk[n]
+            else:
+                return current_chunk[n]
+
+        current_chunk_tokens = self.num_tokens_from_string(" ".join(current_chunk))
+
+        if len(current_chunk) >= 2 and current_chunk_tokens >= self.min_chunk_tokens:
+            logging.debug("Comparing chunks")
             cosine_sim = self.sentence_similarity(
-                " ".join(directional_current_chunk[-2:]), current_sentence
+                retrieve_current_chunks_from_n(-2), current_sentence
             )
             if (
                 cosine_sim < self.similarity_threshold
-                or self.num_tokens_from_string(" ".join(directional_current_chunk))
-                >= self.max_chunk_tokens
+                or current_chunk_tokens >= self.max_chunk_tokens
             ):
                 if len(current_chunk) > 2:
-                    new_chunk = " ".join(directional_current_chunk[:1])
-                    current_chunk = [directional_current_chunk[-1]]
+                    new_chunk = retrieve_current_chunk_up_to_n(1)
+                    current_chunk = [retrieve_current_chunk_at_n(-1)]
                 else:
-                    new_chunk = current_chunk[0]
-                    current_chunk = [current_chunk[1]]
+                    new_chunk = retrieve_current_chunk_at_n(0)
+                    current_chunk = [retrieve_current_chunk_at_n(1)]
+        else:
+            logging.debug("Chunk too small to compare")

         return new_chunk, current_chunk

+    def is_markdown_heading(self, text):
+        return text.strip().startswith("#")
+
     def merge_chunks(self, sentences, is_table_or_figure_map, forwards_direction=True):
         chunks = []
         current_chunk = []
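
Review note: the direction-aware helpers are the subtle part of this refactor. On the forwards pass their indexing works out as below (invented values): everything but the newest sentence is emitted as a chunk, the two most recent sentences are compared against the incoming one, and the newest sentence seeds the next chunk:

current_chunk = ["s1", "s2", "s3"]

" ".join(current_chunk[:-1])   # retrieve_current_chunk_up_to_n(1)  -> "s1 s2"
" ".join(current_chunk[-2:])   # retrieve_current_chunks_from_n(-2) -> "s2 s3"
current_chunk[-1]              # retrieve_current_chunk_at_n(-1)    -> "s3"

The new min_chunk_tokens guard also stops a chunk from being compared and split before it has accumulated enough tokens, regardless of similarity.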
@@ -230,6 +324,10 @@ def retrieve_current_chunk():

            current_sentence = sentences[current_sentence_index]

+            if len(current_sentence.strip()) == 0:
+                index += 1
+                continue
+
            # Detect if table or figure
            if is_table_or_figure_map[current_sentence_index]:
                if forwards_direction:
@@ -244,7 +342,7 @@ def retrieve_current_chunk():
                    # On the backwards pass we don't want to add to the table chunk
                    chunks.append(retrieve_current_chunk())
                    new_is_table_or_figure_map.append(True)
-                    current_chunk.append(current_sentence)
+                    current_chunk = [current_sentence]

                index += 1
                continue
@@ -260,11 +358,18 @@ def retrieve_current_chunk():
                )

                if is_table_or_figure_behind:
-                    # Finish off
-                    current_chunk.append(current_sentence)
-                    chunks.append(retrieve_current_chunk())
-                    new_is_table_or_figure_map.append(False)
-                    current_chunk = []
+                    # Check if the sentence is a Markdown heading
+                    if self.is_markdown_heading(current_sentence):
+                        # Start a new chunk
+                        chunks.append(retrieve_current_chunk())
+                        new_is_table_or_figure_map.append(False)
+                        current_chunk = [current_sentence]
+                    else:
+                        # Finish off the current chunk
+                        current_chunk.append(current_sentence)
+                        chunks.append(retrieve_current_chunk())
+                        new_is_table_or_figure_map.append(False)
+                        current_chunk = []

                index += 1
                continue
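
Review note: is_markdown_heading is a plain prefix test, so on the backwards pass a heading now starts its own chunk instead of being glued to the chunk behind it:

chunker = SemanticTextChunker()

chunker.is_markdown_heading("## Results")   # True
chunker.is_markdown_heading("   # Note")    # True, leading whitespace is stripped
chunker.is_markdown_heading("Plain text")   # False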
@@ -307,7 +412,7 @@ def retrieve_current_chunk():
            index += 1

        if len(current_chunk) > 0:
-            final_chunk = " ".join(current_chunk)
+            final_chunk = retrieve_current_chunk()
            chunks.append(final_chunk)

        new_is_table_or_figure_map.append(
@@ -322,7 +427,13 @@ def sentence_similarity(self, text_1, text_2):

        dot_product = np.dot(vec1, vec2)
        magnitude = np.linalg.norm(vec1) * np.linalg.norm(vec2)
-        return dot_product / magnitude if magnitude != 0 else 0.0
+        similarity = dot_product / magnitude if magnitude != 0 else 0.0
+
+        logging.debug(
+            f"""Similarity between '{text_1}' and '{
+                text_2}': {similarity}"""
+        )
+        return similarity


 async def process_semantic_text_chunker(record: dict, text_chunker) -> dict:
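
Review note: the quantity now being logged is cosine similarity. A standalone sketch for reference; how vec1 and vec2 are produced is not shown in this hunk, so the use of spaCy document vectors here is an assumption:

import numpy as np
import spacy

nlp = spacy.load("en_core_web_md")  # assumed model, matching __init__

def cosine_similarity(text_1: str, text_2: str) -> float:
    vec1 = nlp(text_1).vector
    vec2 = nlp(text_2).vector
    magnitude = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return float(np.dot(vec1, vec2) / magnitude) if magnitude != 0 else 0.0

print(cosine_similarity("Revenue grew strongly.", "Sales increased."))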
