Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions detection_rules/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,12 +990,25 @@ def validates_esql_data(self, data: dict[str, Any], **_: Any) -> None:

# Enforce KEEP command for ESQL rules
# Match | followed by optional whitespace/newlines and then 'keep'
keep_pattern = re.compile(r"\|\s*keep\b", re.IGNORECASE | re.DOTALL)
if not keep_pattern.search(query_lower):
keep_pattern = re.compile(r"\|\s*keep\b\s+([^\|]+)", re.IGNORECASE | re.DOTALL)
keep_match = keep_pattern.search(query_lower)
if not keep_match:
raise EsqlSemanticError(
f"Rule: {data['name']} does not contain a 'keep' command -> Add a 'keep' command to the query."
)

# Ensure that keep clause includes metadata fields on non-aggregate queries
aggregate_pattern = re.compile(r"\bstats\b.*\bby\b", re.IGNORECASE | re.DOTALL)
if not aggregate_pattern.search(query_lower):
keep_fields = keep_match.group(1)
required_metadata = {"_id", "_version", "_index"}
if not required_metadata.issubset(set(map(str.strip, keep_fields.split(",")))):
raise EsqlSemanticError(
f"Rule: {data['name']} contains a keep clause without"
f" metadata fields '_id', '_version', and '_index' ->"
f" Add '_id, _version, _index' to the keep command."
)


@dataclass(frozen=True, kw_only=True)
class ThreatMatchRuleData(QueryRuleData):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "detection_rules"
version = "1.5.22"
version = "1.5.23"
description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Security’s Detection Engine."
readme = "README.md"
requires-python = ">=3.12"
Expand Down
26 changes: 13 additions & 13 deletions tests/test_rules_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_esql_related_integrations(self):
and aws.cloudtrail.user_identity.arn is not null
and aws.cloudtrail.user_identity.type == "IAMUser"
| keep
aws.cloudtrail.user_identity.type
aws.cloudtrail.user_identity.type, _id, _version, _index
"""
rule = RuleCollection().load_dict(production_rule)
related_integrations = rule.contents.to_api_format()["related_integrations"]
Expand All @@ -61,7 +61,7 @@ def test_esql_non_dataset_package_related_integrations(self):
user.id,
gen_ai.request.model.id,
cloud.account.id,
gen_ai.response.error_code
gen_ai.response.error_code, _id, _version, _index
"""
rule = RuleCollection().load_dict(production_rule)
related_integrations = rule.contents.to_api_format()["related_integrations"]
Expand All @@ -81,7 +81,7 @@ def test_esql_event_dataset_schema_error(self):
and event.dataset in ("aws.billing")
and aws.cloudtrail.user_identity.type == "IAMUser"
| keep
aws.cloudtrail.user_identity.type
aws.cloudtrail.user_identity.type, _id, _version, _index
"""
with pytest.raises(EsqlSchemaError):
_ = RuleCollection().load_dict(production_rule)
Expand All @@ -99,7 +99,7 @@ def test_esql_type_mismatch_error(self):
and event.dataset in ("aws.cloudtrail", "aws.billing")
and aws.cloudtrail.user_identity.type == 5
| keep
aws.cloudtrail.user_identity.type
aws.cloudtrail.user_identity.type, _id, _version, _index
"""
with pytest.raises(EsqlTypeMismatchError):
_ = RuleCollection().load_dict(production_rule)
Expand All @@ -117,7 +117,7 @@ def test_esql_syntax_error(self):
and event.dataset in ("aws.cloudtrail", "aws.billing")
and aws.cloudtrail.user_identity.type = "IAMUser"
| keep
aws.cloudtrail.user_identity.type
aws.cloudtrail.user_identity.type, _id, _version, _index
"""
with pytest.raises(EsqlSyntaxError):
_ = RuleCollection().load_dict(production_rule)
Expand All @@ -134,7 +134,7 @@ def test_esql_filtered_index(self):
| where @timestamp > now() - 30 minutes
and aws.cloudtrail.user_identity.type == "IAMUser"
| keep
aws.*
aws.*, _id, _version, _index
"""
_ = RuleCollection().load_dict(production_rule)

