Skip to content

Commit 206c07e

Browse files
committed
Allow symlinks (validate the resolved real path) and refactor shared constants into sagemaker.utils
1 parent db68076 commit 206c07e

File tree

3 files changed

+26
-63
lines changed

3 files changed

+26
-63
lines changed

src/sagemaker/iterators.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import io
1818

1919
from sagemaker.exceptions import ModelStreamError, InternalStreamFailure
20+
from sagemaker.utils import _MAX_BUFFER_SIZE
2021

2122

2223
def handle_stream_errors(chunk):
@@ -114,9 +115,6 @@ def __next__(self):
114115
class LineIterator(BaseIterator):
115116
"""A helper class for parsing the byte Event Stream input to provide Line iteration."""
116117

117-
# Maximum buffer size to prevent unbounded memory consumption (10 MB)
118-
MAX_BUFFER_SIZE = 10 * 1024 * 1024
119-
120118
def __init__(self, event_stream):
121119
"""Initialises a LineIterator Iterator object
122120
@@ -189,9 +187,9 @@ def __next__(self):
189187
# Check buffer size before writing to prevent unbounded memory consumption
190188
chunk_size = len(chunk["PayloadPart"]["Bytes"])
191189
current_size = self.buffer.getbuffer().nbytes
192-
if current_size + chunk_size > self.MAX_BUFFER_SIZE:
190+
if current_size + chunk_size > _MAX_BUFFER_SIZE:
193191
raise RuntimeError(
194-
f"Line buffer exceeded maximum size of {self.MAX_BUFFER_SIZE} bytes. "
192+
f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. "
195193
f"No newline found in stream."
196194
)
197195

src/sagemaker/local/data.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import sagemaker.amazon.common
2727
import sagemaker.local.utils
2828
import sagemaker.utils
29+
from sagemaker.utils import _SENSITIVE_SYSTEM_PATHS
2930

3031

3132
def get_data_source_instance(data_source, sagemaker_session):
@@ -118,28 +119,13 @@ def get_root_dir(self):
118119
class LocalFileDataSource(DataSource):
119120
"""Represents a data source within the local filesystem."""
120121

121-
# Blocklist of sensitive directories that should not be accessible
122-
RESTRICTED_PATHS = [
123-
os.path.abspath(os.path.expanduser("~/.aws")),
124-
os.path.abspath(os.path.expanduser("~/.ssh")),
125-
os.path.abspath(os.path.expanduser("~/.kube")),
126-
os.path.abspath(os.path.expanduser("~/.docker")),
127-
os.path.abspath(os.path.expanduser("~/.config")),
128-
os.path.abspath(os.path.expanduser("~/.credentials")),
129-
"/etc",
130-
"/root",
131-
"/home",
132-
"/var/lib",
133-
"/opt/ml/metadata",
134-
]
135-
136122
def __init__(self, root_path):
137123
super(LocalFileDataSource, self).__init__()
138124

139125
self.root_path = os.path.abspath(root_path)
140126

141127
# Validate that the path is not in restricted locations
142-
for restricted_path in self.RESTRICTED_PATHS:
128+
for restricted_path in _SENSITIVE_SYSTEM_PATHS:
143129
if self.root_path.startswith(restricted_path):
144130
raise ValueError(
145131
f"Local Mode does not support mounting from restricted system paths. "

src/sagemaker/utils.py

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,21 @@
7676
WAITING_DOT_NUMBER = 10
7777
MAX_ITEMS = 100
7878
PAGE_SIZE = 10
79+
_MAX_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB - Maximum buffer size for streaming iterators
80+
81+
_SENSITIVE_SYSTEM_PATHS = [
82+
abspath(os.path.expanduser("~/.aws")),
83+
abspath(os.path.expanduser("~/.ssh")),
84+
abspath(os.path.expanduser("~/.kube")),
85+
abspath(os.path.expanduser("~/.docker")),
86+
abspath(os.path.expanduser("~/.config")),
87+
abspath(os.path.expanduser("~/.credentials")),
88+
"/etc",
89+
"/root",
90+
"/home",
91+
"/var/lib",
92+
"/opt/ml/metadata",
93+
]
7994

8095
logger = logging.getLogger(__name__)
8196

@@ -616,35 +631,17 @@ def _validate_source_directory(source_directory):
616631
# S3 paths and None are safe
617632
return
618633

619-
abs_source = abspath(source_directory)
620-
621-
# Blocklist of sensitive directories that should not be accessible
622-
sensitive_paths = [
623-
abspath(os.path.expanduser("~/.aws")),
624-
abspath(os.path.expanduser("~/.ssh")),
625-
abspath(os.path.expanduser("~/.kube")),
626-
abspath(os.path.expanduser("~/.docker")),
627-
abspath(os.path.expanduser("~/.config")),
628-
abspath(os.path.expanduser("~/.credentials")),
629-
"/etc",
630-
"/root",
631-
"/home",
632-
"/var/lib",
633-
"/opt/ml/metadata",
634-
]
634+
# Resolve symlinks to get the actual path
635+
abs_source = abspath(realpath(source_directory))
635636

636637
# Check if the source path is under any sensitive directory
637-
for sensitive_path in sensitive_paths:
638+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
638639
if abs_source.startswith(sensitive_path):
639640
raise ValueError(
640641
f"source_directory cannot access sensitive system paths. "
641642
f"Got: {source_directory} (resolved to {abs_source})"
642643
)
643644

644-
# Check for symlinks to prevent symlink-based escapes
645-
if os.path.islink(abs_source):
646-
raise ValueError(f"source_directory cannot be a symlink: {source_directory}")
647-
648645

649646
def _validate_dependency_path(dependency):
650647
"""Validate that a dependency path is safe to use.
@@ -660,35 +657,17 @@ def _validate_dependency_path(dependency):
660657
if not dependency:
661658
return
662659

663-
abs_dependency = abspath(dependency)
664-
665-
# Blocklist of sensitive directories that should not be accessible
666-
sensitive_paths = [
667-
abspath(os.path.expanduser("~/.aws")),
668-
abspath(os.path.expanduser("~/.ssh")),
669-
abspath(os.path.expanduser("~/.kube")),
670-
abspath(os.path.expanduser("~/.docker")),
671-
abspath(os.path.expanduser("~/.config")),
672-
abspath(os.path.expanduser("~/.credentials")),
673-
"/etc",
674-
"/root",
675-
"/home",
676-
"/var/lib",
677-
"/opt/ml/metadata",
678-
]
660+
# Resolve symlinks to get the actual path
661+
abs_dependency = abspath(realpath(dependency))
679662

680663
# Check if the dependency path is under any sensitive directory
681-
for sensitive_path in sensitive_paths:
664+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
682665
if abs_dependency.startswith(sensitive_path):
683666
raise ValueError(
684667
f"dependency path cannot access sensitive system paths. "
685668
f"Got: {dependency} (resolved to {abs_dependency})"
686669
)
687670

688-
# Check for symlinks to prevent symlink-based escapes
689-
if os.path.islink(abs_dependency):
690-
raise ValueError(f"dependency path cannot be a symlink: {dependency}")
691-
692671

693672
def _create_or_update_code_dir(
694673
model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp

0 commit comments

Comments
 (0)