Skip to content

Commit 78d88af

Browse files
committed
Add input validation and resource management improvements
1 parent 708c7b2 commit 78d88af

File tree

5 files changed

+167
-8
lines changed

5 files changed

+167
-8
lines changed

src/sagemaker/iterators.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ def __next__(self):
114114
class LineIterator(BaseIterator):
115115
"""A helper class for parsing the byte Event Stream input to provide Line iteration."""
116116

117+
# Maximum buffer size to prevent unbounded memory consumption (10 MB)
118+
MAX_BUFFER_SIZE = 10 * 1024 * 1024
119+
117120
def __init__(self, event_stream):
118121
"""Initialises a LineIterator Iterator object
119122
@@ -182,5 +185,15 @@ def __next__(self):
182185
# print and move on to next response byte
183186
print("Unknown event type:" + chunk)
184187
continue
188+
189+
# Check buffer size before writing to prevent unbounded memory consumption
190+
chunk_size = len(chunk["PayloadPart"]["Bytes"])
191+
current_size = self.buffer.getbuffer().nbytes
192+
if current_size + chunk_size > self.MAX_BUFFER_SIZE:
193+
raise RuntimeError(
194+
f"Line buffer exceeded maximum size of {self.MAX_BUFFER_SIZE} bytes. "
195+
f"No newline found in stream."
196+
)
197+
185198
self.buffer.seek(0, io.SEEK_END)
186199
self.buffer.write(chunk["PayloadPart"]["Bytes"])

src/sagemaker/local/data.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,34 @@ def get_root_dir(self):
118118
class LocalFileDataSource(DataSource):
119119
"""Represents a data source within the local filesystem."""
120120

121+
# Blocklist of sensitive directories that should not be accessible
122+
RESTRICTED_PATHS = [
123+
os.path.abspath(os.path.expanduser("~/.aws")),
124+
os.path.abspath(os.path.expanduser("~/.ssh")),
125+
os.path.abspath(os.path.expanduser("~/.kube")),
126+
os.path.abspath(os.path.expanduser("~/.docker")),
127+
os.path.abspath(os.path.expanduser("~/.config")),
128+
os.path.abspath(os.path.expanduser("~/.credentials")),
129+
"/etc",
130+
"/root",
131+
"/home",
132+
"/var/lib",
133+
"/opt/ml/metadata",
134+
]
135+
121136
def __init__(self, root_path):
122137
super(LocalFileDataSource, self).__init__()
123138

124139
self.root_path = os.path.abspath(root_path)
140+
141+
# Validate that the path is not in restricted locations
142+
for restricted_path in self.RESTRICTED_PATHS:
143+
if self.root_path.startswith(restricted_path):
144+
raise ValueError(
145+
f"Local Mode does not support mounting from restricted system paths. "
146+
f"Got: {root_path}"
147+
)
148+
125149
if not os.path.exists(self.root_path):
126150
raise RuntimeError("Invalid data source: %s does not exist." % self.root_path)
127151

src/sagemaker/local/utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,7 @@ def copy_directory_structure(destination_directory, relative_path):
4848
destination_directory
4949
"""
5050
full_path = os.path.join(destination_directory, relative_path)
51-
if os.path.exists(full_path):
52-
return
53-
54-
os.makedirs(destination_directory, relative_path)
51+
os.makedirs(full_path, exist_ok=True)
5552

5653

5754
def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):

src/sagemaker/utils.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,95 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
601601
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
602602

603603

604+
def _is_path_within(path, parent):
    """Return True when ``path`` is ``parent`` itself or is located beneath it.

    The comparison is done on whole path components, so ``/etcetera`` is NOT
    considered to be inside ``/etc`` (a plain ``str.startswith`` would say it is).

    Args:
        path (str): Normalized absolute path to test.
        parent (str): Normalized absolute path of the candidate ancestor.

    Returns:
        bool: True if ``path`` equals ``parent`` or is a descendant of it.
    """
    path = os.path.normcase(path)
    parent = os.path.normcase(parent)
    try:
        return os.path.commonpath([path, parent]) == parent
    except ValueError:
        # commonpath raises when the paths cannot share an ancestor
        # (e.g. different drives on Windows); they are then not nested.
        return False


def _reject_sensitive_local_path(path, label):
    """Raise ``ValueError`` if ``path`` points into a sensitive location or is a symlink.

    Shared implementation behind ``_validate_source_directory`` and
    ``_validate_dependency_path`` so both apply exactly the same rules.

    Args:
        path (str): The user supplied local path.
        label (str): How the path is named in error messages
            (``"source_directory"`` or ``"dependency path"``).

    Raises:
        ValueError: If the path (as given, or after resolving symlinks) is inside
            a blocklisted directory, or if the path itself is a symlink.
    """
    abs_path = os.path.abspath(path)
    # abspath() does not follow symlinks, so also inspect the real location;
    # otherwise "<link-to-/etc>/passwd" would slip through the blocklist.
    real_path = os.path.realpath(abs_path)

    # Blocklist of sensitive directories that should not be accessible.
    # Built at call time because "~" depends on the current environment.
    # NOTE(review): "/home" blocks every regular user's home directory on
    # Linux -- kept as in the original change, confirm this is intended.
    raw_sensitive_paths = (
        "~/.aws",
        "~/.ssh",
        "~/.kube",
        "~/.docker",
        "~/.config",
        "~/.credentials",
        "/etc",
        "/root",
        "/home",
        "/var/lib",
        "/opt/ml/metadata",
    )
    blocked_roots = []
    for raw in raw_sensitive_paths:
        root = os.path.abspath(os.path.expanduser(raw))
        # Include the resolved form too (e.g. /etc -> /private/etc on macOS).
        for variant in (root, os.path.realpath(root)):
            if variant not in blocked_roots:
                blocked_roots.append(variant)

    for candidate in (abs_path, real_path):
        if any(_is_path_within(candidate, root) for root in blocked_roots):
            raise ValueError(
                f"{label} cannot access sensitive system paths. "
                f"Got: {path} (resolved to {candidate})"
            )

    # Check for symlinks to prevent symlink-based escapes
    if os.path.islink(abs_path):
        raise ValueError(f"{label} cannot be a symlink: {path}")