Expand All @@ -150,7 +150,7 @@ def test_esql_filtered_index_error(self):
| where @timestamp > now() - 30 minutes
and aws.cloudtrail.user_identity.type == "IAMUser"
| keep
aws.cloudtrail.user_identity.type
aws.cloudtrail.user_identity.type, _id, _version, _index
"""
with pytest.raises(EsqlSchemaError):
_ = RuleCollection().load_dict(production_rule)
Expand All @@ -167,7 +167,7 @@ def test_new_line_split_index(self):
| where @timestamp > now() - 30 minutes
and aws.cloudtrail.user_identity.type == "IAMUser"
| keep
aws.*
aws.*, _id, _version, _index
"""
_ = RuleCollection().load_dict(production_rule)

Expand All @@ -179,7 +179,7 @@ def test_esql_endpoint_alerts_index(self):
production_rule["rule"]["query"] = """
from logs-endpoint.alerts-*
| where event.code in ("malicious_file", "memory_signature", "shellcode_thread") and rule.name is not null
| keep host.id, rule.name, event.code
| keep host.id, rule.name, event.code, _id, _version, _index
| stats Esql.host_id_count_distinct = count_distinct(host.id) by rule.name, event.code
| where Esql.host_id_count_distinct >= 3
"""
Expand All @@ -193,7 +193,7 @@ def test_esql_endpoint_unknown_index(self):
production_rule["rule"]["query"] = """
from logs-endpoint.fake-*
| where event.code in ("malicious_file", "memory_signature", "shellcode_thread") and rule.name is not null
| keep host.id, rule.name, event.code
| keep host.id, rule.name, event.code, _id, _version, _index
| stats Esql.host_id_count_distinct = count_distinct(host.id) by rule.name, event.code
| where Esql.host_id_count_distinct >= 3
"""
Expand All @@ -209,7 +209,7 @@ def test_esql_endpoint_alerts_index_endpoint_fields(self):
production_rule["rule"]["query"] = """
from logs-endpoint.alerts-*
| where event.code in ("malicious_file", "memory_signature", "shellcode_thread") and rule.name is not null and file.Ext.entry_modified > 0
| keep host.id, rule.name, event.code, file.Ext.entry_modified
| keep host.id, rule.name, event.code, file.Ext.entry_modified, _id, _version, _index
| stats Esql.host_id_count_distinct = count_distinct(host.id) by rule.name, event.code, file.Ext.entry_modified
| where Esql.host_id_count_distinct >= 3
"""
Expand All @@ -228,7 +228,7 @@ def test_esql_filtered_keep(self):
production_rule["rule"]["query"] = """
from logs-aws.billing* metadata _id, _version, _index
| where @timestamp > now() - 30 minutes and aws.cloudtrail.user_identity.type == "IAMUser"
| keep host.id, rule.name, event.code
| keep host.id, rule.name, event.code, _id, _version, _index
| stats Esql.host_id_count_distinct = count_distinct(host.id) by rule.name, event.code
| where Esql.host_id_count_distinct >= 3
"""
Expand All @@ -248,6 +248,6 @@ def test_esql_non_ecs_schema_conflict_resolution(self):
and event.outcome == "success"
and azure.signinlogs.properties.user_id is not null
| keep
event.outcome
event.outcome, _id, _version, _index
"""
_ = RuleCollection().load_dict(production_rule)
6 changes: 3 additions & 3 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def test_esql_data_validation(self):
query = """
FROM logs-windows.powershell_operational* METADATA _id, _version, _index
| WHERE event.code == "4104"
| KEEP event.code
| KEEP event.code, _id, _version, _index
"""
rule_dict["rule"]["query"] = query
_ = RuleCollection().load_dict(rule_dict, path=rule_path)
Expand All @@ -334,7 +334,7 @@ def test_esql_data_validation(self):
query = """
FROM logs-windows.powershell_operational* METADATA _id, _index, _version
| WHERE event.code == "4104"
| KEEP event.code
| KEEP event.code, _id, _version, _index
"""
rule_dict["rule"]["query"] = query
_ = RuleCollection().load_dict(rule_dict, path=rule_path)
Expand All @@ -344,7 +344,7 @@ def test_esql_data_validation(self):
query = """
FROM logs-windows.powershell_operational* METADATA _foo, _index
| WHERE event.code == "4104"
| KEEP event.code
| KEEP event.code, _id, _version, _index
"""
rule_dict["rule"]["query"] = query
_ = RuleCollection().load_dict(rule_dict, path=rule_path)
Expand Down
Loading