Skip to content

Commit d2151c2

Browse files
committed
Allowing for sym-links, better refactoring
1 parent 532fe77 commit d2151c2

File tree

3 files changed

+26
-63
lines changed

3 files changed

+26
-63
lines changed

sagemaker-core/src/sagemaker/core/common_utils.py

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,21 @@
7474
WAITING_DOT_NUMBER = 10
7575
MAX_ITEMS = 100
7676
PAGE_SIZE = 10
77+
_MAX_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB - Maximum buffer size for streaming iterators
78+
79+
_SENSITIVE_SYSTEM_PATHS = [
80+
abspath(os.path.expanduser("~/.aws")),
81+
abspath(os.path.expanduser("~/.ssh")),
82+
abspath(os.path.expanduser("~/.kube")),
83+
abspath(os.path.expanduser("~/.docker")),
84+
abspath(os.path.expanduser("~/.config")),
85+
abspath(os.path.expanduser("~/.credentials")),
86+
"/etc",
87+
"/root",
88+
"/home",
89+
"/var/lib",
90+
"/opt/ml/metadata",
91+
]
7792

7893
logger = logging.getLogger(__name__)
7994

@@ -622,35 +637,17 @@ def _validate_source_directory(source_directory):
622637
# S3 paths and None are safe
623638
return
624639

625-
abs_source = abspath(source_directory)
626-
627-
# Blocklist of sensitive directories that should not be accessible
628-
sensitive_paths = [
629-
abspath(os.path.expanduser("~/.aws")),
630-
abspath(os.path.expanduser("~/.ssh")),
631-
abspath(os.path.expanduser("~/.kube")),
632-
abspath(os.path.expanduser("~/.docker")),
633-
abspath(os.path.expanduser("~/.config")),
634-
abspath(os.path.expanduser("~/.credentials")),
635-
"/etc",
636-
"/root",
637-
"/home",
638-
"/var/lib",
639-
"/opt/ml/metadata",
640-
]
640+
# Resolve symlinks to get the actual path
641+
abs_source = abspath(realpath(source_directory))
641642

642643
# Check if the source path is under any sensitive directory
643-
for sensitive_path in sensitive_paths:
644+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
644645
if abs_source.startswith(sensitive_path):
645646
raise ValueError(
646647
f"source_directory cannot access sensitive system paths. "
647648
f"Got: {source_directory} (resolved to {abs_source})"
648649
)
649650

650-
# Check for symlinks to prevent symlink-based escapes
651-
if os.path.islink(abs_source):
652-
raise ValueError(f"source_directory cannot be a symlink: {source_directory}")
653-
654651

655652
def _validate_dependency_path(dependency):
656653
"""Validate that a dependency path is safe to use.
@@ -666,35 +663,17 @@ def _validate_dependency_path(dependency):
666663
if not dependency:
667664
return
668665

669-
abs_dependency = abspath(dependency)
670-
671-
# Blocklist of sensitive directories that should not be accessible
672-
sensitive_paths = [
673-
abspath(os.path.expanduser("~/.aws")),
674-
abspath(os.path.expanduser("~/.ssh")),
675-
abspath(os.path.expanduser("~/.kube")),
676-
abspath(os.path.expanduser("~/.docker")),
677-
abspath(os.path.expanduser("~/.config")),
678-
abspath(os.path.expanduser("~/.credentials")),
679-
"/etc",
680-
"/root",
681-
"/home",
682-
"/var/lib",
683-
"/opt/ml/metadata",
684-
]
666+
# Resolve symlinks to get the actual path
667+
abs_dependency = abspath(realpath(dependency))
685668

686669
# Check if the dependency path is under any sensitive directory
687-
for sensitive_path in sensitive_paths:
670+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
688671
if abs_dependency.startswith(sensitive_path):
689672
raise ValueError(
690673
f"dependency path cannot access sensitive system paths. "
691674
f"Got: {dependency} (resolved to {abs_dependency})"
692675
)
693676

694-
# Check for symlinks to prevent symlink-based escapes
695-
if os.path.islink(abs_dependency):
696-
raise ValueError(f"dependency path cannot be a symlink: {dependency}")
697-
698677

699678
def _create_or_update_code_dir(
700679
model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp

sagemaker-core/src/sagemaker/core/iterators.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import io
1818

1919
from sagemaker.core.exceptions import ModelStreamError, InternalStreamFailure
20+
from sagemaker.core.common_utils import _MAX_BUFFER_SIZE
2021

2122

2223
def handle_stream_errors(chunk):
@@ -114,9 +115,6 @@ def __next__(self):
114115
class LineIterator(BaseIterator):
115116
"""A helper class for parsing the byte Event Stream input to provide Line iteration."""
116117

117-
# Maximum buffer size to prevent unbounded memory consumption (10 MB)
118-
MAX_BUFFER_SIZE = 10 * 1024 * 1024
119-
120118
def __init__(self, event_stream):
121119
"""Initialises a LineIterator Iterator object
122120
@@ -189,9 +187,9 @@ def __next__(self):
189187
# Check buffer size before writing to prevent unbounded memory consumption
190188
chunk_size = len(chunk["PayloadPart"]["Bytes"])
191189
current_size = self.buffer.getbuffer().nbytes
192-
if current_size + chunk_size > self.MAX_BUFFER_SIZE:
190+
if current_size + chunk_size > _MAX_BUFFER_SIZE:
193191
raise RuntimeError(
194-
f"Line buffer exceeded maximum size of {self.MAX_BUFFER_SIZE} bytes. "
192+
f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. "
195193
f"No newline found in stream."
196194
)
197195

sagemaker-core/src/sagemaker/core/local/data.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from six.moves.urllib.parse import urlparse
2525

2626
import sagemaker.core
27+
from sagemaker.core.common_utils import _SENSITIVE_SYSTEM_PATHS
2728

2829

2930
def get_data_source_instance(data_source, sagemaker_session):
@@ -116,28 +117,13 @@ def get_root_dir(self):
116117
class LocalFileDataSource(DataSource):
117118
"""Represents a data source within the local filesystem."""
118119

119-
# Blocklist of sensitive directories that should not be accessible
120-
RESTRICTED_PATHS = [
121-
os.path.abspath(os.path.expanduser("~/.aws")),
122-
os.path.abspath(os.path.expanduser("~/.ssh")),
123-
os.path.abspath(os.path.expanduser("~/.kube")),
124-
os.path.abspath(os.path.expanduser("~/.docker")),
125-
os.path.abspath(os.path.expanduser("~/.config")),
126-
os.path.abspath(os.path.expanduser("~/.credentials")),
127-
"/etc",
128-
"/root",
129-
"/home",
130-
"/var/lib",
131-
"/opt/ml/metadata",
132-
]
133-
134120
def __init__(self, root_path):
135121
super(LocalFileDataSource, self).__init__()
136122

137123
self.root_path = os.path.abspath(root_path)
138124

139125
# Validate that the path is not in restricted locations
140-
for restricted_path in self.RESTRICTED_PATHS:
126+
for restricted_path in _SENSITIVE_SYSTEM_PATHS:
141127
if self.root_path.startswith(restricted_path):
142128
raise ValueError(
143129
f"Local Mode does not support mounting from restricted system paths. "

0 commit comments

Comments
 (0)