2222 import typing_extensions as te
2323
2424F = t .TypeVar ("F" , bound = t .Callable [..., t .Any ])
25+ # Typing definition of the Escape function
26+ EscapeFunc = t .Callable [[t .Any ], markupsafe .Markup ]
2527
2628# special singleton representing missing values for the runtime
2729missing : t .Any = type ("MissingType" , (), {"__repr__" : lambda x : "missing" })()
@@ -660,10 +662,10 @@ def __reversed__(self) -> t.Iterator[t.Any]:
660662def select_autoescape (
661663 enabled_extensions : t .Collection [str ] = ("html" , "htm" , "xml" ),
662664 disabled_extensions : t .Collection [str ] = (),
663- special_extensions : t .Optional [t .Dict [str , t . Callable ]] = None ,
665+ special_extensions : t .Optional [t .Dict [str , EscapeFunc ]] = None ,
664666 default_for_string : bool = True ,
665667 default : bool = False ,
666- ) -> t .Callable [[t .Optional [str ]], bool ]:
668+ ) -> t .Callable [[t .Optional [str ]], t . Union [ bool , EscapeFunc ] ]:
667669 """Intelligently sets the initial value of autoescaping based on the
668670 filename of the template. This is the recommended way to configure
669671 autoescaping if you do not want to write a custom function yourself.
@@ -716,7 +718,7 @@ def select_autoescape(
716718 parameter ``special_extensions`` was added
717719 """
718720
719- def extension_str (x ) :
721+ def extension_str (x : str ) -> str :
720722 """return a lower case extension always starting with point"""
721723 return f".{ x .lstrip ('.' ).lower ()} "
722724
@@ -725,22 +727,23 @@ def extension_str(x):
725727
726728 if special_extensions is None :
727729 special_extensions = {}
728- if special_extensions is False :
729- special_extensions = {}
730730 special_extensions = {
731731 extension_str (key ): func for key , func in special_extensions .items ()
732732 }
733733
734- def autoescape (template_name : t .Optional [str ]) -> bool :
734+ def autoescape (template_name : t .Optional [str ]) -> t . Union [ bool , EscapeFunc ] :
735735 if template_name is None :
736736 return default_for_string
737737 template_name = template_name .lower ()
738738 # Lookup autoescape function using the longest matching suffix
739+
739740 for key , func in sorted (
740- special_extensions .items (), key = lambda x : len (x [0 ]), reverse = True
741+ special_extensions .items (), # type: ignore
742+ key = lambda x : len (x [0 ]),
743+ reverse = True ,
741744 ):
742745 if template_name .endswith (key ):
743- return func
746+ return t . cast ( EscapeFunc , func )
744747 if template_name .endswith (enabled_patterns ):
745748 return True
746749 if template_name .endswith (disabled_patterns ):
@@ -826,12 +829,12 @@ class MarkupWrapper(markupsafe.Markup):
826829 """
827830
828831 @classmethod
829- def get_unwrapped_escape (cls ):
832+ def get_unwrapped_escape (cls ) -> t . Callable [[ Any ], str ] :
830833 # Needed for test
831834 return custom_escape
832835
833836 @classmethod
834- def escape (cls , s ) :
837+ def escape (cls , s : Any ) -> markupsafe . Markup :
835838 """
836839 Make sure the custom escape function does not escape
837840 already escaped strings
0 commit comments