Skip to content

Commit 1670855

Browse files
[Bug] Add Required to the Annotation (#5159)
* Add Required to the Annotation * Additional required fields * remove nonempty sting validation * Required Types via Annotated and Dataclass * remove space * Remove inline comment * Switch to getting a list * Fix typo and sort --------- Co-authored-by: Mika Ayenson, PhD <[email protected]> (cherry picked from commit 42be8bc)
1 parent 6692cba commit 1670855

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

detection_rules/cli_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from . import ecs
1919
from .attack import build_threat_map_entry, matrix, tactics
2020
from .config import parse_rules_config
21+
from .mixins import get_dataclass_required_fields
2122
from .rule import BYPASS_VERSION_LOCK, TOMLRule, TOMLRuleContents
2223
from .rule_loader import DEFAULT_PREBUILT_BBR_DIRS, DEFAULT_PREBUILT_RULES_DIRS, RuleCollection, dict_filter
2324
from .schemas import definitions
@@ -166,9 +167,10 @@ def rule_prompt( # noqa: PLR0912, PLR0913, PLR0915
166167
)
167168

168169
target_data_subclass = TOMLRuleContents.get_data_subclass(rule_type_val)
170+
required_fields = get_dataclass_required_fields(target_data_subclass)
169171
schema = target_data_subclass.jsonschema()
170172
props = schema["properties"]
171-
required_fields = schema.get("required", []) + additional_required
173+
required_fields = sorted(required_fields + additional_required)
172174
contents: dict[str, Any] = {}
173175
skipped: list[str] = []
174176

detection_rules/mixins.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
import dataclasses
99
import json
1010
from pathlib import Path
11-
from typing import Any, Literal
11+
from typing import Any, Literal, get_type_hints
1212

1313
import marshmallow
1414
import marshmallow_dataclass
1515
import marshmallow_dataclass.union_field
1616
import marshmallow_jsonschema # type: ignore[reportMissingTypeStubs]
1717
import marshmallow_union # type: ignore[reportMissingTypeStubs]
18+
import typing_inspect # type: ignore[reportMissingTypeStubs]
1819
from marshmallow import Schema, ValidationError, validates_schema
1920
from marshmallow import fields as marshmallow_fields
2021
from semver import Version
@@ -38,6 +39,28 @@ def _strip_none_from_dict(obj: Any) -> Any:
3839
return obj
3940

4041

42+
def get_dataclass_required_fields(cls: Any) -> list[str]:
43+
"""Get required fields based on both dataclass and type Annotations."""
44+
required_fields: list[str] = []
45+
hints = get_type_hints(cls, include_extras=True)
46+
marshmallow_schema = marshmallow_dataclass.class_schema(cls)()
47+
for dc_field in dataclasses.fields(cls):
48+
hint = hints.get(dc_field.name)
49+
if not hint:
50+
continue
51+
52+
mm_field = marshmallow_schema.fields.get(dc_field.name)
53+
if mm_field is None:
54+
continue
55+
if dc_field.default is not dataclasses.MISSING:
56+
continue
57+
if getattr(dc_field, "default_factory", dataclasses.MISSING) is not dataclasses.MISSING:
58+
continue
59+
if not typing_inspect.is_optional_type(hint) or mm_field.required is True: # type: ignore[reportUnknownVariableType]
60+
required_fields.append(dc_field.name)
61+
return required_fields
62+
63+
4164
def patch_jsonschema(obj: Any) -> dict[str, Any]:
4265
"""Patch marshmallow-jsonschema output to look more like JSL."""
4366

@@ -264,5 +287,4 @@ def _get_schema_for_field(self, obj: Any, field: Any) -> Any:
264287
default=field.default, # type: ignore[reportUnknownMemberType]
265288
allow_none=field.allow_none,
266289
)
267-
268290
return super()._get_schema_for_field(obj, field) # type: ignore[reportUnknownMemberType]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "detection_rules"
3-
version = "1.4.6"
3+
version = "1.4.7"
44
description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Security’s Detection Engine."
55
readme = "README.md"
66
requires-python = ">=3.12"

0 commit comments

Comments
 (0)