Skip to content

Commit a9ee32c

Browse files
Required Types via Annotated and Dataclass
1 parent 4ddb675 commit a9ee32c

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

detection_rules/cli_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from . import ecs
1919
from .attack import build_threat_map_entry, matrix, tactics
2020
from .config import parse_rules_config
21+
from .mixins import enforce_required_fields
2122
from .rule import BYPASS_VERSION_LOCK, TOMLRule, TOMLRuleContents
2223
from .rule_loader import DEFAULT_PREBUILT_BBR_DIRS, DEFAULT_PREBUILT_RULES_DIRS, RuleCollection, dict_filter
2324
from .schemas import definitions
@@ -166,6 +167,8 @@ def rule_prompt( # noqa: PLR0912, PLR0913, PLR0915
166167
)
167168

168169
target_data_subclass = TOMLRuleContents.get_data_subclass(rule_type_val)
170+
171+
enforce_required_fields(target_data_subclass)
169172
schema = target_data_subclass.jsonschema()
170173
props = schema["properties"]
171174
required_fields = schema.get("required", []) + additional_required

detection_rules/mixins.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
import dataclasses
99
import json
1010
from pathlib import Path
11-
from typing import Any, Literal
11+
from typing import Any, Literal, get_type_hints
1212

1313
import marshmallow
1414
import marshmallow_dataclass
1515
import marshmallow_dataclass.union_field
1616
import marshmallow_jsonschema # type: ignore[reportMissingTypeStubs]
1717
import marshmallow_union # type: ignore[reportMissingTypeStubs]
18+
import typing_inspect # type: ignore[reportMissingTypeStubs]
1819
from marshmallow import Schema, ValidationError, validates_schema
1920
from marshmallow import fields as marshmallow_fields
2021
from semver import Version
@@ -38,6 +39,27 @@ def _strip_none_from_dict(obj: Any) -> Any:
3839
return obj
3940

4041

42+
def enforce_required_fields(cls: Any) -> None:
43+
"""Enforce required fields based on both dataclass and type Annotations."""
44+
hints = get_type_hints(cls, include_extras=True)
45+
marshmallow_schema = marshmallow_dataclass.class_schema(cls)()
46+
for dc_field in dataclasses.fields(cls):
47+
mm_field = marshmallow_schema.fields.get(dc_field.name)
48+
if mm_field is None or mm_field.required:
49+
continue
50+
51+
if dc_field.default is not dataclasses.MISSING:
52+
continue
53+
if getattr(dc_field, "default_factory", dataclasses.MISSING) is not dataclasses.MISSING:
54+
continue
55+
56+
hint = hints.get(dc_field.name)
57+
if hint is not None and typing_inspect.is_optional_type(hint): # type: ignore[reportMissingTypeStubs]
58+
continue
59+
60+
mm_field.required = True # or set to some new list that is required_fields
61+
62+
4163
def patch_jsonschema(obj: Any) -> dict[str, Any]:
4264
"""Patch marshmallow-jsonschema output to look more like JSL."""
4365

detection_rules/schemas/definitions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,8 @@ def validator_wrapper(value: Any) -> Any:
243243
list[NonEmptyStr], fields.List(NON_EMPTY_STRING_FIELD, validate=validate.Length(min=1, max=3))
244244
]
245245
PositiveInteger = Annotated[int, fields.Integer(validate=validate.Range(min=1))]
246-
RiskScore = Annotated[int, fields.Integer(validate=validate.Range(min=1, max=100), required=True)]
247-
RuleName = Annotated[str, fields.String(validate=elastic_rule_name_regexp(NAME_PATTERN), required=True)]
246+
RiskScore = Annotated[int, fields.Integer(validate=validate.Range(min=1, max=100))]
247+
RuleName = Annotated[str, fields.String(validate=elastic_rule_name_regexp(NAME_PATTERN))]
248248
SemVer = Annotated[str, fields.String(validate=validate.Regexp(VERSION_PATTERN))]
249249
SemVerMinorOnly = Annotated[str, fields.String(validate=validate.Regexp(MINOR_SEMVER))]
250250
Sha256 = Annotated[str, fields.String(validate=validate.Regexp(SHA256_PATTERN))]
@@ -254,7 +254,7 @@ def validator_wrapper(value: Any) -> Any:
254254
ThresholdValue = Annotated[int, fields.Integer(validate=validate.Range(min=1))]
255255
TimelineTemplateId = Annotated[str, fields.String(validate=elastic_timeline_template_id_validator())]
256256
TimelineTemplateTitle = Annotated[str, fields.String(validate=elastic_timeline_template_title_validator())]
257-
UUIDString = Annotated[str, fields.String(validate=validate.Regexp(UUID_PATTERN), required=True)]
257+
UUIDString = Annotated[str, fields.String(validate=validate.Regexp(UUID_PATTERN))]
258258

259259
# experimental machine learning features and releases
260260
MachineLearningType = Literal[MACHINE_LEARNING_PACKAGES]

0 commit comments

Comments
 (0)