|
| 1 | +""" |
| 2 | +Test suite for zipfile validation features. |
| 3 | +""" |
| 4 | + |
| 5 | +import io |
| 6 | +import os |
| 7 | +import struct |
| 8 | +import tempfile |
| 9 | +import unittest |
| 10 | +import zipfile |
| 11 | +from zipfile import ( |
| 12 | + ZipFile, ZipValidationLevel, ZipStructuralError, ZipValidationError, |
| 13 | + BadZipFile, sizeEndCentDir, stringEndArchive, structEndArchive, |
| 14 | + sizeCentralDir, stringCentralDir, structCentralDir, |
| 15 | + sizeFileHeader, stringFileHeader, structFileHeader, |
| 16 | + _ECD_ENTRIES_TOTAL, _ECD_SIZE, _ECD_OFFSET, _ECD_COMMENT_SIZE |
| 17 | +) |
| 18 | +from test.support.os_helper import TESTFN, unlink |
| 19 | + |
| 20 | + |
class TestZipValidation(unittest.TestCase):
    """Tests for the ``strict_validation`` levels accepted by ``ZipFile``.

    Every archive used here is written to a temporary file created through
    :meth:`create_temp_file`; those files are removed again in
    :meth:`tearDown`.
    """

    # Byte offset of the 16-bit "total entries" field inside the
    # end-of-central-directory record: a 4-byte signature followed by three
    # 2-byte fields (disk number, central-directory disk, entries on disk).
    # This is a byte offset, not an index into the unpacked record tuple.
    _EOCD_TOTAL_ENTRIES_OFFSET = struct.calcsize('<4s3H')

    def setUp(self):
        """Start each test with an empty list of temporary files."""
        self.temp_files = []

    def tearDown(self):
        """Remove every temporary file registered during the test."""
        for temp_file in self.temp_files:
            try:
                unlink(temp_file)
            except OSError:
                # Best-effort cleanup: a file that cannot be removed must
                # not turn a passing test into a failing one.
                pass

    def create_temp_file(self, content=b''):
        """Create a temporary file holding *content* and return its path.

        The path is registered in ``self.temp_files`` before anything is
        written, so it is cleaned up even if the write fails.
        """
        fd, path = tempfile.mkstemp()
        self.temp_files.append(path)
        with os.fdopen(fd, 'wb') as f:
            f.write(content)
        return path

    def _make_zip(self, members):
        """Write *members* (a name -> text mapping) to a new temporary ZIP.

        Returns the path of the archive.
        """
        path = self.create_temp_file()
        with ZipFile(path, 'w') as zf:
            for name, text in members.items():
                zf.writestr(name, text)
        return path

    def test_basic_validation_backward_compatibility(self):
        """Default and explicit BASIC mode both read a valid archive."""
        temp_path = self._make_zip({'test.txt': 'Hello, World!'})

        # Default behaviour (no keyword given) is expected to be BASIC.
        with ZipFile(temp_path, 'r') as zf:
            self.assertEqual(zf._strict_validation, ZipValidationLevel.BASIC)
            self.assertEqual(zf.read('test.txt'), b'Hello, World!')

        # Explicitly requesting BASIC must behave the same way.
        with ZipFile(temp_path, 'r',
                     strict_validation=ZipValidationLevel.BASIC) as zf:
            self.assertEqual(zf._strict_validation, ZipValidationLevel.BASIC)
            self.assertEqual(zf.read('test.txt'), b'Hello, World!')

    def test_validation_level_enum(self):
        """Enum members compare equal to 0/1/2 and round-trip from ints."""
        expected = [
            (ZipValidationLevel.BASIC, 0),
            (ZipValidationLevel.STRUCTURAL, 1),
            (ZipValidationLevel.STRICT, 2),
        ]
        for member, value in expected:
            with self.subTest(member=member):
                self.assertEqual(member, value)
                # Conversion from the raw integer yields the same member.
                self.assertEqual(ZipValidationLevel(value), member)

    def test_structural_validation_valid_file(self):
        """A well-formed two-member archive passes STRUCTURAL validation."""
        temp_path = self._make_zip({
            'test.txt': 'Hello, World!',
            'dir/nested.txt': 'Nested content',
        })

        with ZipFile(temp_path, 'r',
                     strict_validation=ZipValidationLevel.STRUCTURAL) as zf:
            self.assertEqual(len(zf.filelist), 2)
            self.assertEqual(zf.read('test.txt'), b'Hello, World!')
            self.assertEqual(zf.read('dir/nested.txt'), b'Nested content')

    def test_strict_validation_valid_file(self):
        """A well-formed archive passes STRICT validation."""
        temp_path = self._make_zip({'test.txt': 'Hello, World!'})

        with ZipFile(temp_path, 'r',
                     strict_validation=ZipValidationLevel.STRICT) as zf:
            self.assertEqual(zf.read('test.txt'), b'Hello, World!')

    def test_malformed_eocd_too_many_entries(self):
        """An EOCD claiming 65535 entries is rejected in STRUCTURAL mode."""
        temp_path = self._make_zip({'test.txt': 'Hello'})

        with open(temp_path, 'rb') as f:
            data = bytearray(f.read())

        # Locate the end-of-central-directory record.  Fail loudly if it is
        # missing instead of silently testing an unmodified archive.
        eocd_pos = data.rfind(stringEndArchive)
        self.assertGreaterEqual(
            eocd_pos, 0, "EOCD signature not found in freshly written archive")

        # Overwrite the total-entries field with the largest value that fits
        # in an unsigned 16-bit field.
        struct.pack_into(
            '<H', data, eocd_pos + self._EOCD_TOTAL_ENTRIES_OFFSET, 0xFFFF)

        malformed_path = self.create_temp_file(data)

        with self.assertRaises(ZipStructuralError) as cm:
            with ZipFile(malformed_path, 'r',
                         strict_validation=ZipValidationLevel.STRUCTURAL):
                pass
        # Either message is acceptable; which one appears depends on which
        # check runs first.
        self.assertRegex(str(cm.exception),
                         r"Too many entries|Entry count mismatch")

        # BASIC mode must keep accepting the same file (backward
        # compatibility).
        with ZipFile(malformed_path, 'r',
                     strict_validation=ZipValidationLevel.BASIC):
            pass

    def test_exception_hierarchy(self):
        """The new exception types derive from ``BadZipFile``."""
        for exc_type in (ZipStructuralError, ZipValidationError):
            with self.subTest(exc_type=exc_type.__name__):
                self.assertTrue(issubclass(exc_type, BadZipFile))

        # Instances must also be catchable as BadZipFile.
        self.assertIsInstance(ZipStructuralError("Structure error"),
                              BadZipFile)
        self.assertIsInstance(ZipValidationError("Validation error"),
                              BadZipFile)

    def test_compression_ratio_detection(self):
        """A 1-byte -> 2000-byte entry is flagged as a suspicious ratio."""
        # Building a real zip bomb would be complex, so the field validator
        # is called directly on a hand-built ZipInfo.
        from zipfile import _validate_zipinfo_fields, ZipInfo

        zinfo = ZipInfo('test.txt')
        zinfo.compress_size = 1      # 1 byte compressed
        zinfo.file_size = 2000       # 2000 bytes uncompressed (ratio 2000)
        zinfo.header_offset = 0
        zinfo.compress_type = zipfile.ZIP_DEFLATED

        # The test expects ratios above 1000 to be rejected.
        with self.assertRaises(ZipStructuralError) as cm:
            _validate_zipinfo_fields(zinfo, ZipValidationLevel.STRUCTURAL)
        self.assertIn("Suspicious compression ratio", str(cm.exception))

    def test_constructor_parameter_validation(self):
        """Unknown levels raise ``ValueError``; known levels are stored."""
        temp_path = self._make_zip({'test.txt': 'Hello'})

        # The ``with`` guarantees the archive is closed should the
        # constructor unexpectedly accept the bogus level.
        with self.assertRaises(ValueError):
            with ZipFile(temp_path, 'r', strict_validation=99):
                pass

        for level in (ZipValidationLevel.BASIC,
                      ZipValidationLevel.STRUCTURAL,
                      ZipValidationLevel.STRICT):
            with self.subTest(level=level):
                with ZipFile(temp_path, 'r', strict_validation=level) as zf:
                    self.assertEqual(zf._strict_validation, level)
