Skip to content

Commit 53b867a

Browse files
committed
typing fixes
1 parent a05ab13 commit 53b867a

File tree

3 files changed

+59
-12
lines changed

3 files changed

+59
-12
lines changed

flag_engine/segments/evaluator.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
)
2121
from flag_engine.result.types import EvaluationResult, FlagResult, SegmentResult
2222
from flag_engine.segments import constants
23-
from flag_engine.segments.types import ConditionOperator, ContextValue, is_context_value
23+
from flag_engine.segments.types import (
24+
ConditionOperator,
25+
ContextValue,
26+
MetadataT,
27+
is_context_value,
28+
)
2429
from flag_engine.segments.utils import escape_double_quotes, get_matching_function
2530
from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids
2631
from flag_engine.utils.semver import is_semver
@@ -32,14 +37,16 @@ class FeatureContextWithSegmentName(typing.TypedDict):
3237
segment_name: str
3338

3439

35-
def get_evaluation_result(context: EvaluationContext) -> EvaluationResult:
40+
def get_evaluation_result(
41+
context: EvaluationContext[MetadataT],
42+
) -> EvaluationResult[MetadataT]:
3643
"""
3744
Get the evaluation result for a given context.
3845
3946
:param context: the evaluation context
4047
:return: EvaluationResult containing the context, flags, and segments
4148
"""
42-
segments: list[SegmentResult] = []
49+
segments: list[SegmentResult[MetadataT]] = []
4350
flags: dict[str, FlagResult] = {}
4451

4552
segment_feature_contexts: dict[SupportsStr, FeatureContextWithSegmentName] = {}
@@ -48,7 +55,7 @@ def get_evaluation_result(context: EvaluationContext) -> EvaluationResult:
4855
if not is_context_in_segment(context, segment_context):
4956
continue
5057

51-
segment_result: SegmentResult = {
58+
segment_result: SegmentResult[MetadataT] = {
5259
"key": segment_context["key"],
5360
"name": segment_context["name"],
5461
}
@@ -152,8 +159,8 @@ def get_flag_result_from_feature_context(
152159

153160

154161
def is_context_in_segment(
155-
context: EvaluationContext,
156-
segment_context: SegmentContext,
162+
context: EvaluationContext[MetadataT],
163+
segment_context: SegmentContext[MetadataT],
157164
) -> bool:
158165
return bool(rules := segment_context["rules"]) and all(
159166
context_matches_rule(
@@ -164,7 +171,7 @@ def is_context_in_segment(
164171

165172

166173
def context_matches_rule(
167-
context: EvaluationContext,
174+
context: EvaluationContext[MetadataT],
168175
rule: SegmentRule,
169176
segment_key: SupportsStr,
170177
) -> bool:
@@ -194,7 +201,7 @@ def context_matches_rule(
194201

195202

196203
def context_matches_condition(
197-
context: EvaluationContext,
204+
context: EvaluationContext[MetadataT],
198205
condition: SegmentCondition,
199206
segment_key: SupportsStr,
200207
) -> bool:
@@ -255,7 +262,7 @@ def context_matches_condition(
255262

256263

257264
def get_context_value(
258-
context: EvaluationContext,
265+
context: EvaluationContext[MetadataT],
259266
property: str,
260267
) -> ContextValue:
261268
value = None
@@ -353,7 +360,7 @@ def inner(
353360
@lru_cache
354361
def _get_context_value_getter(
355362
property: str,
356-
) -> typing.Callable[[EvaluationContext], ContextValue]:
363+
) -> typing.Callable[[EvaluationContext[MetadataT]], ContextValue]:
357364
"""
358365
Get a function to retrieve a context value based on property value,
359366
assumed to be either a JSONPath string or a trait key.
@@ -370,7 +377,7 @@ def _get_context_value_getter(
370377
f'$.identity.traits["{escape_double_quotes(property)}"]',
371378
)
372379

373-
def getter(context: EvaluationContext) -> ContextValue:
380+
def getter(context: EvaluationContext[MetadataT]) -> ContextValue:
374381
if typing.TYPE_CHECKING: # pragma: no cover
375382
# Ugly hack to satisfy mypy :(
376383
data = dict(context)

flag_engine/segments/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing_extensions import TypeGuard, TypeVar
66

7-
MetadataT = TypeVar("MetadataT", default=Dict[str, Any])
7+
MetadataT = TypeVar("MetadataT", default=Dict[str, object])
88

99
ConditionOperator = Literal[
1010
"EQUAL",

tests/unit/test_engine.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from typing import TypedDict
23

34
from flag_engine.context.types import EvaluationContext, IdentityContext, SegmentContext
45
from flag_engine.engine import get_evaluation_result
@@ -357,3 +358,42 @@ def test_get_evaluation_result__segment_override__no_priority__returns_expected(
357358
{"key": "3", "name": "another_segment"},
358359
],
359360
}
361+
362+
363+
def test_segment_metadata_generic_type__returns_expected() -> None:
364+
# Given
365+
class CustomMetadata(TypedDict):
366+
foo: str
367+
bar: int
368+
369+
segment_metadata = CustomMetadata(foo="hello", bar=123)
370+
371+
evaluation_context: EvaluationContext[CustomMetadata] = {
372+
"environment": {"key": "api-key", "name": ""},
373+
"segments": {
374+
"1": {
375+
"key": "1",
376+
"name": "my_segment",
377+
"rules": [
378+
{
379+
"type": "ALL",
380+
"conditions": [
381+
{
382+
"property": "$.environment.name",
383+
"operator": "EQUAL",
384+
"value": "",
385+
}
386+
],
387+
"rules": [],
388+
}
389+
],
390+
"metadata": segment_metadata,
391+
},
392+
},
393+
}
394+
395+
# When
396+
result = get_evaluation_result(evaluation_context)
397+
398+
# Then
399+
assert result["segments"][0]["metadata"] is segment_metadata

0 commit comments

Comments
 (0)