Skip to content

Commit 7480494

Browse files
authored
Change type for --runtime_env in ray in-fw deployment script (#505)
Signed-off-by: Abhishree <[email protected]>
1 parent ccf1eb9 commit 7480494

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

scripts/deploy/nlp/deploy_ray_inframework.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import argparse
16+
import json
1617
import logging
1718
import multiprocessing
1819

@@ -26,6 +27,14 @@ def get_available_cpus():
2627
return multiprocessing.cpu_count()
2728

2829

30+
def json_type(string):
31+
"""Parse JSON string into a dictionary."""
32+
try:
33+
return json.loads(string)
34+
except json.JSONDecodeError as e:
35+
raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
36+
37+
2938
def parse_args():
3039
"""Parse command-line arguments for the Ray deployment script."""
3140
parser = argparse.ArgumentParser(description="Deploy a Megatron model using Ray")
@@ -187,9 +196,9 @@ def parse_args():
187196
)
188197
parser.add_argument(
189198
"--runtime_env",
190-
type=dict,
199+
type=json_type,
191200
default={},
192-
help="Runtime environment for the deployment",
201+
help="Runtime environment for the deployment (JSON string)",
193202
)
194203
return parser.parse_args()
195204

tests/unit_tests/deploy/test_deploy_ray.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,17 @@
1313
# limitations under the License.
1414

1515

16+
import argparse
1617
import unittest
1718
from unittest.mock import MagicMock, patch
1819

1920
from nemo_deploy.deploy_ray import DeployRay
2021

22+
# Import the functions from the deploy script
23+
from scripts.deploy.nlp.deploy_ray_inframework import (
24+
json_type,
25+
)
26+
2127

2228
class TestDeployRay(unittest.TestCase):
2329
def setUp(self):
@@ -297,5 +303,42 @@ def test_stop_with_errors(self, mock_logger, mock_serve, mock_ray):
297303
mock_logger.warning.assert_any_call("Error during ray.shutdown(): Ray shutdown error")
298304

299305

306+
class TestDeployRayInFrameworkScriptJsonType(unittest.TestCase):
307+
"""Test suite for deploy_ray_inframework.py script's json_type function."""
308+
309+
def test_json_type_valid_json(self):
310+
"""Test json_type with valid JSON strings."""
311+
# Test valid dictionary
312+
result = json_type('{"key": "value", "number": 42}')
313+
self.assertEqual(result, {"key": "value", "number": 42})
314+
315+
# Test valid list
316+
result = json_type("[1, 2, 3]")
317+
self.assertEqual(result, [1, 2, 3])
318+
319+
# Test nested structure
320+
result = json_type('{"pip": ["numpy", "pandas"], "env_vars": {"PATH": "/usr/bin"}}')
321+
expected = {"pip": ["numpy", "pandas"], "env_vars": {"PATH": "/usr/bin"}}
322+
self.assertEqual(result, expected)
323+
324+
def test_json_type_invalid_json(self):
325+
"""Test json_type with invalid JSON strings."""
326+
with self.assertRaises(argparse.ArgumentTypeError) as context:
327+
json_type("not a valid json")
328+
self.assertIn("Invalid JSON", str(context.exception))
329+
330+
with self.assertRaises(argparse.ArgumentTypeError) as context:
331+
json_type('{"incomplete": ')
332+
self.assertIn("Invalid JSON", str(context.exception))
333+
334+
def test_json_type_empty_json(self):
335+
"""Test json_type with empty JSON objects."""
336+
result = json_type("{}")
337+
self.assertEqual(result, {})
338+
339+
result = json_type("[]")
340+
self.assertEqual(result, [])
341+
342+
300343
if __name__ == "__main__":
301344
unittest.main()

0 commit comments

Comments
 (0)