Skip to content

Commit b9c9da1

Browse files
committed
Fix create_from_dict to not have hardcoded debug flag
1 parent 1078d86 commit b9c9da1

File tree

4 files changed

+20
-6
lines changed

4 files changed

+20
-6
lines changed

src/sagemaker/hyperpod/inference/hp_endpoint.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,10 @@ def create_from_dict(
7272
input: Dict,
7373
name: str = None,
7474
namespace: str = None,
75+
debug=False
7576
) -> None:
7677
logger = self.get_logger()
77-
logger = setup_logging(logger)
78+
logger = setup_logging(logger, debug)
7879

7980
spec = _HPEndpoint.model_validate(input, by_name=True)
8081

@@ -94,7 +95,7 @@ def create_from_dict(
9495
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
9596
namespace=namespace,
9697
spec=spec,
97-
debug=False,
98+
debug=debug,
9899
)
99100

100101
self.metadata = Metadata(

src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,10 @@ def create_from_dict(
7777
input: Dict,
7878
name: str = None,
7979
namespace: str = None,
80+
debug = False
8081
) -> None:
8182
logger = self.get_logger()
82-
logger = setup_logging(logger)
83+
logger = setup_logging(logger, debug)
8384

8485
spec = _HPJumpStartEndpoint.model_validate(input, by_name=True)
8586

@@ -103,7 +104,7 @@ def create_from_dict(
103104
kind=JUMPSTART_MODEL_KIND,
104105
namespace=namespace,
105106
spec=spec,
106-
debug=False,
107+
debug=debug,
107108
)
108109

109110
self.metadata = Metadata(

test/unit_tests/inference/test_hp_endpoint.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,13 @@ def test_create_from_dict(self, mock_create_api, mock_validate_instance_type):
116116

117117
self.endpoint.create_from_dict(input_dict, namespace="test-ns")
118118

119-
mock_create_api.assert_called_once()
119+
mock_create_api.assert_called_once_with(
120+
name=unittest.mock.ANY,
121+
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
122+
namespace="test-ns",
123+
spec=unittest.mock.ANY,
124+
debug=False,
125+
)
120126

121127
@patch.object(HPEndpoint, "call_get_api")
122128
def test_refresh(self, mock_get_api):

test/unit_tests/inference/test_hp_jumpstart_endpoint.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,13 @@ def test_create_from_dict(self, mock_create_api, mock_validate_instance_type):
6161
input_dict, name="test-name", namespace="test-ns"
6262
)
6363

64-
mock_create_api.assert_called_once()
64+
mock_create_api.assert_called_once_with(
65+
name="test-name",
66+
kind=JUMPSTART_MODEL_KIND,
67+
namespace="test-ns",
68+
spec=unittest.mock.ANY,
69+
debug=False,
70+
)
6571

6672
@patch.object(HPJumpStartEndpoint, "call_get_api")
6773
def test_refresh(self, mock_get_api):

0 commit comments

Comments
 (0)