2828 RequiresServerEvaluation ,
2929 match_feature_flag_properties ,
3030)
31+ from posthog .flag_definition_cache import (
32+ FlagDefinitionCacheData ,
33+ FlagDefinitionCacheProvider ,
34+ )
3135from posthog .poller import Poller
3236from posthog .request import (
3337 DEFAULT_HOST ,
@@ -184,6 +188,7 @@ def __init__(
184188 before_send = None ,
185189 flag_fallback_cache_url = None ,
186190 enable_local_evaluation = True ,
191+ flag_definition_cache_provider : Optional [FlagDefinitionCacheProvider ] = None ,
187192 capture_exception_code_variables = False ,
188193 code_variables_mask_patterns = None ,
189194 code_variables_ignore_patterns = None ,
@@ -222,8 +227,8 @@ def __init__(
222227 self .timeout = timeout
223228 self ._feature_flags = None # private variable to store flags
224229 self .feature_flags_by_key = None
225- self .group_type_mapping = None
226- self .cohorts = None
230+ self .group_type_mapping : Optional [ dict [ str , str ]] = None
231+ self .cohorts : Optional [ dict [ str , Any ]] = None
227232 self .poll_interval = poll_interval
228233 self .feature_flags_request_timeout_seconds = (
229234 feature_flags_request_timeout_seconds
@@ -233,6 +238,7 @@ def __init__(
233238 self .flag_cache = self ._initialize_flag_cache (flag_fallback_cache_url )
234239 self .flag_definition_version = 0
235240 self ._flags_etag : Optional [str ] = None
241+ self ._flag_definition_cache_provider = flag_definition_cache_provider
236242 self .disabled = disabled
237243 self .disable_geoip = disable_geoip
238244 self .historical_migration = historical_migration
@@ -1165,17 +1171,25 @@ def join(self):
11651171 posthog.join()
11661172 ```
11671173 """
1168- for consumer in self .consumers :
1169- consumer .pause ()
1170- try :
1171- consumer .join ()
1172- except RuntimeError :
1173- # consumer thread has not started
1174- pass
1174+ if self .consumers :
1175+ for consumer in self .consumers :
1176+ consumer .pause ()
1177+ try :
1178+ consumer .join ()
1179+ except RuntimeError :
1180+ # consumer thread has not started
1181+ pass
11751182
11761183 if self .poller :
11771184 self .poller .stop ()
11781185
1186+ # Shutdown the cache provider (release locks, cleanup)
1187+ if self ._flag_definition_cache_provider :
1188+ try :
1189+ self ._flag_definition_cache_provider .shutdown ()
1190+ except Exception as e :
1191+ self .log .error (f"[FEATURE FLAGS] Cache provider shutdown error: { e } " )
1192+
11791193 def shutdown (self ):
11801194 """
11811195 Flush all messages and cleanly shutdown the client. Call this before the process ends in serverless environments to avoid data loss.
@@ -1191,7 +1205,71 @@ def shutdown(self):
11911205 if self .exception_capture :
11921206 self .exception_capture .close ()
11931207
1208+ def _update_flag_state (
1209+ self , data : FlagDefinitionCacheData , old_flags_by_key : Optional [dict ] = None
1210+ ) -> None :
1211+ """Update internal flag state from cache data and invalidate evaluation cache if changed."""
1212+ self .feature_flags = data ["flags" ]
1213+ self .group_type_mapping = data ["group_type_mapping" ]
1214+ self .cohorts = data ["cohorts" ]
1215+
1216+ # Invalidate evaluation cache if flag definitions changed
1217+ if (
1218+ self .flag_cache
1219+ and old_flags_by_key is not None
1220+ and old_flags_by_key != (self .feature_flags_by_key or {})
1221+ ):
1222+ old_version = self .flag_definition_version
1223+ self .flag_definition_version += 1
1224+ self .flag_cache .invalidate_version (old_version )
1225+
11941226 def _load_feature_flags (self ):
1227+ should_fetch = True
1228+ if self ._flag_definition_cache_provider :
1229+ try :
1230+ should_fetch = (
1231+ self ._flag_definition_cache_provider .should_fetch_flag_definitions ()
1232+ )
1233+ except Exception as e :
1234+ self .log .error (
1235+ f"[FEATURE FLAGS] Cache provider should_fetch error: { e } "
1236+ )
1237+ # Fail-safe: fetch from API if cache provider errors
1238+ should_fetch = True
1239+
1240+ # If not fetching, try to get from cache
1241+ if not should_fetch and self ._flag_definition_cache_provider :
1242+ try :
1243+ cached_data = (
1244+ self ._flag_definition_cache_provider .get_flag_definitions ()
1245+ )
1246+ if cached_data :
1247+ self .log .debug (
1248+ "[FEATURE FLAGS] Using cached flag definitions from external cache"
1249+ )
1250+ self ._update_flag_state (
1251+ cached_data , old_flags_by_key = self .feature_flags_by_key or {}
1252+ )
1253+ self ._last_feature_flag_poll = datetime .now (tz = tzutc ())
1254+ return
1255+ else :
1256+ # Emergency fallback: if cache is empty and we have no flags, fetch anyway.
1257+ # There's really no other way of recovering in this case.
1258+ if not self .feature_flags :
1259+ self .log .debug (
1260+ "[FEATURE FLAGS] Cache empty and no flags loaded, falling back to API fetch"
1261+ )
1262+ should_fetch = True
1263+ except Exception as e :
1264+ self .log .error (f"[FEATURE FLAGS] Cache provider get error: { e } " )
1265+ # Fail-safe: fetch from API if cache provider errors
1266+ should_fetch = True
1267+
1268+ if should_fetch :
1269+ self ._fetch_feature_flags_from_api ()
1270+
1271+ def _fetch_feature_flags_from_api (self ):
1272+ """Fetch feature flags from the PostHog API."""
11951273 try :
11961274 # Store old flags to detect changes
11971275 old_flags_by_key : dict [str , dict ] = self .feature_flags_by_key or {}
@@ -1221,17 +1299,21 @@ def _load_feature_flags(self):
12211299 )
12221300 return
12231301
1224- self .feature_flags = response .data ["flags" ] or []
1225- self .group_type_mapping = response .data ["group_type_mapping" ] or {}
1226- self .cohorts = response .data ["cohorts" ] or {}
1302+ self ._update_flag_state (response .data , old_flags_by_key = old_flags_by_key )
12271303
1228- # Check if flag definitions changed and update version
1229- if self .flag_cache and old_flags_by_key != (
1230- self .feature_flags_by_key or {}
1231- ):
1232- old_version = self .flag_definition_version
1233- self .flag_definition_version += 1
1234- self .flag_cache .invalidate_version (old_version )
1304+ # Store in external cache if provider is configured
1305+ if self ._flag_definition_cache_provider :
1306+ try :
1307+ self ._flag_definition_cache_provider .on_flag_definitions_received (
1308+ {
1309+ "flags" : self .feature_flags or [],
1310+ "group_type_mapping" : self .group_type_mapping or {},
1311+ "cohorts" : self .cohorts or {},
1312+ }
1313+ )
1314+ except Exception as e :
1315+ self .log .error (f"[FEATURE FLAGS] Cache provider store error: { e } " )
1316+ # Flags are already in memory, so continue normally
12351317
12361318 except APIError as e :
12371319 if e .status == 401 :
@@ -1331,7 +1413,8 @@ def _compute_flag_locally(
13311413 flag_filters = feature_flag .get ("filters" ) or {}
13321414 aggregation_group_type_index = flag_filters .get ("aggregation_group_type_index" )
13331415 if aggregation_group_type_index is not None :
1334- group_name = self .group_type_mapping .get (str (aggregation_group_type_index ))
1416+ group_type_mapping = self .group_type_mapping or {}
1417+ group_name = group_type_mapping .get (str (aggregation_group_type_index ))
13351418
13361419 if not group_name :
13371420 self .log .warning (
0 commit comments