Skip to content

Commit a4352a8

Browse files
authored
Merge branch 'master' into fw-and-version-bug
2 parents 30c1a8b + 840f3a1 commit a4352a8

File tree

15 files changed

+217
-51
lines changed

15 files changed

+217
-51
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
---
2+
name: Bug report
3+
about: File a report to help us reproduce and fix the problem
4+
title: ''
5+
labels: 'bug'
6+
assignees: ''
7+
8+
---
9+
10+
**PySDK Version**
11+
- [ ] PySDK V2 (2.x)
12+
- [ ] PySDK V3 (3.x)
13+
14+
**Describe the bug**
15+
A clear and concise description of what the bug is.
16+
17+
**To reproduce**
18+
A clear, step-by-step set of instructions to reproduce the bug.
19+
The provided code needs to be **complete** and **runnable**. If additional data is needed, please include it in the issue.
20+
21+
**Expected behavior**
22+
A clear and concise description of what you expected to happen.
23+
24+
**Screenshots or logs**
25+
If applicable, add screenshots or logs to help explain your problem.
26+
27+
**System information**
28+
A description of your system. Please provide:
29+
- **SageMaker Python SDK version**:
30+
- **Framework name (e.g., PyTorch) or algorithm (e.g., KMeans)**:
31+
- **Framework version**:
32+
- **Python version**:
33+
- **CPU or GPU**:
34+
- **Custom Docker image (Y/N)**:
35+
36+
**Additional context**
37+
Add any other context about the problem here.

.github/ISSUE_TEMPLATE/config.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
blank_issues_enabled: false
2+
contact_links:
3+
- name: Ask a question
4+
url: https://github.com/aws/sagemaker-python-sdk/discussions
5+
about: Use GitHub Discussions to ask and answer questions
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
---
2+
name: Documentation request
3+
about: Request improved documentation
4+
title: ''
5+
labels: ''
6+
assignees: ''
7+
8+
---
9+
10+
**What did you find confusing? Please describe.**
11+
A clear and concise description of what you found confusing. Ex. I tried to [...] but I didn't understand how to [...]
12+
13+
**Describe how documentation can be improved**
14+
A clear and concise description of where documentation was lacking and how it can be improved.
15+
16+
**Additional context**
17+
Add any other context or screenshots about the documentation request here.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
---
2+
name: Feature request
3+
about: Suggest new functionality for this library
4+
title: ''
5+
labels: 'feature request'
6+
assignees: ''
7+
8+
---
9+
10+
**Describe the feature you'd like**
11+
A clear and concise description of the functionality you want.
12+
13+
**How would this feature be used? Please describe.**
14+
A clear and concise description of the use case for this feature. Please provide an example, if possible.
15+
16+
**Describe alternatives you've considered**
17+
A clear and concise description of any alternative solutions or features you've considered.
18+
19+
**Additional context**
20+
Add any other context or screenshots about the feature request here.

