1616import click
1717import eql # type: ignore[reportMissingTypeStubs]
1818import kql # type: ignore[reportMissingTypeStubs]
19+ from elastic_transport import ObjectApiResponse
1920from elasticsearch import Elasticsearch # type: ignore[reportMissingTypeStubs]
2021from eql import ast # type: ignore[reportMissingTypeStubs]
2122from eql .parser import KvTree , LarkToEQL , NodeInfo , TypeHint # type: ignore[reportMissingTypeStubs]
@@ -618,18 +619,29 @@ def validate_rule_type_configurations(self, data: EQLRuleData, meta: RuleMeta) -
618619class ESQLValidator (QueryValidator ):
619620 """Validate specific fields for ESQL query event types."""
620621
621- esql_unique_fields : list [str ]
622+ def __init__ (self , query : str ) -> None :
623+ """Initialize the ESQLValidator with the given query."""
624+ super ().__init__ (query )
625+ self .esql_unique_fields : list [dict [str , str ]] = []
622626
623627 @cached_property
624628 def ast (self ) -> None : # type: ignore[reportIncompatibleMethodOverride]
629+ """There is no AST for ESQL until we have an ESQL parser."""
625630 return None
626631
627632 @cached_property
628633 def unique_fields (self ) -> list [str ]: # type: ignore[reportIncompatibleMethodOverride]
629634 """Return a list of unique fields in the query. Requires remote validation to have occurred."""
630- if not self .esql_unique_fields :
631- return []
632- return self .esql_unique_fields
635+ if self .esql_unique_fields :
636+ return [field ["name" ] for field in self .esql_unique_fields ]
637+ return []
638+
639+ def get_unique_field_type (self , field_name : str ) -> str | None : # type: ignore[reportIncompatibleMethodOverride]
640+ """Get the type of the unique field. Requires remote validation to have occurred."""
641+ for field in self .esql_unique_fields :
642+ if field ["name" ] == field_name :
643+ return field ["type" ]
644+ return None
633645
634646 def validate (self , rule_data : "QueryRuleData" , rule_meta : RuleMeta ) -> None : # type: ignore[reportIncompatibleMethodOverride]
635647 """Validate an ESQL query while checking TOMLRule."""
@@ -648,7 +660,7 @@ def validate(self, rule_data: "QueryRuleData", rule_meta: RuleMeta) -> None: #
648660 elasticsearch_url = misc .getdefault ("elasticsearch_url" )(),
649661 ignore_ssl_errors = misc .getdefault ("ignore_ssl_errors" )(),
650662 )
651- self .remote_validate_rule (
663+ _ = self .remote_validate_rule (
652664 kibana_client ,
653665 elastic_client ,
654666 rule_data .query ,
@@ -774,7 +786,7 @@ def execute_query_against_indices(
774786 test_index_str : str ,
775787 log : Callable [[str ], None ],
776788 delete_indices : bool = True ,
777- ) -> list [Any ]:
789+ ) -> tuple [ list [Any ], ObjectApiResponse [ Any ] ]:
778790 """Execute the ESQL query against the test indices on a remote Stack and return the columns."""
779791 try :
780792 log (f"Executing a query against `{ test_index_str } `" )
@@ -789,7 +801,7 @@ def execute_query_against_indices(
789801
790802 query_column_names = [c ["name" ] for c in query_columns ]
791803 log (f"Got query columns: { ', ' .join (query_column_names )} " )
792- return query_columns
804+ return query_columns , response
793805
794806 def find_nested_multifields (self , mapping : dict [str , Any ], path : str = "" ) -> list [Any ]:
795807 """Recursively search for nested multi-fields in Elasticsearch mappings."""
@@ -886,9 +898,9 @@ def prepare_mappings(
886898
887899 def remote_validate_rule_contents (
888900 self , kibana_client : Kibana , elastic_client : Elasticsearch , contents : TOMLRuleContents , verbosity : int = 0
889- ) -> None :
901+ ) -> ObjectApiResponse [ Any ] :
890902 """Remote validate a rule's ES|QL query using an Elastic Stack."""
891- self .remote_validate_rule (
903+ return self .remote_validate_rule (
892904 kibana_client = kibana_client ,
893905 elastic_client = elastic_client ,
894906 query = contents .data .query , # type: ignore[reportUnknownVariableType]
@@ -905,7 +917,7 @@ def remote_validate_rule( # noqa: PLR0913
905917 metadata : RuleMeta ,
906918 rule_id : str = "" ,
907919 verbosity : int = 0 ,
908- ) -> None :
920+ ) -> ObjectApiResponse [ Any ] :
909921 """Uses remote validation from an Elastic Stack to validate ES|QL a given rule"""
910922
911923 def log (val : str ) -> None :
@@ -939,7 +951,7 @@ def log(val: str) -> None:
939951 # Replace all sources with the test indices
940952 query = query .replace (indices_str , full_index_str ) # type: ignore[reportUnknownVariableType]
941953
942- query_columns = self .execute_query_against_indices (elastic_client , query , full_index_str , log ) # type: ignore[reportUnknownVariableType]
954+ query_columns , response = self .execute_query_against_indices (elastic_client , query , full_index_str , log ) # type: ignore[reportUnknownVariableType]
943955 self .esql_unique_fields = query_columns
944956
945957 # Validate that all fields (columns) are either dynamic fields or correctly mapped
@@ -949,6 +961,8 @@ def log(val: str) -> None:
949961 else :
950962 log ("Dynamic column(s) have improper formatting." )
951963
964+ return response
965+
952966
953967def extract_error_field (source : str , exc : eql .EqlParseError | kql .KqlParseError ) -> str | None :
954968 """Extract the field name from an EQL or KQL parse error."""
0 commit comments