Skip to content

Commit 12bef9c

Browse files
committed
add tests
1 parent c1c23fb commit 12bef9c

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed

Lib/test/test_pyexpat.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
# handler, are obscure and unhelpful.
33

44
import os
5+
import re
56
import sys
67
import sysconfig
78
import unittest
9+
import textwrap
810
import traceback
911
from io import BytesIO
1012
from test import support
@@ -821,5 +823,92 @@ def start_element(name, _):
821823
self.assertEqual(started, ['doc'])
822824

823825

826+
class AttackProtectionTest(unittest.TestCase):
827+
828+
def billion_laughs(self, ncols, nrows, text='.'):
829+
"""Create a billion laugh payload.
830+
831+
Be careful: the number of total items is pow(n, k), thereby
832+
requiring at least pow(ncols, nrows) * sizeof(base) memory!
833+
"""
834+
body = textwrap.indent('\n'.join(
835+
f'<!ENTITY row{i + 1} "{f"&row{i};" * ncols}">'
836+
for i in range(nrows)
837+
), ' ')
838+
return f"""\
839+
<?xml version="1.0"?>
840+
<!DOCTYPE doc [
841+
<!ENTITY row0 "{text}">
842+
<!ELEMENT doc (#PCDATA)>
843+
{body}
844+
]>
845+
<doc>&row{nrows};</doc>
846+
"""
847+
848+
def test_set_alloc_tracker_maximum_amplification(self):
849+
payload = self.billion_laughs(10, 4)
850+
851+
p = expat.ParserCreate()
852+
# Unconditionally enable maximum amplification factor.
853+
p.SetAllocTrackerActivationThreshold(0)
854+
# At runtime, the peak amplification factor is 101.71,
855+
# which is above the default threshold (100.0).
856+
msg = re.escape("out of memory: line 3, column 15")
857+
self.assertRaisesRegex(expat.ExpatError, msg, p.Parse, payload)
858+
859+
# # Re-create a parser as the current parser is now in an error state.
860+
p = expat.ParserCreate()
861+
# Unconditionally enable maximum amplification factor.
862+
p.SetAllocTrackerActivationThreshold(0)
863+
# Use a max amplification factor a bit above the actual one.
864+
self.assertIsNone(p.SetAllocTrackerMaximumAmplification(101.72))
865+
self.assertIsNotNone(p.Parse(payload))
866+
867+
def test_set_alloc_tracker_maximum_amplification_invalid_args(self):
868+
parser = expat.ParserCreate()
869+
f = parser.SetAllocTrackerMaximumAmplification
870+
871+
msg = re.escape("'max_factor' must be at least 1.0")
872+
self.assertRaisesRegex(expat.ExpatError, msg, f, float('nan'))
873+
self.assertRaisesRegex(expat.ExpatError, msg, f, 0.99)
874+
875+
subparser = parser.ExternalEntityParserCreate(None)
876+
fsub = subparser.SetAllocTrackerMaximumAmplification
877+
msg = re.escape("parser must be a root parser")
878+
self.assertRaisesRegex(expat.ExpatError, msg, fsub, 1.0)
879+
880+
def test_set_alloc_tracker_activation_threshold(self):
881+
# Run the test with EXPAT_MALLOC_DEBUG=2 to detect those constants.
882+
MAX_ALLOC = 17333
883+
MIN_ALLOC = 1096
884+
885+
payload = self.billion_laughs(10, 4)
886+
887+
p = expat.ParserCreate()
888+
p.SetAllocTrackerActivationThreshold(MAX_ALLOC + 1)
889+
self.assertIsNone(p.SetAllocTrackerMaximumAmplification(1.0))
890+
# Check that we never reach the activation threshold.
891+
self.assertIsNotNone(p.Parse(payload))
892+
893+
p = expat.ParserCreate()
894+
p.SetAllocTrackerActivationThreshold(MIN_ALLOC - 1)
895+
# Check that we always reach the activation threshold.
896+
self.assertIsNone(p.SetAllocTrackerMaximumAmplification(1.0))
897+
msg = re.escape("out of memory: line 3, column 10")
898+
self.assertRaisesRegex(expat.ExpatError, msg, p.Parse, payload)
899+
900+
def test_set_alloc_tracker_activation_threshold_invalid_args(self):
901+
parser = expat.ParserCreate()
902+
f = parser.SetAllocTrackerActivationThreshold
903+
904+
ULONG_LONG_MAX = 2 * sys.maxsize + 1
905+
self.assertRaises(OverflowError, f, ULONG_LONG_MAX + 1)
906+
907+
subparser = parser.ExternalEntityParserCreate(None)
908+
fsub = subparser.SetAllocTrackerActivationThreshold
909+
msg = re.escape("parser must be a root parser")
910+
self.assertRaisesRegex(expat.ExpatError, msg, fsub, 12345)
911+
912+
824913
if __name__ == "__main__":
825914
unittest.main()

0 commit comments

Comments
 (0)