diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index 7d9145b31..31d7f6397 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -96,37 +96,40 @@ def __getattr__(self, name): integer = Regex(r"[+-]?(0|[1-9][0-9]*)") boolean = Regex("(True|False)") number = Regex(rf"{integer.pattern}(\.[0-9]+)?([eE][+-][0-9]+)?") -date = Regex(r"(\d{4})-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])") -time = Regex(r"([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])") -datetime = Regex(rf"({date.pattern})(\s)({time.pattern})") +date = Regex(r"(\d{4})-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])", requires_quoting=True) +time = Regex(r"([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])", requires_quoting=True) +datetime = Regex(rf"({date.pattern})(\s)({time.pattern})", requires_quoting=True) # Basic regex types digit = Regex(r"\d") -char = Regex(r"\w") -newline = Regex(r"(\r\n|\r|\n)") # Matched new lines on Linux, Windows & MacOS -whitespace = Regex(r"\s") -hex_str = Regex(r"(0x)?[a-fA-F0-9]+") +char = Regex(r"\w", requires_quoting=True) +newline = Regex(r"(\r\n|\r|\n)", requires_quoting=True) # Matched new lines on Linux, Windows & MacOS +whitespace = Regex(r"\s", requires_quoting=True) +hex_str = Regex(r"(0x)?[a-fA-F0-9]+", requires_quoting=True) uuid4 = Regex( r"[a-fA-F0-9]{8}-" r"[a-fA-F0-9]{4}-" r"4[a-fA-F0-9]{3}-" r"[89abAB][a-fA-F0-9]{3}-" - r"[a-fA-F0-9]{12}" + r"[a-fA-F0-9]{12}", + requires_quoting=True ) ipv4 = Regex( r"((25[0-5]|2[0-4][0-9]|1?[0-9]{1,2})\.){3}" - r"(25[0-5]|2[0-4][0-9]|1?[0-9]{1,2})" + r"(25[0-5]|2[0-4][0-9]|1?[0-9]{1,2})", + requires_quoting=True ) # Document-specific types -sentence = Regex(r"[A-Z].*\s*[.!?]") -paragraph = Regex(rf"{sentence.pattern}(?:\s+{sentence.pattern})*\n+") +sentence = Regex(r"[A-Z].*\s*[.!?]", requires_quoting=True) +paragraph = Regex(rf"{sentence.pattern}(?:\s+{sentence.pattern})*\n+", requires_quoting=True) # The following regex is FRC 5322 compliant and was found at: # https://emailregex.com/ email = Regex( - r"""(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])""" + r"""(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])""", + requires_quoting=True ) # Matches any ISBN number. Note that this is not completely correct as not all @@ -136,5 +139,6 @@ def __getattr__(self, name): # # TODO: The check digit can only be computed by calling a function to compute it dynamically isbn = Regex( - r"(?:ISBN(?:-1[03])?:? )?(?=[0-9X]{10}$|(?=(?:[0-9]+[- ]){3})[- 0-9X]{13}$|97[89][0-9]{10}$|(?=(?:[0-9]+[- ]){4})[- 0-9]{17}$)(?:97[89][- ]?)?[0-9]{1,5}[- ]?[0-9]+[- ]?[0-9]+[- ]?[0-9X]" + r"(?:ISBN(?:-1[03])?:? )?(?=[0-9X]{10}$|(?=(?:[0-9]+[- ]){3})[- 0-9X]{13}$|97[89][0-9]{10}$|(?=(?:[0-9]+[- ]){4})[- 0-9]{17}$)(?:97[89][- ]?)?[0-9]{1,5}[- ]?[0-9]+[- ]?[0-9]+[- ]?[0-9X]", + requires_quoting=True ) diff --git a/outlines/types/dsl.py b/outlines/types/dsl.py index 5a12a9b68..5daf08e58 100644 --- a/outlines/types/dsl.py +++ b/outlines/types/dsl.py @@ -15,7 +15,6 @@ import json import re import sys -import warnings from dataclasses import dataclass from enum import EnumMeta from types import FunctionType @@ -96,6 +95,7 @@ class also handles validation. >>> age: age_type """ + apply_quotation: bool = False def __add__(self: "Term", other: "Term") -> "Sequence": if is_str_instance(other): @@ -174,6 +174,10 @@ def display_ascii_tree(self, indent="", is_last=True) -> str: def _display_node(self): raise NotImplementedError + @property + def requires_quoting(self) -> bool: + raise NotImplementedError + def _display_children(self, indent: str) -> str: """Display the children of this node. Override in subclasses with children.""" return "" @@ -203,10 +207,17 @@ def zero_or_more(self) -> "KleeneStar": return zero_or_more(self) +### Element Terms + + @dataclass class String(Term): + """Class representing a string.""" value: str + def requires_quoting(self) -> bool: + return True + def _display_node(self) -> str: return f"String('{self.value}')" @@ -214,17 +225,17 @@ def __repr__(self): return f"String(value='{self.value}')" -@dataclass class Regex(Term): - """Class representing a regular expression. + """Class representing a regular expression.""" + pattern: str + _requires_quoting: bool = False - Parameters - ---------- - pattern - The regular expression as a string. + def __init__(self, pattern: str, requires_quoting: bool = False): + self.pattern = pattern + self._requires_quoting = requires_quoting - """ - pattern: str + def requires_quoting(self) -> bool: + return self._requires_quoting def _display_node(self) -> str: return f"Regex('{self.pattern}')" @@ -235,16 +246,12 @@ def __repr__(self): @dataclass class CFG(Term): - """Class representing a context-free grammar. - - Parameters - ---------- - definition - The definition of the context-free grammar as a string. - - """ + """Class representing a context-free grammar.""" definition: str + def requires_quoting(self) -> bool: + return True + def _display_node(self) -> str: return f"CFG('{self.definition}')" @@ -283,6 +290,9 @@ class JsonSchema(Term): genSON schema builder. """ + schema: str + whitespace_pattern: OptionalType[str] + def __init__( self, schema: Union[ @@ -331,6 +341,9 @@ def __init__( def __post_init__(self): jsonschema.Draft7Validator.check_schema(json.loads(self.schema)) + def requires_quoting(self) -> bool: + return True + def _display_node(self) -> str: return f"JsonSchema('{self.schema}')" @@ -367,6 +380,9 @@ def from_file(cls, path: str) -> "JsonSchema": return cls(schema) +### Multiple choice terms + + @dataclass class Choice(Term): """Class representing a choice between different items. @@ -378,6 +394,16 @@ class Choice(Term): """ items: List[Any] + _requires_quoting: bool = False + + def requires_quoting(self) -> bool: + return ( + self._requires_quoting + or any( + python_types_to_terms(item).requires_quoting + for item in self.items + ) + ) def _display_node(self) -> str: return f"Choice({repr(self.items)})" @@ -386,10 +412,40 @@ def __repr__(self): return f"Choice(items={repr(self.items)})" +@dataclass +class Alternatives(Term): + terms: List[Term] + _requires_quoting: bool = False + + def requires_quoting(self) -> bool: + return ( + self._requires_quoting + or any(term.requires_quoting for term in self.terms) + ) + + def _display_node(self) -> str: + return "Alternatives(|)" + + def _display_children(self, indent: str) -> str: + return "".join( + term.display_ascii_tree(indent, i == len(self.terms) - 1) + for i, term in enumerate(self.terms) + ) + + def __repr__(self): + return f"Alternatives(terms={repr(self.terms)})" + + +### Quantifier terms + + @dataclass class KleeneStar(Term): term: Term + def requires_quoting(self) -> bool: + return self.term.requires_quoting + def _display_node(self) -> str: return "KleeneStar(*)" @@ -404,6 +460,9 @@ def __repr__(self): class KleenePlus(Term): term: Term + def requires_quoting(self) -> bool: + return self.term.requires_quoting + def _display_node(self) -> str: return "KleenePlus(+)" @@ -418,6 +477,9 @@ def __repr__(self): class Optional(Term): term: Term + def requires_quoting(self) -> bool: + return self.term.requires_quoting + def _display_node(self) -> str: return "Optional(?)" @@ -428,45 +490,14 @@ def __repr__(self): return f"Optional(term={repr(self.term)})" -@dataclass -class Alternatives(Term): - terms: List[Term] - - def _display_node(self) -> str: - return "Alternatives(|)" - - def _display_children(self, indent: str) -> str: - return "".join( - term.display_ascii_tree(indent, i == len(self.terms) - 1) - for i, term in enumerate(self.terms) - ) - - def __repr__(self): - return f"Alternatives(terms={repr(self.terms)})" - - -@dataclass -class Sequence(Term): - terms: List[Term] - - def _display_node(self) -> str: - return "Sequence" - - def _display_children(self, indent: str) -> str: - return "".join( - term.display_ascii_tree(indent, i == len(self.terms) - 1) - for i, term in enumerate(self.terms) - ) - - def __repr__(self): - return f"Sequence(terms={repr(self.terms)})" - - @dataclass class QuantifyExact(Term): term: Term count: int + def requires_quoting(self) -> bool: + return self.term.requires_quoting + def _display_node(self) -> str: return f"Quantify({{{self.count}}})" @@ -482,6 +513,9 @@ class QuantifyMinimum(Term): term: Term min_count: int + def requires_quoting(self) -> bool: + return self.term.requires_quoting + def _display_node(self) -> str: return f"Quantify({{{self.min_count},}})" @@ -499,6 +533,9 @@ class QuantifyMaximum(Term): term: Term max_count: int + def requires_quoting(self) -> bool: + return self.term.requires_quoting + def _display_node(self) -> str: return f"Quantify({{,{self.max_count}}})" @@ -523,6 +560,9 @@ def __post_init__(self): "QuantifyBetween: `max_count` must be greater than `min_count`." ) + def requires_quoting(self) -> bool: + return self.term.requires_quoting + def _display_node(self) -> str: return f"Quantify({{{self.min_count},{self.max_count}}})" @@ -533,6 +573,30 @@ def __repr__(self): return f"QuantifyBetween(term={repr(self.term)}, min_count={repr(self.min_count)}, max_count={repr(self.max_count)})" +### Sequence terms + + +@dataclass +class Sequence(Term): + terms: List[Term] + _requires_quoting: bool = False + + def requires_quoting(self) -> bool: + return self._requires_quoting or any(term.requires_quoting for term in self.terms) + + def _display_node(self) -> str: + return "Sequence" + + def _display_children(self, indent: str) -> str: + return "".join( + term.display_ascii_tree(indent, i == len(self.terms) - 1) + for i, term in enumerate(self.terms) + ) + + def __repr__(self): + return f"Sequence(terms={repr(self.terms)})" + + def regex(pattern: str): return Regex(pattern) @@ -726,6 +790,7 @@ def _handle_list(args: tuple, recursion_depth: int) -> Sequence: f"Only homogeneous lists are supported. Got multiple type arguments {args}." ) item_type = python_types_to_terms(args[0], recursion_depth + 1) + item_type.apply_quotation = True return Sequence( [ String("["), @@ -741,6 +806,7 @@ def _handle_tuple(args: tuple, recursion_depth: int) -> Union[Sequence, String]: return String("()") elif len(args) == 2 and args[1] is Ellipsis: item_term = python_types_to_terms(args[0], recursion_depth + 1) + item_term.apply_quotation = True return Sequence( [ String("("), @@ -754,6 +820,7 @@ def _handle_tuple(args: tuple, recursion_depth: int) -> Union[Sequence, String]: separator = String(", ") elements = [] for i, item in enumerate(items): + item.apply_quotation = True elements.append(item) if i < len(items) - 1: elements.append(separator) @@ -766,6 +833,8 @@ def _handle_dict(args: tuple, recursion_depth: int) -> Sequence: # Add dict support with key:value pairs key_type = python_types_to_terms(args[0], recursion_depth + 1) value_type = python_types_to_terms(args[1], recursion_depth + 1) + key_type.apply_quotation = True + value_type.apply_quotation = True return Sequence( [ String("{"), @@ -786,6 +855,12 @@ def _handle_dict(args: tuple, recursion_depth: int) -> Sequence: ) +def handle_quotation(term: Term, value: str) -> str: + if term.requires_quoting and term.apply_quotation: + return repr(value) + return value + + def to_regex(term: Term) -> str: """Convert a term to a regular expression. @@ -803,35 +878,35 @@ def to_regex(term: Term) -> str: """ if isinstance(term, String): - return re.escape(term.value) + return re.escape(handle_quotation(term, term.value)) elif isinstance(term, Regex): - return f"({term.pattern})" + return f"({handle_quotation(term, term.pattern)})" elif isinstance(term, JsonSchema): regex_str = outlines_core.json_schema.build_regex_from_schema(term.schema, term.whitespace_pattern) - return f"({regex_str})" + return f"({handle_quotation(term, regex_str)})" elif isinstance(term, Choice): - regexes = [to_regex(python_types_to_terms(item)) for item in term.items] - return f"({'|'.join(regexes)})" + regexes = '|'.join([to_regex(python_types_to_terms(item)) for item in term.items]) + return f"({handle_quotation(term, regexes)})" elif isinstance(term, KleeneStar): - return f"({to_regex(term.term)})*" + return f"({handle_quotation(term, to_regex(term.term))})*" elif isinstance(term, KleenePlus): - return f"({to_regex(term.term)})+" + return f"({handle_quotation(term, to_regex(term.term))})+" elif isinstance(term, Optional): - return f"({to_regex(term.term)})?" + return f"({handle_quotation(term, to_regex(term.term))})?" elif isinstance(term, Alternatives): - regexes = [to_regex(subterm) for subterm in term.terms] - return f"({'|'.join(regexes)})" + regexes = '|'.join([to_regex(subterm) for subterm in term.terms]) + return f"({handle_quotation(term, regexes)})" elif isinstance(term, Sequence): - regexes = [to_regex(subterm) for subterm in term.terms] - return f"{''.join(regexes)}" + regexes = ''.join([to_regex(subterm) for subterm in term.terms]) + return f"{handle_quotation(term, regexes)}" elif isinstance(term, QuantifyExact): - return f"({to_regex(term.term)}){{{term.count}}}" + return f"({handle_quotation(term, to_regex(term.term))}){{{term.count}}}" elif isinstance(term, QuantifyMinimum): - return f"({to_regex(term.term)}){{{term.min_count},}}" + return f"({handle_quotation(term, to_regex(term.term))}){{{term.min_count},}}" elif isinstance(term, QuantifyMaximum): - return f"({to_regex(term.term)}){{,{term.max_count}}}" + return f"({handle_quotation(term, to_regex(term.term))}){{,{term.max_count}}}" elif isinstance(term, QuantifyBetween): - return f"({to_regex(term.term)}){{{term.min_count},{term.max_count}}}" + return f"({handle_quotation(term, to_regex(term.term))}){{{term.min_count},{term.max_count}}}" else: raise TypeError( f"Cannot convert object {repr(term)} to a regular expression."