Skip to content

Commit bad1731

Browse files
committed
fix: YAML escape gaps, encoding fallback, simplify tokens & logging
1 parent 06acb7f commit bad1731

File tree

9 files changed

+87
-129
lines changed

9 files changed

+87
-129
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,12 +53,13 @@ classifiers = [
5353
dynamic = [ "version" ] # Version is still managed in version.py
5454
dependencies = [
5555
"pathspec>=0.11,<2.0",
56-
"tiktoken>=0.7,<1.0",
56+
"tiktoken>=0.8,<1.0",
5757
]
5858
optional-dependencies.dev = [
5959
"black>=23.0.0,<27.0",
6060
# Build and release
6161
"build>=0.10,<2.0",
62+
"charset-normalizer>=3.0,<4.0",
6263
"coverage>=7.0,<8.0",
6364
"hypothesis>=6.0,<7.0",
6465
"import-linter>=2.0,<3.0",
@@ -87,6 +88,7 @@ optional-dependencies.embeddings = [
8788
"sentence-transformers>=3.0,<6.0",
8889
]
8990
optional-dependencies.full = [
91+
"charset-normalizer>=3.0,<4.0",
9092
"lxml>=5.0,<7.0",
9193
"mistune>=3.0,<4.0",
9294
"pysbd>=0.3,<1.0",

src/treemapper/cli.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -143,7 +143,7 @@ class ParsedArgs:
143143
whitelist_file: Path | None
144144
output_file: Path | None
145145
no_default_ignores: bool
146-
verbosity: int
146+
verbosity: int | str
147147
output_format: str
148148
max_depth: int | None
149149
no_content: bool
@@ -315,11 +315,7 @@ def parse_args() -> ParsedArgs:
315315
ignore_file = _resolve_ignore_file(args.ignore, root_dir)
316316
whitelist_file = _resolve_whitelist_file(args.whitelist, root_dir)
317317

318-
log_level_map = {"error": 0, "warning": 1, "info": 2, "debug": 3}
319-
verbosity = log_level_map[args.log_level]
320-
321-
if args.quiet:
322-
verbosity = 0
318+
verbosity = "error" if args.quiet else args.log_level
323319

324320
return ParsedArgs(
325321
root_dir=root_dir,

src/treemapper/logger.py

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2,15 +2,20 @@
22

33
PACKAGE_LOGGER_NAME = "treemapper"
44

5+
_LOG_LEVEL_MAP = {
6+
"error": logging.ERROR,
7+
"warning": logging.WARNING,
8+
"info": logging.INFO,
9+
"debug": logging.DEBUG,
10+
}
511

6-
def setup_logging(verbosity: int) -> None:
7-
level_map = {
8-
0: logging.ERROR,
9-
1: logging.WARNING,
10-
2: logging.INFO,
11-
3: logging.DEBUG,
12-
}
13-
level = level_map.get(verbosity, logging.INFO)
12+
13+
def setup_logging(verbosity: int | str) -> None:
14+
if isinstance(verbosity, str):
15+
level = _LOG_LEVEL_MAP.get(verbosity, logging.INFO)
16+
else:
17+
int_to_level = {0: logging.ERROR, 1: logging.WARNING, 2: logging.INFO, 3: logging.DEBUG}
18+
level = int_to_level.get(verbosity, logging.INFO)
1419

1520
pkg_logger = logging.getLogger(PACKAGE_LOGGER_NAME)
1621
pkg_logger.setLevel(level)

src/treemapper/tokens.py

Lines changed: 1 addition & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -7,11 +7,6 @@
77

88
logger = logging.getLogger(__name__)
99

10-
CHUNK_SIZE = 500_000
11-
CHUNK_THRESHOLD = 1_000_000
12-
SAMPLE_CHAR_THRESHOLD = 50_000_000 # 50M characters - use sampling above this
13-
SAMPLE_COUNT = 5
14-
1510

1611
@dataclass
1712
class TokenCountResult:
@@ -42,43 +37,7 @@ def count_tokens(text: str, encoding: str = "o200k_base") -> TokenCountResult:
4237
logger.debug("tiktoken unavailable, using char/4 approximation")
4338
return TokenCountResult(len(text) // 4, False, "approximation")
4439

45-
text_len = len(text)
46-
if text_len <= CHUNK_THRESHOLD:
47-
logger.debug("Token counting: exact mode (%d chars)", text_len)
48-
return TokenCountResult(len(encoder.encode(text)), True, encoding)
49-
50-
if text_len > SAMPLE_CHAR_THRESHOLD:
51-
logger.debug("Token counting: sampled mode (%d chars, %d samples)", text_len, SAMPLE_COUNT)
52-
return _count_tokens_sampled(text, text_len, encoder, encoding)
53-
54-
logger.debug("Token counting: chunked mode (%d chars)", text_len)
55-
total = 0
56-
for i in range(0, text_len, CHUNK_SIZE):
57-
chunk = text[i : i + CHUNK_SIZE]
58-
total += len(encoder.encode(chunk))
59-
return TokenCountResult(total, False, encoding)
60-
61-
62-
def _count_tokens_sampled(text: str, text_len: int, encoder: Any, encoding: str) -> TokenCountResult:
63-
num_chunks = text_len // CHUNK_SIZE
64-
step = max(1, num_chunks // SAMPLE_COUNT)
65-
sampled_tokens = 0
66-
sampled_chars = 0
67-
68-
for i in range(0, num_chunks, step):
69-
start = i * CHUNK_SIZE
70-
chunk = text[start : start + CHUNK_SIZE]
71-
sampled_tokens += len(encoder.encode(chunk))
72-
sampled_chars += len(chunk)
73-
if sampled_chars >= SAMPLE_COUNT * CHUNK_SIZE:
74-
break
75-
76-
if sampled_chars == 0:
77-
return TokenCountResult(text_len // 4, False, "approximation")
78-
79-
tokens_per_char = sampled_tokens / sampled_chars
80-
estimated_total = int(tokens_per_char * text_len)
81-
return TokenCountResult(estimated_total, False, encoding)
40+
return TokenCountResult(len(encoder.encode(text)), True, encoding)
8241

8342

8443
def _format_size(byte_size: int) -> str:

src/treemapper/tree.py

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -212,12 +212,35 @@ def _detect_binary_in_sample(file_path: Path, file_size: int) -> tuple[bytes | N
212212
return raw_bytes, None
213213

214214

215+
def _try_charset_normalizer(raw_bytes: bytes, file_path: Path) -> str | None:
216+
try:
217+
from charset_normalizer import from_bytes
218+
219+
matches = from_bytes(raw_bytes)
220+
best = matches.best()
221+
if best is not None:
222+
logger.info("Decoded %s as %s via charset-normalizer", file_path.name, best.encoding)
223+
return str(best)
224+
except ImportError:
225+
pass
226+
except Exception:
227+
pass
228+
return None
229+
230+
215231
def _decode_file_content(raw_bytes: bytes, file_path: Path, file_size: int) -> str:
216232
if b"\x00" in raw_bytes[BINARY_DETECTION_SAMPLE_SIZE:]:
217233
logger.debug("Detected binary file %s (null in remainder)", file_path.name)
218234
return _format_binary_placeholder(file_size)
219235

220-
content = raw_bytes.decode("utf-8")
236+
try:
237+
content = raw_bytes.decode("utf-8")
238+
except UnicodeDecodeError:
239+
fallback = _try_charset_normalizer(raw_bytes, file_path)
240+
if fallback is None:
241+
raise
242+
content = fallback
243+
221244
content = content.replace("\r\n", "\n").replace("\r", "\n")
222245
if not content:
223246
return ""

src/treemapper/writer.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,26 +15,30 @@
1515

1616
_YAML_PROBLEMATIC_RE = re.compile(r"[\r\x00\x85\u2028\u2029]")
1717

18-
_YAML_STRING_ESCAPE_PATTERN = re.compile(r'[\\"\n\r\x00\x85\u2028\u2029]')
18+
_YAML_STRING_ESCAPE_PATTERN = re.compile(r'[\\"\n\r\x00\x08\x0c\x85\u2028\u2029]')
1919
_YAML_STRING_ESCAPE_MAP = {
2020
"\\": "\\\\",
2121
'"': '\\"',
2222
"\n": "\\n",
2323
"\r": "\\r",
2424
"\x00": "\\0",
25+
"\x08": "\\b",
26+
"\x0c": "\\f",
2527
"\x85": "\\x85",
2628
"\u2028": "\\u2028",
2729
"\u2029": "\\u2029",
2830
}
2931

30-
_YAML_CONTENT_ESCAPE_PATTERN = re.compile(r'[\\"\n\t\r\x00\x85\u2028\u2029]')
32+
_YAML_CONTENT_ESCAPE_PATTERN = re.compile(r'[\\"\n\t\r\x00\x08\x0c\x85\u2028\u2029]')
3133
_YAML_CONTENT_ESCAPE_MAP = {
3234
"\\": "\\\\",
3335
'"': '\\"',
3436
"\n": "\\n",
3537
"\t": "\\t",
3638
"\r": "\\r",
3739
"\x00": "\\0",
40+
"\x08": "\\b",
41+
"\x0c": "\\f",
3842
"\x85": "\\x85",
3943
"\u2028": "\\u2028",
4044
"\u2029": "\\u2029",

tests/test_basic.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -260,10 +260,20 @@ def test_unicode_content_and_encoding_errors(temp_project, run_mapper, caplog):
260260

261261
assert cp1251_node is not None, "'cp1251.txt' not found"
262262
cp1251_content = cp1251_node.get("content", "")
263-
assert "unreadable" in cp1251_content, f"CP1251 file should be marked unreadable, got: {cp1251_content!r}"
264-
assert any(
265-
"cp1251.txt" in record.message for record in caplog.records if record.levelno >= logging.WARNING
266-
), "Expected WARNING about cp1251.txt not found in logs"
263+
try:
264+
from charset_normalizer import from_bytes # noqa: F401
265+
266+
has_charset_normalizer = True
267+
except ImportError:
268+
has_charset_normalizer = False
269+
if has_charset_normalizer:
270+
assert (
271+
cp1251_content and "unreadable" not in cp1251_content
272+
), f"charset-normalizer should decode CP1251, got: {cp1251_content!r}"
273+
else:
274+
assert (
275+
"unreadable" in cp1251_content
276+
), f"CP1251 file should be marked unreadable without charset-normalizer, got: {cp1251_content!r}"
267277

268278
assert binary_node is not None, "'binary.bin' not found"
269279
binary_content = binary_node.get("content", "")

tests/test_complete_coverage.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -472,7 +472,12 @@ def test_non_utf8_placeholder(self, tmp_path):
472472
node = find_node_by_path(tree, ["non_utf8.txt"])
473473
assert node is not None
474474
content = node.get("content", "")
475-
assert "<unreadable content" in content
475+
try:
476+
from charset_normalizer import from_bytes # noqa: F401
477+
478+
assert content and "<unreadable content" not in content
479+
except ImportError:
480+
assert "<unreadable content" in content
476481

477482
@pytest.mark.skipif(
478483
sys.platform == "win32",

tests/test_tokens.py

Lines changed: 18 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,24 @@ def test_newlines_tabs(self):
7575
result = count_tokens("line1\nline2\tline3\r\nline4")
7676
assert result.count > 0
7777

78+
def test_large_text_exact(self):
79+
large_text = "word " * 500_000
80+
result = count_tokens(large_text)
81+
assert result.count > 0
82+
assert result.is_exact is True
83+
84+
def test_exact_count_matches_direct_encode(self):
85+
from treemapper.tokens import _get_encoder
86+
87+
encoder = _get_encoder("o200k_base")
88+
if encoder is None:
89+
return
90+
91+
text = "word " * 5_000
92+
exact_count = len(encoder.encode(text))
93+
result = count_tokens(text)
94+
assert result.count == exact_count
95+
7896

7997
class TestPrintTokenSummary:
8098
def test_prints_to_stderr(self):
@@ -139,67 +157,3 @@ def test_different_encodings_cached_separately(self):
139157
r1 = count_tokens("test", encoding="o200k_base")
140158
r2 = count_tokens("test", encoding="cl100k_base")
141159
assert r1.encoding != r2.encoding or r1.encoding == "approximation"
142-
143-
144-
class TestChunkedCounting:
145-
def test_chunked_counting_for_large_text(self):
146-
from treemapper.tokens import CHUNK_THRESHOLD
147-
148-
large_text = "word " * (CHUNK_THRESHOLD // 5 + 1000)
149-
result = count_tokens(large_text)
150-
assert result.count > 0
151-
# Chunked counting is not exact due to BPE context sensitivity
152-
# is_exact=False with real encoding, or approximation fallback
153-
assert result.is_exact is False
154-
155-
def test_chunked_count_close_to_exact(self, monkeypatch):
156-
import treemapper.tokens as tokens_module
157-
from treemapper.tokens import _get_encoder
158-
159-
encoder = _get_encoder("o200k_base")
160-
if encoder is None:
161-
return
162-
163-
text = "word " * 5_000
164-
exact_count = len(encoder.encode(text))
165-
166-
monkeypatch.setattr(tokens_module, "CHUNK_THRESHOLD", 1_000)
167-
chunked_result = count_tokens(text)
168-
169-
assert abs(chunked_result.count - exact_count) / exact_count < 0.05
170-
171-
def test_small_text_not_chunked(self):
172-
small_text = "hello world"
173-
result = count_tokens(small_text)
174-
assert result.count > 0
175-
176-
177-
class TestSampledCounting:
178-
def test_sampling_threshold_is_reasonable(self):
179-
from treemapper.tokens import SAMPLE_CHAR_THRESHOLD
180-
181-
assert SAMPLE_CHAR_THRESHOLD >= 1_000_000
182-
183-
def test_very_large_text_uses_sampling(self, monkeypatch):
184-
import treemapper.tokens as tokens_module
185-
from treemapper.tokens import _count_tokens_sampled, _get_encoder
186-
187-
encoder = _get_encoder("o200k_base")
188-
if encoder is None:
189-
return
190-
191-
monkeypatch.setattr(tokens_module, "SAMPLE_CHAR_THRESHOLD", 10_000)
192-
large_text = "x" * 15_000
193-
result = _count_tokens_sampled(large_text, len(large_text), encoder, "o200k_base")
194-
assert result.is_exact is False
195-
assert result.count > 0
196-
197-
def test_sampled_result_is_approximate(self, monkeypatch):
198-
import treemapper.tokens as tokens_module
199-
200-
monkeypatch.setattr(tokens_module, "SAMPLE_CHAR_THRESHOLD", 10_000)
201-
monkeypatch.setattr(tokens_module, "CHUNK_THRESHOLD", 1_000)
202-
text = "word " * 5_000
203-
result = count_tokens(text)
204-
if result.encoding != "approximation":
205-
assert result.is_exact is False

0 commit comments

Comments
 (0)