Skip to content

Commit 1305937

Browse files
committed
Add API for TOML spec version detection
This adds new functions `loads_with_info()` and `load_with_info()` that return a `ParseResult` containing: - `data`: The parsed TOML data (same as before) - `spec_version`: Minimum TOML spec version required ("1.0" or "1.1") - `features`: Set of TOML 1.1 features used in the document Currently detects the `\e` escape sequence (merged in PR hukkin#201) as a TOML 1.1 feature. TODO comments mark where detection should be added for pending TOML 1.1 features: - `\xHH` hex escape (PR hukkin#202) - Newlines/trailing commas in inline tables (PR hukkin#200) - Optional seconds in datetime/time (PR hukkin#203) The existing `loads()` and `load()` functions remain unchanged for backward compatibility. New public exports: - `loads_with_info()`, `load_with_info()`: Parse with version info - `ParseResult`: Frozen dataclass with parsing results - `TOMLFeature`: Constants for TOML 1.1 feature identifiers Addresses: hukkin#273
1 parent 0921abf commit 1305937

File tree

3 files changed

+361
-29
lines changed

3 files changed

+361
-29
lines changed

src/tomli/__init__.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,23 @@
22
# SPDX-FileCopyrightText: 2021 Taneli Hukkinen
33
# Licensed to PSF under a Contributor Agreement.
44

5-
__all__ = ("loads", "load", "TOMLDecodeError")
5+
__all__ = (
6+
"loads",
7+
"load",
8+
"loads_with_info",
9+
"load_with_info",
10+
"TOMLDecodeError",
11+
"ParseResult",
12+
"TOMLFeature",
13+
)
614
__version__ = "2.3.0" # DO NOT EDIT THIS LINE MANUALLY. LET bump2version UTILITY DO IT
715

8-
from ._parser import TOMLDecodeError, load, loads
16+
from ._parser import (
17+
TOMLDecodeError,
18+
load,
19+
load_with_info,
20+
loads,
21+
loads_with_info,
22+
ParseResult,
23+
TOMLFeature,
24+
)

src/tomli/_parser.py

Lines changed: 166 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from __future__ import annotations
66

7+
from dataclasses import dataclass
78
import sys
89
from types import MappingProxyType
910

@@ -23,6 +24,57 @@
2324

2425
from ._types import Key, ParseFloat, Pos
2526

27+
28+
class TOMLFeature:
29+
"""Constants for TOML 1.1 features.
30+
31+
These identify specific features that require TOML 1.1 spec compliance.
32+
"""
33+
34+
ESCAPE_CHAR = "escape_char" # \e escape sequence
35+
HEX_ESCAPE = "hex_escape" # \xHH escape sequence
36+
INLINE_TABLE_NEWLINE = "inline_table_newline" # Newlines in inline tables
37+
INLINE_TABLE_TRAILING_COMMA = "inline_table_trailing_comma" # Trailing comma in inline tables
38+
OPTIONAL_SECONDS = "optional_seconds" # Date-time/time without seconds
39+
40+
41+
@dataclass(frozen=True)
42+
class ParseResult:
43+
"""Result of parsing a TOML document with version information.
44+
45+
Attributes:
46+
data: The parsed TOML data as a dictionary.
47+
spec_version: The minimum TOML spec version required ("1.0" or "1.1").
48+
features: Set of TOML 1.1 feature constants that were used.
49+
"""
50+
51+
data: dict[str, Any]
52+
spec_version: str
53+
features: frozenset[str]
54+
55+
56+
class VersionContext:
57+
"""Tracks TOML 1.1 features used during parsing."""
58+
59+
__slots__ = ("_features",)
60+
61+
def __init__(self) -> None:
62+
self._features: set[str] = set()
63+
64+
def mark(self, feature: str) -> None:
65+
"""Mark a 1.1 feature as used."""
66+
self._features.add(feature)
67+
68+
@property
69+
def spec_version(self) -> str:
70+
"""Return the minimum spec version required."""
71+
return "1.1" if self._features else "1.0"
72+
73+
@property
74+
def features(self) -> frozenset[str]:
75+
"""Return the set of 1.1 features used."""
76+
return frozenset(self._features)
77+
2678
# Inline tables/arrays are implemented using recursion. Pathologically
2779
# nested documents cause pure Python to raise RecursionError (which is OK),
2880
# but mypyc binary wheels will crash unrecoverably (not OK). According to
@@ -146,8 +198,42 @@ def load(__fp: IO[bytes], *, parse_float: ParseFloat = float) -> dict[str, Any]:
146198
return loads(s, parse_float=parse_float)
147199

148200

201+
def load_with_info(__fp: IO[bytes], *, parse_float: ParseFloat = float) -> ParseResult:
202+
"""Parse TOML from a binary file object with version information.
203+
204+
Returns a ParseResult containing the parsed data, the minimum
205+
TOML spec version required, and the set of TOML 1.1 features used.
206+
"""
207+
b = __fp.read()
208+
try:
209+
s = b.decode()
210+
except AttributeError:
211+
raise TypeError(
212+
"File must be opened in binary mode, e.g. use `open('foo.toml', 'rb')`"
213+
) from None
214+
return loads_with_info(s, parse_float=parse_float)
215+
216+
149217
def loads(__s: str, *, parse_float: ParseFloat = float) -> dict[str, Any]:
150218
"""Parse TOML from a string."""
219+
return loads_with_info(__s, parse_float=parse_float).data
220+
221+
222+
def loads_with_info(__s: str, *, parse_float: ParseFloat = float) -> ParseResult:
223+
"""Parse TOML from a string with version information.
224+
225+
Returns a ParseResult containing the parsed data, the minimum
226+
TOML spec version required, and the set of TOML 1.1 features used.
227+
228+
Example:
229+
>>> result = loads_with_info('key = "value with \\e escape"')
230+
>>> result.data
231+
{'key': 'value with \\x1b escape'}
232+
>>> result.spec_version
233+
'1.1'
234+
>>> result.features
235+
frozenset({'escape_char'})
236+
"""
151237

152238
# The spec allows converting "\r\n" to "\n", even in string
153239
# literals. Let's do so to simplify parsing.
@@ -161,6 +247,7 @@ def loads(__s: str, *, parse_float: ParseFloat = float) -> dict[str, Any]:
161247
out = Output()
162248
header: Key = ()
163249
parse_float = make_safe_parse_float(parse_float)
250+
version_ctx = VersionContext()
164251

165252
# Parse one statement at a time
166253
# (typically means one line in TOML source)
@@ -184,7 +271,7 @@ def loads(__s: str, *, parse_float: ParseFloat = float) -> dict[str, Any]:
184271
pos += 1
185272
continue
186273
if char in KEY_INITIAL_CHARS:
187-
pos = key_value_rule(src, pos, out, header, parse_float)
274+
pos = key_value_rule(src, pos, out, header, parse_float, version_ctx)
188275
pos = skip_chars(src, pos, TOML_WS)
189276
elif char == "[":
190277
try:
@@ -214,7 +301,11 @@ def loads(__s: str, *, parse_float: ParseFloat = float) -> dict[str, Any]:
214301
)
215302
pos += 1
216303

