Skip to content

Commit 6923891

Browse files
committed
Annotate types
1 parent e07fa16 commit 6923891

File tree

8 files changed

+259
-167
lines changed

8 files changed

+259
-167
lines changed

posthog/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from posthog.client import Client
66
from posthog.exception_capture import Integrations # noqa: F401
77
from posthog.version import VERSION
8-
8+
from posthog.types import FlagsAndPayloads
99
__version__ = VERSION
1010

1111
"""Settings."""
@@ -403,7 +403,7 @@ def get_feature_flag(
403403
only_evaluate_locally=False, # type: bool
404404
send_feature_flag_events=True, # type: bool
405405
disable_geoip=None, # type: Optional[bool]
406-
):
406+
) -> str | bool | None:
407407
"""
408408
Get feature flag variant for users. Used with experiments.
409409
Example:
@@ -446,7 +446,7 @@ def get_all_flags(
446446
group_properties={}, # type: dict
447447
only_evaluate_locally=False, # type: bool
448448
disable_geoip=None, # type: Optional[bool]
449-
):
449+
) -> dict[str, str | bool] | None:
450450
"""
451451
Get all flags for a given user.
452452
Example:
@@ -477,7 +477,7 @@ def get_feature_flag_payload(
477477
only_evaluate_locally=False,
478478
send_feature_flag_events=True,
479479
disable_geoip=None, # type: Optional[bool]
480-
):
480+
) -> str:
481481
return _proxy(
482482
"get_feature_flag_payload",
483483
key=key,
@@ -519,7 +519,7 @@ def get_all_flags_and_payloads(
519519
group_properties={},
520520
only_evaluate_locally=False,
521521
disable_geoip=None, # type: Optional[bool]
522-
):
522+
) -> FlagsAndPayloads:
523523
return _proxy(
524524
"get_all_flags_and_payloads",
525525
distinct_id=distinct_id,

posthog/client.py

Lines changed: 34 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
from posthog.exception_utils import exc_info_from_error, exceptions_from_error_tuple, handle_in_app
1919
from posthog.feature_flags import InconclusiveMatchError, match_feature_flag_properties
2020
from posthog.poller import Poller
21-
from posthog.request import DEFAULT_HOST, APIError, batch_post, decide, determine_server_host, get, remote_config, normalize_decide_response, DecideResponse
21+
from posthog.request import DEFAULT_HOST, APIError, batch_post, decide, determine_server_host, get, remote_config, DecideResponse
2222
from posthog.utils import SizeLimitedDict, clean, guess_timezone, remove_trailing_slash
2323
from posthog.version import VERSION
24-
24+
from posthog.types import FlagsAndPayloads, FlagValue, to_values, to_payloads, to_flags_and_payloads, normalize_decide_response
2525
try:
2626
import queue
2727
except ImportError:
@@ -239,26 +239,23 @@ def identify(self, distinct_id=None, properties=None, context=None, timestamp=No
239239

240240
def get_feature_variants(
241241
self, distinct_id, groups=None, person_properties=None, group_properties=None, disable_geoip=None
242-
):
242+
) -> dict[str, str | bool]:
243243
resp_data = self.get_decide(distinct_id, groups, person_properties, group_properties, disable_geoip)
244-
return self.to_variants(resp_data)
244+
return to_values(resp_data) or {}
245245

246246
def get_feature_payloads(
247247
self, distinct_id, groups=None, person_properties=None, group_properties=None, disable_geoip=None
248-
):
248+
) -> dict[str, str]:
249249
resp_data = self.get_decide(distinct_id, groups, person_properties, group_properties, disable_geoip)
250-
return self.to_payloads(resp_data)
250+
return to_payloads(resp_data) or {}
251251

252252
def get_feature_flags_and_payloads(
253253
self, distinct_id, groups=None, person_properties=None, group_properties=None, disable_geoip=None
254-
):
255-
resp_data = self.get_decide(distinct_id, groups, person_properties, group_properties, disable_geoip)
256-
return {
257-
"featureFlags": self.to_variants(resp_data),
258-
"featureFlagPayloads": self.to_payloads(resp_data),
259-
}
254+
) -> FlagsAndPayloads:
255+
resp = self.get_decide(distinct_id, groups, person_properties, group_properties, disable_geoip)
256+
return to_flags_and_payloads(resp)
260257

