Skip to content

Commit 90793e4

Browse files
Add unique field support
1 parent 8f58df6 commit 90793e4

File tree

3 files changed, with 41 additions and 21 deletions.

detection_rules/remote_validation.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .config import load_current_package_version
2020
from .misc import ClientError, get_elasticsearch_client, get_kibana_client, getdefault
2121
from .rule import TOMLRule, TOMLRuleContents
22+
from .rule_validators import ESQLValidator
2223
from .schemas import definitions
2324

2425

@@ -180,18 +181,14 @@ def request(c: TOMLRuleContents) -> None:
180181
def validate_esql(self, contents: TOMLRuleContents) -> dict[str, Any]:
181182
query = contents.data.query # type: ignore[reportAttributeAccessIssue]
182183
rule_id = contents.data.rule_id
183-
headers = {"accept": "application/json", "content-type": "application/json"}
184-
body = {"query": f"{query} | LIMIT 0"}
185184
if not self.es_client:
186185
raise ValueError("No ES client found")
186+
187+
if not self.kibana_client:
188+
raise ValueError("No Kibana client found")
187189
try:
188-
response = self.es_client.perform_request(
189-
"POST",
190-
"/_query",
191-
headers=headers,
192-
params={"pretty": True},
193-
body=body,
194-
)
190+
validator = ESQLValidator(contents.data.query) # type: ignore[reportIncompatibleMethodOverride]
191+
response = validator.remote_validate_rule_contents(self.kibana_client, self.es_client, contents)
195192
except Exception as exc:
196193
if isinstance(exc, elasticsearch.BadRequestError):
197194
raise ValidationError(f"ES|QL query failed: {exc} for rule: {rule_id}, query: \n{query}") from exc

detection_rules/rule.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,10 @@ def unique_fields(self) -> Any:
647647
def validate(self, _: "QueryRuleData", __: RuleMeta) -> None:
648648
raise NotImplementedError
649649

650+
def get_unique_field_type(self, __: str) -> None:
651+
"""Used to get unique field types when schema is not used"""
652+
raise NotImplementedError
653+
650654
@cached
651655
def get_required_fields(self, index: str) -> list[dict[str, Any]]:
652656
"""Retrieves fields needed for the query along with type information from the schema."""
@@ -663,7 +667,9 @@ def get_required_fields(self, index: str) -> list[dict[str, Any]]:
663667
# construct integration schemas
664668
packages_manifest = load_integrations_manifests()
665669
integrations_schemas = load_integrations_schemas()
666-
datasets, _ = beats.get_datasets_and_modules(self.ast)
670+
datasets: set[str] = set()
671+
if self.ast:
672+
datasets, _ = beats.get_datasets_and_modules(self.ast)
667673
package_integrations = parse_datasets(list(datasets), packages_manifest)
668674
int_schema: dict[str, Any] = {}
669675
data = {"notify": False}
@@ -691,6 +697,9 @@ def get_required_fields(self, index: str) -> list[dict[str, Any]]:
691697
elif endgame_schema:
692698
field_type = endgame_schema.endgame_schema.get(fld, None)
693699

700+
if not field_type and isinstance(self, ESQLValidator):
701+
field_type = self.get_unique_field_type(fld)
702+
694703
required.append({"name": fld, "type": field_type or "unknown", "ecs": is_ecs})
695704

696705
return sorted(required, key=lambda f: f["name"])

detection_rules/rule_validators.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import click
1717
import eql # type: ignore[reportMissingTypeStubs]
1818
import kql # type: ignore[reportMissingTypeStubs]
19+
from elastic_transport import ObjectApiResponse
1920
from elasticsearch import Elasticsearch # type: ignore[reportMissingTypeStubs]
2021
from eql import ast # type: ignore[reportMissingTypeStubs]
2122
from eql.parser import KvTree, LarkToEQL, NodeInfo, TypeHint # type: ignore[reportMissingTypeStubs]
@@ -618,18 +619,29 @@ def validate_rule_type_configurations(self, data: EQLRuleData, meta: RuleMeta) -
618619
class ESQLValidator(QueryValidator):
619620
"""Validate specific fields for ESQL query event types."""
620621

621-
esql_unique_fields: list[str]
622+
def __init__(self, query: str) -> None:
623+
"""Initialize the ESQLValidator with the given query."""
624+
super().__init__(query)
625+
self.esql_unique_fields: list[dict[str, str]] = []
622626

623627
@cached_property
624628
def ast(self) -> None: # type: ignore[reportIncompatibleMethodOverride]
629+
"""There is no AST for ESQL until we have an ESQL parser."""
625630
return None
626631

