Skip to content

Commit c2c72c6

Browse files
committed
Bug fixes for unit and integ tests
1 parent 02f51c3 commit c2c72c6

File tree

9 files changed

+45
-11
lines changed

9 files changed

+45
-11
lines changed

sagemaker-core/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ dependencies = [
3434
"omegaconf>=2.1.0",
3535
"torch>=1.9.0",
3636
"scipy>=1.5.0",
37+
# Remote function dependencies
38+
"cloudpickle>=2.0.0",
39+
"paramiko>=2.11.0",
3740
]
3841
requires-python = ">=3.9"
3942
classifiers = [

sagemaker-core/src/sagemaker/core/training/configs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,15 +257,16 @@ class InputData(BaseConfig):
257257
Parameters:
258258
channel_name (StrPipeVar):
259259
The name of the input data source channel.
260-
data_source (Union[str, S3DataSource, FileSystemDataSource, DatasetSource]):
260+
data_source (Union[StrPipeVar, S3DataSource, FileSystemDataSource, DatasetSource]):
261261
The data source for the channel. Can be an S3 URI string, local file path string,
262-
S3DataSource object, or FileSystemDataSource object.
262+
S3DataSource object, FileSystemDataSource object, DatasetSource object, or a
263+
pipeline variable (Properties) from a previous step.
263264
content_type (StrPipeVar):
264265
The MIME type of the data.
265266
"""
266267

267268
channel_name: StrPipeVar = None
268-
data_source: Union[str, FileSystemDataSource, S3DataSource, DatasetSource] = None
269+
data_source: Union[StrPipeVar, FileSystemDataSource, S3DataSource, DatasetSource] = None
269270
content_type: StrPipeVar = None
270271

271272

sagemaker-core/tests/integ/jumpstart/test_search_integ.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sagemaker.core.resources import HubContent
2020

2121

22+
@pytest.mark.serial
2223
@pytest.mark.integ
2324
def test_search_public_hub_models_default_args():
2425
# Only query, uses default hub name and session
@@ -30,6 +31,7 @@ def test_search_public_hub_models_default_args():
3031
assert len(results) > 0, "Expected at least one matching model from the public hub"
3132

3233

34+
@pytest.mark.serial
3335
@pytest.mark.integ
3436
def test_search_public_hub_models_custom_session():
3537
# Provide a custom SageMaker session
@@ -41,6 +43,7 @@ def test_search_public_hub_models_custom_session():
4143
assert all(isinstance(m, HubContent) for m in results)
4244

4345

46+
@pytest.mark.serial
4447
@pytest.mark.integ
4548
def test_search_public_hub_models_custom_hub_name():
4649
# Using the default public hub but provided explicitly
@@ -51,6 +54,7 @@ def test_search_public_hub_models_custom_hub_name():
5154
assert all(isinstance(m, HubContent) for m in results)
5255

5356

57+
@pytest.mark.serial
5458
@pytest.mark.integ
5559
def test_search_public_hub_models_all_args():
5660
# Provide both hub_name and session explicitly

sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,18 @@
3030
PYTHON_VERSION,
3131
)
3232
from sagemaker.core.user_agent import SDK_VERSION, process_studio_metadata_file
33-
from sagemaker.serve.utils.exceptions import ModelBuilderException, LocalModelOutOfMemoryException
33+
34+
# Try to import sagemaker-serve exceptions, skip tests if not available
35+
try:
36+
from sagemaker.serve.utils.exceptions import ModelBuilderException, LocalModelOutOfMemoryException
37+
SAGEMAKER_SERVE_AVAILABLE = True
38+
except ImportError:
39+
SAGEMAKER_SERVE_AVAILABLE = False
40+
# Create mock exceptions for type hints
41+
class ModelBuilderException(Exception):
42+
pass
43+
class LocalModelOutOfMemoryException(Exception):
44+
pass
3445

3546
MOCK_SESSION = Mock()
3647
MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex")
@@ -147,6 +158,10 @@ def test_telemetry_emitter_decorator_success(
147158
1, [1, 2], MOCK_SESSION, None, None, expected_extra_str
148159
)
149160

161+
@pytest.mark.skipif(
162+
not SAGEMAKER_SERVE_AVAILABLE,
163+
reason="Requires sagemaker-serve package"
164+
)
150165
@patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
151166
@patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
152167
def test_telemetry_emitter_decorator_handle_exception_success(

sagemaker-core/tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ markers =
6363
release
6464
image_uris_unit_test
6565
timeout: mark a test as a timeout.
66+
serial: marks tests that must run serially (not in parallel)
6667

6768
[testenv]
6869
setenv =

sagemaker-mlops/tests/integ/test_pipeline_train_registry.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,19 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, r
215215
assert execution_status == "Succeeded"
216216
break
217217
elif execution_status in ["Failed", "Stopped"]:
218-
pytest.fail(f"Pipeline execution {execution_status}")
218+
# Get detailed failure information
219+
steps = sagemaker_session.sagemaker_client.list_pipeline_execution_steps(
220+
PipelineExecutionArn=execution_desc["PipelineExecutionArn"]
221+
)["PipelineExecutionSteps"]
222+
223+
failed_steps = []
224+
for step in steps:
225+
if step.get("StepStatus") == "Failed":
226+
failure_reason = step.get("FailureReason", "Unknown reason")
227+
failed_steps.append(f"{step['StepName']}: {failure_reason}")
228+
229+
failure_details = "\n".join(failed_steps) if failed_steps else "No detailed failure information available"
230+
pytest.fail(f"Pipeline execution {execution_status}. Failed steps:\n{failure_details}")
219231

220232
time.sleep(60)
221233
else:

sagemaker-mlops/tox.ini

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ allowlist_externals =
8787
commands =
8888
python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')"
8989
pip install 'apache-airflow==2.10.4' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.4/constraints-3.9.txt"
90-
pip install 'torch==2.3.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html'
91-
pip install 'torchvision==0.18.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html'
90+
pip install 'torch==2.8.0' 'torchvision==0.23.0'
9291
pip install 'dill>=0.3.9'
9392

9493
pytest {posargs}

sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -323,12 +323,10 @@ class BenchMarkEvaluator(BaseEvaluator):
323323
"""
324324

325325
benchmark: _Benchmark
326+
dataset: Union[str, Any] # Required field, must come before optional fields
326327
subtasks: Optional[Union[str, List[str]]] = None
327-
_hyperparameters: Optional[Any] = None
328-
329-
# Template-required fields
330-
dataset: Union[str, Any]
331328
evaluate_base_model: bool = True
329+
_hyperparameters: Optional[Any] = None
332330

333331
@validator('dataset', pre=True)
334332
def _resolve_dataset(cls, v):

sagemaker-train/tests/integ/ai_registry/test_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sagemaker.ai_registry.air_constants import HubContentStatus
2222

2323

24+
@pytest.mark.serial
2425
class TestDataSetIntegration:
2526
"""Integration tests for DataSet operations."""
2627

0 commit comments

Comments
 (0)