261-
def get_decide(self, distinct_id, groups=None, person_properties=None, group_properties=None, disable_geoip=None):
258+
def get_decide(self, distinct_id, groups=None, person_properties=None, group_properties=None, disable_geoip=None) -> DecideResponse:
262259
require("distinct_id", distinct_id, ID_TYPES)
263260

264261
if disable_geoip is None:
@@ -317,8 +314,8 @@ def capture(
317314
require("groups", groups, dict)
318315
msg["properties"]["$groups"] = groups
319316

320-
extra_properties = {}
321-
feature_variants = {}
317+
extra_properties: dict[str, Any] = {}
318+
feature_variants: dict[str, bool | str] | None = {}
322319
if send_feature_flags:
323320
try:
324321
feature_variants = self.get_feature_variants(distinct_id, groups, disable_geoip=disable_geoip)
@@ -331,10 +328,10 @@ def capture(
331328
distinct_id, groups=(groups or {}), disable_geoip=disable_geoip, only_evaluate_locally=True
332329
)
333330

334-
for feature, variant in feature_variants.items():
331+
for feature, variant in (feature_variants or {}).items():
335332
extra_properties[f"$feature/{feature}"] = variant
336333

337-
active_feature_flags = [key for (key, value) in feature_variants.items() if value is not False]
334+
active_feature_flags = [key for (key, value) in (feature_variants or {}).items() if value is not False]
338335
if active_feature_flags:
339336
extra_properties["$active_feature_flags"] = active_feature_flags
340337

@@ -711,7 +708,7 @@ def _compute_flag_locally(
711708
person_properties={},
712709
group_properties={},
713710
warn_on_unknown_groups=True,
714-
):
711+
) -> FlagValue:
715712
if feature_flag.get("ensure_experience_continuity", False):
716713
raise InconclusiveMatchError("Flag has experience continuity enabled")
717714