def _validate_source_directory(source_directory):
    """Validate that source_directory is safe to use.

    Ensures the source directory path does not access restricted system locations.
    ``None``/empty values and S3 URIs are accepted without further checks.

    Args:
        source_directory (str): The source directory path to validate.

    Raises:
        ValueError: If the path is not allowed.
    """
    if not source_directory or source_directory.lower().startswith("s3://"):
        # S3 paths and None are safe
        return

    _reject_sensitive_local_path(source_directory, "source_directory")


def _validate_dependency_path(dependency):
    """Validate that a dependency path is safe to use.

    Ensures the dependency path does not access restricted system locations.
    ``None``/empty values are accepted without further checks.

    Args:
        dependency (str): The dependency path to validate.

    Raises:
        ValueError: If the path is not allowed.
    """
    if not dependency:
        return

    _reject_sensitive_local_path(dependency, "dependency path")
691+
692+
604693
def _create_or_update_code_dir(
605694
model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
606695
):
@@ -614,6 +703,8 @@ def _create_or_update_code_dir(
614703
custom_extractall_tarfile(t, code_dir)
615704

616705
elif source_directory:
706+
# Validate source_directory for security
707+
_validate_source_directory(source_directory)
617708
if os.path.exists(code_dir):
618709
shutil.rmtree(code_dir)
619710
shutil.copytree(source_directory, code_dir)
@@ -646,6 +737,8 @@ def _create_or_update_code_dir(
646737
)
647738

648739
for dependency in dependencies:
740+
# Validate dependency path for security
741+
_validate_dependency_path(dependency)
649742
lib_dir = os.path.join(code_dir, "lib")
650743
if os.path.isdir(dependency):
651744
shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
@@ -1620,6 +1713,38 @@ def _get_safe_members(members):
16201713
yield file_info
16211714

16221715

1716+
def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Performs post-extraction validation: walks ``extract_path`` and checks that
    every directory and file found there still resolves (via
    ``_get_resolved_path``) to a location inside the resolved extraction path.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted file is outside the expected extraction path.
    """
    base = _get_resolved_path(extract_path)
    # Compare against "base + separator" rather than the bare prefix so that a
    # sibling such as "/tmp/out-evil" is not mistaken for a child of "/tmp/out".
    base_prefix = base.rstrip(os.sep) + os.sep

    for root, dirs, files in os.walk(extract_path):
        # Directories and files get the same check; only the log wording differs.
        for kind, names in (("directory", dirs), ("file", files)):
            for name in names:
                entry_path = os.path.join(root, name)
                resolved = _get_resolved_path(entry_path)
                if resolved != base and not resolved.startswith(base_prefix):
                    logger.error("Extracted %s escaped extraction path: %s", kind, entry_path)
                    raise ValueError(f"Extracted path outside expected directory: {entry_path}")
1746+
1747+
16231748
def custom_extractall_tarfile(tar, extract_path):
16241749
"""Extract a tarfile, optionally using data_filter if available.
16251750
@@ -1640,6 +1765,8 @@ def custom_extractall_tarfile(tar, extract_path):
16401765
tar.extractall(path=extract_path, filter="data")
16411766
else:
16421767
tar.extractall(path=extract_path, members=_get_safe_members(tar))
1768+
# Re-validate extracted paths to catch symlink race conditions
1769+
_validate_extracted_paths(extract_path)
16431770

16441771

16451772
def can_model_package_source_uri_autopopulate(source_uri: str):

tests/unit/sagemaker/local/test_local_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@
2222
from sagemaker.session_settings import SessionSettings
2323

2424

25-
@patch("sagemaker.local.utils.os")
def test_copy_directory_structure(m_os):
    """copy_directory_structure joins the two paths and creates the result idempotently."""
    # The whole ``os`` module is replaced by a mock, so ``os.path.join`` would
    # otherwise return a MagicMock and the makedirs assertion below could never
    # match the literal path.
    m_os.path.join.return_value = "/tmp/code/"

    sagemaker.local.utils.copy_directory_structure("/tmp/", "code/")

    m_os.path.join.assert_called_with("/tmp/", "code/")
    m_os.makedirs.assert_called_with("/tmp/code/", exist_ok=True)
3129

3230

3331
@patch("shutil.rmtree", Mock())

0 commit comments

Comments
 (0)