Skip to content

Commit de8bc1e

Browse files
committed
Adding additional validation and removing home as sensitive path
1 parent 206c07e commit de8bc1e

File tree

2 files changed

+172
-7
lines changed

2 files changed

+172
-7
lines changed

src/sagemaker/utils.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,10 @@
8585
abspath(os.path.expanduser("~/.docker")),
8686
abspath(os.path.expanduser("~/.config")),
8787
abspath(os.path.expanduser("~/.credentials")),
88-
"/etc",
89-
"/root",
90-
"/home",
91-
"/var/lib",
92-
"/opt/ml/metadata",
88+
abspath(realpath("/etc")),
89+
abspath(realpath("/root")),
90+
abspath(realpath("/var/lib")),
91+
abspath(realpath("/opt/ml/metadata")),
9392
]
9493

9594
logger = logging.getLogger(__name__)
@@ -636,7 +635,7 @@ def _validate_source_directory(source_directory):
636635

637636
# Check if the source path is under any sensitive directory
638637
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
639-
if abs_source.startswith(sensitive_path):
638+
if abs_source != "/" and abs_source.startswith(sensitive_path):
640639
raise ValueError(
641640
f"source_directory cannot access sensitive system paths. "
642641
f"Got: {source_directory} (resolved to {abs_source})"
@@ -662,7 +661,7 @@ def _validate_dependency_path(dependency):
662661

663662
# Check if the dependency path is under any sensitive directory
664663
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
665-
if abs_dependency.startswith(sensitive_path):
664+
if abs_dependency != "/" and abs_dependency.startswith(sensitive_path):
666665
raise ValueError(
667666
f"dependency path cannot access sensitive system paths. "
668667
f"Got: {dependency} (resolved to {abs_dependency})"
@@ -674,6 +673,15 @@ def _create_or_update_code_dir(
674673
):
675674
"""Placeholder docstring"""
676675
code_dir = os.path.join(model_dir, "code")
676+
resolved_code_dir = _get_resolved_path(code_dir)
677+
678+
# Validate that code_dir does not resolve to a sensitive system path
679+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
680+
if resolved_code_dir.startswith(sensitive_path):
681+
raise ValueError(
682+
f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
683+
)
684+
677685
if source_directory and source_directory.lower().startswith("s3://"):
678686
local_code_path = os.path.join(tmp, "local_code.tar.gz")
679687
download_file_from_url(source_directory, local_code_path, sagemaker_session)

tests/unit/test_utils.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2245,3 +2245,160 @@ def test_get_domain_for_region(self):
22452245
self.assertEqual(get_domain_for_region("us-iso-east-1"), "c2s.ic.gov")
22462246
self.assertEqual(get_domain_for_region("us-isob-east-1"), "sc2s.sgov.gov")
22472247
self.assertEqual(get_domain_for_region("invalid-region"), "amazonaws.com")
2248+
2249+
2250+
2251+
class TestValidateSourceDirectory(TestCase):
2252+
"""Tests for _validate_source_directory function"""
2253+
2254+
def test_validate_source_directory_with_s3_path(self):
2255+
"""S3 paths should be allowed"""
2256+
from sagemaker.utils import _validate_source_directory
2257+
# Should not raise any exception
2258+
_validate_source_directory("s3://my-bucket/my-prefix")
2259+
2260+
def test_validate_source_directory_with_none(self):
2261+
"""None should be allowed"""
2262+
from sagemaker.utils import _validate_source_directory
2263+
# Should not raise any exception
2264+
_validate_source_directory(None)
2265+
2266+
def test_validate_source_directory_with_safe_local_path(self):
2267+
"""Safe local paths should be allowed"""
2268+
from sagemaker.utils import _validate_source_directory
2269+
# Should not raise any exception
2270+
_validate_source_directory("/tmp/my_code")
2271+
_validate_source_directory("./my_code")
2272+
_validate_source_directory("../my_code")
2273+
2274+
def test_validate_source_directory_with_sensitive_path_aws(self):
2275+
"""Paths under ~/.aws should be rejected"""
2276+
from sagemaker.utils import _validate_source_directory
2277+
with pytest.raises(ValueError, match="cannot access sensitive system paths"):
2278+
_validate_source_directory(os.path.expanduser("~/.aws/credentials"))
2279+
2280+
def test_validate_source_directory_with_sensitive_path_ssh(self):
2281+
"""Paths under ~/.ssh should be rejected"""
2282+
from sagemaker.utils import _validate_source_directory
2283+
with pytest.raises(ValueError, match="cannot access sensitive system paths"):
2284+
_validate_source_directory(os.path.expanduser("~/.ssh/id_rsa"))
2285+
2286+
def test_validate_source_directory_with_root_directory(self):
2287+
"""Root directory itself should be allowed (not rejected)"""
2288+
from sagemaker.utils import _validate_source_directory
2289+
# Should not raise any exception - root directory is explicitly allowed
2290+
_validate_source_directory("/")
2291+
2292+
2293+
class TestValidateDependencyPath(TestCase):
2294+
"""Tests for _validate_dependency_path function"""
2295+
2296+
def test_validate_dependency_path_with_none(self):
2297+
"""None should be allowed"""
2298+
from sagemaker.utils import _validate_dependency_path
2299+
# Should not raise any exception
2300+
_validate_dependency_path(None)
2301+
2302+
def test_validate_dependency_path_with_safe_local_path(self):
2303+
"""Safe local paths should be allowed"""
2304+
from sagemaker.utils import _validate_dependency_path
2305+
# Should not raise any exception
2306+
_validate_dependency_path("/tmp/my_lib")
2307+
_validate_dependency_path("./my_lib")
2308+
_validate_dependency_path("../my_lib")
2309+
2310+
def test_validate_dependency_path_with_sensitive_path_aws(self):
2311+
"""Paths under ~/.aws should be rejected"""
2312+
from sagemaker.utils import _validate_dependency_path
2313+
with pytest.raises(ValueError, match="cannot access sensitive system paths"):
2314+
_validate_dependency_path(os.path.expanduser("~/.aws"))
2315+
2316+
def test_validate_dependency_path_with_sensitive_path_docker(self):
2317+
"""Paths under ~/.docker should be rejected"""
2318+
from sagemaker.utils import _validate_dependency_path
2319+
with pytest.raises(ValueError, match="cannot access sensitive system paths"):
2320+
_validate_dependency_path(os.path.expanduser("~/.docker/config.json"))
2321+
2322+
def test_validate_dependency_path_with_root_directory(self):
2323+
"""Root directory itself should be allowed (not rejected)"""
2324+
from sagemaker.utils import _validate_dependency_path
2325+
# Should not raise any exception - root directory is explicitly allowed
2326+
_validate_dependency_path("/")
2327+
2328+
2329+
class TestCreateOrUpdateCodeDir(TestCase):
2330+
"""Tests for _create_or_update_code_dir function"""
2331+
2332+
@patch("sagemaker.utils._validate_source_directory")
2333+
@patch("sagemaker.utils._validate_dependency_path")
2334+
@patch("sagemaker.utils.os.path.exists")
2335+
@patch("sagemaker.utils.os.mkdir")
2336+
@patch("sagemaker.utils.shutil.copy2")
2337+
def test_create_or_update_code_dir_with_inference_script(
2338+
self, mock_copy, mock_mkdir, mock_exists, mock_validate_dep, mock_validate_src
2339+
):
2340+
"""Test creating code dir with inference script"""
2341+
from sagemaker.utils import _create_or_update_code_dir
2342+
2343+
mock_exists.return_value = False
2344+
2345+
with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved:
2346+
mock_get_resolved.return_value = "/tmp/model/code"
2347+
2348+
_create_or_update_code_dir(
2349+
model_dir="/tmp/model",
2350+
inference_script="inference.py",
2351+
source_directory=None,
2352+
dependencies=[],
2353+
sagemaker_session=None,
2354+
tmp="/tmp"
2355+
)
2356+
2357+
mock_mkdir.assert_called()
2358+
mock_copy.assert_called_once()
2359+
2360+
@patch("sagemaker.utils._validate_source_directory")
2361+
@patch("sagemaker.utils.os.path.exists")
2362+
@patch("sagemaker.utils.shutil.rmtree")
2363+
@patch("sagemaker.utils.shutil.copytree")
2364+
def test_create_or_update_code_dir_with_source_directory(
2365+
self, mock_copytree, mock_rmtree, mock_exists, mock_validate_src
2366+
):
2367+
"""Test creating code dir with source directory"""
2368+
from sagemaker.utils import _create_or_update_code_dir
2369+
2370+
mock_exists.return_value = True
2371+
2372+
with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved:
2373+
mock_get_resolved.return_value = "/tmp/model/code"
2374+
2375+
_create_or_update_code_dir(
2376+
model_dir="/tmp/model",
2377+
inference_script=None,
2378+
source_directory="/tmp/my_code",
2379+
dependencies=[],
2380+
sagemaker_session=None,
2381+
tmp="/tmp"
2382+
)
2383+
2384+
mock_validate_src.assert_called_once_with("/tmp/my_code")
2385+
mock_rmtree.assert_called_once()
2386+
mock_copytree.assert_called_once()
2387+
2388+
def test_create_or_update_code_dir_with_sensitive_code_dir(self):
2389+
"""Test that code_dir resolving to sensitive path is rejected"""
2390+
from sagemaker.utils import _create_or_update_code_dir
2391+
2392+
with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved:
2393+
# Simulate code_dir resolving to a sensitive path
2394+
mock_get_resolved.return_value = os.path.abspath(os.path.expanduser("~/.aws"))
2395+
2396+
with pytest.raises(ValueError, match="Invalid code_dir path"):
2397+
_create_or_update_code_dir(
2398+
model_dir="/tmp/model",
2399+
inference_script="inference.py",
2400+
source_directory=None,
2401+
dependencies=[],
2402+
sagemaker_session=None,
2403+
tmp="/tmp"
2404+
)

0 commit comments

Comments
 (0)