diff --git a/protovalidate/internal/extra_func.py b/protovalidate/internal/extra_func.py index adcfcbf8..eca6e1b5 100644 --- a/protovalidate/internal/extra_func.py +++ b/protovalidate/internal/extra_func.py @@ -931,7 +931,7 @@ def __scheme(self) -> bool: while self.__alpha() or self.__digit() or self.__take("+") or self.__take("-") or self.__take("."): pass - if self._string[self._index] == ":": + if self.__peek(":"): return True self._index = start @@ -997,9 +997,8 @@ def __userinfo(self) -> bool: while self.__unreserved() or self.__pct_encoded() or self.__sub_delims() or self.__take(":"): pass - if self._index < len(self._string): - if self._string[self._index] == "@": - return True + if self.__peek("@"): + return True self._index = start return False @@ -1023,14 +1022,11 @@ def __host(self) -> bool: host = IP-literal / IPv4address / reg-name. """ - if self._index >= len(self._string): - return False - start = self._index self._pct_encoded_found = False # Note: IPv4address is a subset of reg-name - if (self._string[self._index] == "[" and self.__ip_literal()) or self.__reg_name(): + if (self.__peek("[") and self.__ip_literal()) or self.__reg_name(): if self._pct_encoded_found: raw_host = self._string[start : self._index] # RFC 3986: @@ -1188,7 +1184,7 @@ def __reg_name(self) -> bool: # End of authority return True - if self._string[self._index] == ":": + if self.__peek(":"): return True self._index = start @@ -1380,7 +1376,7 @@ def __query(self) -> bool: while self.__pchar() or self.__take("/") or self.__take("?"): pass - if self._index == len(self._string) or self._string[self._index] == "#": + if self._index == len(self._string) or self.__peek("#"): return True self._index = start @@ -1537,6 +1533,9 @@ def __take(self, char: str) -> bool: return False + def __peek(self, char: str) -> bool: + return self._index < len(self._string) and self._string[self._index] == char + def make_extra_funcs(locale: str) -> dict[str, celpy.CELFunction]: # TODO(#257): Fix types and add tests for StringFormat.