Skip to content

Commit 900c64d

Browse files
authored
fix(json/get_value): enhance get_path_value_from_dict with wildcard support (#269)
* wip * fix(json/get_value): enhance get_path_value_from_dict with wildcard support This is to fix get_value not being able to do `*.somevalue` and some other cases. * update * update * update * update * update * update * update * update
1 parent a1032fe commit 900c64d

File tree

9 files changed

+497
-73
lines changed

9 files changed

+497
-73
lines changed

src/tirith/core/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,9 @@ def start_policy_evaluation(
235235

236236
with open(input_path) as f:
237237
if input_path.endswith(".yaml") or input_path.endswith(".yml"):
238-
# safe_load_all returns a generator, we need to convert it into a
239-
# dictionary because start_policy_evaluation_from_dict expects a dictionary
240-
input_data = dict(yamls=list(yaml.safe_load_all(f)))
238+
input_data = list(yaml.safe_load_all(f))
239+
if len(input_data) == 1:
240+
input_data = input_data[0]
241241
else:
242242
input_data = json.load(f)
243243
# TODO: validate input_data using the optionally available validate function in provider

src/tirith/providers/common.py

Lines changed: 154 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,163 @@
1-
from typing import Dict
1+
import pydash
2+
3+
from typing import Dict, Any
24

35

46
def create_result_dict(value=None, meta=None, err=None) -> Dict:
    """
    Bundle a provider result into a plain dictionary.

    :param value: The value produced by the provider (defaults to None).
    :param meta: Optional metadata attached to the result (defaults to None).
    :param err: Optional error description (defaults to None).
    :return: A dict with exactly the keys ``value``, ``meta`` and ``err``.
    :rtype: dict
    """
    return {"value": value, "meta": meta, "err": err}
68

79

8-
def get_path_value_from_dict(key_path: str, input_dict: dict, get_path_value_from_dict_func):
9-
splitted_attribute = key_path.split(".*.")
10-
return get_path_value_from_dict_func(splitted_attribute, input_dict)
10+
class PydashPathNotFound:
    """
    Sentinel marking "path not found" results from ``pydash.get``.

    The class object itself (never an instance) is passed as the ``default``
    argument of ``pydash.get`` and compared by identity, so a missing path can
    be told apart from a path whose stored value is ``None``.
    """

    pass
12+
13+
14+
def _get_path_value_from_input_internal(splitted_paths, input_data, place_none_if_not_found=False):
    """
    Recursively resolve a pre-split path against ``input_data``.

    :param splitted_paths: Path segments in traversal order. An empty string
        segment is the wildcard marker and fans out over every list item or
        every dict value at that level.
    :param input_data: The data to traverse (dict, list, or primitive).
    :param place_none_if_not_found: If True, a segment that ``pydash.get``
        cannot resolve yields ``[None]``; otherwise it yields ``[]``.
    :return: A flat list of every value matched by the path.
    :rtype: list

    A wildcard applied to a primitive matches the primitive itself when it is
    the last segment; when more segments follow, it yields ``[]`` regardless
    of ``place_none_if_not_found``.
    """
    not_found_result = [None] if place_none_if_not_found else []

    # Nothing left to traverse: the current data is the match.
    if not splitted_paths:
        return not_found_result if input_data is PydashPathNotFound else [input_data]

    segment, *rest = splitted_paths

    if segment == "":
        # Wildcard: pick the collection of children to fan out over.
        if isinstance(input_data, list):
            children = input_data
        elif isinstance(input_data, dict):
            children = input_data.values()
        else:
            # A primitive has no children; it only matches a trailing wildcard.
            return [] if rest else [input_data]

        if not rest:
            return list(children)

        matches = []
        for child in children:
            matches.extend(_get_path_value_from_input_internal(rest, child, place_none_if_not_found))
        return matches

    # Plain segment: look it up, using the sentinel to detect a missing path.
    found = pydash.get(input_data, segment, default=PydashPathNotFound)
    if found is PydashPathNotFound:
        return not_found_result

    if not rest:
        return [found]

    # Any wildcard that follows is handled by the wildcard branch of the
    # recursive call, so lists and dicts need no special casing here.
    return _get_path_value_from_input_internal(rest, found, place_none_if_not_found)
76+
77+
78+
def get_path_value_from_input(key_path: str, input: Any, place_none_if_not_found: bool = False):
    """
    Collect every value reachable through a dot-separated path, with wildcards.

    :param key_path: Dot-separated path. A ``*`` segment matches every item of
        a list or every value of a dict at that level. An empty segment (as
        produced by ``a..b`` or a trailing dot) is treated the same way as
        ``*``, because the empty string is the internal wildcard marker.
    :type key_path: str
    :param input: The data to search (dict, list, or primitive).
    :type input: Any
    :param place_none_if_not_found: If True, a path segment that cannot be
        resolved contributes ``None`` to the result; if False it contributes
        nothing. Defaults to False.
    :type place_none_if_not_found: bool
    :return: A list with all matched values. An empty ``key_path`` returns
        ``[input]`` unchanged.
    :rtype: list

    **Examples:**

    Plain traversal::

        >>> get_path_value_from_input("user.name", {"user": {"name": "Alice", "age": 30}})
        ['Alice']

    Wildcard over list items::

        >>> get_path_value_from_input("users.*.name", {"users": [{"name": "Alice"}, {"name": "Bob"}]})
        ['Alice', 'Bob']

    Wildcard over dictionary values::

        >>> data = {"countries": {"US": {"capital": "Washington"}, "UK": {"capital": "London"}}}
        >>> get_path_value_from_input("countries.*.capital", data)
        ['Washington', 'London']

    Leading wildcard on a list::

        >>> get_path_value_from_input("*.name", [{"name": "Alice"}, {"name": "Bob"}])
        ['Alice', 'Bob']

    Wildcard on primitives::

        >>> get_path_value_from_input("*", 42)
        [42]
        >>> get_path_value_from_input("*", "hello")
        ['hello']

    Several wildcards::

        >>> get_path_value_from_input("groups.*.*.id", {"groups": [[{"id": 1}, {"id": 2}], [{"id": 3}]]})
        [1, 2, 3]

    Empty path::

        >>> get_path_value_from_input("", {"key": "value"})
        [{'key': 'value'}]

    Missing path::

        >>> get_path_value_from_input("missing.path", {"user": {"name": "Alice"}})
        []
        >>> get_path_value_from_input("missing.path", {"user": {"name": "Alice"}}, place_none_if_not_found=True)
        [None]
    """
    # An empty path selects the input itself.
    if not key_path:
        return [input]

    # Translate the textual path into segments, turning "*" into the empty
    # string that the internal resolver recognises as its wildcard marker:
    #   "users.*.name" -> ["users", "", "name"]
    #   "*.name"       -> ["", "name"]
    #   "numbers.*"    -> ["numbers", ""]
    segments = []
    for raw_segment in key_path.split("."):
        segments.append("" if raw_segment == "*" else raw_segment)

    return _get_path_value_from_input_internal(segments, input, place_none_if_not_found)
11161

12162

13163
class ProviderError:

src/tirith/providers/json/handler.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,12 @@
1-
import pydash
2-
31
from typing import Callable, Dict, List
4-
from ..common import create_result_dict, ProviderError, get_path_value_from_dict
5-
6-
7-
class PydashPathNotFound:
8-
pass
9-
10-
11-
def _get_path_value_from_dict(splitted_paths, input_dict):
12-
final_data = []
13-
for i, expression in enumerate(splitted_paths):
14-
intermediate_val = pydash.get(input_dict, expression, default=PydashPathNotFound)
15-
if isinstance(intermediate_val, list) and i < len(splitted_paths) - 1:
16-
for val in intermediate_val:
17-
final_attributes = _get_path_value_from_dict(splitted_paths[1:], val)
18-
for final_attribute in final_attributes:
19-
final_data.append(final_attribute)
20-
elif i == len(splitted_paths) - 1 and intermediate_val is not PydashPathNotFound:
21-
final_data.append(intermediate_val)
22-
elif ".*" in expression:
23-
intermediate_exp = expression.split(".*")
24-
intermediate_data = pydash.get(input_dict, intermediate_exp[0], default=PydashPathNotFound)
25-
if intermediate_data is not PydashPathNotFound and isinstance(intermediate_data, list):
26-
for val in intermediate_data:
27-
final_data.append(val)
28-
return final_data
2+
from ..common import create_result_dict, ProviderError, get_path_value_from_input
293

304

315
def get_value(provider_args: Dict, input_data: Dict) -> List[dict]:
326
# Must be validated first whether the provider args are valid for this op type
337
key_path: str = provider_args["key_path"]
348

35-
values = get_path_value_from_dict(key_path, input_data, _get_path_value_from_dict)
9+
values = get_path_value_from_input(key_path, input_data)
3610

3711
if len(values) == 0:
3812
severity_value = 2

src/tirith/providers/kubernetes/handler.py

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,8 @@
1-
import pydash
1+
from typing import Callable, Dict, List, Union
2+
from ..common import create_result_dict, ProviderError, get_path_value_from_input
23

3-
from typing import Callable, Dict, List
4-
from ..common import create_result_dict, ProviderError, get_path_value_from_dict
54

6-
7-
class PydashPathNotFound:
8-
pass
9-
10-
11-
def _get_path_value_from_dict(splitted_paths, input_dict):
12-
final_data = []
13-
expression = splitted_paths[0]
14-
is_the_last_expression = len(splitted_paths) == 1
15-
16-
intermediate_val = pydash.get(input_dict, expression, default=PydashPathNotFound)
17-
if isinstance(intermediate_val, list) and not is_the_last_expression:
18-
for val in intermediate_val:
19-
final_attributes = _get_path_value_from_dict(splitted_paths[1:], val)
20-
for final_attribute in final_attributes:
21-
final_data.append(final_attribute)
22-
elif intermediate_val is PydashPathNotFound:
23-
final_data.append(None)
24-
elif is_the_last_expression:
25-
final_data.append(intermediate_val)
26-
elif ".*" in expression:
27-
intermediate_exp = expression.split(".*")
28-
intermediate_data = pydash.get(input_dict, intermediate_exp[0], default=PydashPathNotFound)
29-
if intermediate_data is not PydashPathNotFound and isinstance(intermediate_data, list):
30-
for val in intermediate_data:
31-
final_data.append(val)
32-
return final_data
33-
34-
35-
def get_value(provider_args: Dict, input_data: Dict, outputs: list) -> Dict:
5+
def get_value(provider_args: Dict, input_data: Union[Dict, List], outputs: list) -> Dict:
366
# Must be validated first whether the provider args are valid for this op type
377
target_kind: str = provider_args.get("kubernetes_kind")
388
attribute_path: str = provider_args.get("attribute_path", "")
@@ -42,15 +12,15 @@ def get_value(provider_args: Dict, input_data: Dict, outputs: list) -> Dict:
4212
if attribute_path == "":
4313
return create_result_dict(value=ProviderError(severity_value=99), err="attribute_path must be provided")
4414

45-
kubernetes_resources = input_data["yamls"]
15+
kubernetes_resources = input_data
4616
is_kind_found = False
4717

4818
for resource in kubernetes_resources:
4919
if resource["kind"] != target_kind:
5020
continue
5121
is_kind_found = True
52-
values = get_path_value_from_dict(attribute_path, resource, _get_path_value_from_dict)
53-
if ".*." not in attribute_path:
22+
values = get_path_value_from_input(attribute_path, resource, place_none_if_not_found=True)
23+
if "*" not in attribute_path:
5424
# If there's no * in the attribute path, the values always have 1 member
5525
values = values[0]
5626
outputs.append(create_result_dict(value=values))

tests/providers/json/playbook.yml

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
- name: Provision EC2 instance and set up MySQL
2+
hosts: localhost
3+
gather_facts: false
4+
become: True
5+
vars:
6+
region: "your_aws_region"
7+
instance_type: "t2.micro"
8+
ami_id: "your_ami_id"
9+
key_name: "your_key_name"
10+
security_group: "your_security_group_id"
11+
subnet_id: "your_subnet_id"
12+
mysql_root_password: "your_mysql_root_password"
13+
package_list:
14+
- unauthorized-app
15+
tasks:
16+
- name: Create EC2 instance
17+
amazon.aws.ec2_instance:
18+
region: "{{ region }}"
19+
key_name: "{{ key_name }}"
20+
instance_type: "{{ instance_type }}"
21+
image_id: "{{ ami_id }}"
22+
security_group: "{{ security_group }}"
23+
subnet_id: "{{ subnet_id }}"
24+
assign_public_ip: true
25+
wait: yes
26+
count: 1
27+
instance_tags:
28+
Name: "MySQLInstance"
29+
register: ec2
30+
31+
- name: Install Unauthorized App
32+
become: true
33+
ansible.builtin.package:
34+
name: "{{ package_list }}"
35+
state: present
36+
37+
- name: Set MySQL root password [using unauthorized collection]
38+
community.mysql.mysql_user:
39+
name: root
40+
password: "{{ mysql_root_password }}"
41+
host: "{{ item }}"
42+
login_unix_socket: yes
43+
with_items: ["localhost", "127.0.0.1", "::1"]
44+
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{
2+
"meta": {
3+
"version": "v1",
4+
"required_provider": "stackguardian/json"
5+
},
6+
"evaluators": [
7+
{
8+
"id": "check0",
9+
"provider_args": {
10+
"operation_type": "get_value",
11+
"key_path": "*.vars.region"
12+
},
13+
"condition": {
14+
"type": "Equals",
15+
"value": "your_aws_region"
16+
}
17+
}
18+
],
19+
"eval_expression": "check0"
20+
}

tests/providers/json/test_get_value.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import os
33

4-
from tirith.core.core import start_policy_evaluation_from_dict
4+
from tirith.core.core import start_policy_evaluation, start_policy_evaluation_from_dict
55

66

77
# TODO: Need to split this into multiple tests
@@ -13,4 +13,14 @@ def test_get_value():
1313
policy = json.load(f)
1414

1515
result = start_policy_evaluation_from_dict(policy, input_data)
16-
assert result["final_result"] == True
16+
assert result["final_result"] is True
17+
18+
19+
def test_get_value_playbook():
    """Evaluate the wildcard policy against the playbook YAML fixture and expect a pass."""
    fixtures_dir = os.path.dirname(os.path.realpath(__file__))

    # Both fixtures live next to this test module.
    result = start_policy_evaluation(
        policy_path=os.path.join(fixtures_dir, "policy_playbook.json"),
        input_path=os.path.join(fixtures_dir, "playbook.yml"),
    )

    assert result["final_result"] is True
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import os
2+
3+
from tirith.core.core import start_policy_evaluation
4+
5+
6+
def test_get_value():
    """Evaluate ``policy.json`` against ``input.yml`` and expect the policy to fail."""
    fixtures_dir = os.path.dirname(os.path.realpath(__file__))

    # Both fixtures live next to this test module.
    result = start_policy_evaluation(
        policy_path=os.path.join(fixtures_dir, "policy.json"),
        input_path=os.path.join(fixtures_dir, "input.yml"),
    )

    assert result["final_result"] is False

0 commit comments

Comments
 (0)