|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 |
|
| 16 | +import argparse |
16 | 17 | import unittest |
17 | 18 | from unittest.mock import MagicMock, patch |
18 | 19 |
|
19 | 20 | from nemo_deploy.deploy_ray import DeployRay |
20 | 21 |
|
| 22 | +# Import the functions from the deploy script |
| 23 | +from scripts.deploy.nlp.deploy_ray_inframework import ( |
| 24 | + json_type, |
| 25 | +) |
| 26 | + |
21 | 27 |
|
22 | 28 | class TestDeployRay(unittest.TestCase): |
23 | 29 | def setUp(self): |
@@ -297,5 +303,42 @@ def test_stop_with_errors(self, mock_logger, mock_serve, mock_ray): |
297 | 303 | mock_logger.warning.assert_any_call("Error during ray.shutdown(): Ray shutdown error") |
298 | 304 |
|
299 | 305 |
|
| 306 | +class TestDeployRayInFrameworkScriptJsonType(unittest.TestCase): |
| 307 | + """Test suite for deploy_ray_inframework.py script's json_type function.""" |
| 308 | + |
| 309 | + def test_json_type_valid_json(self): |
| 310 | + """Test json_type with valid JSON strings.""" |
| 311 | + # Test valid dictionary |
| 312 | + result = json_type('{"key": "value", "number": 42}') |
| 313 | + self.assertEqual(result, {"key": "value", "number": 42}) |
| 314 | + |
| 315 | + # Test valid list |
| 316 | + result = json_type("[1, 2, 3]") |
| 317 | + self.assertEqual(result, [1, 2, 3]) |
| 318 | + |
| 319 | + # Test nested structure |
| 320 | + result = json_type('{"pip": ["numpy", "pandas"], "env_vars": {"PATH": "/usr/bin"}}') |
| 321 | + expected = {"pip": ["numpy", "pandas"], "env_vars": {"PATH": "/usr/bin"}} |
| 322 | + self.assertEqual(result, expected) |
| 323 | + |
| 324 | + def test_json_type_invalid_json(self): |
| 325 | + """Test json_type with invalid JSON strings.""" |
| 326 | + with self.assertRaises(argparse.ArgumentTypeError) as context: |
| 327 | + json_type("not a valid json") |
| 328 | + self.assertIn("Invalid JSON", str(context.exception)) |
| 329 | + |
| 330 | + with self.assertRaises(argparse.ArgumentTypeError) as context: |
| 331 | + json_type('{"incomplete": ') |
| 332 | + self.assertIn("Invalid JSON", str(context.exception)) |
| 333 | + |
| 334 | + def test_json_type_empty_json(self): |
| 335 | + """Test json_type with empty JSON objects.""" |
| 336 | + result = json_type("{}") |
| 337 | + self.assertEqual(result, {}) |
| 338 | + |
| 339 | + result = json_type("[]") |
| 340 | + self.assertEqual(result, []) |
| 341 | + |
| 342 | + |
300 | 343 | if __name__ == "__main__": |
301 | 344 | unittest.main() |
0 commit comments