Skip to content

Commit 898defc

Browse files
committed
[FR] Support Multi-Dataset Sequence Validation
1 parent b4db783 commit 898defc

File tree

2 files changed

+312
-114
lines changed

2 files changed

+312
-114
lines changed

detection_rules/rule_validators.py

Lines changed: 195 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
from semver import Version
2323

2424
from . import ecs, endgame
25+
from .beats import get_datasets_and_modules
2526
from .config import CUSTOM_RULES_DIR, load_current_package_version, parse_rules_config
2627
from .custom_schemas import update_auto_generated_schema
27-
from .integrations import get_integration_schema_data, load_integrations_manifests
28+
from .integrations import get_integration_schema_data, load_integrations_manifests, parse_datasets
2829
from .rule import EQLRuleData, QueryRuleData, QueryValidator, RuleMeta, TOMLRuleContents, set_eql_config
2930
from .schemas import get_stack_schemas
3031

@@ -445,89 +446,202 @@ def validate_stack_combos(self, data: QueryRuleData, meta: RuleMeta) -> EQL_ERRO
445446
raise exc
446447
return None
447448

448-
def validate_integration( # noqa: PLR0912
449+
def validate_integration( # noqa: PLR0912, PLR0915, PLR0911
449450
self,
450451
data: QueryRuleData,
451452
meta: RuleMeta,
452453
package_integrations: list[dict[str, Any]],
453454
) -> EQL_ERROR_TYPES | None | ValueError:
454-
"""Validate an EQL query while checking TOMLRule against integration schemas."""
455+
"""Validate an EQL query while checking TOMLRule against integration schemas.
456+
457+
If the EQL query is a sequence, validate each subquery against the schema of the dataset's
458+
integration.package referenced within that subquery. This avoids cross-integration field
459+
mismatches when multiple datasets from the same integration are used in different subqueries.
460+
"""
455461
if meta.query_schema_validation is False or meta.maturity == "deprecated":
456462
# syntax only, which is done via self.ast
457463
return None
458464

459-
error_fields = {}
465+
error_fields: dict[str, dict[str, Any]] = {}
460466
package_schemas: dict[str, Any] = {}
461467

462-
# Initialize package_schemas with a nested structure
463-
for integration_data in package_integrations:
464-
package = integration_data["package"]
465-
integration = integration_data["integration"]
466-
if integration:
467-
package_schemas.setdefault(package, {}).setdefault(integration, {})
468-
else:
469-
package_schemas.setdefault(package, {})
470-
471-
# Process each integration schema
472-
for integration_schema_data in get_integration_schema_data(data, meta, package_integrations):
473-
ecs_version = integration_schema_data["ecs_version"]
474-
package, integration = (
475-
integration_schema_data["package"],
476-
integration_schema_data["integration"],
477-
)
478-
package_version = integration_schema_data["package_version"]
479-
integration_schema = integration_schema_data["schema"]
480-
stack_version = integration_schema_data["stack_version"]
481-
482-
# add non-ecs-schema fields for edge cases not added to the integration
468+
# Helper to extend an integration schema with index, custom, and endpoint fields
469+
def _prepare_integration_schema(schema_dict: dict[str, Any], stack_version: str) -> dict[str, Any]:
470+
"""Add index/custom/endpoint fields to the base integration schema."""
483471
if data.index_or_dataview:
484472
for index_name in data.index_or_dataview:
485-
integration_schema.update(**ecs.flatten(ecs.get_index_schema(index_name)))
473+
schema_dict.update(**ecs.flatten(ecs.get_index_schema(index_name)))
486474

487-
# Add custom schema fields for appropriate stack version
488475
if data.index_or_dataview and CUSTOM_RULES_DIR:
489476
for index_name in data.index_or_dataview:
490-
integration_schema.update(**ecs.flatten(ecs.get_custom_index_schema(index_name, stack_version)))
477+
schema_dict.update(**ecs.flatten(ecs.get_custom_index_schema(index_name, stack_version)))
478+
479+
schema_dict.update(**ecs.flatten(ecs.get_endpoint_schemas()))
480+
return schema_dict
481+
482+
# Function to validate against a list of packaged integrations
483+
def _validate_against_packaged_integrations(
484+
query_text: str,
485+
packaged: list[dict[str, Any]],
486+
trailer_builder: Callable[[str, str | None, str, str, str], str],
487+
join_values: list[Any] | None = None,
488+
*,
489+
accumulate_schemas: bool = True,
490+
) -> EQL_ERROR_TYPES | ValueError | None:
491+
"""Validate a query text against a set of packaged integrations, collect field errors.
492+
493+
- query_text: EQL snippet to validate (full query or a subquery's event query).
494+
- packaged: list of {package, integration} dicts to build schemas for.
495+
- trailer_builder: function to build error trailer text for context.
496+
- join_values: optional join/group-by fields to validate exist in the schema.
497+
"""
498+
for integration_schema_data in get_integration_schema_data(data, meta, packaged):
499+
ecs_version = integration_schema_data["ecs_version"]
500+
package, integration = (
501+
integration_schema_data["package"],
502+
integration_schema_data["integration"],
503+
)
504+
package_version = integration_schema_data["package_version"]
505+
integration_schema = integration_schema_data["schema"]
506+
stack_version = integration_schema_data["stack_version"]
507+
508+
# Prepare schema with index/custom/endpoint additions
509+
integration_schema = _prepare_integration_schema(integration_schema, stack_version)
510+
if accumulate_schemas:
511+
package_schemas.setdefault(package, {}).update(**integration_schema)
512+
513+
# Build trailer and validate the query text
514+
err_trailer = trailer_builder(package, integration, package_version, stack_version, ecs_version)
515+
exc = self.validate_query_text_with_schema(
516+
query_text,
517+
ecs.KqlSchema2Eql(integration_schema),
518+
err_trailer=err_trailer,
519+
min_stack_version=meta.min_stack_version, # type: ignore[reportArgumentType]
520+
)
521+
if isinstance(exc, eql.EqlParseError):
522+
field = extract_error_field(query_text, exc)
523+
error_fields[field or "<unknown field>"] = {
524+
"error": exc,
525+
"trailer": exc.trailer if hasattr(exc, "trailer") else err_trailer, # type: ignore[reportUnknownArgumentType]
526+
"package": package,
527+
"integration": integration,
528+
}
529+
if data.get("notify", False):
530+
print(
531+
f"\nWarning: `{field}` in `{data.name}` not found in schema. "
532+
f"{error_fields[field or '<unknown field>']['trailer']}"
533+
)
534+
elif exc is not None:
535+
return exc
491536

492-
# add endpoint schema fields for multi-line fields
493-
integration_schema.update(**ecs.flatten(ecs.get_endpoint_schemas()))
494-
package_schemas[package].update(**integration_schema)
537+
# Validate join/group-by fields exist in this integration schema (if provided)
538+
for jf in join_values or []:
539+
jf_str = str(jf)
540+
if jf_str not in integration_schema:
541+
trailer = (
542+
f"\n\tJoin field not found in schema.\n\t"
543+
f"package: {package}, integration: {integration}, package_version: {package_version}, "
544+
f"stack: {stack_version}, ecs: {ecs_version}"
545+
)
546+
error_fields[jf_str] = {
547+
"error": ValueError(f"Unknown field: {jf_str}"),
548+
"trailer": trailer,
549+
"package": package,
550+
"integration": integration,
551+
}
495552

496-
eql_schema = ecs.KqlSchema2Eql(integration_schema)
497-
err_trailer = (
498-
f"stack: {stack_version}, integration: {integration},"
499-
f"ecs: {ecs_version}, package: {package}, package_version: {package_version}"
553+
return None
554+
555+
# Helper to build the error trailer for a sequence subquery schema mismatch
556+
def _subquery_trailer_builder(pkg: str, integ: str | None, pkg_ver: str, stk_ver: str, ecs_ver: str) -> str:
557+
return (
558+
f"Subquery schema mismatch. "
559+
f"package: {pkg}, integration: {integ}, package_version: {pkg_ver}, "
560+
f"stack: {stk_ver}, ecs: {ecs_ver}"
500561
)
501562

502-
# Validate the query against the schema
503-
exc = self.validate_query_with_schema(
504-
data=data,
505-
schema=eql_schema,
506-
err_trailer=err_trailer,
507-
min_stack_version=meta.min_stack_version, # type: ignore[reportArgumentType]
563+
# Helper to build the error trailer when the full query is validated against each integration schema
564+
def _full_query_trailer_builder(pkg: str, integ: str | None, pkg_ver: str, stk_ver: str, ecs_ver: str) -> str:
565+
return (
566+
f"Try adding event.module or event.dataset to specify integration module\n\t"
567+
f"Will check against integrations {meta.integration} combined.\n\t"
568+
f"package: {pkg}, integration: {integ}, package_version: {pkg_ver}, stack: {stk_ver}, ecs: {ecs_ver}"
508569
)
509570

510-
if isinstance(exc, eql.EqlParseError):
511-
message = exc.error_msg # type: ignore[reportUnknownVariableType]
512-
if message == "Unknown field" or "Field not recognized" in message:
513-
field = extract_error_field(self.query, exc)
514-
trailer = (
515-
f"\n\tTry adding event.module or data_stream.dataset to specify integration module\n\t"
516-
f"Will check against integrations {meta.integration} combined.\n\t"
517-
f"{package=}, {integration=}, {package_version=}, "
518-
f"{stack_version=}, {ecs_version=}"
519-
)
520-
error_fields[field] = {
521-
"error": exc,
522-
"trailer": trailer,
523-
"package": package,
524-
"integration": integration,
525-
}
526-
if data.get("notify", False):
527-
print(f"\nWarning: `{field}` in `{data.name}` not found in schema. {trailer}")
528-
else:
571+
# Determine if this is a sequence query via rule data flag
572+
if data.is_sequence: # type: ignore[reportAttributeAccessIssue]
573+
sequence: ast.Sequence = self.ast.first # type: ignore[reportAttributeAccessIssue]
574+
575+
did_subquery_validation = False
576+
packages_manifest = load_integrations_manifests()
577+
# Validate each subquery against the corresponding integration.package schema
578+
for subquery in sequence.queries: # type: ignore[reportUnknownVariableType]
579+
# Get datasets used in this subquery only
580+
subquery_datasets, _ = get_datasets_and_modules(subquery) # type: ignore[reportUnknownVariableType]
581+
if not subquery_datasets:
582+
# If dataset isn't specified in the subquery, skip to avoid false positives here
583+
# The stack schema validation will provide generic guidance
584+
continue
585+
586+
# Build subquery-specific package_integrations
587+
subquery_pkg_ints = parse_datasets(list(subquery_datasets), packages_manifest)
588+
589+
# Render the subquery's event query (without the "by" fields) so it can be validated on its own
590+
subquery_query_str = subquery.query.render() # type: ignore[reportUnknownVariableType]
591+
592+
# Only mark as validated if there are subquery-specific integrations to check
593+
if subquery_pkg_ints:
594+
did_subquery_validation = True
595+
596+
exc = _validate_against_packaged_integrations(
597+
subquery_query_str, # type: ignore[reportUnknownVariableType]
598+
subquery_pkg_ints,
599+
_subquery_trailer_builder,
600+
join_values=list(getattr(subquery, "join_values", []) or []), # type: ignore[reportUnknownVariableType]
601+
accumulate_schemas=False,
602+
)
603+
if exc is not None:
529604
return exc
530605

606+
# If no subquery yielded dataset-specific integrations to check (nothing validated),
607+
# fall back to validating the full query against provided integrations
608+
if not did_subquery_validation:
609+
exc = _validate_against_packaged_integrations(
610+
self.query,
611+
package_integrations,
612+
_full_query_trailer_builder,
613+
join_values=None,
614+
)
615+
if exc is not None:
616+
return exc
617+
618+
# Return the first error collected across subqueries
619+
if error_fields:
620+
_, data_ = next(iter(error_fields.items()))
621+
err = data_["error"]
622+
# If it's an EQL error, wrap with trailer for better context
623+
if isinstance(err, eql.EqlParseError):
624+
return err.__class__(
625+
err.error_msg, # type: ignore[reportUnknownArgumentType]
626+
err.line, # type: ignore[reportUnknownArgumentType]
627+
err.column, # type: ignore[reportUnknownArgumentType]
628+
err.source, # type: ignore[reportUnknownArgumentType]
629+
len(err.caret.lstrip()),
630+
trailer=data_["trailer"], # type: ignore[reportUnknownArgumentType]
631+
)
632+
return err
633+
634+
return None
635+
636+
exc = _validate_against_packaged_integrations(
637+
self.query,
638+
package_integrations,
639+
_full_query_trailer_builder,
640+
join_values=None,
641+
)
642+
if exc is not None:
643+
return exc
644+
531645
# Check error fields against schemas of different packages or different integrations
532646
for field, error_data in list(error_fields.items()): # type: ignore[reportUnknownArgumentType]
533647
error_package, error_integration = ( # type: ignore[reportUnknownVariableType]
@@ -562,17 +676,34 @@ def validate_query_with_schema(
562676
min_stack_version: str,
563677
beat_types: list[str] | None = None,
564678
) -> EQL_ERROR_TYPES | ValueError | None:
565-
"""Validate the query against the schema."""
679+
"""Validate the query against the schema (delegates to validate_query_text_with_schema)."""
680+
return self.validate_query_text_with_schema(
681+
self.query,
682+
schema,
683+
err_trailer=err_trailer,
684+
min_stack_version=min_stack_version,
685+
beat_types=beat_types,
686+
)
687+
688+
def validate_query_text_with_schema(
689+
self,
690+
query_text: str,
691+
schema: ecs.KqlSchema2Eql | endgame.EndgameSchema,
692+
err_trailer: str,
693+
min_stack_version: str,
694+
beat_types: list[str] | None = None,
695+
) -> EQL_ERROR_TYPES | ValueError | None:
696+
"""Validate the provided EQL query text against the schema (variant of validate_query_with_schema)."""
566697
try:
567698
config = set_eql_config(min_stack_version)
568699
with config, schema, eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions:
569-
_ = eql.parse_query(self.query) # type: ignore[reportUnknownMemberType]
700+
_ = eql.parse_query(query_text) # type: ignore[reportUnknownMemberType]
570701
except eql.EqlParseError as exc:
571702
message = exc.error_msg
572703
trailer = err_trailer
573704
if "Unknown field" in message and beat_types:
574-
trailer = f"\nTry adding event.module or data_stream.dataset to specify beats module\n\n{trailer}"
575-
elif "Field not recognized" in message:
705+
trailer = f"\nTry adding event.module or event.dataset to specify beats module\n\n{trailer}"
706+
elif "Field not recognized" in message and isinstance(schema, ecs.KqlSchema2Eql):
576707
text_fields = self.text_fields(schema)
577708
if text_fields:
578709
fields_str = ", ".join(text_fields)
@@ -586,7 +717,6 @@ def validate_query_with_schema(
586717
len(exc.caret.lstrip()),
587718
trailer=trailer,
588719
)
589-
590720
except Exception as exc: # noqa: BLE001
591721
print(err_trailer)
592722
return exc # type: ignore[reportReturnType]

0 commit comments

Comments
 (0)