Skip to content

Commit 43b5bce

Browse files
authored
Merge branch 'master' into command-injection
2 parents 9c916ce + a74f9ab commit 43b5bce

File tree

199 files changed

+4287
-46343
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

199 files changed

+4287
-46343
lines changed

requirements/extras/test_requirements.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@ pytest-xdist
44
mock
55
pydantic==2.11.9
66
pydantic_core==2.33.2
7-
pandas
7+
pandas>=2.3.0
8+
numpy>=2.0.0, <3.0
9+
scikit-learn==1.6.1
810
scipy
911
omegaconf
1012
graphene
11-
typing_extensions>=4.9.0
13+
typing_extensions>=4.9.0
14+
tensorflow>=2.16.2,<=2.19.0

sagemaker-core/src/sagemaker/core/common_utils.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
"us-isob-east-1": "sc2s.sgov.gov",
6060
"us-isof-south-1": "csp.hci.ic.gov",
6161
"us-isof-east-1": "csp.hci.ic.gov",
62+
"eu-isoe-west-1": "cloud.adc-e.uk",
6263
}
6364

6465
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
@@ -74,6 +75,20 @@
7475
WAITING_DOT_NUMBER = 10
7576
MAX_ITEMS = 100
7677
PAGE_SIZE = 10
78+
_MAX_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB - Maximum buffer size for streaming iterators
79+
80+
_SENSITIVE_SYSTEM_PATHS = [
81+
abspath(os.path.expanduser("~/.aws")),
82+
abspath(os.path.expanduser("~/.ssh")),
83+
abspath(os.path.expanduser("~/.kube")),
84+
abspath(os.path.expanduser("~/.docker")),
85+
abspath(os.path.expanduser("~/.config")),
86+
abspath(os.path.expanduser("~/.credentials")),
87+
"/etc",
88+
"/root",
89+
"/var/lib",
90+
"/opt/ml/metadata",
91+
]
7792

7893
logger = logging.getLogger(__name__)
7994

@@ -607,11 +622,73 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
607622
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
608623

609624

