Skip to content

Commit b5d2a75

Browse files
authored
feat(api): filter Attack Paths query results by provider_id (#10118)
1 parent c12f274 commit b5d2a75

File tree

6 files changed

+90
-21
lines changed

6 files changed

+90
-21
lines changed

api/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ All notable changes to the **Prowler API** are documented in this file.
2424
- Attack Paths: Add `graph_data_ready` field to decouple query availability from scan state [(#10089)](https://github.com/prowler-cloud/prowler/pull/10089)
2525
- AI agent guidelines with TDD and testing skills references [(#9925)](https://github.com/prowler-cloud/prowler/pull/9925)
2626
- Attack Paths: Upgrade Cartography from fork 0.126.1 to upstream 0.129.0 and Neo4j driver from 5.x to 6.x [(#10110)](https://github.com/prowler-cloud/prowler/pull/10110)
27+
- Attack Paths: Query results now filtered by provider, preventing future cross-tenant and cross-provider data leakage [(#10118)](https://github.com/prowler-cloud/prowler/pull/10118)
2728

2829
### 🐞 Fixed
2930

api/src/backend/api/attack_paths/queries/aws.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
description="Detect EC2 instances with SSH exposed to the internet that can assume higher-privileged roles to read tagged sensitive S3 buckets despite bucket-level public access blocks.",
1717
provider="aws",
1818
cypher=f"""
19-
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet'}})
19+
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet', provider_id: $provider_id}})
2020
YIELD node AS internet
2121
2222
MATCH path_s3 = (aws:AWSAccount {{id: $provider_uid}})--(s3:S3Bucket)--(t:AWSTag)
@@ -32,7 +32,7 @@
3232
3333
MATCH path_assume_role = (ec2)-[p:STS_ASSUMEROLE_ALLOW*1..9]-(r:AWSRole)
3434
35-
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{}}, ec2)
35+
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{provider_id: $provider_id}}, ec2)
3636
YIELD rel AS can_access
3737
3838
UNWIND nodes(path_s3) + nodes(path_ec2) + nodes(path_role) + nodes(path_assume_role) as n
@@ -181,13 +181,13 @@
181181
description="Find EC2 instances flagged as exposed to the internet within the selected account.",
182182
provider="aws",
183183
cypher=f"""
184-
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet'}})
184+
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet', provider_id: $provider_id}})
185185
YIELD node AS internet
186186
187187
MATCH path = (aws:AWSAccount {{id: $provider_uid}})--(ec2:EC2Instance)
188188
WHERE ec2.exposed_internet = true
189189
190-
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{}}, ec2)
190+
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{provider_id: $provider_id}}, ec2)
191191
YIELD rel AS can_access
192192
193193
UNWIND nodes(path) as n
@@ -205,15 +205,15 @@
205205
description="Find internet-facing resources associated with security groups that allow inbound access from '0.0.0.0/0'.",
206206
provider="aws",
207207
cypher=f"""
208-
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet'}})
208+
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet', provider_id: $provider_id}})
209209
YIELD node AS internet
210210
211211
// Match EC2 instances that are internet-exposed with open security groups (0.0.0.0/0)
212212
MATCH path_ec2 = (aws:AWSAccount {{id: $provider_uid}})--(ec2:EC2Instance)--(sg:EC2SecurityGroup)--(ipi:IpPermissionInbound)--(ir:IpRange)
213213
WHERE ec2.exposed_internet = true
214214
AND ir.range = "0.0.0.0/0"
215215
216-
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{}}, ec2)
216+
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{provider_id: $provider_id}}, ec2)
217217
YIELD rel AS can_access
218218
219219
UNWIND nodes(path_ec2) as n
@@ -231,13 +231,13 @@
231231
description="Find Classic Load Balancers exposed to the internet along with their listeners.",
232232
provider="aws",
233233
cypher=f"""
234-
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet'}})
234+
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet', provider_id: $provider_id}})
235235
YIELD node AS internet
236236
237237
MATCH path = (aws:AWSAccount {{id: $provider_uid}})--(elb:LoadBalancer)--(listener:ELBListener)
238238
WHERE elb.exposed_internet = true
239239
240-
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{}}, elb)
240+
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{provider_id: $provider_id}}, elb)
241241
YIELD rel AS can_access
242242
243243
UNWIND nodes(path) as n
@@ -255,13 +255,13 @@
255255
description="Find ELBv2 load balancers exposed to the internet along with their listeners.",
256256
provider="aws",
257257
cypher=f"""
258-
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet'}})
258+
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet', provider_id: $provider_id}})
259259
YIELD node AS internet
260260
261261
MATCH path = (aws:AWSAccount {{id: $provider_uid}})--(elbv2:LoadBalancerV2)--(listener:ELBV2Listener)
262262
WHERE elbv2.exposed_internet = true
263263
264-
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{}}, elbv2)
264+
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{provider_id: $provider_id}}, elbv2)
265265
YIELD rel AS can_access
266266
267267
UNWIND nodes(path) as n
@@ -279,7 +279,7 @@
279279
description="Given a public IP address, find the related AWS resource and its adjacent node within the selected account.",
280280
provider="aws",
281281
cypher=f"""
282-
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet'}})
282+
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet', provider_id: $provider_id}})
283283
YIELD node AS internet
284284
285285
CALL () {{
@@ -302,7 +302,7 @@
302302
303303
WITH path, x, internet
304304
305-
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{}}, x)
305+
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{provider_id: $provider_id}}, x)
306306
YIELD rel AS can_access
307307
308308
UNWIND nodes(path) as n