@@ -901,8 +898,13 @@ def get_feature_flag_payload(
901898
responses_and_payloads = self.get_feature_flags_and_payloads(
902899
distinct_id, groups, person_properties, group_properties, disable_geoip
903900
)
904-
response = responses_and_payloads["featureFlags"].get(key, None)
905-
payload = responses_and_payloads["featureFlagPayloads"].get(str(key), None)
901+
featureFlags = responses_and_payloads["featureFlags"]
902+
if featureFlags is not None:
903+
response = featureFlags.get(key, None)
904+
905+
featureFlagPayloads = responses_and_payloads["featureFlagPayloads"]
906+
if featureFlagPayloads is not None:
907+
payload = featureFlagPayloads.get(str(key), None)
906908
except Exception as e:
907909
self.log.exception(f"[FEATURE FLAGS] Unable to get feature flags and payloads: {e}")
908910

@@ -949,7 +951,7 @@ def get_remote_config_payload(self, key: str):
949951
except Exception as e:
950952
self.log.exception(f"[FEATURE FLAGS] Unable to get decrypted feature flag payload: {e}")
951953

952-
def _compute_payload_locally(self, key, match_value):
954+
def _compute_payload_locally(self, key: str, match_value: FlagValue) -> str | None:
953955
payload = None
954956

955957
if self.feature_flags_by_key is None:
@@ -974,7 +976,7 @@ def get_all_flags(
974976
group_properties={},
975977
only_evaluate_locally=False,
976978
disable_geoip=None,
977-
) -> dict[str, bool | str]:
979+
) -> dict[str, bool | str] | None:
978980
response = self.get_all_flags_and_payloads(
979981
distinct_id,
980982
groups=groups,
@@ -984,7 +986,7 @@ def get_all_flags(
984986
disable_geoip=disable_geoip,
985987
)
986988

987-
return self.to_variants(response)
989+
return response["featureFlags"]
988990

989991
def get_all_flags_and_payloads(
990992
self,
@@ -995,46 +997,45 @@ def get_all_flags_and_payloads(
995997
group_properties={},
996998
only_evaluate_locally=False,
997999
disable_geoip=None,
998-
) -> DecideResponse:
1000+
) -> FlagsAndPayloads:
9991001
if self.disabled:
10001002
return {"featureFlags": None, "featureFlagPayloads": None}
10011003

10021004
person_properties, group_properties = self._add_local_person_and_group_properties(
10031005
distinct_id, groups, person_properties, group_properties
10041006
)
10051007

1006-
flags, payloads, fallback_to_decide = self._get_all_flags_and_payloads_locally(
1008+
response, fallback_to_decide = self._get_all_flags_and_payloads_locally(
10071009
distinct_id, groups=groups, person_properties=person_properties, group_properties=group_properties
10081010
)
1009-
1010-
response = normalize_decide_response({"featureFlags": flags, "featureFlagPayloads": payloads})
10111011

10121012
if fallback_to_decide and not only_evaluate_locally:
10131013
try:
1014-
flags_and_payloads = self.get_decide(
1014+
decide_response = self.get_decide(
10151015
distinct_id,
10161016
groups=groups,
10171017
person_properties=person_properties,
10181018
group_properties=group_properties,
10191019
disable_geoip=disable_geoip,
10201020
)
1021-
response = flags_and_payloads
1021+
return to_flags_and_payloads(decide_response)
10221022
except Exception as e:
10231023
self.log.exception(f"[FEATURE FLAGS] Unable to get feature flags and payloads: {e}")
10241024

10251025
return response
10261026

1027+
10271028
def _get_all_flags_and_payloads_locally(
10281029
self, distinct_id, *, groups={}, person_properties={}, group_properties={}, warn_on_unknown_groups=False
1029-
):
1030+
) -> tuple[FlagsAndPayloads, bool]:
10301031
require("distinct_id", distinct_id, ID_TYPES)
10311032
require("groups", groups, dict)
10321033

10331034
if self.feature_flags is None and self.personal_api_key:
10341035
self.load_feature_flags()
10351036

1036-
flags = {}
1037-
payloads = {}
1037+
flags: dict[str, FlagValue] = {}
1038+
payloads: dict[str, str] = {}
10381039
fallback_to_decide = False
10391040
# If loading in previous line failed
10401041
if self.feature_flags:
@@ -1060,7 +1061,7 @@ def _get_all_flags_and_payloads_locally(
10601061
else:
10611062
fallback_to_decide = True
10621063

1063-
return flags, payloads, fallback_to_decide
1064+
return {"featureFlags": flags, "featureFlagPayloads": payloads}, fallback_to_decide
10641065

10651066
def feature_flag_definitions(self):
10661067
return self.feature_flags
@@ -1078,18 +1079,6 @@ def _add_local_person_and_group_properties(self, distinct_id, groups, person_pro
10781079

10791080
return all_person_properties, all_group_properties
10801081

1081-
def to_variants(self, response: any) -> dict[str, bool | str]:
1082-
if "flags" not in response:
1083-
return None
1084-
1085-
return {key: value.get_value() for key, value in response.get("flags", {}).items()}
1086-
1087-
def to_payloads(self, response: any) -> dict[str, str]:
1088-
if "flags" not in response:
1089-
return None
1090-
1091-
return {key: value.metadata.payload for key, value in response.get("flags", {}).items() if value.enabled}
1092-
10931082

10941083
def require(name, field, data_type):
10951084
"""Require that the named `field` has the right `data_type`"""

posthog/feature_flags.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dateutil.relativedelta import relativedelta
99

1010
from posthog import utils
11+
from posthog.types import FlagValue
1112
from posthog.utils import convert_to_datetime_aware, is_valid_regex
1213

1314
__LONG_SCALE__ = float(0xFFFFFFFFFFFFFFF)
@@ -25,7 +26,7 @@ class InconclusiveMatchError(Exception):
2526
# Given the same distinct_id and key, it'll always return the same float. These floats are
2627
# uniformly distributed between 0 and 1, so if we want to show this feature to 20% of traffic
2728
# we can do _hash(key, distinct_id) < 0.2
28-
def _hash(key, distinct_id, salt=""):
29+
def _hash(key: str, distinct_id: str, salt: str = "") -> float:
2930
hash_key = f"{key}.{distinct_id}{salt}"
3031
hash_val = int(hashlib.sha1(hash_key.encode("utf-8")).hexdigest()[:15], 16)
3132
return hash_val / __LONG_SCALE__
@@ -50,7 +51,7 @@ def variant_lookup_table(feature_flag):
5051
return lookup_table
5152

5253

53-
def match_feature_flag_properties(flag, distinct_id, properties, cohort_properties=None):
54+
def match_feature_flag_properties(flag, distinct_id, properties, cohort_properties=None) -> FlagValue:
5455
flag_conditions = (flag.get("filters") or {}).get("groups") or []
5556
is_inconclusive = False
5657
cohort_properties = cohort_properties or {}
@@ -87,7 +88,7 @@ def match_feature_flag_properties(flag, distinct_id, properties, cohort_properti
8788
return False
8889

8990

90-
def is_condition_match(feature_flag, distinct_id, condition, properties, cohort_properties):
91+
def is_condition_match(feature_flag, distinct_id, condition, properties, cohort_properties) -> bool:
9192
rollout_percentage = condition.get("rollout_percentage")
9293
if len(condition.get("properties") or []) > 0:
9394
for prop in condition.get("properties"):

posthog/request.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import requests
99
from dateutil.tz import tzutc
1010

11-
from posthog.types import DecideResponse, FeatureFlag
11+
from posthog.types import DecideResponse, FeatureFlag, FlagValue
1212
from posthog.utils import remove_trailing_slash
1313
from posthog.version import VERSION
1414

@@ -97,20 +97,6 @@ def decide(api_key: str, host: Optional[str] = None, gzip: bool = False, timeout
9797
res = post(api_key, host, "/decide/?v=3", gzip, timeout, **kwargs)
9898
return _process_response(res, success_message="Feature flags decided successfully")
9999

100-
def normalize_decide_response(resp: any) -> DecideResponse:
101-
if "requestId" not in resp:
102-
resp["requestId"] = None
103-
if "flags" not in resp:
104-
featureFlags = resp.get("featureFlags", {})
105-
featureFlagPayloads = resp.get("featureFlagPayloads", {})
106-
resp.pop("featureFlags", None)
107-
resp.pop("featureFlagPayloads", None)
108-
# look at each key in featureFlags and create a FeatureFlag object
109-
flags = {}
110-
for key, value in featureFlags.items():
111-
flags[key] = FeatureFlag.from_value_and_payload(key, value, featureFlagPayloads.get(key, None))
112-
resp["flags"] = flags
113-
return cast(DecideResponse, resp)
114100

115101
def remote_config(personal_api_key: str, host: Optional[str] = None, key: str = "", timeout: int = 15) -> Any:
116102
"""Get remote config flag value from remote_config API endpoint"""

posthog/test/test_feature_flags.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -642,9 +642,8 @@ def test_get_all_flags_and_payloads_with_fallback(self, patch_decide, patch_capt
642642
},
643643
]
644644
# beta-feature value overridden by /decide
645-
flags_and_payloads = client.get_all_flags_and_payloads("distinct_id")
646645
self.assertEqual(
647-
client.to_payloads(flags_and_payloads),
646+
client.get_all_flags_and_payloads("distinct_id")["featureFlagPayloads"],
648647
{
649648
"beta-feature": 100,
650649
"beta-feature2": 300,
@@ -676,9 +675,8 @@ def test_get_all_flags_and_payloads_with_fallback_empty_local_flags(self, patch_
676675
client = self.client
677676
client.feature_flags = []
678677
# beta-feature value overridden by /decide
679-
flags_and_payloads = client.get_all_flags_and_payloads("distinct_id")
680678
self.assertEqual(
681-
client.to_payloads(flags_and_payloads),
679+
client.get_all_flags_and_payloads("distinct_id")["featureFlagPayloads"],
682680
{"beta-feature": 100, "beta-feature2": 300},
683681
)
684682
self.assertEqual(patch_decide.call_count, 1)
@@ -768,9 +766,9 @@ def test_get_all_flags_and_payloads_with_no_fallback(self, patch_decide, patch_c
768766
basic_flag,
769767
disabled_flag,
770768
]
771-
all_flags_and_payloads = client.get_all_flags_and_payloads("distinct_id")
772769
self.assertEqual(
773-
client.to_payloads(all_flags_and_payloads), {"beta-feature": "new"}
770+
client.get_all_flags_and_payloads("distinct_id")["featureFlagPayloads"],
771+
{"beta-feature": "new"}
774772
)
775773
# decide not called because this can be evaluated locally
776774
self.assertEqual(patch_decide.call_count, 0)
@@ -900,9 +898,8 @@ def test_get_all_flags_and_payloads_with_fallback_but_only_local_evaluation_set(
900898
flag_3,
901899
]
902900
# beta-feature2 has no value
903-
flags_and_payloads = client.get_all_flags_and_payloads("distinct_id", only_evaluate_locally=True)
904901
self.assertEqual(
905-
client.to_payloads(flags_and_payloads),
902+
client.get_all_flags_and_payloads("distinct_id", only_evaluate_locally=True)["featureFlagPayloads"],
906903
{"beta-feature": "some-payload"},
907904
)
908905
self.assertEqual(patch_decide.call_count, 0)

0 commit comments

Comments
 (0)