|
1 | 1 | # XXX TypeErrors on calling handlers, or on bad return values from a
|
2 | 2 | # handler, are obscure and unhelpful.
|
3 | 3 |
|
| 4 | +import abc |
| 5 | +import functools |
4 | 6 | import os
|
| 7 | +import re |
5 | 8 | import sys
|
6 | 9 | import sysconfig
|
| 10 | +import textwrap |
7 | 11 | import unittest
|
8 | 12 | import traceback
|
9 | 13 | from io import BytesIO
|
10 | 14 | from test import support
|
11 |
| -from test.support import os_helper |
| 15 | +from test.support import import_helper, os_helper |
12 | 16 |
|
13 | 17 | from xml.parsers import expat
|
14 | 18 | from xml.parsers.expat import errors
|
@@ -809,5 +813,199 @@ def start_element(name, _):
|
809 | 813 | self.assertEqual(started, ['doc'])
|
810 | 814 |
|
811 | 815 |
|
class AttackProtectionTestBase(abc.ABC):
    """
    Base class for testing protections against XML payloads with
    disproportionate amplification.

    The protections being tested should detect and prevent attacks
    that leverage disproportionate amplification from small inputs.

    Concrete subclasses must implement the abstract hooks and tests
    declared below. The assert* helpers used here are expected to be
    provided by unittest.TestCase, which subclasses also derive from.
    """

    @staticmethod
    def exponential_expansion_payload(*, nrows, ncols, text='.'):
        """Create a billion laughs attack payload.

        The document declares the entities row0 .. row<nrows>: row0
        expands to *text* and every following row references the
        previous row *ncols* times. The document element contains a
        single reference to the last row.

        Be careful: the total number of expanded items is
        pow(ncols, nrows), thereby requiring at least
        pow(ncols, nrows) * sizeof(text) memory!
        """
        body = '\n'.join(
            f'<!ENTITY row{i + 1} "{f"&row{i};" * ncols}">'
            for i in range(nrows)
        )
        # Assemble the payload line by line instead of interpolating
        # 'text' into a template that is later passed to str.format():
        # braces in 'text' would otherwise be parsed as format fields.
        return '\n'.join([
            '<?xml version="1.0"?>',
            '<!DOCTYPE doc [',
            f'    <!ENTITY row0 "{text}">',
            '    <!ELEMENT doc (#PCDATA)>',
            textwrap.indent(body, ' ' * 4),
            ']>',
            f'<doc>&row{nrows};</doc>',
        ])

    def test_payload_generation(self):
        # self-test for exponential_expansion_payload()
        payload = self.exponential_expansion_payload(nrows=2, ncols=3)
        self.assertEqual(payload, textwrap.dedent("""\
            <?xml version="1.0"?>
            <!DOCTYPE doc [
                <!ENTITY row0 ".">
                <!ELEMENT doc (#PCDATA)>
                <!ENTITY row1 "&row0;&row0;&row0;">
                <!ENTITY row2 "&row1;&row1;&row1;">
            ]>
            <doc>&row2;</doc>
        """).rstrip())

    def assert_root_parser_failure(self, func, /, *args, **kwargs):
        """Check that func(*args, **kwargs) is invalid for a sub-parser."""
        msg = "parser must be a root parser"
        self.assertRaisesRegex(expat.ExpatError, msg, func, *args, **kwargs)

    @abc.abstractmethod
    def assert_rejected(self, func, /, *args, **kwargs):
        """Assert that func(*args, **kwargs) triggers the attack protection.

        Note: this method must ensure that the attack protection being tested
        is the one that is actually triggered at runtime, e.g., by matching
        the exact error message.
        """

    @abc.abstractmethod
    def set_activation_threshold(self, parser, threshold):
        """Set the activation threshold for the tested protection."""

    @abc.abstractmethod
    def set_maximum_amplification(self, parser, max_factor):
        """Set the maximum amplification factor for the tested protection."""

    @abc.abstractmethod
    def test_set_activation_threshold__threshold_reached(self):
        """Test when the activation threshold is exceeded."""

    @abc.abstractmethod
    def test_set_activation_threshold__threshold_not_reached(self):
        """Test when the activation threshold is not exceeded."""

    def test_set_activation_threshold__invalid_threshold_type(self):
        parser = expat.ParserCreate()
        setter = functools.partial(self.set_activation_threshold, parser)

        # Floats are rejected whatever their sign.
        self.assertRaises(TypeError, setter, 1.0)
        self.assertRaises(TypeError, setter, -1.5)
        # A negative integer has the right type but an invalid value.
        self.assertRaises(ValueError, setter, -5)

    def test_set_activation_threshold__invalid_threshold_range(self):
        # import_module() is expected to skip the test if the _testcapi
        # module is not available.
        _testcapi = import_helper.import_module("_testcapi")
        parser = expat.ParserCreate()
        setter = functools.partial(self.set_activation_threshold, parser)

        # One past the largest 'unsigned long long' value.
        self.assertRaises(OverflowError, setter, _testcapi.ULLONG_MAX + 1)

    def test_set_activation_threshold__fail_for_subparser(self):
        parser = expat.ParserCreate()
        subparser = parser.ExternalEntityParserCreate(None)
        setter = functools.partial(self.set_activation_threshold, subparser)
        self.assert_root_parser_failure(setter, 12345)

    @abc.abstractmethod
    def test_set_maximum_amplification__amplification_exceeded(self):
        """Test when the amplification factor is exceeded."""

    @abc.abstractmethod
    def test_set_maximum_amplification__amplification_not_exceeded(self):
        """Test when the amplification factor is not exceeded."""

    def test_set_maximum_amplification__infinity(self):
        inf = float('inf')  # an 'inf' max factor is allowed by Expat
        parser = expat.ParserCreate()
        self.assertIsNone(self.set_maximum_amplification(parser, inf))

    def test_set_maximum_amplification__invalid_max_factor_type(self):
        parser = expat.ParserCreate()
        setter = functools.partial(self.set_maximum_amplification, parser)

        self.assertRaises(TypeError, setter, None)
        self.assertRaises(TypeError, setter, 'abc')

    def test_set_maximum_amplification__invalid_max_factor_range(self):
        parser = expat.ParserCreate()
        setter = functools.partial(self.set_maximum_amplification, parser)

        # Both NaN and values below 1.0 must be reported with this message.
        msg = re.escape("'max_factor' must be at least 1.0")
        self.assertRaisesRegex(expat.ExpatError, msg, setter, float('nan'))
        self.assertRaisesRegex(expat.ExpatError, msg, setter, 0.99)

    def test_set_maximum_amplification__fail_for_subparser(self):
        parser = expat.ParserCreate()
        subparser = parser.ExternalEntityParserCreate(None)
        setter = functools.partial(self.set_maximum_amplification, subparser)
        self.assert_root_parser_failure(setter, 123.45)