627632
@cached_property
628633
def unique_fields(self) -> list[str]: # type: ignore[reportIncompatibleMethodOverride]
629634
"""Return a list of unique fields in the query. Requires remote validation to have occurred."""
630-
if not self.esql_unique_fields:
631-
return []
632-
return self.esql_unique_fields
635+
if self.esql_unique_fields:
636+
return [field["name"] for field in self.esql_unique_fields]
637+
return []
638+
639+
def get_unique_field_type(self, field_name: str) -> str | None: # type: ignore[reportIncompatibleMethodOverride]
640+
"""Get the type of the unique field. Requires remote validation to have occurred."""
641+
for field in self.esql_unique_fields:
642+
if field["name"] == field_name:
643+
return field["type"]
644+
return None
633645

634646
def validate(self, rule_data: "QueryRuleData", rule_meta: RuleMeta) -> None: # type: ignore[reportIncompatibleMethodOverride]
635647
"""Validate an ESQL query while checking TOMLRule."""
@@ -648,7 +660,7 @@ def validate(self, rule_data: "QueryRuleData", rule_meta: RuleMeta) -> None: #
648660
elasticsearch_url=misc.getdefault("elasticsearch_url")(),
649661
ignore_ssl_errors=misc.getdefault("ignore_ssl_errors")(),
650662
)
651-
self.remote_validate_rule(
663+
_ = self.remote_validate_rule(
652664
kibana_client,
653665
elastic_client,
654666
rule_data.query,
@@ -774,7 +786,7 @@ def execute_query_against_indices(
774786
test_index_str: str,
775787
log: Callable[[str], None],
776788
delete_indices: bool = True,
777-
) -> list[Any]:
789+
) -> tuple[list[Any], ObjectApiResponse[Any]]:
778790
"""Execute the ESQL query against the test indices on a remote Stack and return the columns."""
779791
try:
780792
log(f"Executing a query against `{test_index_str}`")
@@ -789,7 +801,7 @@ def execute_query_against_indices(
789801

790802
query_column_names = [c["name"] for c in query_columns]
791803
log(f"Got query columns: {', '.join(query_column_names)}")
792-
return query_columns
804+
return query_columns, response
793805

794806
def find_nested_multifields(self, mapping: dict[str, Any], path: str = "") -> list[Any]:
795807
"""Recursively search for nested multi-fields in Elasticsearch mappings."""
@@ -886,9 +898,9 @@ def prepare_mappings(
886898

887899
def remote_validate_rule_contents(
888900
self, kibana_client: Kibana, elastic_client: Elasticsearch, contents: TOMLRuleContents, verbosity: int = 0
889-
) -> None:
901+
) -> ObjectApiResponse[Any]:
890902
"""Remote validate a rule's ES|QL query using an Elastic Stack."""
891-
self.remote_validate_rule(
903+
return self.remote_validate_rule(
892904
kibana_client=kibana_client,
893905
elastic_client=elastic_client,
894906
query=contents.data.query, # type: ignore[reportUnknownVariableType]
@@ -905,7 +917,7 @@ def remote_validate_rule( # noqa: PLR0913
905917
metadata: RuleMeta,
906918
rule_id: str = "",
907919
verbosity: int = 0,
908-
) -> None:
920+
) -> ObjectApiResponse[Any]:
909921
"""Uses remote validation from an Elastic Stack to validate ES|QL a given rule"""
910922

911923
def log(val: str) -> None:
@@ -939,7 +951,7 @@ def log(val: str) -> None:
939951
# Replace all sources with the test indices
940952
query = query.replace(indices_str, full_index_str) # type: ignore[reportUnknownVariableType]
941953

942-
query_columns = self.execute_query_against_indices(elastic_client, query, full_index_str, log) # type: ignore[reportUnknownVariableType]
954+
query_columns, response = self.execute_query_against_indices(elastic_client, query, full_index_str, log) # type: ignore[reportUnknownVariableType]
943955
self.esql_unique_fields = query_columns
944956

945957
# Validate that all fields (columns) are either dynamic fields or correctly mapped
@@ -949,6 +961,8 @@ def log(val: str) -> None:
949961
else:
950962
log("Dynamic column(s) have improper formatting.")
951963

964+
return response
965+
952966

953967
def extract_error_field(source: str, exc: eql.EqlParseError | kql.KqlParseError) -> str | None:
954968
"""Extract the field name from an EQL or KQL parse error."""

0 commit comments

Comments
 (0)