| 169 | + |
| 170 | + |
class TestValidationIntegration(unittest.TestCase):
    """Check that validation coexists with the regular ``ZipFile`` API."""

    def setUp(self):
        """Reset the list of temporary files owned by the current test."""
        self.temp_files = []

    def tearDown(self):
        """Delete the temporary files created by the test, ignoring errors."""
        for leftover in self.temp_files:
            try:
                unlink(leftover)
            except OSError:
                # Cleanup is best-effort only.
                pass

    def create_temp_file(self, content=b''):
        """Write *content* to a fresh temporary file and return its path."""
        handle, filename = tempfile.mkstemp()
        # Register first so the file is cleaned up even if writing fails.
        self.temp_files.append(filename)
        with os.fdopen(handle, 'wb') as stream:
            stream.write(content)
        return filename

    def test_existing_methods_work_with_validation(self):
        """Listing, lookup, read and ``testzip`` work in STRUCTURAL mode."""
        archive_path = self.create_temp_file()
        with ZipFile(archive_path, 'w') as writer:
            writer.writestr('test1.txt', 'Content 1')
            writer.writestr('test2.txt', 'Content 2')

        with ZipFile(archive_path, 'r',
                     strict_validation=ZipValidationLevel.STRUCTURAL) as reader:
            # namelist: both members are reported (order is not checked).
            self.assertEqual(set(reader.namelist()),
                             {'test1.txt', 'test2.txt'})

            # infolist: one entry per member.
            self.assertEqual(len(reader.infolist()), 2)

            # getinfo: lookup by name returns the matching entry.
            self.assertEqual(reader.getinfo('test1.txt').filename,
                             'test1.txt')

            # read: stored bytes come back unchanged.
            self.assertEqual(reader.read('test1.txt'), b'Content 1')

            # testzip: returns None when no member is reported as bad.
            self.assertIsNone(reader.testzip())

    def test_validation_with_different_compression_methods(self):
        """Stored and (when zlib is importable) deflated members validate."""
        archive_path = self.create_temp_file()
        with ZipFile(archive_path, 'w') as writer:
            writer.writestr('stored.txt', 'Stored content',
                            compress_type=zipfile.ZIP_STORED)
            # The deflated member is written only if zlib can be imported.
            try:
                import zlib
                writer.writestr('deflated.txt', 'Deflated content',
                                compress_type=zipfile.ZIP_DEFLATED)
                has_zlib = True
            except ImportError:
                has_zlib = False

        with ZipFile(archive_path, 'r',
                     strict_validation=ZipValidationLevel.STRUCTURAL) as reader:
            self.assertEqual(reader.read('stored.txt'), b'Stored content')
            if has_zlib:
                self.assertEqual(reader.read('deflated.txt'),
                                 b'Deflated content')
| 237 | + |
| 238 | + |
# Allow the module to be run directly as a script; unittest discovers and
# runs every TestCase defined above.
if __name__ == '__main__':
    unittest.main()
0 commit comments