Skip to content

Commit 9560a48

Browse files
Update generate_click_command inject logic to not expose unwanted flags to hyp-jumpstart-endpoint (#213)
* Update generate_click_command inject logic to not expose unwanted flags to hyp-jumpstart-endpoint * Update unit tests for bug fix, change --label_selector to --label-selector
1 parent 0fd2bef commit 9560a48

File tree

4 files changed

+32
-40
lines changed

4 files changed

+32
-40
lines changed

src/sagemaker/hyperpod/cli/inference_utils.py

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -40,45 +40,29 @@ def wrapped_func(*args, **kwargs):
4040
domain = flat.to_domain()
4141
return func(namespace, version, domain)
4242

43-
# 2) inject the special JSON‐env flag before everything else
44-
wrapped_func = click.option(
45-
"--env",
46-
callback=_parse_json_flag,
47-
type=str,
48-
default=None,
49-
help=(
50-
"JSON object of environment variables, e.g. "
51-
'\'{"VAR1":"foo","VAR2":"bar"}\''
52-
),
53-
metavar="JSON",
54-
)(wrapped_func)
55-
56-
wrapped_func = click.option(
57-
"--dimensions",
58-
callback=_parse_json_flag,
59-
type=str,
60-
default=None,
61-
help=("JSON object of dimensions, e.g. " '\'{"VAR1":"foo","VAR2":"bar"}\''),
62-
metavar="JSON",
63-
)(wrapped_func)
64-
65-
wrapped_func = click.option(
66-
"--resources-limits",
67-
callback=_parse_json_flag,
68-
help='JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\'',
69-
metavar="JSON",
70-
)(wrapped_func)
71-
72-
wrapped_func = click.option(
73-
"--resources-requests",
74-
callback=_parse_json_flag,
75-
help='JSON object of resource requests, e.g. \'{"cpu":"1","memory":"2Gi"}\'',
76-
metavar="JSON",
77-
)(wrapped_func)
78-
79-
# 3) auto-inject all schema.json fields
43+
# 2) inject JSON flags only if they exist in the schema
8044
schema = load_schema_for_version(version, schema_pkg)
8145
props = schema.get("properties", {})
46+
47+
json_flags = {
48+
"env": ("JSON object of environment variables, e.g. " '\'{"VAR1":"foo","VAR2":"bar"}\''),
49+
"dimensions": ("JSON object of dimensions, e.g. " '\'{"VAR1":"foo","VAR2":"bar"}\''),
50+
"resources_limits": ('JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\''),
51+
"resources_requests": ('JSON object of resource requests, e.g. \'{"cpu":"1","memory":"2Gi"}\''),
52+
}
53+
54+
for flag_name, help_text in json_flags.items():
55+
if flag_name in props:
56+
wrapped_func = click.option(
57+
f"--{flag_name.replace('_', '-')}",
58+
callback=_parse_json_flag,
59+
type=str,
60+
default=None,
61+
help=help_text,
62+
metavar="JSON",
63+
)(wrapped_func)
64+
65+
# 3) auto-inject all schema.json fields
8266
reqs = set(schema.get("required", []))
8367

8468
for name, spec in reversed(list(props.items())):

src/sagemaker/hyperpod/cli/training_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def wrapped_func(*args, **kwargs):
107107
metavar="JSON",
108108
)(wrapped_func)
109109
wrapped_func = click.option(
110-
"--label_selector",
110+
"--label-selector",
111111
callback=_parse_json_flag,
112112
help='JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\'',
113113
metavar="JSON",

test/unit_tests/cli/test_inference_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,15 @@ def cmd(namespace, version, domain):
5959

6060
@patch('sagemaker.hyperpod.cli.inference_utils.load_schema_for_version')
6161
def test_json_flags(self, mock_load_schema):
62-
mock_load_schema.return_value = {'properties': {}, 'required': []}
62+
mock_load_schema.return_value = {
63+
'properties': {
64+
'env': {'type': 'object'},
65+
'dimensions': {'type': 'object'},
66+
'resources_limits': {'type': 'object'},
67+
'resources_requests': {'type': 'object'}
68+
},
69+
'required': []
70+
}
6371
# Domain receives flags as attributes env, dimensions, resources_limits, resources_requests
6472
class DummyFlat:
6573
def __init__(self, **kwargs): self.__dict__.update(kwargs)

test/unit_tests/cli/test_training_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def cmd(version, debug, config):
8080
# Test valid JSON input
8181
result = self.runner.invoke(cmd, [
8282
'--environment', '{"VAR1":"val1"}',
83-
'--label_selector', '{"key":"value"}'
83+
'--label-selector', '{"key":"value"}'
8484
])
8585
assert result.exit_code == 0
8686
output = json.loads(result.output)

0 commit comments

Comments
 (0)