|
| 1 | +import os |
1 | 2 | from abc import abstractmethod |
| 3 | +from functools import cache |
2 | 4 | from io import TextIOWrapper |
3 | | -from typing import Generator |
| 5 | +from typing import Generator, Optional |
| 6 | + |
| 7 | +from pygments.lexer import Lexer |
| 8 | +from pygments.lexers import guess_lexer_for_filename |
| 9 | +from pygments.util import ClassNotFound |
| 10 | +from tree_sitter import Node |
| 11 | +from tree_sitter_language_pack import get_parser |
4 | 12 |
|
5 | 13 |
|
6 | 14 | class ChunkerBase: |
@@ -59,3 +67,68 @@ def chunk(self, data: TextIOWrapper) -> Generator[str, None, None]: |
59 | 67 | yield output |
60 | 68 | if len(new_chars) < step_size: |
61 | 69 | return |
| 70 | + |
| 71 | + |
class TreeSitterChunker(ChunkerBase):
    """Chunk a source file along the boundaries of its tree-sitter syntax tree.

    The file's language is guessed with pygments; the lexer's name and aliases
    are then tried as tree-sitter language names.  When no parser is available
    the file content is handed to ``StringChunker`` instead.

    Args:
        chunk_size: Maximum chunk length, measured in characters, for the
            syntax-aware path.  A non-positive value is treated as "no limit".
        overlap_ratio: Only forwarded to the ``StringChunker`` fallback; the
            syntax-aware path produces non-overlapping chunks.
    """

    def __init__(self, chunk_size: int = -1, overlap_ratio: float = 0.2):
        super().__init__()
        assert isinstance(chunk_size, int), "chunk_size parameter must be an integer"
        assert 0 <= overlap_ratio < 1, (
            "Overlap ratio has to be a float between 0 (inclusive) and 1 (exclusive)."
        )
        self.__chunk_size = chunk_size
        self.__overlap_ratio = overlap_ratio

    def __chunk_node(self, node: Node, source: bytes) -> Generator[str, None, None]:
        """Yield chunks of at most ``chunk_size`` characters from ``node``'s children.

        Consecutive children are concatenated while they fit; a child that is
        too large on its own is chunked recursively, or split into fixed-size
        slices when it is a leaf.  Text lying between sibling nodes (typically
        whitespace) is not part of any child and is therefore not emitted.

        Args:
            node: Syntax-tree node whose children are chunked.
            source: The exact byte buffer that was given to the parser.  Node
                offsets are byte offsets, so they must be applied to the bytes
                and not to the decoded string.
        """
        limit = self.__chunk_size
        current_chunk = ""
        for child in node.children:
            # Slice the encoded buffer, then decode: start_byte/end_byte index
            # bytes, which differ from str indices once non-ASCII text appears.
            child_text = source[child.start_byte : child.end_byte].decode(
                "utf-8", errors="replace"
            )
            if len(child_text) > limit:
                # Flush what has been collected so far to keep source order.
                if current_chunk:
                    yield current_chunk
                    current_chunk = ""
                if child.children:
                    yield from self.__chunk_node(child, source)
                else:
                    # An oversized leaf (long string literal, comment, ...) has
                    # nothing to recurse into; slice it so it is not lost.
                    for start in range(0, len(child_text), limit):
                        yield child_text[start : start + limit]
            elif len(current_chunk) + len(child_text) > limit:
                yield current_chunk
                current_chunk = child_text
            else:
                current_chunk += child_text
        if current_chunk:
            yield current_chunk

    def __guess_type(self, path: str, content: str) -> Optional[Lexer]:
        """Return the pygments lexer guessed for the file, or ``None`` if unknown.

        Deliberately not memoised: a cache keyed on ``(self, path, content)``
        would keep every instance and every file's full text alive.
        """
        try:
            return guess_lexer_for_filename(path, content)
        except ClassNotFound:
            return None

    def __find_parser(self, lexer: Optional[Lexer]):
        """Return a tree-sitter parser matching ``lexer``, or ``None`` if none exists."""
        if lexer is None:
            return None
        # pygments and tree-sitter do not share a naming scheme, so try the
        # display name first and then every alias.
        for name in [lexer.name, *lexer.aliases]:
            try:
                return get_parser(name.lower())
            except LookupError:
                continue
        return None

    def chunk(self, data: str) -> Generator[str, None, None]:
        """Yield chunks of the file at path ``data``.

        Args:
            data: Path to the file to chunk.

        Yields:
            Chunks of the file content as strings.
        """
        assert os.path.isfile(data)
        with open(data) as fin:
            content = fin.read()

        parser = self.__find_parser(self.__guess_type(data, content))
        if parser is None:
            # fall back to naive chunking
            yield from StringChunker(self.__chunk_size, self.__overlap_ratio).chunk(
                content
            )
            return

        if self.__chunk_size <= 0:
            # A non-positive size cannot bound a chunk (previously this path
            # yielded nothing at all, including for the default of -1).
            # NOTE(review): treated as "no limit" here -- confirm this matches
            # the convention used by the other chunkers.
            if content:
                yield content
            return

        content_bytes = content.encode("utf-8")
        tree = parser.parse(content_bytes)
        yield from self.__chunk_node(tree.root_node, content_bytes)
0 commit comments