Skip to content

Commit 7b7f4ef

Browse files
committed
fix typing issues
1 parent 3ca356a commit 7b7f4ef

File tree

5 files changed

+30
-22
lines changed

5 files changed

+30
-22
lines changed

src/jinja2/environment.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from .utils import _PassArg
5252
from .utils import concat
5353
from .utils import consume
54+
from .utils import EscapeFunc
5455
from .utils import get_wrapped_escape_class
5556
from .utils import import_string
5657
from .utils import internalcode
@@ -346,7 +347,7 @@ class Environment:
346347
context_class: t.Type[Context] = Context
347348

348349
template_class: t.Type["Template"]
349-
default_markup_class: t.Type["Markup"]
350+
default_markup_class: t.Type[Markup]
350351

351352
def __init__(
352353
self,
@@ -372,7 +373,7 @@ def __init__(
372373
auto_reload: bool = True,
373374
bytecode_cache: t.Optional["BytecodeCache"] = None,
374375
enable_async: bool = False,
375-
default_escape=html_escape,
376+
default_escape: t.Union[EscapeFunc, t.Type[Markup]] = html_escape,
376377
allow_mixed_escape_extends: bool = False,
377378
):
378379
# !!Important notice!!
@@ -406,6 +407,7 @@ def __init__(
406407
self.finalize = finalize
407408
self.autoescape = autoescape
408409
if isclass(default_escape):
410+
default_escape = t.cast(t.Type[Markup], default_escape)
409411
self.default_markup_class = default_escape
410412
elif default_escape != html_escape:
411413
self.default_markup_class = get_wrapped_escape_class(default_escape)
@@ -434,9 +436,7 @@ def __init__(
434436
self.is_async = enable_async
435437
_environment_config_check(self)
436438

437-
def get_markup_class(
438-
self, template_name: t.Optional[str] = None
439-
) -> t.Type["Markup"]:
439+
def get_markup_class(self, template_name: t.Optional[str] = None) -> t.Type[Markup]:
440440
"""
441441
Get the correct :class:`Markup` for the given template name.
442442
@@ -1193,14 +1193,14 @@ def select_template(
11931193

11941194
for name in names:
11951195
if isinstance(name, Template):
1196-
self._check_multi_template_autoescape(names, parent_template, caller)
1196+
self._check_multi_template_autoescape(name, parent_template, caller)
11971197
return name
11981198
if parent is not None:
11991199
name = self.join_path(name, parent)
12001200
try:
12011201
template = self._load_template(name, globals)
12021202
# Only check autoescape if template can be loaded
1203-
self._check_multi_template_autoescape(names, parent_template, caller)
1203+
self._check_multi_template_autoescape(name, parent_template, caller)
12041204
return template
12051205
except (TemplateNotFound, UndefinedError):
12061206
pass

src/jinja2/ext.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def _make_new_gettext(func: t.Callable[[str], str]) -> t.Callable[..., str]:
171171
def gettext(__context: Context, __string: str, **variables: t.Any) -> str:
172172
rv = __context.call(func, __string)
173173
if __context.eval_ctx.autoescape:
174+
rv = t.cast(str, rv)
174175
rv = __context.eval_ctx.mark_safe(rv)
175176
# Always treat as a format string, even if there are no
176177
# variables. This makes translation strings more consistent
@@ -192,6 +193,7 @@ def ngettext(
192193
variables.setdefault("num", __num)
193194
rv = __context.call(func, __singular, __plural, __num)
194195
if __context.eval_ctx.autoescape:
196+
rv = t.cast(str, rv)
195197
rv = __context.eval_ctx.mark_safe(rv)
196198
# Always treat as a format string, see gettext comment above.
197199
return rv % variables # type: ignore
@@ -208,6 +210,7 @@ def pgettext(
208210
rv = __context.call(func, __string_ctx, __string)
209211

210212
if __context.eval_ctx.autoescape:
213+
rv = t.cast(str, rv)
211214
rv = __context.eval_ctx.mark_safe(rv)
212215

213216
# Always treat as a format string, see gettext comment above.
@@ -233,6 +236,7 @@ def npgettext(
233236
rv = __context.call(func, __string_ctx, __singular, __plural, __num)
234237

235238
if __context.eval_ctx.autoescape:
239+
rv = t.cast(str, rv)
236240
rv = __context.eval_ctx.mark_safe(rv)
237241

238242
# Always treat as a format string, see gettext comment above.

src/jinja2/filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def do_replace(
287287
):
288288
s = do_escape(eval_ctx, s)
289289
else:
290-
s = markupsafe.soft_str(s) # type: ignore
290+
s = markupsafe.soft_str(s)
291291

292292
# Special case, if user uses Markup class directly to mark
293293
# something as safe but uses custom escape function

src/jinja2/runtime.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .nodes import EvalContext
1717
from .utils import _PassArg
1818
from .utils import concat
19+
from .utils import EscapeFunc
1920
from .utils import internalcode
2021
from .utils import missing
2122
from .utils import Namespace # noqa: F401
@@ -71,7 +72,7 @@ def identity(x: V) -> V:
7172
return x
7273

7374

74-
def markup_join(seq: t.Iterable[t.Any], escape_func=html_escape) -> str:
75+
def markup_join(seq: t.Iterable[t.Any], escape_func: EscapeFunc = html_escape) -> str:
7576
"""
7677
Concatenation that escapes if necessary and converts to string.
7778
@@ -88,7 +89,7 @@ def markup_join(seq: t.Iterable[t.Any], escape_func=html_escape) -> str:
8889
return concat(buf)
8990

9091

91-
def str_join(seq: t.Iterable[t.Any], escape_func=html_escape):
92+
def str_join(seq: t.Iterable[t.Any], escape_func: EscapeFunc = html_escape) -> str:
9293
"""
9394
Simple args to string conversion and concatenation.
9495
@@ -727,7 +728,7 @@ def __init__(
727728
default_autoescape: t.Optional[bool] = None,
728729
):
729730
self._environment = environment
730-
self._mark_safe = environment.default_markup_class
731+
self._mark_safe: EscapeFunc = environment.get_markup_class()
731732
self._func = func
732733
self._argument_count = len(arguments)
733734
self.name = name
@@ -770,7 +771,7 @@ def __call__(self, *args: t.Any, **kwargs: t.Any) -> str:
770771
# If the eval context is available we use it to determine
771772
# the correct mark safe method
772773
# otherwise mark safe is already set in the __init__
773-
# function from enviromental context
774+
# function from environmental context
774775
self._mark_safe = args[0].mark_safe
775776
args = args[1:]
776777
else:

src/jinja2/utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import typing_extensions as te
2323

2424
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
25+
# Typing definition of the Escape function
26+
EscapeFunc = t.Callable[[t.Any], markupsafe.Markup]
2527

2628
# special singleton representing missing values for the runtime
2729
missing: t.Any = type("MissingType", (), {"__repr__": lambda x: "missing"})()
@@ -660,10 +662,10 @@ def __reversed__(self) -> t.Iterator[t.Any]:
660662
def select_autoescape(
661663
enabled_extensions: t.Collection[str] = ("html", "htm", "xml"),
662664
disabled_extensions: t.Collection[str] = (),
663-
special_extensions: t.Optional[t.Dict[str, t.Callable]] = None,
665+
special_extensions: t.Optional[t.Dict[str, EscapeFunc]] = None,
664666
default_for_string: bool = True,
665667
default: bool = False,
666-
) -> t.Callable[[t.Optional[str]], bool]:
668+
) -> t.Callable[[t.Optional[str]], t.Union[bool, EscapeFunc]]:
667669
"""Intelligently sets the initial value of autoescaping based on the
668670
filename of the template. This is the recommended way to configure
669671
autoescaping if you do not want to write a custom function yourself.
@@ -716,7 +718,7 @@ def select_autoescape(
716718
parameter ``special_extensions`` was added
717719
"""
718720

719-
def extension_str(x):
721+
def extension_str(x: str) -> str:
720722
"""return a lower case extension always starting with point"""
721723
return f".{x.lstrip('.').lower()}"
722724

@@ -725,22 +727,23 @@ def extension_str(x):
725727

726728
if special_extensions is None:
727729
special_extensions = {}
728-
if special_extensions is False:
729-
special_extensions = {}
730730
special_extensions = {
731731
extension_str(key): func for key, func in special_extensions.items()
732732
}
733733

734-
def autoescape(template_name: t.Optional[str]) -> bool:
734+
def autoescape(template_name: t.Optional[str]) -> t.Union[bool, EscapeFunc]:
735735
if template_name is None:
736736
return default_for_string
737737
template_name = template_name.lower()
738738
# Lookup autoescape function using the longest matching suffix
739+
739740
for key, func in sorted(
740-
special_extensions.items(), key=lambda x: len(x[0]), reverse=True
741+
special_extensions.items(), # type: ignore
742+
key=lambda x: len(x[0]),
743+
reverse=True,
741744
):
742745
if template_name.endswith(key):
743-
return func
746+
return t.cast(EscapeFunc, func)
744747
if template_name.endswith(enabled_patterns):
745748
return True
746749
if template_name.endswith(disabled_patterns):
@@ -826,12 +829,12 @@ class MarkupWrapper(markupsafe.Markup):
826829
"""
827830

828831
@classmethod
829-
def get_unwrapped_escape(cls):
832+
def get_unwrapped_escape(cls) -> t.Callable[[Any], str]:
830833
# Needed for test
831834
return custom_escape
832835

833836
@classmethod
834-
def escape(cls, s):
837+
def escape(cls, s: Any) -> markupsafe.Markup:
835838
"""
836839
Make sure the custom escape function does not escape
837840
already escaped strings

0 commit comments

Comments
 (0)