|
| 1 | +From 5b0b4f9f8db6f0de51011ef40e662b940ff44c07 Mon Sep 17 00:00:00 2001 |
| 2 | +From: Fabien Hertschuh < [email protected]> |
| 3 | +Date: Mon, 20 Oct 2025 12:31:04 -0700 |
| 4 | +Subject: [PATCH] Use `filter="data"` option of `TarFile.extractall`. |
| 5 | + |
| 6 | +For Python versions between 3.12 (inclusive) and 3.14 (exclusive). |
| 7 | + |
| 8 | +The "data" filter performs a number of additional checks on links and paths. The `filter` option was added in Python 3.12. The `filter="data"` option became the default in Python 3.14. |
| 9 | + |
| 10 | +Also: |
| 11 | +- added similar path filtering when extracting zip archives |
| 12 | +- shared the extraction code between `file_utils` and `saving_lib` |
| 13 | + |
| 14 | +Signed-off-by: Azure Linux Security Servicing Account < [email protected]> |
| 15 | +Upstream-reference: https://patch-diff.githubusercontent.com/raw/keras-team/keras/pull/21760.patch |
| 16 | +--- |
| 17 | + keras/src/saving/saving_lib.py | 2 +- |
| 18 | + keras/src/utils/file_utils.py | 60 ++++++++++++++++++++++++------ |
| 19 | + keras/src/utils/file_utils_test.py | 8 ++-- |
| 20 | + 3 files changed, 54 insertions(+), 16 deletions(-) |
| 21 | + |
| 22 | +diff --git a/keras/src/saving/saving_lib.py b/keras/src/saving/saving_lib.py |
| 23 | +index 94b9561..1668489 100644 |
| 24 | +--- a/keras/src/saving/saving_lib.py |
| 25 | ++++ b/keras/src/saving/saving_lib.py |
| 26 | +@@ -556,7 +556,7 @@ class DiskIOStore: |
| 27 | + if self.archive: |
| 28 | + self.tmp_dir = get_temp_dir() |
| 29 | + if self.mode == "r": |
| 30 | +- self.archive.extractall(path=self.tmp_dir) |
| 31 | ++ file_utils.extract_open_archive(self.archive, self.tmp_dir) |
| 32 | + self.working_dir = file_utils.join( |
| 33 | + self.tmp_dir, self.root_path |
| 34 | + ).replace("\\", "/") |
| 35 | +diff --git a/keras/src/utils/file_utils.py b/keras/src/utils/file_utils.py |
| 36 | +index e625a9f..c52c352 100644 |
| 37 | +--- a/keras/src/utils/file_utils.py |
| 38 | ++++ b/keras/src/utils/file_utils.py |
| 39 | +@@ -3,6 +3,7 @@ import os |
| 40 | + import pathlib |
| 41 | + import re |
| 42 | + import shutil |
| 43 | ++import sys |
| 44 | + import tarfile |
| 45 | + import urllib |
| 46 | + import warnings |
| 47 | +@@ -50,17 +51,32 @@ def is_link_in_dir(info, base): |
| 48 | + return is_path_in_dir(info.linkname, base_dir=tip) |
| 49 | + |
| 50 | + |
| 51 | +-def filter_safe_paths(members): |
| 52 | ++def filter_safe_zipinfos(members): |
| 53 | + base_dir = resolve_path(".") |
| 54 | + for finfo in members: |
| 55 | + valid_path = False |
| 56 | +- if is_path_in_dir(finfo.name, base_dir): |
| 57 | ++ if is_path_in_dir(finfo.filename, base_dir): |
| 58 | + valid_path = True |
| 59 | + yield finfo |
| 60 | +- elif finfo.issym() or finfo.islnk(): |
| 61 | ++ if not valid_path: |
| 62 | ++ warnings.warn( |
| 63 | ++ "Skipping invalid path during archive extraction: " |
| 64 | ++ f"'{finfo.name}'.", |
| 65 | ++ stacklevel=2, |
| 66 | ++ ) |
| 67 | ++ |
| 68 | ++ |
| 69 | ++def filter_safe_tarinfos(members): |
| 70 | ++ base_dir = resolve_path(".") |
| 71 | ++ for finfo in members: |
| 72 | ++ valid_path = False |
| 73 | ++ if finfo.issym() or finfo.islnk(): |
| 74 | + if is_link_in_dir(finfo, base_dir): |
| 75 | + valid_path = True |
| 76 | + yield finfo |
| 77 | ++ elif is_path_in_dir(finfo.name, base_dir): |
| 78 | ++ valid_path = True |
| 79 | ++ yield finfo |
| 80 | + if not valid_path: |
| 81 | + warnings.warn( |
| 82 | + "Skipping invalid path during archive extraction: " |
| 83 | +@@ -69,6 +85,35 @@ def filter_safe_paths(members): |
| 84 | + ) |
| 85 | + |
| 86 | + |
| 87 | ++def extract_open_archive(archive, path="."): |
| 88 | ++ """Extracts an open tar or zip archive to the provided directory. |
| 89 | ++ |
| 90 | ++ This function filters unsafe paths during extraction. |
| 91 | ++ |
| 92 | ++ Args: |
| 93 | ++ archive: The archive object, either a `TarFile` or a `ZipFile`. |
| 94 | ++ path: Where to extract the archive file. |
| 95 | ++ """ |
| 96 | ++ if isinstance(archive, zipfile.ZipFile): |
| 97 | ++ # Zip archive. |
| 98 | ++ archive.extractall( |
| 99 | ++ path, members=filter_safe_zipinfos(archive.infolist()) |
| 100 | ++ ) |
| 101 | ++ else: |
| 102 | ++ # Tar archive. |
| 103 | ++ extractall_kwargs = {} |
| 104 | ++ # The `filter="data"` option was added in Python 3.12. It became the |
| 105 | ++ # default starting from Python 3.14. So we only specify it between |
| 106 | ++ # those two versions. |
| 107 | ++ if sys.version_info >= (3, 12) and sys.version_info < (3, 14): |
| 108 | ++ extractall_kwargs = {"filter": "data"} |
| 109 | ++ archive.extractall( |
| 110 | ++ path, |
| 111 | ++ members=filter_safe_tarinfos(archive), |
| 112 | ++ **extractall_kwargs, |
| 113 | ++ ) |
| 114 | ++ |
| 115 | ++ |
| 116 | + def extract_archive(file_path, path=".", archive_format="auto"): |
| 117 | + """Extracts an archive if it matches a support format. |
| 118 | + |
| 119 | +@@ -108,14 +153,7 @@ def extract_archive(file_path, path=".", archive_format="auto"): |
| 120 | + if is_match_fn(file_path): |
| 121 | + with open_fn(file_path) as archive: |
| 122 | + try: |
| 123 | +- if zipfile.is_zipfile(file_path): |
| 124 | +- # Zip archive. |
| 125 | +- archive.extractall(path) |
| 126 | +- else: |
| 127 | +- # Tar archive, perhaps unsafe. Filter paths. |
| 128 | +- archive.extractall( |
| 129 | +- path, members=filter_safe_paths(archive) |
| 130 | +- ) |
| 131 | ++ extract_open_archive(archive, path) |
| 132 | + except (tarfile.TarError, RuntimeError, KeyboardInterrupt): |
| 133 | + if os.path.exists(path): |
| 134 | + if os.path.isfile(path): |
| 135 | +diff --git a/keras/src/utils/file_utils_test.py b/keras/src/utils/file_utils_test.py |
| 136 | +index c09f47a..c39314e 100644 |
| 137 | +--- a/keras/src/utils/file_utils_test.py |
| 138 | ++++ b/keras/src/utils/file_utils_test.py |
| 139 | +@@ -139,7 +139,7 @@ class FilterSafePathsTest(test_case.TestCase): |
| 140 | + with tarfile.open(self.tar_path, "w") as tar: |
| 141 | + tar.add(__file__, arcname="safe_path.txt") |
| 142 | + with tarfile.open(self.tar_path, "r") as tar: |
| 143 | +- members = list(file_utils.filter_safe_paths(tar.getmembers())) |
| 144 | ++ members = list(file_utils.filter_safe_tarinfos(tar.getmembers())) |
| 145 | + self.assertEqual(len(members), 1) |
| 146 | + self.assertEqual(members[0].name, "safe_path.txt") |
| 147 | + |
| 148 | +@@ -153,7 +153,7 @@ class FilterSafePathsTest(test_case.TestCase): |
| 149 | + with tarfile.open(self.tar_path, "w") as tar: |
| 150 | + tar.add(symlink_path, arcname="symlink.txt") |
| 151 | + with tarfile.open(self.tar_path, "r") as tar: |
| 152 | +- members = list(file_utils.filter_safe_paths(tar.getmembers())) |
| 153 | ++ members = list(file_utils.filter_safe_tarinfos(tar.getmembers())) |
| 154 | + self.assertEqual(len(members), 1) |
| 155 | + self.assertEqual(members[0].name, "symlink.txt") |
| 156 | + os.remove(symlink_path) |
| 157 | +@@ -170,7 +170,7 @@ class FilterSafePathsTest(test_case.TestCase): |
| 158 | + ) # Path intended to be outside of base dir |
| 159 | + with tarfile.open(self.tar_path, "r") as tar: |
| 160 | + with patch("warnings.warn") as mock_warn: |
| 161 | +- _ = list(file_utils.filter_safe_paths(tar.getmembers())) |
| 162 | ++ _ = list(file_utils.filter_safe_tarinfos(tar.getmembers())) |
| 163 | + warning_msg = ( |
| 164 | + "Skipping invalid path during archive extraction: " |
| 165 | + "'../../invalid.txt'." |
| 166 | +@@ -193,7 +193,7 @@ class FilterSafePathsTest(test_case.TestCase): |
| 167 | + tar.add(symlink_path, arcname="symlink.txt") |
| 168 | + |
| 169 | + with tarfile.open(self.tar_path, "r") as tar: |
| 170 | +- members = list(file_utils.filter_safe_paths(tar.getmembers())) |
| 171 | ++ members = list(file_utils.filter_safe_tarinfos(tar.getmembers())) |
| 172 | + self.assertEqual(len(members), 1) |
| 173 | + self.assertEqual(members[0].name, "symlink.txt") |
| 174 | + self.assertTrue( |
| 175 | +-- |
| 176 | +2.45.4 |
| 177 | + |
0 commit comments