diff --git a/detection_rules/rule.py b/detection_rules/rule.py index 474d49902a7..34a8ff6e6d3 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -990,12 +990,25 @@ def validates_esql_data(self, data: dict[str, Any], **_: Any) -> None: # Enforce KEEP command for ESQL rules # Match | followed by optional whitespace/newlines and then 'keep' - keep_pattern = re.compile(r"\|\s*keep\b", re.IGNORECASE | re.DOTALL) - if not keep_pattern.search(query_lower): + keep_pattern = re.compile(r"\|\s*keep\b\s+([^\|]+)", re.IGNORECASE | re.DOTALL) + keep_match = keep_pattern.search(query_lower) + if not keep_match: raise EsqlSemanticError( f"Rule: {data['name']} does not contain a 'keep' command -> Add a 'keep' command to the query." ) + # Ensure that keep clause includes metadata fields on non-aggregate queries + aggregate_pattern = re.compile(r"\bstats\b.*\bby\b", re.IGNORECASE | re.DOTALL) + if not aggregate_pattern.search(query_lower): + keep_fields = keep_match.group(1) + required_metadata = {"_id", "_version", "_index"} + if not required_metadata.issubset(set(map(str.strip, keep_fields.split(",")))): + raise EsqlSemanticError( + f"Rule: {data['name']} contains a keep clause without" + f" metadata fields '_id', '_version', and '_index' ->" + f" Add '_id, _version, _index' to the keep command." + ) + @dataclass(frozen=True, kw_only=True) class ThreatMatchRuleData(QueryRuleData): diff --git a/pyproject.toml b/pyproject.toml index b6eb932824e..350f74e650e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "detection_rules" -version = "1.5.22" +version = "1.5.23" description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Security’s Detection Engine." readme = "README.md" requires-python = ">=3.12" diff --git a/tests/test_rules_remote.py b/tests/test_rules_remote.py index 507654411fb..be6b027f62c 100644 --- a/tests/test_rules_remote.py +++ b/tests/test_rules_remote.py @@ -39,7 +39,7 @@ def test_esql_related_integrations(self): and aws.cloudtrail.user_identity.arn is not null and aws.cloudtrail.user_identity.type == "IAMUser" | keep - aws.cloudtrail.user_identity.type + aws.cloudtrail.user_identity.type, _id, _version, _index """ rule = RuleCollection().load_dict(production_rule) related_integrations = rule.contents.to_api_format()["related_integrations"] @@ -61,7 +61,7 @@ def test_esql_non_dataset_package_related_integrations(self): user.id, gen_ai.request.model.id, cloud.account.id, - gen_ai.response.error_code + gen_ai.response.error_code, _id, _version, _index """ rule = RuleCollection().load_dict(production_rule) related_integrations = rule.contents.to_api_format()["related_integrations"] @@ -81,7 +81,7 @@ def test_esql_event_dataset_schema_error(self): and event.dataset in ("aws.billing") and aws.cloudtrail.user_identity.type == "IAMUser" | keep - aws.cloudtrail.user_identity.type + aws.cloudtrail.user_identity.type, _id, _version, _index """ with pytest.raises(EsqlSchemaError): _ = RuleCollection().load_dict(production_rule) @@ -99,7 +99,7 @@ def test_esql_type_mismatch_error(self): and event.dataset in ("aws.cloudtrail", "aws.billing") and aws.cloudtrail.user_identity.type == 5 | keep - aws.cloudtrail.user_identity.type + aws.cloudtrail.user_identity.type, _id, _version, _index """ with pytest.raises(EsqlTypeMismatchError): _ = RuleCollection().load_dict(production_rule) @@ -117,7 +117,7 @@ def test_esql_syntax_error(self): and event.dataset in ("aws.cloudtrail", "aws.billing") and aws.cloudtrail.user_identity.type = "IAMUser" | keep - aws.cloudtrail.user_identity.type + aws.cloudtrail.user_identity.type, _id, _version, _index """ with pytest.raises(EsqlSyntaxError): _ = RuleCollection().load_dict(production_rule) @@ -134,7 +134,7 @@ def test_esql_filtered_index(self): | where @timestamp > now() - 30 minutes and aws.cloudtrail.user_identity.type == "IAMUser" | keep - aws.* + aws.*, _id, _version, _index """ _ = RuleCollection().load_dict(production_rule) @@ -150,7 +150,7 @@ def test_esql_filtered_index_error(self): | where @timestamp > now() - 30 minutes and aws.cloudtrail.user_identity.type == "IAMUser" | keep - aws.cloudtrail.user_identity.type + aws.cloudtrail.user_identity.type, _id, _version, _index """ with pytest.raises(EsqlSchemaError): _ = RuleCollection().load_dict(production_rule) @@ -167,7 +167,7 @@ def test_new_line_split_index(self): | where @timestamp > now() - 30 minutes and aws.cloudtrail.user_identity.type == "IAMUser" | keep - aws.* + aws.*, _id, _version, _index """ _ = RuleCollection().load_dict(production_rule) @@ -179,7 +179,7 @@ def test_esql_endpoint_alerts_index(self): production_rule["rule"]["query"] = """ from logs-endpoint.alerts-* | where event.code in ("malicious_file", "memory_signature", "shellcode_thread") and rule.name is not null - | keep host.id, rule.name, event.code + | keep host.id, rule.name, event.code, _id, _version, _index | stats Esql.host_id_count_distinct = count_distinct(host.id) by rule.name, event.code | where Esql.host_id_count_distinct >= 3 """ @@ -193,7 +193,7 @@ def test_esql_endpoint_unknown_index(self): production_rule["rule"]["query"] = """ from logs-endpoint.fake-* | where event.code in ("malicious_file", "memory_signature", "shellcode_thread") and rule.name is not null - | keep host.id, rule.name, event.code + | keep host.id, rule.name, event.code, _id, _version, _index | stats Esql.host_id_count_distinct = count_distinct(host.id) by rule.name, event.code | where Esql.host_id_count_distinct >= 3 """ @@ -209,7 +209,7 @@ def test_esql_endpoint_alerts_index_endpoint_fields(self): production_rule["rule"]["query"] = """ from logs-endpoint.alerts-* | where event.code in ("malicious_file", "memory_signature", "shellcode_thread") and rule.name is not null and file.Ext.entry_modified > 0 - | keep host.id, rule.name, event.code, file.Ext.entry_modified + | keep host.id, rule.name, event.code, file.Ext.entry_modified, _id, _version, _index | stats Esql.host_id_count_distinct = count_distinct(host.id) by rule.name, event.code, file.Ext.entry_modified | where Esql.host_id_count_distinct >= 3 """ @@ -228,7 +228,7 @@ def test_esql_filtered_keep(self): production_rule["rule"]["query"] = """ from logs-aws.billing* metadata _id, _version, _index | where @timestamp > now() - 30 minutes and aws.cloudtrail.user_identity.type == "IAMUser" - | keep host.id, rule.name, event.code + | keep host.id, rule.name, event.code, _id, _version, _index | stats Esql.host_id_count_distinct = count_distinct(host.id) by rule.name, event.code | where Esql.host_id_count_distinct >= 3 """ @@ -248,6 +248,6 @@ def test_esql_non_ecs_schema_conflict_resolution(self): and event.outcome == "success" and azure.signinlogs.properties.user_id is not null | keep - event.outcome + event.outcome, _id, _version, _index """ _ = RuleCollection().load_dict(production_rule) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 1215e4b0fbc..5f5e609eb94 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -324,7 +324,7 @@ def test_esql_data_validation(self): query = """ FROM logs-windows.powershell_operational* METADATA _id, _version, _index | WHERE event.code == "4104" - | KEEP event.code + | KEEP event.code, _id, _version, _index """ rule_dict["rule"]["query"] = query _ = RuleCollection().load_dict(rule_dict, path=rule_path) @@ -334,7 +334,7 @@ def test_esql_data_validation(self): query = """ FROM logs-windows.powershell_operational* METADATA _id, _index, _version | WHERE event.code == "4104" - | KEEP event.code + | KEEP event.code, _id, _version, _index """ rule_dict["rule"]["query"] = query _ = RuleCollection().load_dict(rule_dict, path=rule_path) @@ -344,7 +344,7 @@ def test_esql_data_validation(self): query = """ FROM logs-windows.powershell_operational* METADATA _foo, _index | WHERE event.code == "4104" - | KEEP event.code + | KEEP event.code, _id, _version, _index """ rule_dict["rule"]["query"] = query _ = RuleCollection().load_dict(rule_dict, path=rule_path)