api/src/backend/api/attack_paths/views_helpers.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def prepare_query_parameters(
3535
definition: AttackPathsQueryDefinition,
3636
provided_parameters: dict[str, Any],
3737
provider_uid: str,
38+
provider_id: str,
3839
) -> dict[str, Any]:
3940
parameters = dict(provided_parameters or {})
4041
expected_names = {parameter.name for parameter in definition.parameters}
@@ -56,6 +57,7 @@ def prepare_query_parameters(
5657

5758
clean_parameters = {
5859
"provider_uid": str(provider_uid),
60+
"provider_id": str(provider_id),
5961
}
6062

6163
for definition_parameter in definition.parameters:
@@ -82,11 +84,12 @@ def execute_attack_paths_query(
8284
database_name: str,
8385
definition: AttackPathsQueryDefinition,
8486
parameters: dict[str, Any],
87+
provider_id: str,
8588
) -> dict[str, Any]:
8689
try:
8790
with graph_database.get_session(database_name) as session:
8891
result = session.run(definition.cypher, parameters)
89-
return _serialize_graph(result.graph())
92+
return _serialize_graph(result.graph(), provider_id)
9093

9194
except graph_database.GraphDatabaseQueryException as exc:
9295
logger.error(f"Query failed for Attack Paths query `{definition.id}`: {exc}")
@@ -95,9 +98,14 @@ def execute_attack_paths_query(
9598
)
9699

97100

98-
def _serialize_graph(graph):
101+
def _serialize_graph(graph, provider_id: str):
99102
nodes = []
103+
kept_node_ids = set()
100104
for node in graph.nodes:
105+
if node._properties.get("provider_id") != provider_id:
106+
continue
107+
108+
kept_node_ids.add(node.element_id)
101109
nodes.append(
102110
{
103111
"id": node.element_id,
@@ -108,6 +116,15 @@ def _serialize_graph(graph):
108116

109117
relationships = []
110118
for relationship in graph.relationships:
119+
if relationship._properties.get("provider_id") != provider_id:
120+
continue
121+
122+
if (
123+
relationship.start_node.element_id not in kept_node_ids
124+
or relationship.end_node.element_id not in kept_node_ids
125+
):
126+
continue
127+
111128
relationships.append(
112129
{
113130
"id": relationship.element_id,

api/src/backend/api/tests/test_attack_paths.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@ def test_prepare_query_parameters_includes_provider_and_casts(
3838
definition,
3939
{"limit": "5"},
4040
provider_uid="123456789012",
41+
provider_id="test-provider-id",
4142
)
4243

4344
assert result["provider_uid"] == "123456789012"
45+
assert result["provider_id"] == "test-provider-id"
4446
assert result["limit"] == 5
4547

4648

@@ -57,7 +59,9 @@ def test_prepare_query_parameters_validates_names(
5759
definition = attack_paths_query_definition_factory()
5860

5961
with pytest.raises(ValidationError) as exc:
60-
views_helpers.prepare_query_parameters(definition, provided, provider_uid="1")
62+
views_helpers.prepare_query_parameters(
63+
definition, provided, provider_uid="1", provider_id="p1"
64+
)
6165

6266
assert expected_message in str(exc.value)
6367

@@ -72,6 +76,7 @@ def test_prepare_query_parameters_validates_cast(
7276
definition,
7377
{"limit": "not-an-int"},
7478
provider_uid="1",
79+
provider_id="p1",
7580
)
7681

7782
assert "Invalid value" in str(exc.value)
@@ -90,11 +95,13 @@ def test_execute_attack_paths_query_serializes_graph(
9095
)
9196
parameters = {"provider_uid": "123"}
9297

98+
provider_id = "test-provider-123"
9399
node = attack_paths_graph_stub_classes.Node(
94100
element_id="node-1",
95101
labels=["AWSAccount"],
96102
properties={
97103
"name": "account",
104+
"provider_id": provider_id,
98105
"complex": {
99106
"items": [
100107
attack_paths_graph_stub_classes.NativeValue("value"),
@@ -103,14 +110,17 @@ def test_execute_attack_paths_query_serializes_graph(
103110
},
104111
},
105112
)
113+
node_2 = attack_paths_graph_stub_classes.Node(
114+
"node-2", ["RDSInstance"], {"provider_id": provider_id}
115+
)
106116
relationship = attack_paths_graph_stub_classes.Relationship(
107117
element_id="rel-1",
108118
rel_type="OWNS",
109119
start_node=node,
110-
end_node=attack_paths_graph_stub_classes.Node("node-2", ["RDSInstance"], {}),
111-
properties={"weight": 1},
120+
end_node=node_2,
121+
properties={"weight": 1, "provider_id": provider_id},
112122
)
113-
graph = SimpleNamespace(nodes=[node], relationships=[relationship])
123+
graph = SimpleNamespace(nodes=[node, node_2], relationships=[relationship])
114124

115125
run_result = MagicMock()
116126
run_result.graph.return_value = graph
@@ -129,7 +139,7 @@ def test_execute_attack_paths_query_serializes_graph(
129139
return_value=session_ctx,
130140
) as mock_get_session:
131141
result = views_helpers.execute_attack_paths_query(
132-
database_name, definition, parameters
142+
database_name, definition, parameters, provider_id=provider_id
133143
)
134144

135145
mock_get_session.assert_called_once_with(database_name)
@@ -169,7 +179,40 @@ def __exit__(self, exc_type, exc, tb):
169179
):
170180
with pytest.raises(APIException):
171181
views_helpers.execute_attack_paths_query(
172-
database_name, definition, parameters
182+
database_name, definition, parameters, provider_id="test-provider-123"
173183
)
174184

175185
mock_logger.error.assert_called_once()
186+
187+
188+
def test_serialize_graph_filters_by_provider_id(attack_paths_graph_stub_classes):
189+
provider_id = "provider-keep"
190+
191+
node_keep = attack_paths_graph_stub_classes.Node(
192+
"n1", ["AWSAccount"], {"provider_id": provider_id}
193+
)
194+
node_drop = attack_paths_graph_stub_classes.Node(
195+
"n2", ["AWSAccount"], {"provider_id": "provider-other"}
196+
)
197+
198+
rel_keep = attack_paths_graph_stub_classes.Relationship(
199+
"r1", "OWNS", node_keep, node_keep, {"provider_id": provider_id}
200+
)
201+
rel_drop_by_provider = attack_paths_graph_stub_classes.Relationship(
202+
"r2", "OWNS", node_keep, node_drop, {"provider_id": "provider-other"}
203+
)
204+
rel_drop_orphaned = attack_paths_graph_stub_classes.Relationship(
205+
"r3", "OWNS", node_keep, node_drop, {"provider_id": provider_id}
206+
)
207+
208+
graph = SimpleNamespace(
209+
nodes=[node_keep, node_drop],
210+
relationships=[rel_keep, rel_drop_by_provider, rel_drop_orphaned],
211+
)
212+
213+
result = views_helpers._serialize_graph(graph, provider_id)
214+
215+
assert len(result["nodes"]) == 1
216+
assert result["nodes"][0]["id"] == "n1"
217+
assert len(result["relationships"]) == 1
218+
assert result["relationships"][0]["id"] == "r1"

api/src/backend/api/tests/test_views.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3995,15 +3995,18 @@ def test_run_attack_paths_query_returns_graph(
39953995
assert response.status_code == status.HTTP_200_OK
39963996
mock_get_query.assert_called_once_with("aws-rds")
39973997
mock_get_db_name.assert_called_once_with(attack_paths_scan.provider.tenant_id)
3998+
provider_id = str(attack_paths_scan.provider_id)
39983999
mock_prepare.assert_called_once_with(
39994000
query_definition,
40004001
{},
40014002
attack_paths_scan.provider.uid,
4003+
provider_id,
40024004
)
40034005
mock_execute.assert_called_once_with(
40044006
expected_db_name,
40054007
query_definition,
40064008
prepared_parameters,
4009+
provider_id,
40074010
)
40084011
mock_clear_cache.assert_called_once_with(expected_db_name)
40094012
result = response.json()["data"]

api/src/backend/api/v1/views.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2505,14 +2505,19 @@ def run_attack_paths_query(self, request, pk=None):
25052505
database_name = graph_database.get_database_name(
25062506
attack_paths_scan.provider.tenant_id
25072507
)
2508+
provider_id = str(attack_paths_scan.provider_id)
25082509
parameters = attack_paths_views_helpers.prepare_query_parameters(
25092510
query_definition,
25102511
serializer.validated_data.get("parameters", {}),
25112512
attack_paths_scan.provider.uid,
2513+
provider_id,
25122514
)
25132515

25142516
graph = attack_paths_views_helpers.execute_attack_paths_query(
2515-
database_name, query_definition, parameters
2517+
database_name,
2518+
query_definition,
2519+
parameters,
2520+
provider_id,
25162521
)
25172522
graph_database.clear_cache(database_name)
25182523

0 commit comments

Comments
 (0)