| 947 | + |
| 948 | + |
@unittest.skipIf(expat.version_info < (2, 7, 2), "requires Expat >= 2.7.2")
class MemoryProtectionTest(AttackProtectionTestBase, unittest.TestCase):
    """Run the shared attack-protection tests against the allocation tracker."""

    # NOTE: under the default Expat configuration, the billion laughs
    # protection may fire before the allocation limiter does when the
    # arguments to exponential_expansion_payload() are not picked with care.
    # Every payload below must therefore either hit the allocation limiter
    # before any other protection, or trigger no protection at all.

    def assert_rejected(self, func, /, *args, **kwargs):
        """Check that func(*args, **kwargs) hits the allocation limit."""
        expected = r"out of memory: line \d+, column \d+"
        with self.assertRaisesRegex(expat.ExpatError, expected):
            func(*args, **kwargs)

    def set_activation_threshold(self, parser, threshold):
        # Forward to the allocation tracker setter of the parser.
        return parser.SetAllocTrackerActivationThreshold(threshold)

    def set_maximum_amplification(self, parser, max_factor):
        # Forward to the allocation tracker setter of the parser.
        return parser.SetAllocTrackerMaximumAmplification(max_factor)

    def test_set_activation_threshold__threshold_reached(self):
        xml_parser = expat.ParserCreate()
        # A threshold this small is expected to always be reached.
        self.set_activation_threshold(xml_parser, 3)
        # Pair it with a small maximum factor and a payload whose peak
        # amplification factor exceeds it: parsing must be rejected.
        outcome = self.set_maximum_amplification(xml_parser, 1.0)
        self.assertIsNone(outcome)
        data = self.exponential_expansion_payload(nrows=4, ncols=10)
        self.assert_rejected(xml_parser.Parse, data, True)

    def test_set_activation_threshold__threshold_not_reached(self):
        xml_parser = expat.ParserCreate()
        # A threshold this large is expected to never be reached.
        self.set_activation_threshold(xml_parser, 10 ** 5)
        # Same small maximum factor and same payload as in the "reached"
        # test, but here parsing is expected to succeed.
        outcome = self.set_maximum_amplification(xml_parser, 1.0)
        self.assertIsNone(outcome)
        data = self.exponential_expansion_payload(nrows=4, ncols=10)
        parsed = xml_parser.Parse(data, True)
        self.assertIsNotNone(parsed)

    def test_set_maximum_amplification__amplification_exceeded(self):
        xml_parser = expat.ParserCreate()
        # A zero threshold makes the maximum amplification factor check
        # apply unconditionally.
        self.set_activation_threshold(xml_parser, 0)
        # This maximum amplification factor is expected to always be
        # exceeded.
        outcome = self.set_maximum_amplification(xml_parser, 1.0)
        self.assertIsNone(outcome)
        # The peak amplification factor of this payload is > 1.0.
        data = self.exponential_expansion_payload(nrows=2, ncols=1)
        self.assert_rejected(xml_parser.Parse, data, True)

    def test_set_maximum_amplification__amplification_not_exceeded(self):
        xml_parser = expat.ParserCreate()
        # A zero threshold makes the maximum amplification factor check
        # apply unconditionally.
        self.set_activation_threshold(xml_parser, 0)
        # This maximum amplification factor is expected to never be
        # exceeded.
        outcome = self.set_maximum_amplification(xml_parser, 1e4)
        self.assertIsNone(outcome)
        # The peak amplification factor of this payload is < 1e4.
        data = self.exponential_expansion_payload(nrows=2, ncols=1)
        parsed = xml_parser.Parse(data, True)
        self.assertIsNotNone(parsed)
| 1008 | + |
| 1009 | + |
# Run the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
0 commit comments