sagemaker-core/src/sagemaker/core/local/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,11 @@ def get_child_process_ids(pid):
137137
Returns:
138138
(List[int]): Child process ids
139139
"""
140-
cmd = f"pgrep -P {pid}".split()
140+
if not str(pid).isdigit():
141+
raise ValueError("Invalid PID")
142+
143+
cmd = ["pgrep", "-P", str(pid)]
144+
141145
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
142146
output, err = process.communicate()
143147
if err:

sagemaker-core/tests/unit/local/test_local_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,21 +103,24 @@ def test_recursive_copy(copy_tree, m_os_path):
103103
@patch("sagemaker.core.local.utils.os")
104104
@patch("sagemaker.core.local.utils.get_child_process_ids")
105105
def test_kill_child_processes(m_get_child_process_ids, m_os):
106-
m_get_child_process_ids.return_value = ["child_pids"]
107-
kill_child_processes("pid")
108-
m_os.kill.assert_called_with("child_pids", 15)
106+
m_get_child_process_ids.return_value = ["345"]
107+
kill_child_processes("123")
108+
m_os.kill.assert_called_with("345", 15)
109109

110110

111111
@patch("sagemaker.core.local.utils.subprocess")
112112
def test_get_child_process_ids(m_subprocess):
113-
cmd = "pgrep -P pid".split()
113+
cmd = "pgrep -P 123".split()
114114
process_mock = Mock()
115115
attrs = {"communicate.return_value": (b"\n", False), "returncode": 0}
116116
process_mock.configure_mock(**attrs)
117117
m_subprocess.Popen.return_value = process_mock
118-
get_child_process_ids("pid")
118+
get_child_process_ids("123")
119119
m_subprocess.Popen.assert_called_with(cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE)
120120

121+
def test_get_child_process_ids_exception():
122+
with pytest.raises(ValueError, match="Invalid PID"):
123+
get_child_process_ids("abc")
121124

122125
@patch("sagemaker.core.local.utils.subprocess")
123126
def test_get_docker_host(m_subprocess):

sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from dataclasses import dataclass
1414
from enum import Enum
1515
import re
16+
from sagemaker.train.base_trainer import BaseTrainer
17+
from sagemaker.core.utils.utils import Unassigned
1618

1719

1820
class _ModelType(Enum):
@@ -65,14 +67,14 @@ def __init__(self, sagemaker_session=None):
6567

6668
def resolve_model_info(
6769
self,
68-
base_model: Union[str, 'ModelPackage'],
70+
base_model: Union[str, BaseTrainer, 'ModelPackage'],
6971
hub_name: Optional[str] = None
7072
) -> _ModelInfo:
7173
"""
7274
Resolve model information from various input types.
7375
7476
Args:
75-
base_model: Either a JumpStart model ID (str) or ModelPackage object/ARN
77+
base_model: Either a JumpStart model ID (str) or ModelPackage object/ARN or BaseTrainer object with a completed job
7678
hub_name: Optional hub name for JumpStart models (defaults to SageMakerPublicHub)
7779
7880
Returns:
@@ -88,6 +90,17 @@ def resolve_model_info(
8890
return self._resolve_model_package_arn(base_model)
8991
else:
9092
return self._resolve_jumpstart_model(base_model, hub_name or self.DEFAULT_HUB_NAME)
93+
# Handle BaseTrainer type
94+
elif isinstance(base_model, BaseTrainer):
95+
if hasattr(base_model, '_latest_training_job') and hasattr(base_model._latest_training_job,
96+
'output_model_package_arn'):
97+
arn = base_model._latest_training_job.output_model_package_arn
98+
if not isinstance(arn, Unassigned):
99+
return self._resolve_model_package_arn(arn)
100+
else:
101+
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
102+
else:
103+
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
91104
else:
92105
# Not a string, so assume it's a ModelPackage object
93106
# Check if it has the expected attributes of a ModelPackage

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313

1414
from pydantic import BaseModel, validator
1515

16-
from sagemaker.core.resources import ModelPackageGroup
16+
from sagemaker.core.resources import ModelPackageGroup, ModelPackage
1717
from sagemaker.core.shapes import VpcConfig
1818

1919
if TYPE_CHECKING:
2020
from sagemaker.core.helper.session_helper import Session
2121

22+
from sagemaker.train.base_trainer import BaseTrainer
2223
# Module-level logger
2324
_logger = logging.getLogger(__name__)
2425

@@ -53,6 +54,7 @@ class BaseEvaluator(BaseModel):
5354
- JumpStart model ID (str): e.g., 'llama3-2-1b-instruct'
5455
- ModelPackage object: A fine-tuned model package
5556
- ModelPackage ARN (str): e.g., 'arn:aws:sagemaker:region:account:model-package/name/version'
57+
- BaseTrainer object: A completed training job (i.e., it must have _latest_training_job with output_model_package_arn populated)
5658
base_eval_name (Optional[str]): Optional base name for evaluation jobs. This name is used
5759
as the PipelineExecutionDisplayName when creating the SageMaker pipeline execution.
5860
The actual display name will be "{base_eval_name}-{timestamp}". This parameter can
@@ -86,7 +88,7 @@ class BaseEvaluator(BaseModel):
8688

8789
region: Optional[str] = None
8890
sagemaker_session: Optional[Any] = None
89-
model: Union[str, Any]
91+
model: Union[str, BaseTrainer, ModelPackage]
9092
base_eval_name: Optional[str] = None
9193
s3_output_path: str
9294
mlflow_resource_arn: Optional[str] = None
@@ -278,7 +280,7 @@ def _validate_mlflow_arn_format(cls, v: Optional[str]) -> Optional[str]:
278280
return v
279281

280282
@validator('model')
281-
def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any]:
283+
def _resolve_model_info(cls, v: Union[str, BaseTrainer, ModelPackage], values: dict) -> Union[str, Any]:
282284
"""Resolve model information from various input types.
283285
284286
This validator uses the common model resolution utility to extract:
@@ -289,7 +291,7 @@ def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any
289291
The resolved information is stored in private attributes for use by subclasses.
290292
291293
Args:
292-
v (Union[str, Any]): Model identifier (JumpStart ID, ModelPackage, or ARN).
294+
v (Union[str, BaseTrainer, ModelPackage]): Model identifier (JumpStart ID, ModelPackage, ARN, or BaseTrainer).
293295
values (dict): Dictionary of already-validated fields.
294296
295297
Returns:

sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -300,18 +300,10 @@ class BenchMarkEvaluator(BaseEvaluator):
300300
"""
301301

302302
benchmark: _Benchmark
303-
dataset: Union[str, Any] # Required field, must come before optional fields
304303
subtasks: Optional[Union[str, List[str]]] = None
305304
evaluate_base_model: bool = True
306305
_hyperparameters: Optional[Any] = None
307-
308-
@validator('dataset', pre=True)
309-
def _resolve_dataset(cls, v):
310-
"""Resolve dataset to string (S3 URI or ARN) and validate format.
311-
312-
Uses BaseEvaluator's common validation logic to avoid code duplication.
313-
"""
314-
return BaseEvaluator._validate_and_resolve_dataset(v)
306+
315307

316308
@validator('benchmark')
317309
def _validate_benchmark_model_compatibility(cls, v, values):

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
286286
except TimeoutExceededError as e:
287287
logger.error("Error: %s", e)
288288

289-
self.latest_training_job = training_job
289+
self._latest_training_job = training_job
290290
return training_job
291291

292292
def _process_hyperparameters(self):

0 commit comments

Comments
 (0)