625+
def _validate_source_directory(source_directory):
626+
"""Validate that source_directory is safe to use.
627+
628+
Ensures the source directory path does not access restricted system locations.
629+
630+
Args:
631+
source_directory (str): The source directory path to validate.
632+
633+
Raises:
634+
ValueError: If the path is not allowed.
635+
"""
636+
if not source_directory or source_directory.lower().startswith("s3://"):
637+
# S3 paths and None are safe
638+
return
639+
640+
# Resolve symlinks to get the actual path
641+
abs_source = abspath(realpath(source_directory))
642+
643+
# Check if the source path is under any sensitive directory
644+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
645+
if abs_source != "/" and abs_source.startswith(sensitive_path):
646+
raise ValueError(
647+
f"source_directory cannot access sensitive system paths. "
648+
f"Got: {source_directory} (resolved to {abs_source})"
649+
)
650+
651+
652+
def _validate_dependency_path(dependency):
653+
"""Validate that a dependency path is safe to use.
654+
655+
Ensures the dependency path does not access restricted system locations.
656+
657+
Args:
658+
dependency (str): The dependency path to validate.
659+
660+
Raises:
661+
ValueError: If the path is not allowed.
662+
"""
663+
if not dependency:
664+
return
665+
666+
# Resolve symlinks to get the actual path
667+
abs_dependency = abspath(realpath(dependency))
668+
669+
# Check if the dependency path is under any sensitive directory
670+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
671+
if abs_dependency != "/" and abs_dependency.startswith(sensitive_path):
672+
raise ValueError(
673+
f"dependency path cannot access sensitive system paths. "
674+
f"Got: {dependency} (resolved to {abs_dependency})"
675+
)
676+
677+
610678
def _create_or_update_code_dir(
611679
model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
612680
):
613681
"""Placeholder docstring"""
614682
code_dir = os.path.join(model_dir, "code")
683+
resolved_code_dir = _get_resolved_path(code_dir)
684+
685+
# Validate that code_dir does not resolve to a sensitive system path
686+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
687+
if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path):
688+
raise ValueError(
689+
f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
690+
)
691+
615692
if source_directory and source_directory.lower().startswith("s3://"):
616693
local_code_path = os.path.join(tmp, "local_code.tar.gz")
617694
download_file_from_url(source_directory, local_code_path, sagemaker_session)
@@ -620,6 +697,8 @@ def _create_or_update_code_dir(
620697
custom_extractall_tarfile(t, code_dir)
621698

622699
elif source_directory:
700+
# Validate source_directory for security
701+
_validate_source_directory(source_directory)
623702
if os.path.exists(code_dir):
624703
shutil.rmtree(code_dir)
625704
shutil.copytree(source_directory, code_dir)
@@ -635,6 +714,8 @@ def _create_or_update_code_dir(
635714
raise
636715

637716
for dependency in dependencies:
717+
# Validate dependency path for security
718+
_validate_dependency_path(dependency)
638719
lib_dir = os.path.join(code_dir, "lib")
639720
if os.path.isdir(dependency):
640721
shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
@@ -1555,7 +1636,7 @@ def get_instance_type_family(instance_type: str) -> str:
15551636
"""
15561637
instance_type_family = ""
15571638
if isinstance(instance_type, str):
1558-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1639+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
15591640
if match is not None:
15601641
instance_type_family = match[1]
15611642
return instance_type_family
@@ -1646,6 +1727,38 @@ def _get_safe_members(members):
16461727
yield file_info
16471728

16481729

1730+
def _validate_extracted_paths(extract_path):
1731+
"""Validate that extracted paths remain within the expected directory.
1732+
1733+
Performs post-extraction validation to ensure all extracted files and directories
1734+
are within the intended extraction path.
1735+
1736+
Args:
1737+
extract_path (str): The path where files were extracted.
1738+
1739+
Raises:
1740+
ValueError: If any extracted file is outside the expected extraction path.
1741+
"""
1742+
base = _get_resolved_path(extract_path)
1743+
1744+
for root, dirs, files in os.walk(extract_path):
1745+
# Check directories
1746+
for dir_name in dirs:
1747+
dir_path = os.path.join(root, dir_name)
1748+
resolved = _get_resolved_path(dir_path)
1749+
if not resolved.startswith(base):
1750+
logger.error("Extracted directory escaped extraction path: %s", dir_path)
1751+
raise ValueError(f"Extracted path outside expected directory: {dir_path}")
1752+
1753+
# Check files
1754+
for file_name in files:
1755+
file_path = os.path.join(root, file_name)
1756+
resolved = _get_resolved_path(file_path)
1757+
if not resolved.startswith(base):
1758+
logger.error("Extracted file escaped extraction path: %s", file_path)
1759+
raise ValueError(f"Extracted path outside expected directory: {file_path}")
1760+
1761+
16491762
def custom_extractall_tarfile(tar, extract_path):
16501763
"""Extract a tarfile, optionally using data_filter if available.
16511764
@@ -1666,6 +1779,8 @@ def custom_extractall_tarfile(tar, extract_path):
16661779
tar.extractall(path=extract_path, filter="data")
16671780
else:
16681781
tar.extractall(path=extract_path, members=_get_safe_members(tar))
1782+
# Re-validate extracted paths to catch symlink race conditions
1783+
_validate_extracted_paths(extract_path)
16691784

16701785

16711786
def can_model_package_source_uri_autopopulate(source_uri: str):

sagemaker-core/src/sagemaker/core/fw_utils.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525

2626
from packaging import version
2727

28-
import sagemaker.core.common_utils as sagemaker_utils
29-
from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs
28+
import sagemaker.core.common_utils as utils
29+
from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning
3030
from sagemaker.core.instance_group import InstanceGroup
31-
from sagemaker.core.s3 import s3_path_join
31+
from sagemaker.core.s3.utils import s3_path_join
3232
from sagemaker.core.session_settings import SessionSettings
3333
from sagemaker.core.workflow import is_pipeline_variable
34-
from sagemaker.core.helper.pipeline_variable import PipelineVariable
34+
from sagemaker.core.workflow.entities import PipelineVariable
3535

3636
logger = logging.getLogger(__name__)
3737

@@ -155,6 +155,9 @@
155155
"2.3.1",
156156
"2.4.1",
157157
"2.5.1",
158+
"2.6.0",
159+
"2.7.1",
160+
"2.8.0",
158161
]
159162

160163
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
@@ -455,7 +458,7 @@ def tar_and_upload_dir(
455458

456459
try:
457460
source_files = _list_files_to_compress(script, directory) + dependencies
458-
tar_file = sagemaker_utils.create_tar_file(
461+
tar_file = utils.create_tar_file(
459462
source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME)
460463
)
461464

@@ -516,7 +519,7 @@ def framework_name_from_image(image_uri):
516519
- str: The image tag
517520
- str: If the TensorFlow image is script mode
518521
"""
519-
sagemaker_pattern = re.compile(sagemaker_utils.ECR_URI_PATTERN)
522+
sagemaker_pattern = re.compile(utils.ECR_URI_PATTERN)
520523
sagemaker_match = sagemaker_pattern.match(image_uri)
521524
if sagemaker_match is None:
522525
return None, None, None, None
@@ -595,7 +598,7 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
595598
"""
596599
name_from_image = f"/model_code/{int(time.time())}"
597600
if not is_pipeline_variable(image):
598-
name_from_image = sagemaker_utils.name_from_image(image)
601+
name_from_image = utils.name_from_image(image)
599602
return s3_path_join(code_location_key_prefix, model_name or name_from_image)
600603

601604

@@ -961,7 +964,7 @@ def validate_distribution_for_instance_type(instance_type, distribution):
961964
"""
962965
err_msg = ""
963966
if isinstance(instance_type, str):
964-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
967+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
965968
if match and match[1].startswith("trn"):
966969
keys = list(distribution.keys())
967970
if len(keys) == 0:
@@ -1062,7 +1065,7 @@ def validate_torch_distributed_distribution(
10621065
)
10631066

10641067
# Check entry point type
1065-
if not entry_point.endswith(".py"):
1068+
if entry_point is not None and not entry_point.endswith(".py"):
10661069
err_msg += (
10671070
"Unsupported entry point type for the distribution torch_distributed.\n"
10681071
"Only python programs (*.py) are supported."
@@ -1082,7 +1085,7 @@ def _is_gpu_instance(instance_type):
10821085
bool: Whether or not the instance_type supports GPU
10831086
"""
10841087
if isinstance(instance_type, str):
1085-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1088+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
10861089
if match:
10871090
if match[1].startswith("p") or match[1].startswith("g"):
10881091
return True
@@ -1101,7 +1104,7 @@ def _is_trainium_instance(instance_type):
11011104
bool: Whether or not the instance_type is a Trainium instance
11021105
"""
11031106
if isinstance(instance_type, str):
1104-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1107+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
11051108
if match and match[1].startswith("trn"):
11061109
return True
11071110
return False
@@ -1148,7 +1151,7 @@ def _instance_type_supports_profiler(instance_type):
11481151
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
11491152
"""
11501153
if isinstance(instance_type, str):
1151-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1154+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
11521155
if match and match[1].startswith("trn"):
11531156
return True
11541157
return False
@@ -1174,3 +1177,44 @@ def validate_version_or_image_args(framework_version, py_version, image_uri):
11741177
"framework_version or py_version was None, yet image_uri was also None. "
11751178
"Either specify both framework_version and py_version, or specify image_uri."
11761179
)
1180+
1181+
1182+
def create_image_uri(
1183+
region,
1184+
framework,
1185+
instance_type,
1186+
framework_version,
1187+
py_version=None,
1188+
account=None, # pylint: disable=W0613
1189+
accelerator_type=None,
1190+
optimized_families=None, # pylint: disable=W0613
1191+
):
1192+
"""Deprecated method. Please use sagemaker.image_uris.retrieve().
1193+
1194+
Args:
1195+
region (str): AWS region where the image is uploaded.
1196+
framework (str): framework used by the image.
1197+
instance_type (str): SageMaker instance type. Used to determine device
1198+
type (cpu/gpu/family-specific optimized).
1199+
framework_version (str): The version of the framework.
1200+
py_version (str): Optional. Python version Ex: `py38, py39, py310, py311`.
1201+
If not specified, image uri will not include a python component.
1202+
account (str): AWS account that contains the image. (default:
1203+
'520713654638')
1204+
accelerator_type (str): SageMaker Elastic Inference accelerator type.
1205+
optimized_families (str): Deprecated. A no-op argument.
1206+
1207+
Returns:
1208+
the image uri
1209+
"""
1210+
from sagemaker.core import image_uris
1211+
1212+
renamed_warning("The method create_image_uri")
1213+
return image_uris.retrieve(
1214+
framework=framework,
1215+
region=region,
1216+
version=framework_version,
1217+
py_version=py_version,
1218+
instance_type=instance_type,
1219+
accelerator_type=accelerator_type,
1220+
)

0 commit comments

Comments
 (0)