|
4 | 4 |
|
5 | 5 | import pytest |
6 | 6 |
|
| 7 | +from azure.ai.ml._internal._schema.component import NodeType as InternalNodeType |
7 | 8 | from azure.ai.ml._utils.utils import ( |
8 | 9 | _get_mfe_base_url_from_batch_endpoint, |
9 | 10 | dict_eq, |
10 | 11 | get_all_data_binding_expressions, |
| 12 | + get_valid_dot_keys_with_wildcard, |
11 | 13 | is_data_binding_expression, |
12 | 14 | map_single_brackets_and_warn, |
13 | 15 | write_to_shared_file, |
14 | | - get_valid_dot_keys_with_wildcard, |
15 | 16 | ) |
| 17 | +from azure.ai.ml.constants._component import NodeType |
16 | 18 | from azure.ai.ml.entities import BatchEndpoint |
17 | | -from azure.ai.ml.entities._util import convert_ordered_dict_to_dict |
| 19 | +from azure.ai.ml.entities._util import convert_ordered_dict_to_dict, get_type_from_spec |
| 20 | +from azure.ai.ml.exceptions import ValidationException |
18 | 21 |
|
19 | 22 |
|
20 | 23 | @pytest.mark.unittest |
@@ -120,3 +123,17 @@ def test_get_valid_dot_keys_with_wildcard(self): |
120 | 123 | "deep.*.*", |
121 | 124 | validate_func=lambda _root, _parts: _parts[1] == "l1_2", |
122 | 125 | ) == ["deep.l1_2.l2"] |
| 126 | + |
| 127 | + def test_get_type_from_spec_case_insensitive(self): |
| 128 | + """Test that get_type_from_spec normalizes type to lowercase for case-insensitive validation.""" |
| 129 | + valid_keys = [NodeType.COMMAND, InternalNodeType.COMMAND] |
| 130 | + |
| 131 | + test_cases = [ |
| 132 | + ({"type": "command"}, "command"), # lowercase |
| 133 | + ({"type": "Command"}, "command"), # uppercase - should normalize to lowercase |
| 134 | + ({"type": "CommandComponent"}, "CommandComponent"), # remains unchanged as it's not in NodeType |
| 135 | + ] |
| 136 | + |
| 137 | + for data, expected in test_cases: |
| 138 | + result = get_type_from_spec(data, valid_keys=valid_keys) |
| 139 | + assert result == expected |
0 commit comments