217-
return out.data.dict
304+
return ParseResult(
305+
data=out.data.dict,
306+
spec_version=version_ctx.spec_version,
307+
features=version_ctx.features,
308+
)
218309

219310

220311
class Flags:
@@ -411,9 +502,14 @@ def create_list_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
411502

412503

413504
def key_value_rule(
414-
src: str, pos: Pos, out: Output, header: Key, parse_float: ParseFloat
505+
src: str,
506+
pos: Pos,
507+
out: Output,
508+
header: Key,
509+
parse_float: ParseFloat,
510+
version_ctx: VersionContext | None = None,
415511
) -> Pos:
416-
pos, key, value = parse_key_value_pair(src, pos, parse_float, nest_lvl=0)
512+
pos, key, value = parse_key_value_pair(src, pos, parse_float, nest_lvl=0, version_ctx=version_ctx)
417513
key_parent, key_stem = key[:-1], key[-1]
418514
abs_key_parent = header + key_parent
419515

@@ -445,7 +541,11 @@ def key_value_rule(
445541

446542

447543
def parse_key_value_pair(
448-
src: str, pos: Pos, parse_float: ParseFloat, nest_lvl: int
544+
src: str,
545+
pos: Pos,
546+
parse_float: ParseFloat,
547+
nest_lvl: int,
548+
version_ctx: VersionContext | None = None,
449549
) -> tuple[Pos, Key, Any]:
450550
pos, key = parse_key(src, pos)
451551
try:
@@ -456,7 +556,7 @@ def parse_key_value_pair(
456556
raise TOMLDecodeError("Expected '=' after a key in a key/value pair", src, pos)
457557
pos += 1
458558
pos = skip_chars(src, pos, TOML_WS)
459-
pos, value = parse_value(src, pos, parse_float, nest_lvl)
559+
pos, value = parse_value(src, pos, parse_float, nest_lvl, version_ctx)
460560
return pos, key, value
461561

462562

@@ -494,13 +594,19 @@ def parse_key_part(src: str, pos: Pos) -> tuple[Pos, str]:
494594
raise TOMLDecodeError("Invalid initial character for a key part", src, pos)
495595

496596

497-
def parse_one_line_basic_str(src: str, pos: Pos) -> tuple[Pos, str]:
597+
def parse_one_line_basic_str(
598+
src: str, pos: Pos, version_ctx: VersionContext | None = None
599+
) -> tuple[Pos, str]:
498600
pos += 1
499-
return parse_basic_str(src, pos, multiline=False)
601+
return parse_basic_str(src, pos, multiline=False, version_ctx=version_ctx)
500602

501603

502604
def parse_array(
503-
src: str, pos: Pos, parse_float: ParseFloat, nest_lvl: int
605+
src: str,
606+
pos: Pos,
607+
parse_float: ParseFloat,
608+
nest_lvl: int,
609+
version_ctx: VersionContext | None = None,
504610
) -> tuple[Pos, list[Any]]:
505611
pos += 1
506612
array: list[Any] = []
@@ -509,7 +615,7 @@ def parse_array(
509615
if src.startswith("]", pos):
510616
return pos + 1, array
511617
while True:
512-
pos, val = parse_value(src, pos, parse_float, nest_lvl)
618+
pos, val = parse_value(src, pos, parse_float, nest_lvl, version_ctx)
513619
array.append(val)
514620
pos = skip_comments_and_array_ws(src, pos)
515621

@@ -526,8 +632,13 @@ def parse_array(
526632

527633

528634
def parse_inline_table(
529-
src: str, pos: Pos, parse_float: ParseFloat, nest_lvl: int
635+
src: str,
636+
pos: Pos,
637+
parse_float: ParseFloat,
638+
nest_lvl: int,
639+
version_ctx: VersionContext | None = None,
530640
) -> tuple[Pos, dict[str, Any]]:
641+
# TODO: Add newlines and trailing comma detection for TOML 1.1 (PR #200)
531642
pos += 1
532643
nested_dict = NestedDict()
533644
flags = Flags()
@@ -536,7 +647,7 @@ def parse_inline_table(
536647
if src.startswith("}", pos):
537648
return pos + 1, nested_dict.dict
538649
while True:
539-
pos, key, value = parse_key_value_pair(src, pos, parse_float, nest_lvl)
650+
pos, key, value = parse_key_value_pair(src, pos, parse_float, nest_lvl, version_ctx)
540651
key_parent, key_stem = key[:-1], key[-1]
541652
if flags.is_(key, Flags.FROZEN):
542653
raise TOMLDecodeError(f"Cannot mutate immutable namespace {key}", src, pos)
@@ -560,7 +671,11 @@ def parse_inline_table(
560671

561672

562673
def parse_basic_str_escape(
563-
src: str, pos: Pos, *, multiline: bool = False
674+
src: str,
675+
pos: Pos,
676+
*,
677+
multiline: bool = False,
678+
version_ctx: VersionContext | None = None,
564679
) -> tuple[Pos, str]:
565680
escape_id = src[pos : pos + 2]
566681
pos += 2
@@ -582,14 +697,21 @@ def parse_basic_str_escape(
582697
return parse_hex_char(src, pos, 4)
583698
if escape_id == "\\U":
584699
return parse_hex_char(src, pos, 8)
700+
# TODO: Add \xHH escape detection for TOML 1.1 (PR #202)
585701
try:
586-
return pos, BASIC_STR_ESCAPE_REPLACEMENTS[escape_id]
702+
replacement = BASIC_STR_ESCAPE_REPLACEMENTS[escape_id]
587703
except KeyError:
588704
raise TOMLDecodeError("Unescaped '\\' in a string", src, pos) from None
705+
# Detect TOML 1.1 escape sequences
706+
if escape_id == "\\e" and version_ctx is not None:
707+
version_ctx.mark(TOMLFeature.ESCAPE_CHAR)
708+
return pos, replacement
589709

590710

591-
def parse_basic_str_escape_multiline(src: str, pos: Pos) -> tuple[Pos, str]:
592-
return parse_basic_str_escape(src, pos, multiline=True)
711+
def parse_basic_str_escape_multiline(
712+
src: str, pos: Pos, version_ctx: VersionContext | None = None
713+
) -> tuple[Pos, str]:
714+
return parse_basic_str_escape(src, pos, multiline=True, version_ctx=version_ctx)
593715

594716

595717
def parse_hex_char(src: str, pos: Pos, hex_len: int) -> tuple[Pos, str]:
@@ -614,7 +736,13 @@ def parse_literal_str(src: str, pos: Pos) -> tuple[Pos, str]:
614736
return pos + 1, src[start_pos:pos] # Skip ending apostrophe
615737

616738

617-
def parse_multiline_str(src: str, pos: Pos, *, literal: bool) -> tuple[Pos, str]:
739+
def parse_multiline_str(
740+
src: str,
741+
pos: Pos,
742+
*,
743+
literal: bool,
744+
version_ctx: VersionContext | None = None,
745+
) -> tuple[Pos, str]:
618746
pos += 3
619747
if src.startswith("\n", pos):
620748
pos += 1
@@ -632,7 +760,7 @@ def parse_multiline_str(src: str, pos: Pos, *, literal: bool) -> tuple[Pos, str]
632760
pos = end_pos + 3
633761
else:
634762
delim = '"'
635-
pos, result = parse_basic_str(src, pos, multiline=True)
763+
pos, result = parse_basic_str(src, pos, multiline=True, version_ctx=version_ctx)
636764

637765
# Add at maximum two extra apostrophes/quotes if the end sequence
638766
# is 4 or 5 chars long instead of just 3.
@@ -645,13 +773,17 @@ def parse_multiline_str(src: str, pos: Pos, *, literal: bool) -> tuple[Pos, str]
645773
return pos, result + (delim * 2)
646774

647775

648-
def parse_basic_str(src: str, pos: Pos, *, multiline: bool) -> tuple[Pos, str]:
776+
def parse_basic_str(
777+
src: str,
778+
pos: Pos,
779+
*,
780+
multiline: bool,
781+
version_ctx: VersionContext | None = None,
782+
) -> tuple[Pos, str]:
649783
if multiline:
650784
error_on = ILLEGAL_MULTILINE_BASIC_STR_CHARS
651-
parse_escapes = parse_basic_str_escape_multiline
652785
else:
653786
error_on = ILLEGAL_BASIC_STR_CHARS
654-
parse_escapes = parse_basic_str_escape
655787
result = ""
656788
start_pos = pos
657789
while True:
@@ -668,7 +800,9 @@ def parse_basic_str(src: str, pos: Pos, *, multiline: bool) -> tuple[Pos, str]:
668800
continue
669801
if char == "\\":
670802
result += src[start_pos:pos]
671-
pos, parsed_escape = parse_escapes(src, pos)
803+
pos, parsed_escape = parse_basic_str_escape(
804+
src, pos, multiline=multiline, version_ctx=version_ctx
805+
)
672806
result += parsed_escape
673807
start_pos = pos
674808
continue
@@ -678,7 +812,11 @@ def parse_basic_str(src: str, pos: Pos, *, multiline: bool) -> tuple[Pos, str]:
678812

679813

680814
def parse_value(
681-
src: str, pos: Pos, parse_float: ParseFloat, nest_lvl: int
815+
src: str,
816+
pos: Pos,
817+
parse_float: ParseFloat,
818+
nest_lvl: int,
819+
version_ctx: VersionContext | None = None,
682820
) -> tuple[Pos, Any]:
683821
if nest_lvl > MAX_INLINE_NESTING:
684822
# Pure Python should have raised RecursionError already.
@@ -698,8 +836,8 @@ def parse_value(
698836
# Basic strings
699837
if char == '"':
700838
if src.startswith('"""', pos):
701-
return parse_multiline_str(src, pos, literal=False)
702-
return parse_one_line_basic_str(src, pos)
839+
return parse_multiline_str(src, pos, literal=False, version_ctx=version_ctx)
840+
return parse_one_line_basic_str(src, pos, version_ctx=version_ctx)
703841

704842
# Literal strings
705843
if char == "'":
@@ -717,13 +855,14 @@ def parse_value(
717855

718856
# Arrays
719857
if char == "[":
720-
return parse_array(src, pos, parse_float, nest_lvl + 1)
858+
return parse_array(src, pos, parse_float, nest_lvl + 1, version_ctx)
721859

722860
# Inline tables
723861
if char == "{":
724-
return parse_inline_table(src, pos, parse_float, nest_lvl + 1)
862+
return parse_inline_table(src, pos, parse_float, nest_lvl + 1, version_ctx)
725863

726864
# Dates and times
865+
# TODO: Add optional seconds detection for TOML 1.1 (PR #203)
727866
datetime_match = RE_DATETIME.match(src, pos)
728867
if datetime_match:
729868
try:

0 commit comments

Comments
 (0)