Skip to content

Commit f6f5451

Browse files
authored
fix(sdk): add type inference for CLI run parameters. Fixes #11607 (#12767)
Signed-off-by: Pavan More <[email protected]>
1 parent 8d938b9 commit f6f5451

File tree

3 files changed

+130
-3
lines changed

3 files changed

+130
-3
lines changed

sdk/python/kfp/cli/run.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,15 @@ def create(ctx: click.Context, experiment_name: str, run_name: str,
126126
err=True)
127127
sys.exit(1)
128128

129-
arg_dict = dict(arg.split('=', maxsplit=1) for arg in args)
129+
arg_dict = {}
130+
for arg in args:
131+
if '=' not in arg:
132+
click.echo(
133+
f"Invalid argument format: '{arg}'. Expected 'key=value'.",
134+
err=True)
135+
sys.exit(1)
136+
k, v = arg.split('=', maxsplit=1)
137+
arg_dict[k] = parsing.parse_parameter_value(v)
130138

131139
experiment = client_obj.create_experiment(experiment_name)
132140
run = client_obj.run_pipeline(

sdk/python/kfp/cli/utils/parsing.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,47 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import json
1515
import re
16-
from typing import Callable
16+
from typing import Any, Callable
17+
18+
19+
def parse_parameter_value(value: str) -> Any:
20+
"""Parse a CLI string value into the appropriate Python type.
21+
22+
Attempts JSON parsing for complex types (lists, dicts, quoted strings),
23+
then tries numeric and boolean conversion, falling back to string.
24+
25+
Args:
26+
value: The string value from CLI argument.
27+
28+
Returns:
29+
The parsed value with inferred type (int, float, bool, list, dict,
30+
or str).
31+
"""
32+
try:
33+
parsed = json.loads(value)
34+
if isinstance(parsed, (list, dict, int, float, bool, str, type(None))):
35+
return parsed
36+
except (json.JSONDecodeError, ValueError):
37+
pass # Not valid JSON, try other conversions
38+
39+
if value.lower() == 'true':
40+
return True
41+
if value.lower() == 'false':
42+
return False
43+
44+
try:
45+
return int(value)
46+
except ValueError:
47+
pass # Not an integer, try float
48+
49+
try:
50+
return float(value)
51+
except ValueError:
52+
pass # Not a number, return as string
53+
54+
return value
1755

1856

1957
def get_param_descr(fn: Callable, param_name: str) -> str:

sdk/python/kfp/cli/utils/parsing_test.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,5 +132,86 @@ def test_multiline(self):
132132
)
133133

134134

135+
class TestParseParameterValue(unittest.TestCase):
136+
"""Tests for parse_parameter_value function."""
137+
138+
def test_integer_positive(self):
139+
self.assertEqual(parsing.parse_parameter_value('123'), 123)
140+
141+
def test_integer_negative(self):
142+
self.assertEqual(parsing.parse_parameter_value('-456'), -456)
143+
144+
def test_integer_zero(self):
145+
self.assertEqual(parsing.parse_parameter_value('0'), 0)
146+
147+
def test_float_positive(self):
148+
self.assertEqual(parsing.parse_parameter_value('12.5'), 12.5)
149+
150+
def test_float_negative(self):
151+
self.assertEqual(parsing.parse_parameter_value('-3.14'), -3.14)
152+
153+
def test_float_scientific(self):
154+
self.assertEqual(parsing.parse_parameter_value('1e10'), 1e10)
155+
156+
def test_boolean_true_lowercase(self):
157+
self.assertEqual(parsing.parse_parameter_value('true'), True)
158+
159+
def test_boolean_false_lowercase(self):
160+
self.assertEqual(parsing.parse_parameter_value('false'), False)
161+
162+
def test_boolean_true_capitalized(self):
163+
self.assertEqual(parsing.parse_parameter_value('True'), True)
164+
165+
def test_boolean_false_capitalized(self):
166+
self.assertEqual(parsing.parse_parameter_value('False'), False)
167+
168+
def test_list_integers(self):
169+
self.assertEqual(parsing.parse_parameter_value('[1, 2, 3]'), [1, 2, 3])
170+
171+
def test_list_strings(self):
172+
self.assertEqual(
173+
parsing.parse_parameter_value('["a", "b", "c"]'), ['a', 'b', 'c'])
174+
175+
def test_list_empty(self):
176+
self.assertEqual(parsing.parse_parameter_value('[]'), [])
177+
178+
def test_dict_simple(self):
179+
self.assertEqual(
180+
parsing.parse_parameter_value('{"key": "value"}'), {'key': 'value'})
181+
182+
def test_dict_nested(self):
183+
self.assertEqual(
184+
parsing.parse_parameter_value('{"a": {"b": 1}}'), {'a': {
185+
'b': 1
186+
}})
187+
188+
def test_dict_empty(self):
189+
self.assertEqual(parsing.parse_parameter_value('{}'), {})
190+
191+
def test_null_value(self):
192+
self.assertIsNone(parsing.parse_parameter_value('null'))
193+
194+
def test_string_simple(self):
195+
self.assertEqual(parsing.parse_parameter_value('hello'), 'hello')
196+
197+
def test_string_with_spaces(self):
198+
self.assertEqual(
199+
parsing.parse_parameter_value('hello world'), 'hello world')
200+
201+
def test_json_quoted_string_preserves_value(self):
202+
self.assertEqual(parsing.parse_parameter_value('"007"'), '007')
203+
self.assertEqual(parsing.parse_parameter_value('"hello"'), 'hello')
204+
205+
def test_string_that_looks_like_number(self):
206+
self.assertEqual(parsing.parse_parameter_value('007'), 7)
207+
self.assertEqual(
208+
parsing.parse_parameter_value('+1-555-1234'), '+1-555-1234')
209+
210+
def test_string_preserves_value(self):
211+
self.assertEqual(
212+
parsing.parse_parameter_value('some-pipeline-name'),
213+
'some-pipeline-name')
214+
215+
135216
if __name__ == '__main__':
136217
unittest.main()

0 commit comments

Comments
 (0)