From d3f131cfb75113df5ba264231d0869f6a40bca55 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 8 Jan 2025 16:32:11 +0200 Subject: [PATCH 01/12] Improve text splitter to avoid cutting words in chunks --- .../approximate_fixed_size_splitter.py | 118 ++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 src/neo4j_graphrag/experimental/components/text_splitters/approximate_fixed_size_splitter.py diff --git a/src/neo4j_graphrag/experimental/components/text_splitters/approximate_fixed_size_splitter.py b/src/neo4j_graphrag/experimental/components/text_splitters/approximate_fixed_size_splitter.py new file mode 100644 index 000000000..e47eec946 --- /dev/null +++ b/src/neo4j_graphrag/experimental/components/text_splitters/approximate_fixed_size_splitter.py @@ -0,0 +1,118 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pydantic import validate_call + +from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter +from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks + + +def _adjust_chunk_start(text: str, proposed_start: int) -> int: + """ + Shift the starting index backward if it lands in the middle of a word. + If no whitespace is found, use the proposed start. + + Args: + text (str): The text being split. + proposed_start (int): The initial starting index of the chunk. + + Returns: + int: The adjusted starting index, ensuring the chunk does not begin in the + middle of a word. + """ + start = proposed_start + if start > 0 and not text[start].isspace() and not text[start - 1].isspace(): + while start > 0 and not text[start - 1].isspace(): + start -= 1 + + # fallback if no whitespace is found + if start == 0 and not text[0].isspace(): + start = proposed_start + return start + + +def _adjust_chunk_end(text: str, start: int, approximate_end: int) -> int: + """ + Shift the ending index backward if it lands in the middle of a word. + If no whitespace is found, use 'approximate_end' to avoid an infinite loop. + + Args: + text (str): The full text being split. + start (int): The adjusted starting index for this chunk. + approximate_end (int): The initial end index. + + Returns: + int: The adjusted ending index, ensuring the chunk does not end in the middle of + a word if possible. + """ + end = approximate_end + if end < len(text): + while end > start and not text[end - 1].isspace(): + end -= 1 + + # fallback if no whitespace is found + if end == start: + end = approximate_end + return end + + +class ApproximateFixedSizeSplitter(TextSplitter): + """Text splitter which splits the input text into approximate fixed size chunks with + optional overlap, avoiding cutting words. + + Args: + chunk_size (int): The number of characters in each chunk. + chunk_overlap (int): The number of characters from the previous chunk to overlap + with each chunk. Must be less than `chunk_size`. + + """ + @validate_call + def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200) -> None: + if chunk_overlap >= chunk_size: + raise ValueError("chunk_overlap must be strictly less than chunk_size") + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + @validate_call + async def run(self, text: str) -> TextChunks: + """Splits a piece of text into chunks without cutting words in half. + + Args: + text (str): The text to be split. + + Returns: + TextChunks: A list of TextChunk objects with chunked text. + """ + chunks = [] + index = 0 + + step = self.chunk_size - self.chunk_overlap + text_length = len(text) + + i = 0 + while i < text_length: + # adjust chunk start + start = _adjust_chunk_start(text, i) + + # adjust chunk end + approximate_end = min(start + self.chunk_size, text_length) + end = _adjust_chunk_end(text, start, approximate_end) + + chunk_text = text[start:end] + chunks.append(TextChunk(text=chunk_text, index=index)) + index += 1 + + i = max(start + step, end) + + return TextChunks(chunks=chunks) From 30ae18d9f101a9fd704cc00c7f573bc5c69ed499 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 9 Jan 2025 13:02:18 +0200 Subject: [PATCH 02/12] Add a flag to fixed size text splitter to avoid cutting words at chunk boundaries --- .../approximate_fixed_size_splitter.py | 118 ------------------ .../text_splitters/fixed_size_splitter.py | 84 +++++++++++-- 2 files changed, 77 insertions(+), 125 deletions(-) delete mode 100644 src/neo4j_graphrag/experimental/components/text_splitters/approximate_fixed_size_splitter.py diff --git a/src/neo4j_graphrag/experimental/components/text_splitters/approximate_fixed_size_splitter.py b/src/neo4j_graphrag/experimental/components/text_splitters/approximate_fixed_size_splitter.py deleted file mode 100644 index e47eec946..000000000 --- a/src/neo4j_graphrag/experimental/components/text_splitters/approximate_fixed_size_splitter.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pydantic import validate_call - -from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter -from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks - - -def _adjust_chunk_start(text: str, proposed_start: int) -> int: - """ - Shift the starting index backward if it lands in the middle of a word. - If no whitespace is found, use the proposed start. - - Args: - text (str): The text being split. - proposed_start (int): The initial starting index of the chunk. - - Returns: - int: The adjusted starting index, ensuring the chunk does not begin in the - middle of a word. - """ - start = proposed_start - if start > 0 and not text[start].isspace() and not text[start - 1].isspace(): - while start > 0 and not text[start - 1].isspace(): - start -= 1 - - # fallback if no whitespace is found - if start == 0 and not text[0].isspace(): - start = proposed_start - return start - - -def _adjust_chunk_end(text: str, start: int, approximate_end: int) -> int: - """ - Shift the ending index backward if it lands in the middle of a word. - If no whitespace is found, use 'approximate_end' to avoid an infinite loop. - - Args: - text (str): The full text being split. - start (int): The adjusted starting index for this chunk. - approximate_end (int): The initial end index. - - Returns: - int: The adjusted ending index, ensuring the chunk does not end in the middle of - a word if possible. - """ - end = approximate_end - if end < len(text): - while end > start and not text[end - 1].isspace(): - end -= 1 - - # fallback if no whitespace is found - if end == start: - end = approximate_end - return end - - -class ApproximateFixedSizeSplitter(TextSplitter): - """Text splitter which splits the input text into approximate fixed size chunks with - optional overlap, avoiding cutting words. - - Args: - chunk_size (int): The number of characters in each chunk. - chunk_overlap (int): The number of characters from the previous chunk to overlap - with each chunk. Must be less than `chunk_size`. - - """ - @validate_call - def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200) -> None: - if chunk_overlap >= chunk_size: - raise ValueError("chunk_overlap must be strictly less than chunk_size") - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - - @validate_call - async def run(self, text: str) -> TextChunks: - """Splits a piece of text into chunks without cutting words in half. - - Args: - text (str): The text to be split. - - Returns: - TextChunks: A list of TextChunk objects with chunked text. - """ - chunks = [] - index = 0 - - step = self.chunk_size - self.chunk_overlap - text_length = len(text) - - i = 0 - while i < text_length: - # adjust chunk start - start = _adjust_chunk_start(text, i) - - # adjust chunk end - approximate_end = min(start + self.chunk_size, text_length) - end = _adjust_chunk_end(text, start, approximate_end) - - chunk_text = text[start:end] - chunks.append(TextChunk(text=chunk_text, index=index)) - index += 1 - - i = max(start + step, end) - - return TextChunks(chunks=chunks) diff --git a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py index 6add30d82..5cab2ab51 100644 --- a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py +++ b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py @@ -18,12 +18,66 @@ from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks +def _adjust_chunk_start(text: str, proposed_start: int) -> int: + """ + Shift the starting index backward if it lands in the middle of a word. + If no whitespace is found, use the proposed start. + + Args: + text (str): The text being split. + proposed_start (int): The initial starting index of the chunk. + + Returns: + int: The adjusted starting index, ensuring the chunk does not begin in the + middle of a word if possible. + """ + start = proposed_start + if start > 0 and not text[start].isspace() and not text[start - 1].isspace(): + while start > 0 and not text[start - 1].isspace(): + start -= 1 + + # fallback if no whitespace is found + if start == 0 and not text[0].isspace(): + start = proposed_start + return start + + +def _adjust_chunk_end(text: str, start: int, approximate_end: int) -> int: + """ + Shift the ending index backward if it lands in the middle of a word. + If no whitespace is found, use 'approximate_end'. + + Args: + text (str): The full text being split. + start (int): The adjusted starting index for this chunk. + approximate_end (int): The initial end index. + + Returns: + int: The adjusted ending index, ensuring the chunk does not end in the middle of + a word if possible. + """ + end = approximate_end + if end < len(text): + while end > start and not text[end - 1].isspace(): + end -= 1 + + # fallback if no whitespace is found + if end == start: + end = approximate_end + return end + + class FixedSizeSplitter(TextSplitter): - """Text splitter which splits the input text into fixed size chunks with optional overlap. + """Text splitter which splits the input text into fixed or approximate fixed size + chunks with optional overlap. Args: chunk_size (int): The number of characters in each chunk. - chunk_overlap (int): The number of characters from the previous chunk to overlap with each chunk. Must be less than `chunk_size`. + chunk_overlap (int): The number of characters from the previous chunk to overlap + with each chunk. Must be less than `chunk_size`. + approximate (bool): If True, avoids splitting words in the middle at chunk + boundaries. Defaults to True. + Example: @@ -33,16 +87,17 @@ class FixedSizeSplitter(TextSplitter): from neo4j_graphrag.experimental.pipeline import Pipeline pipeline = Pipeline() - text_splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200) + text_splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=True) pipeline.add_component(text_splitter, "text_splitter") """ @validate_call - def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200) -> None: + def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200, approximate: bool = True) -> None: if chunk_overlap >= chunk_size: raise ValueError("chunk_overlap must be strictly less than chunk_size") self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap + self.approximate = approximate @validate_call async def run(self, text: str) -> TextChunks: @@ -56,10 +111,25 @@ async def run(self, text: str) -> TextChunks: """ chunks = [] index = 0 - for i in range(0, len(text), self.chunk_size - self.chunk_overlap): - start = i - end = min(start + self.chunk_size, len(text)) + step = self.chunk_size - self.chunk_overlap + text_length = len(text) + + i = 0 + while i < text_length: + if self.approximate: + # adjust start and end to avoid cutting words in the middle + start = _adjust_chunk_start(text, i) + approximate_end = min(start + self.chunk_size, text_length) + end = _adjust_chunk_end(text, start, approximate_end) + else: + # fixed size splitting with possibly words cut in half at chunk boundaries + start = i + end = min(start + self.chunk_size, text_length) + chunk_text = text[start:end] chunks.append(TextChunk(text=chunk_text, index=index)) index += 1 + + i = max(start + step, end) + return TextChunks(chunks=chunks) From f622a44b80acde3f1cf1dc72ebabd8249d939b99 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Fri, 10 Jan 2025 16:38:30 +0200 Subject: [PATCH 03/12] Fix existing unit tests --- .../text_splitters/test_fixed_size_splitter.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py index 0467201fa..c51d6a428 100644 --- a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py +++ b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py @@ -26,7 +26,8 @@ async def test_split_text_no_overlap() -> None: text = "may thy knife chip and shatter" chunk_size = 5 chunk_overlap = 0 - splitter = FixedSizeSplitter(chunk_size, chunk_overlap) + approximate = False + splitter = FixedSizeSplitter(chunk_size, chunk_overlap, approximate) chunks = await splitter.run(text) expected_chunks = [ TextChunk(text="may t", index=0), @@ -47,7 +48,8 @@ async def test_split_text_with_overlap() -> None: text = "may thy knife chip and shatter" chunk_size = 10 chunk_overlap = 2 - splitter = FixedSizeSplitter(chunk_size, chunk_overlap) + approximate = False + splitter = FixedSizeSplitter(chunk_size, chunk_overlap, approximate) chunks = await splitter.run(text) expected_chunks = [ TextChunk(text="may thy kn", index=0), @@ -66,7 +68,8 @@ async def test_split_text_empty_string() -> None: text = "" chunk_size = 5 chunk_overlap = 1 - splitter = FixedSizeSplitter(chunk_size, chunk_overlap) + approximate = False + splitter = FixedSizeSplitter(chunk_size, chunk_overlap, approximate) chunks = await splitter.run(text) assert chunks.chunks == [] From a130054d34c6c2988c84987f54db48119ba6fd43 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Fri, 10 Jan 2025 16:50:26 +0200 Subject: [PATCH 04/12] Bug Fixes --- .../text_splitters/fixed_size_splitter.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py index 5cab2ab51..7418cde72 100644 --- a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py +++ b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py @@ -58,7 +58,7 @@ def _adjust_chunk_end(text: str, start: int, approximate_end: int) -> int: """ end = approximate_end if end < len(text): - while end > start and not text[end - 1].isspace(): + while end > start and not text[end].isspace() and not text[end-1].isspace(): end -= 1 # fallback if no whitespace is found @@ -114,22 +114,31 @@ async def run(self, text: str) -> TextChunks: step = self.chunk_size - self.chunk_overlap text_length = len(text) - i = 0 - while i < text_length: + approximate_start = 0 + skip_adjust_chunk_start = False + while approximate_start < text_length: if self.approximate: + if skip_adjust_chunk_start: + start = approximate_start + else: + start = _adjust_chunk_start(text, approximate_start) # adjust start and end to avoid cutting words in the middle - start = _adjust_chunk_start(text, i) approximate_end = min(start + self.chunk_size, text_length) end = _adjust_chunk_end(text, start, approximate_end) + # when avoiding splitting words in the middle is not possible, revert to initial chunk end and skip adjusting next chunk start + if end == approximate_end: + skip_adjust_chunk_start = True + else: + skip_adjust_chunk_start = False else: - # fixed size splitting with possibly words cut in half at chunk boundaries - start = i + # apply fixed size splitting with possibly words cut in half at chunk boundaries + start = approximate_start end = min(start + self.chunk_size, text_length) chunk_text = text[start:end] chunks.append(TextChunk(text=chunk_text, index=index)) index += 1 - i = max(start + step, end) + approximate_start = start + step return TextChunks(chunks=chunks) From 40c18b313af2f3ef9d5d183a71c3aeb5fb1e894b Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 13 Jan 2025 10:35:36 +0200 Subject: [PATCH 05/12] Code cleanup --- .../text_splitters/fixed_size_splitter.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py index 7418cde72..e90dca4ea 100644 --- a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py +++ b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py @@ -18,27 +18,27 @@ from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks -def _adjust_chunk_start(text: str, proposed_start: int) -> int: +def _adjust_chunk_start(text: str, approximate_start: int) -> int: """ Shift the starting index backward if it lands in the middle of a word. If no whitespace is found, use the proposed start. Args: text (str): The text being split. - proposed_start (int): The initial starting index of the chunk. + approximate_start (int): The initial starting index of the chunk. Returns: int: The adjusted starting index, ensuring the chunk does not begin in the middle of a word if possible. """ - start = proposed_start + start = approximate_start if start > 0 and not text[start].isspace() and not text[start - 1].isspace(): while start > 0 and not text[start - 1].isspace(): start -= 1 # fallback if no whitespace is found if start == 0 and not text[0].isspace(): - start = proposed_start + start = approximate_start return start @@ -113,25 +113,25 @@ async def run(self, text: str) -> TextChunks: index = 0 step = self.chunk_size - self.chunk_overlap text_length = len(text) - approximate_start = 0 + skip_adjust_chunk_start = False while approximate_start < text_length: if self.approximate: - if skip_adjust_chunk_start: - start = approximate_start - else: - start = _adjust_chunk_start(text, approximate_start) + start = ( + approximate_start + if skip_adjust_chunk_start + else _adjust_chunk_start(text, approximate_start) + ) # adjust start and end to avoid cutting words in the middle approximate_end = min(start + self.chunk_size, text_length) end = _adjust_chunk_end(text, start, approximate_end) - # when avoiding splitting words in the middle is not possible, revert to initial chunk end and skip adjusting next chunk start - if end == approximate_end: - skip_adjust_chunk_start = True - else: - skip_adjust_chunk_start = False + # when avoiding splitting words in the middle is not possible, revert to + # initial chunk end and skip adjusting next chunk start + skip_adjust_chunk_start = (end == approximate_end) else: - # apply fixed size splitting with possibly words cut in half at chunk boundaries + # apply fixed size splitting with possibly words cut in half at chunk + # boundaries start = approximate_start end = min(start + self.chunk_size, text_length) From 48092e7fd1e64ea076ec90ce474de472d9654074 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 13 Jan 2025 15:41:38 +0200 Subject: [PATCH 06/12] Avoid chunk_size 0 --- .../components/text_splitters/fixed_size_splitter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py index e90dca4ea..94af5c27b 100644 --- a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py +++ b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py @@ -93,6 +93,8 @@ class FixedSizeSplitter(TextSplitter): @validate_call def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200, approximate: bool = True) -> None: + if chunk_size <= 0: + raise ValueError("chunk_size must be strictly greater than 0") if chunk_overlap >= chunk_size: raise ValueError("chunk_overlap must be strictly less than chunk_size") self.chunk_size = chunk_size From 55ff83ff4e65d4daa9f8b50579d86216a3849496 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 13 Jan 2025 21:31:49 +0200 Subject: [PATCH 07/12] Fix edge case where chunk end greater than text length --- .../components/text_splitters/fixed_size_splitter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py index 94af5c27b..40fdac767 100644 --- a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py +++ b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py @@ -116,9 +116,10 @@ async def run(self, text: str) -> TextChunks: step = self.chunk_size - self.chunk_overlap text_length = len(text) approximate_start = 0 - skip_adjust_chunk_start = False - while approximate_start < text_length: + end = 0 + + while end < text_length: if self.approximate: start = ( approximate_start From a5b2270098074b43cef6adaf09d4c3e524a023d2 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 13 Jan 2025 21:31:57 +0200 Subject: [PATCH 08/12] Add unit tests --- .../test_fixed_size_splitter.py | 138 +++++++++++++++++- 1 file changed, 137 insertions(+), 1 deletion(-) diff --git a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py index c51d6a428..6193d56a6 100644 --- a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py +++ b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py @@ -16,7 +16,7 @@ import pytest from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( - FixedSizeSplitter, + FixedSizeSplitter, _adjust_chunk_start, _adjust_chunk_end, ) from neo4j_graphrag.experimental.components.types import TextChunk @@ -78,3 +78,139 @@ def test_invalid_chunk_overlap() -> None: with pytest.raises(ValueError) as excinfo: FixedSizeSplitter(5, 5) assert "chunk_overlap must be strictly less than chunk_size" in str(excinfo) + + +def test_invalid_chunk_size() -> None: + with pytest.raises(ValueError) as excinfo: + FixedSizeSplitter(0, 0) + assert "chunk_size must be strictly greater than 0" in str(excinfo) + + +@pytest.mark.parametrize( + "text, approximate_start, expected_start", + [ + # Case: approximate_start is at word boundary already + ("Hello World", 6, 6), + # Case: approximate_start is at a whitespace already + ("Hello World", 5, 5), + # Case: approximate_start is at the middle of word and no whitespace is found + ("Hello World", 2, 2), + # Case: approximate_start is at the middle of a word + ("Hello World", 8, 6), + # Case: approximate_start = 0 + ("Hello World", 0, 0), + ], +) +def test_adjust_chunk_start(text, approximate_start, expected_start): + """ + Test that the _adjust_chunk_start function correctly shifts + the start index to avoid breaking words, unless no whitespace is found. + """ + result = _adjust_chunk_start(text, approximate_start) + assert result == expected_start + + +@pytest.mark.parametrize( + "text, start, approximate_end, expected_end", + [ + # Case: approximate_end is at word boundary already + ("Hello World", 0, 5, 5), + # Case: approximate_end is at the middle of a word + ("Hello World", 0, 8, 6), + # Case: approximate_end is at the middle of word and no whitespace is found + ("Hello World", 0, 3, 3), + # Case: adjusted_end == start => fallback to approximate_end + ("Hello World", 6, 7, 7), + # Case: end>=len(text) + ("Hello World", 6, 15, 15), + ], +) +def test_adjust_chunk_end(text, start, approximate_end, expected_end): + """ + Test that the _adjust_chunk_end function correctly shifts + the end index to avoid breaking words, unless no whitespace is found. + """ + result = _adjust_chunk_end(text, start, approximate_end) + assert result == expected_end + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "text, chunk_size, chunk_overlap, approximate, expected_chunks", + [ + # Case: approximate fixed size splitting + ( + "Hello World, this is a test message.", + 10, + 2, + True, + [ + "Hello ", + "World, ", + "this is a ", + "a test ", + "message." + ], + ), + # Case: fixed size splitting + ( + "Hello World, this is a test message.", + 10, + 2, + False, + [ + "Hello Worl", + "rld, this ", + "s is a tes", + "est messag", + "age." + ], + ), + # Case: short text => only one chunk + ( + "Short text", + 20, + 5, + True, + ["Short text"], + ), + # Case: short text => only one chunk + ( + "Short text", + 12, + 4, + True, + ["Short text"], + ), + # Case: text with no spaces + ( + "1234567890", + 5, + 1, + True, + ["12345", "56789", "90"], + ), + ], +) +async def test_fixed_size_splitter_run( + text, chunk_size, chunk_overlap, approximate, expected_chunks +): + """ + Test that 'FixedSizeSplitter.run' returns the expected chunks + for different configurations. + """ + splitter = FixedSizeSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + approximate=approximate, + ) + text_chunks = await splitter.run(text) + + # Verify number of chunks + assert len(text_chunks.chunks) == len(expected_chunks) + + # Verify content of each chunk + for i, expected_text in enumerate(expected_chunks): + assert text_chunks.chunks[i].text == expected_text + assert isinstance(text_chunks.chunks[i], TextChunk) + assert text_chunks.chunks[i].index == i From 4fa1c9ac1cf7e6d1e74d4ffdddd76182d3ad772b Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 20 Jan 2025 13:46:48 +0200 Subject: [PATCH 09/12] Update CHANGELOG, examples and documentation --- CHANGELOG.md | 1 + docs/source/user_guide_kg_builder.rst | 5 ++++- .../build_graph/components/splitters/fixed_size_splitter.py | 3 ++- .../customize/build_graph/pipeline/kg_builder_from_pdf.py | 2 +- .../customize/build_graph/pipeline/kg_builder_from_text.py | 2 +- .../build_graph/pipeline/lexical_graph_builder_from_text.py | 2 +- .../text_to_lexical_graph_to_entity_graph_single_pipeline.py | 2 +- .../text_to_lexical_graph_to_entity_graph_two_pipelines.py | 2 +- 8 files changed, 12 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a4b5e897..3bf39dd60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ ### Changed - Updated LLM implementations to handle message history consistently across providers. - The `id_prefix` parameter in the `LexicalGraphConfig` is deprecated. +- Changed the default behaviour of `FixedSizeSplitter` to avoid words cut-off in the chunks whenever it is possible. ### Fixed - IDs for the Document and Chunk nodes in the lexical graph are now randomly generated and unique across multiple runs, fixing issues in the lexical graph where relationships were created between chunks that were created by different pipeline runs. diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index c45590721..ea91ec2cc 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -581,9 +581,12 @@ that can be processed within the LLM token limits: from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter - splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200) + splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False) splitter.run(text="Hello World. Life is beautiful.") +.. note:: + + `approximate` flag is by default set to True to ensure clean chunk start and end (i.e. avoid words cut in the middle) whenever it is possible. Wrappers for LangChain and LlamaIndex text splitters are included in this package: diff --git a/examples/customize/build_graph/components/splitters/fixed_size_splitter.py b/examples/customize/build_graph/components/splitters/fixed_size_splitter.py index 8b2f2cc19..0b97f3938 100644 --- a/examples/customize/build_graph/components/splitters/fixed_size_splitter.py +++ b/examples/customize/build_graph/components/splitters/fixed_size_splitter.py @@ -6,9 +6,10 @@ async def main() -> TextChunks: splitter = FixedSizeSplitter( - # optionally, configure chunk_size and chunk_overlap + # optionally, configure chunk_size, chunk_overlap, and approximate flag # chunk_size=4000, # chunk_overlap=200, + # approximate = False ) chunks = await splitter.run(text="text to split") return chunks diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py index e81d482cd..f418efd73 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py @@ -83,7 +83,7 @@ async def define_and_run_pipeline( pipe = Pipeline() pipe.add_component(PdfLoader(), "pdf_loader") pipe.add_component( - FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), "splitter" + FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False), "splitter" ) pipe.add_component(SchemaBuilder(), "schema") pipe.add_component( diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_text.py b/examples/customize/build_graph/pipeline/kg_builder_from_text.py index 2beed1248..e8f665170 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_text.py @@ -58,7 +58,7 @@ async def define_and_run_pipeline( # define the components pipe.add_component( # chunk_size=50 for the sake of this demo - FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), + FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False), "splitter", ) pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder") diff --git a/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py b/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py index 4c332802c..62e0197a9 100644 --- a/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py @@ -27,7 +27,7 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult: pipe = Pipeline() # define the components pipe.add_component( - FixedSizeSplitter(chunk_size=20, chunk_overlap=1), + FixedSizeSplitter(chunk_size=20, chunk_overlap=1, approximate=False), "splitter", ) pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder") diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py index b7ba60e1d..7d07af441 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py @@ -56,7 +56,7 @@ async def define_and_run_pipeline( pipe = Pipeline() # define the components pipe.add_component( - FixedSizeSplitter(chunk_size=200, chunk_overlap=50), + FixedSizeSplitter(chunk_size=200, chunk_overlap=50,approximate=False), "splitter", ) pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder") diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py index 7cfa4cbb5..db9a7f714 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py @@ -47,7 +47,7 @@ async def build_lexical_graph( pipe = Pipeline() # define the components pipe.add_component( - FixedSizeSplitter(chunk_size=200, chunk_overlap=50), + FixedSizeSplitter(chunk_size=200, chunk_overlap=50, approximate=False), "splitter", ) pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder") From 643060b34c98da3b61ba980ab53406b59768b307 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 20 Jan 2025 18:13:59 +0200 Subject: [PATCH 10/12] Ruff formatting --- .../pipeline/kg_builder_from_pdf.py | 19 +++------ .../pipeline/kg_builder_from_text.py | 15 +++---- .../lexical_graph_builder_from_text.py | 7 ++-- ...l_graph_to_entity_graph_single_pipeline.py | 13 +++--- ...cal_graph_to_entity_graph_two_pipelines.py | 11 ++--- .../text_splitters/fixed_size_splitter.py | 8 ++-- .../test_fixed_size_splitter.py | 41 ++++++++++--------- 7 files changed, 50 insertions(+), 64 deletions(-) diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py index f418efd73..01672259c 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py @@ -17,7 +17,6 @@ import asyncio import logging -import neo4j from neo4j_graphrag.experimental.components.entity_relation_extractor import ( LLMEntityRelationExtractor, OnError, @@ -35,12 +34,12 @@ from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.llm import LLMInterface, OpenAILLM +import neo4j + logging.basicConfig(level=logging.INFO) -async def define_and_run_pipeline( - neo4j_driver: neo4j.Driver, llm: LLMInterface -) -> PipelineResult: +async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface) -> PipelineResult: from neo4j_graphrag.experimental.pipeline import Pipeline # Instantiate Entity and Relation objects @@ -57,9 +56,7 @@ async def define_and_run_pipeline( ), ] relations = [ - SchemaRelation( - label="SITUATED_AT", description="Indicates the location of a person." - ), + SchemaRelation(label="SITUATED_AT", description="Indicates the location of a person."), SchemaRelation( label="LED_BY", description="Indicates the leader of an organization.", @@ -68,9 +65,7 @@ async def define_and_run_pipeline( label="OWNS", description="Indicates the ownership of an item such as a Horcrux.", ), - SchemaRelation( - label="INTERACTS", description="The interaction between two people." - ), + SchemaRelation(label="INTERACTS", description="The interaction between two people."), ] potential_schema = [ ("PERSON", "SITUATED_AT", "LOCATION"), @@ -131,9 +126,7 @@ async def main() -> PipelineResult: "response_format": {"type": "json_object"}, }, ) - driver = neo4j.GraphDatabase.driver( - "bolt://localhost:7687", auth=("neo4j", "password") - ) + driver = neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) res = await define_and_run_pipeline(driver, llm) driver.close() await llm.async_client.close() diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_text.py b/examples/customize/build_graph/pipeline/kg_builder_from_text.py index e8f665170..390280dda 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_text.py @@ -16,7 +16,6 @@ import asyncio -import neo4j from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.entity_relation_extractor import ( @@ -37,10 +36,10 @@ from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.llm import LLMInterface, OpenAILLM +import neo4j + -async def define_and_run_pipeline( - neo4j_driver: neo4j.Driver, llm: LLMInterface -) -> PipelineResult: +async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface) -> PipelineResult: """This is where we define and run the KG builder pipeline, instantiating a few components: - Text Splitter: in this example we use the fixed size text splitter @@ -75,9 +74,7 @@ async def define_and_run_pipeline( # and how the output of previous components must be used pipe.connect("splitter", "chunk_embedder", input_config={"text_chunks": "splitter"}) pipe.connect("schema", "extractor", input_config={"schema": "schema"}) - pipe.connect( - "chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"} - ) + pipe.connect("chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"}) pipe.connect( "extractor", "writer", @@ -148,9 +145,7 @@ async def main() -> PipelineResult: "response_format": {"type": "json_object"}, }, ) - driver = neo4j.GraphDatabase.driver( - "bolt://localhost:7687", auth=("neo4j", "password") - ) + driver = neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) res = await define_and_run_pipeline(driver, llm) driver.close() await llm.async_client.close() diff --git a/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py b/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py index 62e0197a9..069562cd6 100644 --- a/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py @@ -2,7 +2,6 @@ import asyncio -import neo4j from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter @@ -14,6 +13,8 @@ from neo4j_graphrag.experimental.pipeline import Pipeline from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult +import neo4j + async def main(neo4j_driver: neo4j.Driver) -> PipelineResult: """This is where we define and run the Lexical Graph builder pipeline, instantiating @@ -78,7 +79,5 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult: if __name__ == "__main__": - with neo4j.GraphDatabase.driver( - "bolt://localhost:7687", auth=("neo4j", "password") - ) as driver: + with neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) as driver: print(asyncio.run(main(driver))) diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py index 7d07af441..ce25a5d49 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py @@ -7,7 +7,6 @@ import asyncio -import neo4j from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.entity_relation_extractor import ( @@ -29,6 +28,8 @@ from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.llm import LLMInterface, OpenAILLM +import neo4j + async def define_and_run_pipeline( neo4j_driver: neo4j.Driver, @@ -56,7 +57,7 @@ async def define_and_run_pipeline( pipe = Pipeline() # define the components pipe.add_component( - FixedSizeSplitter(chunk_size=200, chunk_overlap=50,approximate=False), + FixedSizeSplitter(chunk_size=200, chunk_overlap=50, approximate=False), "splitter", ) pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder") @@ -92,9 +93,7 @@ async def define_and_run_pipeline( ) # define the execution order of component # and how the output of previous components must be used - pipe.connect( - "chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"} - ) + pipe.connect("chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"}) pipe.connect("schema", "extractor", input_config={"schema": "schema"}) pipe.connect( "extractor", @@ -189,7 +188,5 @@ async def main(driver: neo4j.Driver) -> PipelineResult: if __name__ == "__main__": - with neo4j.GraphDatabase.driver( - "bolt://localhost:7687", auth=("neo4j", "password") - ) as driver: + with neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) as driver: print(asyncio.run(main(driver))) diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py index db9a7f714..fea205de5 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py @@ -8,7 +8,6 @@ import asyncio -import neo4j from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.entity_relation_extractor import ( @@ -31,6 +30,8 @@ from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.llm import LLMInterface, OpenAILLM +import neo4j + async def build_lexical_graph( neo4j_driver: neo4j.Driver, @@ -200,15 +201,11 @@ async def main(driver: neo4j.Driver) -> PipelineResult: }, ) await build_lexical_graph(driver, lexical_graph_config, text=text) - res = await read_chunk_and_perform_entity_extraction( - driver, llm, lexical_graph_config - ) + res = await read_chunk_and_perform_entity_extraction(driver, llm, lexical_graph_config) await llm.async_client.close() return res if __name__ == "__main__": - with neo4j.GraphDatabase.driver( - "bolt://localhost:7687", auth=("neo4j", "password") - ) as driver: + with neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) as driver: print(asyncio.run(main(driver))) diff --git a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py index 40fdac767..515387f43 100644 --- a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py +++ b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py @@ -58,7 +58,7 @@ def _adjust_chunk_end(text: str, start: int, approximate_end: int) -> int: """ end = approximate_end if end < len(text): - while end > start and not text[end].isspace() and not text[end-1].isspace(): + while end > start and not text[end].isspace() and not text[end - 1].isspace(): end -= 1 # fallback if no whitespace is found @@ -92,7 +92,9 @@ class FixedSizeSplitter(TextSplitter): """ @validate_call - def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200, approximate: bool = True) -> None: + def __init__( + self, chunk_size: int = 4000, chunk_overlap: int = 200, approximate: bool = True + ) -> None: if chunk_size <= 0: raise ValueError("chunk_size must be strictly greater than 0") if chunk_overlap >= chunk_size: @@ -131,7 +133,7 @@ async def run(self, text: str) -> TextChunks: end = _adjust_chunk_end(text, start, approximate_end) # when avoiding splitting words in the middle is not possible, revert to # initial chunk end and skip adjusting next chunk start - skip_adjust_chunk_start = (end == approximate_end) + skip_adjust_chunk_start = end == approximate_end else: # apply fixed size splitting with possibly words cut in half at chunk # boundaries diff --git a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py index 6193d56a6..f505777ae 100644 --- a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py +++ b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py @@ -16,7 +16,9 @@ import pytest from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( - FixedSizeSplitter, _adjust_chunk_start, _adjust_chunk_end, + FixedSizeSplitter, + _adjust_chunk_end, + _adjust_chunk_start, ) from neo4j_graphrag.experimental.components.types import TextChunk @@ -101,7 +103,11 @@ def test_invalid_chunk_size() -> None: ("Hello World", 0, 0), ], ) -def test_adjust_chunk_start(text, approximate_start, expected_start): +def test_adjust_chunk_start( + text: str, + approximate_start: int, + expected_start: int +) -> None: """ Test that the _adjust_chunk_start function correctly shifts the start index to avoid breaking words, unless no whitespace is found. @@ -125,7 +131,12 @@ def test_adjust_chunk_start(text, approximate_start, expected_start): ("Hello World", 6, 15, 15), ], ) -def test_adjust_chunk_end(text, start, approximate_end, expected_end): +def test_adjust_chunk_end( + text: str, + start: int, + approximate_end: int, + expected_end: int +) -> None: """ Test that the _adjust_chunk_end function correctly shifts the end index to avoid breaking words, unless no whitespace is found. @@ -144,13 +155,7 @@ def test_adjust_chunk_end(text, start, approximate_end, expected_end): 10, 2, True, - [ - "Hello ", - "World, ", - "this is a ", - "a test ", - "message." - ], + ["Hello ", "World, ", "this is a ", "a test ", "message."], ), # Case: fixed size splitting ( @@ -158,13 +163,7 @@ def test_adjust_chunk_end(text, start, approximate_end, expected_end): 10, 2, False, - [ - "Hello Worl", - "rld, this ", - "s is a tes", - "est messag", - "age." - ], + ["Hello Worl", "rld, this ", "s is a tes", "est messag", "age."], ), # Case: short text => only one chunk ( @@ -193,8 +192,12 @@ def test_adjust_chunk_end(text, start, approximate_end, expected_end): ], ) async def test_fixed_size_splitter_run( - text, chunk_size, chunk_overlap, approximate, expected_chunks -): + text: str, + chunk_size: int, + chunk_overlap: int, + approximate: bool, + expected_chunks: list[str] +) -> None: """ Test that 'FixedSizeSplitter.run' returns the expected chunks for different configurations. From fbbd2c434f6c30ab45fe7ed2f92f6a7b18783bff Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 20 Jan 2025 18:51:38 +0200 Subject: [PATCH 11/12] More ruff formatting --- .../test_fixed_size_splitter.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py index f505777ae..ca575b65d 100644 --- a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py +++ b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py @@ -103,11 +103,7 @@ def test_invalid_chunk_size() -> None: ("Hello World", 0, 0), ], ) -def test_adjust_chunk_start( - text: str, - approximate_start: int, - expected_start: int -) -> None: +def test_adjust_chunk_start(text: str, approximate_start: int, expected_start: int) -> None: """ Test that the _adjust_chunk_start function correctly shifts the start index to avoid breaking words, unless no whitespace is found. @@ -131,12 +127,7 @@ def test_adjust_chunk_start( ("Hello World", 6, 15, 15), ], ) -def test_adjust_chunk_end( - text: str, - start: int, - approximate_end: int, - expected_end: int -) -> None: +def test_adjust_chunk_end(text: str, start: int, approximate_end: int, expected_end: int) -> None: """ Test that the _adjust_chunk_end function correctly shifts the end index to avoid breaking words, unless no whitespace is found. @@ -192,11 +183,7 @@ def test_adjust_chunk_end( ], ) async def test_fixed_size_splitter_run( - text: str, - chunk_size: int, - chunk_overlap: int, - approximate: bool, - expected_chunks: list[str] + text: str, chunk_size: int, chunk_overlap: int, approximate: bool, expected_chunks: list[str] ) -> None: """ Test that 'FixedSizeSplitter.run' returns the expected chunks From 5a81f874904604a64648ae9e87e55c31cbea9b31 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 21 Jan 2025 12:10:05 +0200 Subject: [PATCH 12/12] More ruff formatting --- .../pipeline/kg_builder_from_pdf.py | 19 ++++++++++++++----- .../pipeline/kg_builder_from_text.py | 12 +++++++++--- .../lexical_graph_builder_from_text.py | 4 +++- ...l_graph_to_entity_graph_single_pipeline.py | 8 ++++++-- ...cal_graph_to_entity_graph_two_pipelines.py | 8 ++++++-- .../test_fixed_size_splitter.py | 14 +++++++++++--- 6 files changed, 49 insertions(+), 16 deletions(-) diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py index 01672259c..ab11206da 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py @@ -39,7 +39,9 @@ logging.basicConfig(level=logging.INFO) -async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface) -> PipelineResult: +async def define_and_run_pipeline( + neo4j_driver: neo4j.Driver, llm: LLMInterface +) -> PipelineResult: from neo4j_graphrag.experimental.pipeline import Pipeline # Instantiate Entity and Relation objects @@ -56,7 +58,9 @@ async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface) ), ] relations = [ - SchemaRelation(label="SITUATED_AT", description="Indicates the location of a person."), + SchemaRelation( + label="SITUATED_AT", description="Indicates the location of a person." + ), SchemaRelation( label="LED_BY", description="Indicates the leader of an organization.", @@ -65,7 +69,9 @@ async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface) label="OWNS", description="Indicates the ownership of an item such as a Horcrux.", ), - SchemaRelation(label="INTERACTS", description="The interaction between two people."), + SchemaRelation( + label="INTERACTS", description="The interaction between two people." + ), ] potential_schema = [ ("PERSON", "SITUATED_AT", "LOCATION"), @@ -78,7 +84,8 @@ async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface) pipe = Pipeline() pipe.add_component(PdfLoader(), "pdf_loader") pipe.add_component( - FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False), "splitter" + FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False), + "splitter", ) pipe.add_component(SchemaBuilder(), "schema") pipe.add_component( @@ -126,7 +133,9 @@ async def main() -> PipelineResult: "response_format": {"type": "json_object"}, }, ) - driver = neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) + driver = neo4j.GraphDatabase.driver( + "bolt://localhost:7687", auth=("neo4j", "password") + ) res = await define_and_run_pipeline(driver, llm) driver.close() await llm.async_client.close() diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_text.py b/examples/customize/build_graph/pipeline/kg_builder_from_text.py index 390280dda..907a02825 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_text.py @@ -39,7 +39,9 @@ import neo4j -async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface) -> PipelineResult: +async def define_and_run_pipeline( + neo4j_driver: neo4j.Driver, llm: LLMInterface +) -> PipelineResult: """This is where we define and run the KG builder pipeline, instantiating a few components: - Text Splitter: in this example we use the fixed size text splitter @@ -74,7 +76,9 @@ async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface) # and how the output of previous components must be used pipe.connect("splitter", "chunk_embedder", input_config={"text_chunks": "splitter"}) pipe.connect("schema", "extractor", input_config={"schema": "schema"}) - pipe.connect("chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"}) + pipe.connect( + "chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"} + ) pipe.connect( "extractor", "writer", @@ -145,7 +149,9 @@ async def main() -> PipelineResult: "response_format": {"type": "json_object"}, }, ) - driver = neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) + driver = neo4j.GraphDatabase.driver( + "bolt://localhost:7687", auth=("neo4j", "password") + ) res = await define_and_run_pipeline(driver, llm) driver.close() await llm.async_client.close() diff --git a/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py b/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py index 069562cd6..c2fbdec4e 100644 --- a/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py @@ -79,5 +79,7 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult: if __name__ == "__main__": - with neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) as driver: + with neo4j.GraphDatabase.driver( + "bolt://localhost:7687", auth=("neo4j", "password") + ) as driver: print(asyncio.run(main(driver))) diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py index ce25a5d49..6867d9068 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py @@ -93,7 +93,9 @@ async def define_and_run_pipeline( ) # define the execution order of component # and how the output of previous components must be used - pipe.connect("chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"}) + pipe.connect( + "chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"} + ) pipe.connect("schema", "extractor", input_config={"schema": "schema"}) pipe.connect( "extractor", @@ -188,5 +190,7 @@ async def main(driver: neo4j.Driver) -> PipelineResult: if __name__ == "__main__": - with neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) as driver: + with neo4j.GraphDatabase.driver( + "bolt://localhost:7687", auth=("neo4j", "password") + ) as driver: print(asyncio.run(main(driver))) diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py index fea205de5..0fd354db4 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py @@ -201,11 +201,15 @@ async def main(driver: neo4j.Driver) -> PipelineResult: }, ) await build_lexical_graph(driver, lexical_graph_config, text=text) - res = await read_chunk_and_perform_entity_extraction(driver, llm, lexical_graph_config) + res = await read_chunk_and_perform_entity_extraction( + driver, llm, lexical_graph_config + ) await llm.async_client.close() return res if __name__ == "__main__": - with neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) as driver: + with neo4j.GraphDatabase.driver( + "bolt://localhost:7687", auth=("neo4j", "password") + ) as driver: print(asyncio.run(main(driver))) diff --git a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py index ca575b65d..d03006ddb 100644 --- a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py +++ b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py @@ -103,7 +103,9 @@ def test_invalid_chunk_size() -> None: ("Hello World", 0, 0), ], ) -def test_adjust_chunk_start(text: str, approximate_start: int, expected_start: int) -> None: +def test_adjust_chunk_start( + text: str, approximate_start: int, expected_start: int +) -> None: """ Test that the _adjust_chunk_start function correctly shifts the start index to avoid breaking words, unless no whitespace is found. @@ -127,7 +129,9 @@ def test_adjust_chunk_start(text: str, approximate_start: int, expected_start: i ("Hello World", 6, 15, 15), ], ) -def test_adjust_chunk_end(text: str, start: int, approximate_end: int, expected_end: int) -> None: +def test_adjust_chunk_end( + text: str, start: int, approximate_end: int, expected_end: int +) -> None: """ Test that the _adjust_chunk_end function correctly shifts the end index to avoid breaking words, unless no whitespace is found. @@ -183,7 +187,11 @@ def test_adjust_chunk_end(text: str, start: int, approximate_end: int, expected_ ], ) async def test_fixed_size_splitter_run( - text: str, chunk_size: int, chunk_overlap: int, approximate: bool, expected_chunks: list[str] + text: str, + chunk_size: int, + chunk_overlap: int, + approximate: bool, + expected_chunks: list[str], ) -> None: """ Test that 'FixedSizeSplitter.run' returns the expected chunks