Skip to content

Commit 60982da

Browse files
authored
Small bug fix to print debug messages for inference logger (PySDK) (#246)
* Draft of inference logger bug fix * Draft fix of inference logger for SDK * Revert adding --debug flag * Add debug parameter to failing unit tests * Fix create_from_dict to not have hardcoded debug flag
1 parent 58cff10 commit 60982da

File tree

5 files changed

+26
-5
lines changed

5 files changed

+26
-5
lines changed

src/sagemaker/hyperpod/inference/hp_endpoint.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def create(
5454
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
5555
namespace=namespace,
5656
spec=spec,
57+
debug=debug,
5758
)
5859

5960
self.metadata = Metadata(
@@ -71,9 +72,10 @@ def create_from_dict(
7172
input: Dict,
7273
name: str = None,
7374
namespace: str = None,
75+
debug=False
7476
) -> None:
7577
logger = self.get_logger()
76-
logger = setup_logging(logger)
78+
logger = setup_logging(logger, debug)
7779

7880
spec = _HPEndpoint.model_validate(input, by_name=True)
7981

@@ -93,6 +95,7 @@ def create_from_dict(
9395
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
9496
namespace=namespace,
9597
spec=spec,
98+
debug=debug,
9699
)
97100

98101
self.metadata = Metadata(

src/sagemaker/hyperpod/inference/hp_endpoint_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def call_create_api(
6363
kind: str,
6464
namespace: str,
6565
spec: Union[_HPJumpStartEndpoint, _HPEndpoint],
66+
debug: bool = False,
6667
):
6768
"""Create an inference endpoint using Kubernetes API.
6869
@@ -104,7 +105,7 @@ def call_create_api(
104105
cls.verify_kube_config()
105106

106107
logger = cls.get_logger()
107-
logger = setup_logging(logger)
108+
logger = setup_logging(logger, debug)
108109

109110
custom_api = client.CustomObjectsApi()
110111

src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def create(
5959
kind=JUMPSTART_MODEL_KIND,
6060
namespace=namespace,
6161
spec=spec,
62+
debug=debug,
6263
)
6364

6465
self.metadata = Metadata(
@@ -76,9 +77,10 @@ def create_from_dict(
7677
input: Dict,
7778
name: str = None,
7879
namespace: str = None,
80+
debug = False
7981
) -> None:
8082
logger = self.get_logger()
81-
logger = setup_logging(logger)
83+
logger = setup_logging(logger, debug)
8284

8385
spec = _HPJumpStartEndpoint.model_validate(input, by_name=True)
8486

@@ -102,6 +104,7 @@ def create_from_dict(
102104
kind=JUMPSTART_MODEL_KIND,
103105
namespace=namespace,
104106
spec=spec,
107+
debug=debug,
105108
)
106109

107110
self.metadata = Metadata(

test/unit_tests/inference/test_hp_endpoint.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def test_create(self, mock_create_api, mock_validate_instance_type):
104104
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
105105
namespace="test-ns",
106106
spec=unittest.mock.ANY,
107+
debug=False,
107108
)
108109
self.assertEqual(self.endpoint.metadata.name, "test-name")
109110

@@ -115,7 +116,13 @@ def test_create_from_dict(self, mock_create_api, mock_validate_instance_type):
115116

116117
self.endpoint.create_from_dict(input_dict, namespace="test-ns")
117118

118-
mock_create_api.assert_called_once()
119+
mock_create_api.assert_called_once_with(
120+
name=unittest.mock.ANY,
121+
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
122+
namespace="test-ns",
123+
spec=unittest.mock.ANY,
124+
debug=False,
125+
)
119126

120127
@patch.object(HPEndpoint, "call_get_api")
121128
def test_refresh(self, mock_get_api):

test/unit_tests/inference/test_hp_jumpstart_endpoint.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def test_create(self, mock_create_api, mock_validate_instance_type):
4444
kind=JUMPSTART_MODEL_KIND,
4545
namespace="test-ns",
4646
spec=unittest.mock.ANY,
47+
debug=False,
4748
)
4849
self.assertEqual(self.endpoint.metadata.name, "test-name")
4950

@@ -60,7 +61,13 @@ def test_create_from_dict(self, mock_create_api, mock_validate_instance_type):
6061
input_dict, name="test-name", namespace="test-ns"
6162
)
6263

63-
mock_create_api.assert_called_once()
64+
mock_create_api.assert_called_once_with(
65+
name="test-name",
66+
kind=JUMPSTART_MODEL_KIND,
67+
namespace="test-ns",
68+
spec=unittest.mock.ANY,
69+
debug=False,
70+
)
6471

6572
@patch.object(HPJumpStartEndpoint, "call_get_api")
6673
def test_refresh(self, mock_get_api):

0 commit comments

Comments
 (0)