Skip to content

Commit a5c100a

Browse files
[Bug] Add unit tests and fix Alert Suppression schema validation for ThresholdQueryRuleData (#5196)
* Add schema validation for AlertSuppressionMapping * Add support for indicator match alert suppression * Add unit tests * Update order and remove validates_schema method * Add comments * Add test for query rule duration only
1 parent ebb7bb5 commit a5c100a

File tree

3 files changed

+231
-8
lines changed

3 files changed

+231
-8
lines changed

detection_rules/rule.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,10 +1074,17 @@ def validate(self, meta: RuleMeta) -> None: # noqa: ARG002
10741074

10751075
# All of the possible rule types
10761076
# Sort inverse of any inheritance - see comment in TOMLRuleContents.to_dict
1077+
# ThresholdQueryRuleData needs to be first in this union to handle cases where there is ambiguity between
1078+
# ThresholdAlertSuppression and AlertSuppressionMapping. Since AlertSuppressionMapping has duration as an
1079+
# optional field, ThresholdAlertSuppression objects can be mistakenly loaded as an AlertSuppressionMapping
1080+
# object with group_by and missing_fields_strategy as missing parameters, resulting in an error.
1081+
# Checking the type against ThresholdQueryRuleData first in the union prevent this from occurring.
1082+
# Please also keep issue 1141 in mind when handling union schemas.
1083+
10771084
AnyRuleData = (
1078-
EQLRuleData
1085+
ThresholdQueryRuleData
1086+
| EQLRuleData
10791087
| ESQLRuleData
1080-
| ThresholdQueryRuleData
10811088
| ThreatMatchRuleData
10821089
| MachineLearningRuleData
10831090
| QueryRuleData

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "detection_rules"
3-
version = "1.4.11"
3+
version = "1.4.12"
44
description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Security’s Detection Engine."
55
readme = "README.md"
66
requires-python = ">=3.12"

tests/test_python_library.py

Lines changed: 221 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
# 2.0; you may not use this file except in compliance with the Elastic License
44
# 2.0.
55

6+
from typing import Any
7+
68
import eql
9+
from marshmallow import ValidationError
710

811
from detection_rules.rule_loader import RuleCollection
912

@@ -22,26 +25,46 @@ def mk_metadata(integrations: list[str], comments: str = "Test metadata") -> dic
2225
}
2326

2427

25-
def mk_rule(
28+
def mk_rule( # noqa: PLR0913
2629
*,
2730
name: str,
2831
rule_id: str,
2932
description: str,
3033
risk_score: int,
3134
query: str,
32-
) -> dict:
35+
language: str = "eql",
36+
query_type: str = "eql",
37+
threshold: dict[str, Any] | None = None,
38+
alert_suppression: dict[str, Any] | None = None,
39+
index: list[str] | None = None,
40+
threat_language: str | None = None,
41+
threat_index: list[str] | None = None,
42+
threat_indicator_path: str | None = None,
43+
threat_mapping: list[Any] | None = None,
44+
) -> dict[str, Any]:
3345
"""Create rule dictionary."""
34-
return {
46+
rule = {
3547
"author": ["Elastic"],
3648
"description": description,
37-
"language": "eql",
49+
"language": language,
3850
"name": name,
3951
"risk_score": risk_score,
4052
"rule_id": rule_id,
4153
"severity": "low",
42-
"type": "eql",
54+
"type": query_type,
4355
"query": query,
56+
"alert_suppression": alert_suppression,
4457
}
58+
if threshold is not None:
59+
rule["threshold"] = threshold
60+
if query_type == "threat_match":
61+
rule["index"] = index
62+
rule["threat_language"] = threat_language
63+
rule["threat_index"] = threat_index
64+
rule["threat_indicator_path"] = threat_indicator_path
65+
rule["threat_mapping"] = threat_mapping
66+
67+
return rule
4568

4669

4770
class TestEQLInSet(BaseRuleTest):
@@ -283,3 +306,196 @@ def test_sequence_datasetless_subquery_with_metadata_integration_valid(self) ->
283306
),
284307
}
285308
rc.load_dict(rule)
309+
310+
311+
class TestAlertSuppressionValidation(BaseRuleTest):
312+
"""Tests for alert_suppression field validation in rules."""
313+
314+
def test_threshold_rule_duration(self) -> None:
315+
"""Test that a threshold rule with alert_suppression with just duration validates correctly."""
316+
rc = RuleCollection()
317+
query = """
318+
process.name: \"test\"
319+
"""
320+
rule_dict: dict[str, Any] = {
321+
"metadata": mk_metadata(
322+
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
323+
),
324+
"rule": mk_rule(
325+
name="Fake Test Rule",
326+
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
327+
description="Test Rule.",
328+
risk_score=47,
329+
query=query,
330+
language="kuery",
331+
query_type="threshold",
332+
threshold={"field": [], "value": 200, "cardinality": []},
333+
alert_suppression={"duration": {"value": 5, "unit": "h"}},
334+
),
335+
}
336+
_ = rc.load_dict(rule_dict)
337+
338+
def test_query_rule_duration(self) -> None:
339+
"""Test that a query rule with alert_suppression with group_by and missing_fields_strategy validates correctly."""
340+
rc = RuleCollection()
341+
query = """
342+
process.name: \"test\"
343+
"""
344+
rule_dict: dict[str, Any] = {
345+
"metadata": mk_metadata(
346+
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
347+
),
348+
"rule": mk_rule(
349+
name="Fake Test Rule",
350+
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
351+
description="Test Rule.",
352+
risk_score=47,
353+
query=query,
354+
language="kuery",
355+
query_type="query",
356+
threshold=None,
357+
alert_suppression={"duration": {"value": 5, "unit": "h"}},
358+
),
359+
}
360+
with self.assertRaises((ValidationError, TypeError)):
361+
_ = rc.load_dict(rule_dict)
362+
363+
def test_query_rule_group_by_missing_fields(self) -> None:
364+
"""Test that a query rule with alert_suppression with group_by and missing_fields_strategy validates correctly."""
365+
rc = RuleCollection()
366+
query = """
367+
process.name: \"test\"
368+
"""
369+
rule_dict: dict[str, Any] = {
370+
"metadata": mk_metadata(
371+
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
372+
),
373+
"rule": mk_rule(
374+
name="Fake Test Rule",
375+
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
376+
description="Test Rule.",
377+
risk_score=47,
378+
query=query,
379+
language="kuery",
380+
query_type="query",
381+
threshold=None,
382+
alert_suppression={"group_by": ["process.id"], "missing_fields_strategy": "suppress"},
383+
),
384+
}
385+
_ = rc.load_dict(rule_dict)
386+
387+
def test_query_rule_group_by(self) -> None:
388+
"""Test that a query rule with alert_suppression with just group_by is not valid."""
389+
rc = RuleCollection()
390+
query = """
391+
process.name: \"test\"
392+
"""
393+
rule_dict: dict[str, Any] = {
394+
"metadata": mk_metadata(
395+
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
396+
),
397+
"rule": mk_rule(
398+
name="Fake Test Rule",
399+
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
400+
description="Test Rule.",
401+
risk_score=47,
402+
query=query,
403+
language="kuery",
404+
query_type="query",
405+
threshold=None,
406+
alert_suppression={"group_by": ["process.id"]},
407+
),
408+
}
409+
with self.assertRaises((ValidationError, TypeError)):
410+
_ = rc.load_dict(rule_dict)
411+
412+
def test_query_rule_missing_fields_strategy(self) -> None:
413+
"""Test that a query rule with alert_suppression with just missing_fields_strategy is not valid."""
414+
rc = RuleCollection()
415+
query = """
416+
process.name: \"test\"
417+
"""
418+
rule_dict: dict[str, Any] = {
419+
"metadata": mk_metadata(
420+
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
421+
),
422+
"rule": mk_rule(
423+
name="Fake Test Rule",
424+
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
425+
description="Test Rule.",
426+
risk_score=47,
427+
query=query,
428+
language="kuery",
429+
query_type="query",
430+
threshold=None,
431+
alert_suppression={"missing_fields_strategy": "suppress"},
432+
),
433+
}
434+
with self.assertRaises((ValidationError, TypeError)):
435+
_ = rc.load_dict(rule_dict)
436+
437+
def test_threat_match_rule(self) -> None:
438+
"""Test that a threat_match rule with alert_suppression with all fields set is valid."""
439+
rc = RuleCollection()
440+
query = """
441+
process.name: \"test\"
442+
"""
443+
rule_dict: dict[str, Any] = {
444+
"metadata": mk_metadata(
445+
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
446+
),
447+
"rule": mk_rule(
448+
name="Fake Test Rule",
449+
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
450+
description="Test Rule.",
451+
risk_score=47,
452+
query=query,
453+
language="kuery",
454+
query_type="threat_match",
455+
threshold=None,
456+
alert_suppression={
457+
"group_by": ["client.ip"],
458+
"duration": {"value": 12, "unit": "h"},
459+
"missing_fields_strategy": "suppress",
460+
},
461+
index=["logs-*"],
462+
threat_language="kuery",
463+
threat_index=["logs-*"],
464+
threat_indicator_path="threat.indicator",
465+
threat_mapping=[{"entries": [{"field": "client.ip", "type": "mapping", "value": "client.ip"}]}],
466+
),
467+
}
468+
_ = rc.load_dict(rule_dict)
469+
470+
def test_threat_match_rule_missing_fields_duration(self) -> None:
471+
"""Test that a threat_match rule with alert_suppression with missing_fields_strategy and duration is not valid."""
472+
rc = RuleCollection()
473+
query = """
474+
process.name: \"test\"
475+
"""
476+
rule_dict: dict[str, Any] = {
477+
"metadata": mk_metadata(
478+
["endpoint", "windows"], comments="New fields added: required_fields, related_integrations, setup"
479+
),
480+
"rule": mk_rule(
481+
name="Fake Test Rule",
482+
rule_id="4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
483+
description="Test Rule.",
484+
risk_score=47,
485+
query=query,
486+
language="kuery",
487+
query_type="threat_match",
488+
threshold=None,
489+
alert_suppression={
490+
"duration": {"value": 12, "unit": "h"},
491+
"missing_fields_strategy": "suppress",
492+
},
493+
index=["logs-*"],
494+
threat_language="kuery",
495+
threat_index=["logs-*"],
496+
threat_indicator_path="threat.indicator",
497+
threat_mapping=[{"entries": [{"field": "client.ip", "type": "mapping", "value": "client.ip"}]}],
498+
),
499+
}
500+
with self.assertRaises((ValidationError, TypeError)):
501+
_ = rc.load_dict(rule_dict)

0 commit comments

Comments
 (0)