Skip to content

Commit ca2da09

Browse files
committed
Removing home path and adding additional validation
1 parent 849b169 commit ca2da09

File tree

2 files changed

+299
-1
lines changed

2 files changed

+299
-1
lines changed

sagemaker-core/src/sagemaker/core/common_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@
8585
abspath(os.path.expanduser("~/.credentials")),
8686
"/etc",
8787
"/root",
88-
"/home",
8988
"/var/lib",
9089
"/opt/ml/metadata",
9190
]
@@ -680,6 +679,15 @@ def _create_or_update_code_dir(
680679
):
681680
"""Placeholder docstring"""
682681
code_dir = os.path.join(model_dir, "code")
682+
resolved_code_dir = _get_resolved_path(code_dir)
683+
684+
# Validate that code_dir does not resolve to a sensitive system path
685+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
686+
if resolved_code_dir.startswith(sensitive_path):
687+
raise ValueError(
688+
f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
689+
)
690+
683691
if source_directory and source_directory.lower().startswith("s3://"):
684692
local_code_path = os.path.join(tmp, "local_code.tar.gz")
685693
download_file_from_url(source_directory, local_code_path, sagemaker_session)

sagemaker-core/tests/unit/test_common_utils.py

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2209,3 +2209,293 @@ def test_nested_set_dict_multiple_keys(self):
22092209
d = {}
22102210
nested_set_dict(d, ["a", "b", "c"], "value")
22112211
assert d["a"]["b"]["c"] == "value"
2212+
2213+
2214+
2215+
class TestValidateSourceDirectory:
2216+
"""Test _validate_source_directory function."""
2217+
2218+
def test_validate_source_directory_none(self):
2219+
"""Test with None source directory."""
2220+
from sagemaker.core.common_utils import _validate_source_directory
2221+
2222+
# Should not raise
2223+
_validate_source_directory(None)
2224+
2225+
def test_validate_source_directory_s3_path(self):
2226+
"""Test with S3 path."""
2227+
from sagemaker.core.common_utils import _validate_source_directory
2228+
2229+
# Should not raise for S3 paths
2230+
_validate_source_directory("s3://my-bucket/my-code")
2231+
2232+
def test_validate_source_directory_valid_local_path(self):
2233+
"""Test with valid local path."""
2234+
from sagemaker.core.common_utils import _validate_source_directory
2235+
2236+
with tempfile.TemporaryDirectory() as tmpdir:
2237+
# Should not raise for valid local paths
2238+
_validate_source_directory(tmpdir)
2239+
2240+
def test_validate_source_directory_sensitive_path_aws(self):
2241+
"""Test rejection of ~/.aws path."""
2242+
from sagemaker.core.common_utils import _validate_source_directory
2243+
2244+
aws_dir = os.path.expanduser("~/.aws")
2245+
if os.path.exists(aws_dir):
2246+
with pytest.raises(ValueError, match="cannot access sensitive system paths"):
2247+
_validate_source_directory(aws_dir)
2248+
2249+
def test_validate_source_directory_sensitive_path_ssh(self):
2250+
"""Test rejection of ~/.ssh path."""
2251+
from sagemaker.core.common_utils import _validate_source_directory
2252+
2253+
ssh_dir = os.path.expanduser("~/.ssh")
2254+
if os.path.exists(ssh_dir):
2255+
with pytest.raises(ValueError, match="cannot access sensitive system paths"):
2256+
_validate_source_directory(ssh_dir)
2257+
2258+
def test_validate_source_directory_sensitive_path_root(self):
2259+
"""Test rejection of /root path."""
2260+
from sagemaker.core.common_utils import _validate_source_directory
2261+
2262+
# Test with /root which is a sensitive path
2263+
if os.path.exists("/root") and os.access("/root", os.R_OK):
2264+
with pytest.raises(ValueError, match="cannot access sensitive system paths"):
2265+
_validate_source_directory("/root")
2266+
2267+
def test_validate_source_directory_symlink_resolution(self):
2268+
"""Test that symlinks are resolved correctly."""
2269+
from sagemaker.core.common_utils import _validate_source_directory
2270+
2271+
with tempfile.TemporaryDirectory() as tmpdir:
2272+
# Create a real directory
2273+
real_dir = os.path.join(tmpdir, "real_code")
2274+
os.makedirs(real_dir)
2275+
2276+
# Create a symlink to it
2277+
symlink_path = os.path.join(tmpdir, "link_to_code")
2278+
os.symlink(real_dir, symlink_path)
2279+
2280+
# Should not raise - symlink should be resolved and validated
2281+
_validate_source_directory(symlink_path)
2282+
2283+
2284+
class TestValidateDependencyPath:
2285+
"""Test _validate_dependency_path function."""
2286+
2287+
def test_validate_dependency_path_none(self):
2288+
"""Test with None dependency."""
2289+
from sagemaker.core.common_utils import _validate_dependency_path
2290+
2291+
# Should not raise
2292+
_validate_dependency_path(None)
2293+
2294+
def test_validate_dependency_path_valid_local_path(self):
2295+
"""Test with valid local path."""
2296+
from sagemaker.core.common_utils import _validate_dependency_path
2297+
2298+
with tempfile.TemporaryDirectory() as tmpdir:
2299+
# Should not raise for valid local paths
2300+
_validate_dependency_path(tmpdir)
2301+
2302+
def test_validate_dependency_path_sensitive_path_aws(self):
2303+
"""Test rejection of ~/.aws path."""
2304+
from sagemaker.core.common_utils import _validate_dependency_path
2305+
2306+
aws_dir = os.path.expanduser("~/.aws")
2307+
if os.path.exists(aws_dir):
2308+
with pytest.raises(ValueError, match="cannot access sensitive system paths"):
2309+
_validate_dependency_path(aws_dir)
2310+
2311+
def test_validate_dependency_path_sensitive_path_credentials(self):
2312+
"""Test rejection of ~/.credentials path."""
2313+
from sagemaker.core.common_utils import _validate_dependency_path
2314+
2315+
creds_dir = os.path.expanduser("~/.credentials")
2316+
if os.path.exists(creds_dir):
2317+
with pytest.raises(ValueError, match="cannot access sensitive system paths"):
2318+
_validate_dependency_path(creds_dir)
2319+
2320+
def test_validate_dependency_path_symlink_resolution(self):
2321+
"""Test that symlinks are resolved correctly."""
2322+
from sagemaker.core.common_utils import _validate_dependency_path
2323+
2324+
with tempfile.TemporaryDirectory() as tmpdir:
2325+
# Create a real directory
2326+
real_dir = os.path.join(tmpdir, "real_lib")
2327+
os.makedirs(real_dir)
2328+
2329+
# Create a symlink to it
2330+
symlink_path = os.path.join(tmpdir, "link_to_lib")
2331+
os.symlink(real_dir, symlink_path)
2332+
2333+
# Should not raise - symlink should be resolved and validated
2334+
_validate_dependency_path(symlink_path)
2335+
2336+
2337+
class TestCreateOrUpdateCodeDir:
2338+
"""Test _create_or_update_code_dir function."""
2339+
2340+
def test_create_or_update_code_dir_basic(self):
2341+
"""Test basic code directory creation."""
2342+
from sagemaker.core.common_utils import _create_or_update_code_dir
2343+
2344+
with tempfile.TemporaryDirectory() as tmpdir:
2345+
model_dir = os.path.join(tmpdir, "model")
2346+
os.makedirs(model_dir)
2347+
2348+
inference_script = os.path.join(tmpdir, "inference.py")
2349+
with open(inference_script, "w") as f:
2350+
f.write("# inference code")
2351+
2352+
# Should create code directory and copy inference script
2353+
_create_or_update_code_dir(
2354+
model_dir,
2355+
inference_script,
2356+
None,
2357+
[],
2358+
None,
2359+
tmpdir,
2360+
)
2361+
2362+
code_dir = os.path.join(model_dir, "code")
2363+
assert os.path.exists(code_dir)
2364+
assert os.path.exists(os.path.join(code_dir, "inference.py"))
2365+
2366+
def test_create_or_update_code_dir_with_source_directory(self):
2367+
"""Test code directory creation with source directory."""
2368+
from sagemaker.core.common_utils import _create_or_update_code_dir
2369+
2370+
with tempfile.TemporaryDirectory() as tmpdir:
2371+
model_dir = os.path.join(tmpdir, "model")
2372+
os.makedirs(model_dir)
2373+
2374+
source_dir = os.path.join(tmpdir, "source")
2375+
os.makedirs(source_dir)
2376+
with open(os.path.join(source_dir, "app.py"), "w") as f:
2377+
f.write("# app code")
2378+
2379+
# Should copy source directory to code directory
2380+
_create_or_update_code_dir(
2381+
model_dir,
2382+
"inference.py",
2383+
source_dir,
2384+
[],
2385+
None,
2386+
tmpdir,
2387+
)
2388+
2389+
code_dir = os.path.join(model_dir, "code")
2390+
assert os.path.exists(code_dir)
2391+
assert os.path.exists(os.path.join(code_dir, "app.py"))
2392+
2393+
def test_create_or_update_code_dir_with_dependencies(self):
2394+
"""Test code directory creation with dependencies."""
2395+
from sagemaker.core.common_utils import _create_or_update_code_dir
2396+
2397+
with tempfile.TemporaryDirectory() as tmpdir:
2398+
model_dir = os.path.join(tmpdir, "model")
2399+
os.makedirs(model_dir)
2400+
2401+
inference_script = os.path.join(tmpdir, "inference.py")
2402+
with open(inference_script, "w") as f:
2403+
f.write("# inference code")
2404+
2405+
dep_dir = os.path.join(tmpdir, "my_lib")
2406+
os.makedirs(dep_dir)
2407+
with open(os.path.join(dep_dir, "helper.py"), "w") as f:
2408+
f.write("# helper code")
2409+
2410+
# Should create code directory with dependencies
2411+
_create_or_update_code_dir(
2412+
model_dir,
2413+
inference_script,
2414+
None,
2415+
[dep_dir],
2416+
None,
2417+
tmpdir,
2418+
)
2419+
2420+
code_dir = os.path.join(model_dir, "code")
2421+
lib_dir = os.path.join(code_dir, "lib")
2422+
assert os.path.exists(lib_dir)
2423+
assert os.path.exists(os.path.join(lib_dir, "my_lib", "helper.py"))
2424+
2425+
def test_create_or_update_code_dir_rejects_sensitive_paths(self):
2426+
"""Test that code_dir validation rejects sensitive system paths."""
2427+
from sagemaker.core.common_utils import _create_or_update_code_dir
2428+
2429+
with tempfile.TemporaryDirectory() as tmpdir:
2430+
# Create a model_dir that would resolve to a sensitive path
2431+
# This is tricky to test without mocking, so we'll mock _get_resolved_path
2432+
with patch("sagemaker.core.common_utils._get_resolved_path") as mock_resolve:
2433+
mock_resolve.return_value = "/etc"
2434+
2435+
inference_script = os.path.join(tmpdir, "inference.py")
2436+
with open(inference_script, "w") as f:
2437+
f.write("# inference code")
2438+
2439+
model_dir = os.path.join(tmpdir, "model")
2440+
os.makedirs(model_dir)
2441+
2442+
# Should raise ValueError for sensitive path
2443+
with pytest.raises(ValueError, match="Invalid code_dir path"):
2444+
_create_or_update_code_dir(
2445+
model_dir,
2446+
"inference.py",
2447+
None,
2448+
[],
2449+
None,
2450+
tmpdir,
2451+
)
2452+
2453+
def test_create_or_update_code_dir_validates_source_directory(self):
2454+
"""Test that source_directory is validated."""
2455+
from sagemaker.core.common_utils import _create_or_update_code_dir
2456+
2457+
with tempfile.TemporaryDirectory() as tmpdir:
2458+
model_dir = os.path.join(tmpdir, "model")
2459+
os.makedirs(model_dir)
2460+
2461+
inference_script = os.path.join(tmpdir, "inference.py")
2462+
with open(inference_script, "w") as f:
2463+
f.write("# inference code")
2464+
2465+
# Try to use a sensitive path as source_directory
2466+
aws_dir = os.path.expanduser("~/.aws")
2467+
if os.path.exists(aws_dir):
2468+
with pytest.raises(ValueError, match="cannot access sensitive system paths"):
2469+
_create_or_update_code_dir(
2470+
model_dir,
2471+
"inference.py",
2472+
aws_dir,
2473+
[],
2474+
None,
2475+
tmpdir,
2476+
)
2477+
2478+
def test_create_or_update_code_dir_validates_dependencies(self):
2479+
"""Test that dependencies are validated."""
2480+
from sagemaker.core.common_utils import _create_or_update_code_dir
2481+
2482+
with tempfile.TemporaryDirectory() as tmpdir:
2483+
model_dir = os.path.join(tmpdir, "model")
2484+
os.makedirs(model_dir)
2485+
2486+
inference_script = os.path.join(tmpdir, "inference.py")
2487+
with open(inference_script, "w") as f:
2488+
f.write("# inference code")
2489+
2490+
# Try to use a sensitive path as dependency
2491+
aws_dir = os.path.expanduser("~/.aws")
2492+
if os.path.exists(aws_dir):
2493+
with pytest.raises(ValueError, match="cannot access sensitive system paths"):
2494+
_create_or_update_code_dir(
2495+
model_dir,
2496+
inference_script,
2497+
None,
2498+
[aws_dir],
2499+
None,
2500+
tmpdir,
2501+
)

0 commit comments

Comments
 (0)