Skip to content

Commit 9d538e4

Browse files
committed
address review
1 parent 220f3e2 commit 9d538e4

File tree

1 file changed

+48
-35
lines changed

1 file changed

+48
-35
lines changed

Lib/test/test_pyexpat.py

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# XXX TypeErrors on calling handlers, or on bad return values from a
22
# handler, are obscure and unhelpful.
33

4+
import abc
45
import functools
56
import os
67
import re
@@ -823,11 +824,13 @@ def start_element(name, _):
823824
self.assertEqual(started, ['doc'])
824825

825826

826-
class AttackProtectionTestCases:
827-
"""Generic interface for testing XML Expat protections.
827+
class AttackProtectionTestBase(abc.ABC):
828+
"""
829+
Base class for testing protections against XML payloads with
830+
disproportionate amplification.
828831
829-
The protections being tested should mitigate attacks based
830-
on Billion Laughs payloads.
832+
The protections being tested should detect and prevent attacks
833+
that leverage disproportionate amplification from small inputs.
831834
"""
832835

833836
@staticmethod
@@ -855,88 +858,98 @@ def assert_root_parser_failure(self, func, /, *args, **kwargs):
855858
msg = "parser must be a root parser"
856859
self.assertRaisesRegex(expat.ExpatError, msg, func, *args, **kwargs)
857860

858-
def assert_active_protection(self, func, /, *args, **kwargs):
859-
"""Assert that func(*args, **kwargs) triggers the attack protection."""
860-
raise NotImplementedError
861+
@abc.abstractmethod
862+
def assert_rejected(self, func, /, *args, **kwargs):
863+
"""Assert that func(*args, **kwargs) triggers the attack protection.
864+
865+
Note: this method must ensure that the attack protection being tested
866+
is the one that is actually triggered at runtime, e.g., by matching
867+
the exact error message.
868+
"""
861869

870+
@abc.abstractmethod
862871
def set_activation_threshold(self, parser, threshold):
863872
"""Set the activation threshold for the tested protection."""
864-
raise NotImplementedError
865873

874+
@abc.abstractmethod
866875
def set_maximum_amplification(self, parser, max_factor):
867876
"""Set the maximum amplification factor for the tested protection."""
868-
raise NotImplementedError
869877

870-
def test_set_attack_protection_threshold_reached(self):
871-
raise NotImplementedError
878+
@abc.abstractmethod
879+
def test_set_activation_threshold__threshold_reached(self):
880+
"""Test when the activation threshold is exceeded."""
872881

873-
def test_set_attack_protection_threshold_ignored(self):
874-
raise NotImplementedError
882+
@abc.abstractmethod
883+
def test_set_activation_threshold__threshold_not_reached(self):
884+
"""Test when the activation threshold is not exceeded."""
875885

876-
def test_set_attack_protection_threshold_arg_invalid_type(self):
886+
def test_set_activation_threshold__invalid_threshold_type(self):
877887
parser = expat.ParserCreate()
878888
setter = functools.partial(self.set_activation_threshold, parser)
879889

880890
self.assertRaises(TypeError, setter, 1.0)
881891
self.assertRaises(TypeError, setter, -1.5)
882892
self.assertRaises(ValueError, setter, -5)
883893

884-
def test_set_attack_protection_threshold_arg_invalid_range(self):
894+
def test_set_activation_threshold__invalid_threshold_range(self):
885895
_testcapi = import_helper.import_module("_testcapi")
886896
parser = expat.ParserCreate()
887897
setter = functools.partial(self.set_activation_threshold, parser)
888898

889899
self.assertRaises(OverflowError, setter, _testcapi.ULLONG_MAX + 1)
890900

891-
def test_set_attack_protection_threshold_fail_for_subparser(self):
901+
def test_set_activation_threshold__fail_for_subparser(self):
892902
parser = expat.ParserCreate()
893903
subparser = parser.ExternalEntityParserCreate(None)
894904
setter = functools.partial(self.set_activation_threshold, subparser)
895905
self.assert_root_parser_failure(setter, 12345)
896906

897-
def test_set_maximum_amplification_reached(self):
898-
raise NotImplementedError
907+
@abc.abstractmethod
908+
def test_set_maximum_amplification__amplification_exceeded(self):
909+
"""Test when the amplification factor is exceeded."""
899910

900-
def test_set_maximum_amplification_ignored(self):
901-
raise NotImplementedError
911+
@abc.abstractmethod
912+
def test_set_maximum_amplification__amplification_not_exceeded(self):
913+
"""Test when the amplification factor is not exceeded."""
902914

903-
def test_set_maximum_amplification_infinity(self):
915+
def test_set_maximum_amplification__infinity(self):
904916
inf = float('inf') # an 'inf' threshold is allowed by Expat
905917
parser = expat.ParserCreate()
906918
self.assertIsNone(self.set_maximum_amplification(parser, inf))
907919

908-
def test_set_maximum_amplification_arg_invalid_type(self):
920+
def test_set_maximum_amplification__invalid_max_factor_type(self):
909921
parser = expat.ParserCreate()
910922
setter = functools.partial(self.set_maximum_amplification, parser)
911923

912924
self.assertRaises(TypeError, setter, None)
913925
self.assertRaises(TypeError, setter, 'abc')
914926

915-
def test_set_maximum_amplification_arg_invalid_range(self):
927+
def test_set_maximum_amplification__invalid_max_factor_range(self):
916928
parser = expat.ParserCreate()
917929
setter = functools.partial(self.set_maximum_amplification, parser)
918930

919931
msg = re.escape("'max_factor' must be at least 1.0")
920932
self.assertRaisesRegex(expat.ExpatError, msg, setter, float('nan'))
921933
self.assertRaisesRegex(expat.ExpatError, msg, setter, 0.99)
922934

923-
def test_set_maximum_amplification_fail_for_subparser(self):
935+
def test_set_maximum_amplification__fail_for_subparser(self):
924936
parser = expat.ParserCreate()
925937
subparser = parser.ExternalEntityParserCreate(None)
926938
setter = functools.partial(self.set_maximum_amplification, subparser)
927939
self.assert_root_parser_failure(setter, 123.45)
928940

929941

930942
@unittest.skipIf(expat.version_info < (2, 7, 2), "requires Expat >= 2.7.2")
931-
class MemoryProtectionTest(AttackProtectionTestCases, unittest.TestCase):
943+
class MemoryProtectionTest(AttackProtectionTestBase, unittest.TestCase):
932944

933945
# With the default Expat configuration, the billion laughs protection may
934946
# hit before the allocation limiter if exponential_expansion_payload() is
935-
# not carefully parametrized. In particular, use the following assert_*()
936-
# methods to check the error message of the active protection.
947+
# not carefully parametrized. As such, the payloads should be chosen so
948+
# that either the allocation limiter is hit before other protections are
949+
# triggered or no protection at all is triggered.
937950

938-
def assert_active_protection(self, func, /, *args, **kwargs):
939-
"""Check that fnuc(*args, **kwargs) hits the allocation limit."""
951+
def assert_rejected(self, func, /, *args, **kwargs):
952+
"""Check that func(*args, **kwargs) hits the allocation limit."""
940953
msg = r"out of memory: line \d+, column \d+"
941954
self.assertRaisesRegex(expat.ExpatError, msg, func, *args, **kwargs)
942955

@@ -946,17 +959,17 @@ def set_activation_threshold(self, parser, threshold):
946959
def set_maximum_amplification(self, parser, max_factor):
947960
return parser.SetAllocTrackerMaximumAmplification(max_factor)
948961

949-
def test_set_attack_protection_threshold_reached(self):
962+
def test_set_activation_threshold__threshold_reached(self):
950963
parser = expat.ParserCreate()
951964
# Choose a threshold expected to be always reached.
952965
self.set_activation_threshold(parser, 3)
953966
# Check that the threshold is reached by choosing a small factor
954967
# and a payload whose peak amplification factor exceeds it.
955968
self.assertIsNone(self.set_maximum_amplification(parser, 1.0))
956969
payload = self.exponential_expansion_payload(10, 4)
957-
self.assert_active_protection(parser.Parse, payload, True)
970+
self.assert_rejected(parser.Parse, payload, True)
958971

959-
def test_set_attack_protection_threshold_ignored(self):
972+
def test_set_activation_threshold__threshold_not_reached(self):
960973
parser = expat.ParserCreate()
961974
# Choose a threshold expected to be never reached.
962975
self.set_activation_threshold(parser, pow(10, 5))
@@ -966,17 +979,17 @@ def test_set_attack_protection_threshold_ignored(self):
966979
payload = self.exponential_expansion_payload(10, 4)
967980
self.assertIsNotNone(parser.Parse(payload, True))
968981

969-
def test_set_maximum_amplification_reached(self):
982+
def test_set_maximum_amplification__amplification_exceeded(self):
970983
parser = expat.ParserCreate()
971984
# Unconditionally enable maximum activation factor.
972985
self.set_activation_threshold(parser, 0)
973986
# Choose a max amplification factor expected to always be exceeded.
974987
self.assertIsNone(self.set_maximum_amplification(parser, 1.0))
975988
# Craft a payload for which the peak amplification factor is > 1.0.
976989
payload = self.exponential_expansion_payload(1, 2)
977-
self.assert_active_protection(parser.Parse, payload, True)
990+
self.assert_rejected(parser.Parse, payload, True)
978991

979-
def test_set_maximum_amplification_ignored(self):
992+
def test_set_maximum_amplification__amplification_not_exceeded(self):
980993
parser = expat.ParserCreate()
981994
# Unconditionally enable maximum activation factor.
982995
self.set_activation_threshold(parser, 0)

0 commit comments

Comments
 (0)