Skip to content

Commit b6dbff1

Browse files
authored
feat: Add FlagDefinitionCacheProvider interface (#387)
* feat: Add FlagDefinitionCacheProvider interface * feat: Add a Redis example for FlagDefinitionCacheProvider * style: ruff format * style: ruff format * refactor: clean up mypy errors * chore: mypy-baseline sync * fix: Type Redis as Redis[str] * fix: Adhere to strict typing The defined types don't leave room for missing or optional keys. We'll use the types as they're defined.
1 parent 9f8faf7 commit b6dbff1

File tree

7 files changed

+990
-25
lines changed

7 files changed

+990
-25
lines changed

examples/redis_flag_cache.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""
2+
Redis-based distributed cache for PostHog feature flag definitions.
3+
4+
This example demonstrates how to implement a FlagDefinitionCacheProvider
5+
using Redis for multi-instance deployments (leader election pattern).
6+
7+
Usage:
8+
import redis
9+
from posthog import Posthog
10+
11+
redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
12+
cache = RedisFlagCache(redis_client, service_key="my-service")
13+
14+
posthog = Posthog(
15+
"<project_api_key>",
16+
personal_api_key="<personal_api_key>",
17+
flag_definition_cache_provider=cache,
18+
)
19+
20+
Requirements:
21+
pip install redis
22+
"""
23+
24+
import json
25+
import uuid
26+
27+
from posthog import FlagDefinitionCacheData, FlagDefinitionCacheProvider
28+
from redis import Redis
29+
from typing import Optional
30+
31+
32+
class RedisFlagCache(FlagDefinitionCacheProvider):
    """
    A distributed cache for PostHog feature flag definitions using Redis.

    In a multi-instance deployment (e.g., multiple serverless functions or containers),
    we want only ONE instance to poll PostHog for flag updates, while all instances
    share the cached results. This prevents N instances from making N redundant API calls.

    The implementation uses leader election:
    - One instance "wins" and becomes responsible for fetching
    - Other instances read from the shared cache
    - If the leader dies, the lock expires (TTL) and another instance takes over

    Uses Lua scripts for atomic operations, following Redis distributed lock best practices:
    https://redis.io/docs/latest/develop/clients/patterns/distributed-locks/
    """

    LOCK_TTL_MS = 60 * 1000  # 60 seconds, should be longer than the flags poll interval
    CACHE_TTL_SECONDS = 60 * 60 * 24  # 24 hours

    # Lua script: acquire lock if free, or extend if we own it
    _LUA_TRY_LEAD = """
    local current = redis.call('GET', KEYS[1])
    if current == false then
        redis.call('SET', KEYS[1], ARGV[1], 'PX', ARGV[2])
        return 1
    elseif current == ARGV[1] then
        redis.call('PEXPIRE', KEYS[1], ARGV[2])
        return 1
    end
    return 0
    """

    # Lua script: release lock only if we own it
    _LUA_STOP_LEAD = """
    if redis.call('GET', KEYS[1]) == ARGV[1] then
        return redis.call('DEL', KEYS[1])
    end
    return 0
    """

    # NOTE: the annotation is a string on purpose. ``Redis`` is generic only in
    # the type stubs; evaluating ``Redis[str]`` when the method is defined
    # raises ``TypeError`` at import time with the runtime redis-py class.
    def __init__(self, redis: "Redis[str]", service_key: str):
        """
        Initialize the Redis flag cache.

        Args:
            redis: A redis-py client instance. Must be configured with
                decode_responses=True for correct string handling.
            service_key: A unique identifier for this service/environment.
                Used to scope Redis keys, allowing multiple services
                or environments to share the same Redis instance.
                Examples: "my-api-prod", "checkout-service", "staging".

        Redis Keys Created:
            - posthog:flags:{service_key} - Cached flag definitions (JSON)
            - posthog:flags:{service_key}:lock - Leader election lock

        Example:
            redis_client = redis.Redis(
                host='localhost',
                port=6379,
                decode_responses=True
            )
            cache = RedisFlagCache(redis_client, service_key="my-api-prod")
        """
        self._redis = redis
        self._cache_key = f"posthog:flags:{service_key}"
        self._lock_key = f"posthog:flags:{service_key}:lock"
        # Random per-instance token stored as the lock value, so the Lua
        # scripts can tell "our" lock apart from another instance's.
        self._instance_id = str(uuid.uuid4())
        self._try_lead = self._redis.register_script(self._LUA_TRY_LEAD)
        self._stop_lead = self._redis.register_script(self._LUA_STOP_LEAD)

    def get_flag_definitions(self) -> Optional[FlagDefinitionCacheData]:
        """
        Retrieve cached flag definitions from Redis.

        Returns:
            Cached flag definitions if available, None otherwise.
        """
        cached = self._redis.get(self._cache_key)
        return json.loads(cached) if cached else None

    def should_fetch_flag_definitions(self) -> bool:
        """
        Determines if this instance should fetch flag definitions from PostHog.

        Atomically either:
        - Acquires the lock if no one holds it, OR
        - Extends the lock TTL if we already hold it

        Returns:
            True if this instance is the leader and should fetch, False otherwise.
        """
        result = self._try_lead(
            keys=[self._lock_key],
            args=[self._instance_id, self.LOCK_TTL_MS],
        )
        return result == 1

    def on_flag_definitions_received(self, data: FlagDefinitionCacheData) -> None:
        """
        Store fetched flag definitions in Redis.

        The entry is written as JSON and expires after CACHE_TTL_SECONDS.

        Args:
            data: The flag definitions to cache.
        """
        self._redis.set(self._cache_key, json.dumps(data), ex=self.CACHE_TTL_SECONDS)

    def shutdown(self) -> None:
        """
        Release leadership if we hold it. Safe to call even if not the leader.
        """
        self._stop_lead(keys=[self._lock_key], args=[self._instance_id])

mypy-baseline.txt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,9 @@ posthog/client.py:0: error: Incompatible types in assignment (expression has typ
2626
posthog/client.py:0: error: Incompatible types in assignment (expression has type "dict[Any, Any]", variable has type "None") [assignment]
2727
posthog/client.py:0: error: "None" has no attribute "__iter__" (not iterable) [attr-defined]
2828
posthog/client.py:0: error: Statement is unreachable [unreachable]
29-
posthog/client.py:0: error: Incompatible types in assignment (expression has type "Any | dict[Any, Any]", variable has type "None") [assignment]
30-
posthog/client.py:0: error: Incompatible types in assignment (expression has type "Any | dict[Any, Any]", variable has type "None") [assignment]
31-
posthog/client.py:0: error: Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "None") [assignment]
32-
posthog/client.py:0: error: Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "None") [assignment]
3329
posthog/client.py:0: error: Right operand of "and" is never evaluated [unreachable]
3430
posthog/client.py:0: error: Incompatible types in assignment (expression has type "Poller", variable has type "None") [assignment]
3531
posthog/client.py:0: error: "None" has no attribute "start" [attr-defined]
36-
posthog/client.py:0: error: "None" has no attribute "get" [attr-defined]
3732
posthog/client.py:0: error: Statement is unreachable [unreachable]
3833
posthog/client.py:0: error: Statement is unreachable [unreachable]
3934
posthog/client.py:0: error: Name "urlparse" already defined (possibly by an import) [no-redef]

posthog/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
InconclusiveMatchError as InconclusiveMatchError,
2323
RequiresServerEvaluation as RequiresServerEvaluation,
2424
)
25+
from posthog.flag_definition_cache import (
26+
FlagDefinitionCacheData as FlagDefinitionCacheData,
27+
FlagDefinitionCacheProvider as FlagDefinitionCacheProvider,
28+
)
2529
from posthog.request import (
2630
disable_connection_reuse as disable_connection_reuse,
2731
enable_keep_alive as enable_keep_alive,

posthog/client.py

Lines changed: 103 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
RequiresServerEvaluation,
2929
match_feature_flag_properties,
3030
)
31+
from posthog.flag_definition_cache import (
32+
FlagDefinitionCacheData,
33+
FlagDefinitionCacheProvider,
34+
)
3135
from posthog.poller import Poller
3236
from posthog.request import (
3337
DEFAULT_HOST,
@@ -184,6 +188,7 @@ def __init__(
184188
before_send=None,
185189
flag_fallback_cache_url=None,
186190
enable_local_evaluation=True,
191+
flag_definition_cache_provider: Optional[FlagDefinitionCacheProvider] = None,
187192
capture_exception_code_variables=False,
188193
code_variables_mask_patterns=None,
189194
code_variables_ignore_patterns=None,
@@ -222,8 +227,8 @@ def __init__(
222227
self.timeout = timeout
223228
self._feature_flags = None # private variable to store flags
224229
self.feature_flags_by_key = None
225-
self.group_type_mapping = None
226-
self.cohorts = None
230+
self.group_type_mapping: Optional[dict[str, str]] = None
231+
self.cohorts: Optional[dict[str, Any]] = None
227232
self.poll_interval = poll_interval
228233
self.feature_flags_request_timeout_seconds = (
229234
feature_flags_request_timeout_seconds
@@ -233,6 +238,7 @@ def __init__(
233238
self.flag_cache = self._initialize_flag_cache(flag_fallback_cache_url)
234239
self.flag_definition_version = 0
235240
self._flags_etag: Optional[str] = None
241+
self._flag_definition_cache_provider = flag_definition_cache_provider
236242
self.disabled = disabled
237243
self.disable_geoip = disable_geoip
238244
self.historical_migration = historical_migration
@@ -1165,17 +1171,25 @@ def join(self):
11651171
posthog.join()
11661172
```
11671173
"""
1168-
for consumer in self.consumers:
1169-
consumer.pause()
1170-
try:
1171-
consumer.join()
1172-
except RuntimeError:
1173-
# consumer thread has not started
1174-
pass
1174+
if self.consumers:
1175+
for consumer in self.consumers:
1176+
consumer.pause()
1177+
try:
1178+
consumer.join()
1179+
except RuntimeError:
1180+
# consumer thread has not started
1181+
pass
11751182

11761183
if self.poller:
11771184
self.poller.stop()
11781185

1186+
# Shutdown the cache provider (release locks, cleanup)
1187+
if self._flag_definition_cache_provider:
1188+
try:
1189+
self._flag_definition_cache_provider.shutdown()
1190+
except Exception as e:
1191+
self.log.error(f"[FEATURE FLAGS] Cache provider shutdown error: {e}")
1192+
11791193
def shutdown(self):
11801194
"""
11811195
Flush all messages and cleanly shutdown the client. Call this before the process ends in serverless environments to avoid data loss.
@@ -1191,7 +1205,71 @@ def shutdown(self):
11911205
if self.exception_capture:
11921206
self.exception_capture.close()
11931207

1208+
def _update_flag_state(
    self, data: FlagDefinitionCacheData, old_flags_by_key: Optional[dict] = None
) -> None:
    """Apply a flag-definition payload and bump the definition version on change.

    Args:
        data: Payload providing the "flags", "group_type_mapping" and
            "cohorts" entries; all three keys are read unconditionally.
        old_flags_by_key: Snapshot of the previous flags-by-key mapping to
            compare against. When None, no change detection is performed.
    """
    self.feature_flags = data["flags"]
    self.group_type_mapping = data["group_type_mapping"]
    self.cohorts = data["cohorts"]

    # Nothing to invalidate without an evaluation cache or a baseline snapshot.
    if not self.flag_cache or old_flags_by_key is None:
        return

    # Definitions changed: advance the version and drop entries cached
    # under the previous one.
    if old_flags_by_key != (self.feature_flags_by_key or {}):
        stale_version = self.flag_definition_version
        self.flag_definition_version = stale_version + 1
        self.flag_cache.invalidate_version(stale_version)
1225+
11941226
def _load_feature_flags(self):
    """Load flag definitions from the external cache provider or the API.

    Without a provider the API is always queried. With a provider, it is
    asked whether this instance should fetch; if not, the provider's cached
    definitions are applied instead. Any provider error, or an empty cache
    while no flags are loaded yet, falls back to fetching from the API.
    """
    provider = self._flag_definition_cache_provider
    if not provider:
        self._fetch_feature_flags_from_api()
        return

    try:
        fetch_from_api = provider.should_fetch_flag_definitions()
    except Exception as e:
        self.log.error(f"[FEATURE FLAGS] Cache provider should_fetch error: {e}")
        # Fail-safe: fetch from API if cache provider errors
        fetch_from_api = True

    if not fetch_from_api:
        # Another instance is responsible for fetching; read the shared cache.
        try:
            cached_data = provider.get_flag_definitions()
            if cached_data:
                self.log.debug(
                    "[FEATURE FLAGS] Using cached flag definitions from external cache"
                )
                previous_by_key = self.feature_flags_by_key or {}
                self._update_flag_state(cached_data, old_flags_by_key=previous_by_key)
                self._last_feature_flag_poll = datetime.now(tz=tzutc())
                return
            # Emergency fallback: with an empty cache and no flags in memory
            # there is no other way to recover, so fetch anyway.
            if not self.feature_flags:
                self.log.debug(
                    "[FEATURE FLAGS] Cache empty and no flags loaded, falling back to API fetch"
                )
                fetch_from_api = True
        except Exception as e:
            self.log.error(f"[FEATURE FLAGS] Cache provider get error: {e}")
            # Fail-safe: fetch from API if cache provider errors
            fetch_from_api = True

    if fetch_from_api:
        self._fetch_feature_flags_from_api()
1270+
1271+
def _fetch_feature_flags_from_api(self):
1272+
"""Fetch feature flags from the PostHog API."""
11951273
try:
11961274
# Store old flags to detect changes
11971275
old_flags_by_key: dict[str, dict] = self.feature_flags_by_key or {}
@@ -1221,17 +1299,21 @@ def _load_feature_flags(self):
12211299
)
12221300
return
12231301

1224-
self.feature_flags = response.data["flags"] or []
1225-
self.group_type_mapping = response.data["group_type_mapping"] or {}
1226-
self.cohorts = response.data["cohorts"] or {}
1302+
self._update_flag_state(response.data, old_flags_by_key=old_flags_by_key)
12271303

1228-
# Check if flag definitions changed and update version
1229-
if self.flag_cache and old_flags_by_key != (
1230-
self.feature_flags_by_key or {}
1231-
):
1232-
old_version = self.flag_definition_version
1233-
self.flag_definition_version += 1
1234-
self.flag_cache.invalidate_version(old_version)
1304+
# Store in external cache if provider is configured
1305+
if self._flag_definition_cache_provider:
1306+
try:
1307+
self._flag_definition_cache_provider.on_flag_definitions_received(
1308+
{
1309+
"flags": self.feature_flags or [],
1310+
"group_type_mapping": self.group_type_mapping or {},
1311+
"cohorts": self.cohorts or {},
1312+
}
1313+
)
1314+
except Exception as e:
1315+
self.log.error(f"[FEATURE FLAGS] Cache provider store error: {e}")
1316+
# Flags are already in memory, so continue normally
12351317

12361318
except APIError as e:
12371319
if e.status == 401:
@@ -1331,7 +1413,8 @@ def _compute_flag_locally(
13311413
flag_filters = feature_flag.get("filters") or {}
13321414
aggregation_group_type_index = flag_filters.get("aggregation_group_type_index")
13331415
if aggregation_group_type_index is not None:
1334-
group_name = self.group_type_mapping.get(str(aggregation_group_type_index))
1416+
group_type_mapping = self.group_type_mapping or {}
1417+
group_name = group_type_mapping.get(str(aggregation_group_type_index))
13351418

13361419
if not group_name:
13371420
self.log.warning(

0 commit comments

Comments
 (0)