This module provides a HyperCache that stores feature flags for the feature-flags service.
Unlike the local_evaluation.py cache which provides rich data for SDKs (including cohorts
and group type mappings), this cache provides flag data plus the cohort definitions that
those flags actually reference.

The cache is automatically invalidated when:
- FeatureFlag models are created, updated, or deleted
- Team models are created or deleted (to ensure flag caches are cleaned up)
- FeatureFlagEvaluationTag models are created or deleted
- Tag models are updated (since tag names are cached in evaluation_tags)
- Cohort definitions are created, updated, or deleted (not recalculation-only saves)
- Hourly refresh job detects expiring entries (TTL < 24h)

Cache Key Pattern:
4446
4547from posthog .caching .flags_redis_cache import FLAGS_DEDICATED_CACHE_ALIAS
4648from posthog .metrics import TOMBSTONE_COUNTER
49+ from posthog .models .cohort .cohort import Cohort
50+ from posthog .models .cohort .dependencies import extract_cohort_dependencies
4751from posthog .models .feature_flag import FeatureFlag
4852from posthog .models .feature_flag .feature_flag import (
4953 FeatureFlagEvaluationTag ,
@@ -92,6 +96,138 @@ def _extract_direct_dependency_ids(flag_data: dict[str, Any]) -> set[int]:
9296 return dep_ids
9397
9498
99+ # Cohort model fields that change only during recalculation, not definition edits.
100+ # Used by the post_save signal to avoid rebuilding the flags cache on every
101+ # static cohort recalculation or count update.
102+ # NOTE: cohort_type is included because calculate_people_ch() always saves it in
103+ # update_fields even when unchanged. The rare actual cohort_type change (realtime
104+ # exceeding person limit) uses Cohort.objects.filter().update() which bypasses signals.
105+ _COHORT_RECALCULATION_FIELDS = frozenset (
106+ [
107+ "count" ,
108+ "version" ,
109+ "pending_version" ,
110+ "is_calculating" ,
111+ "last_calculation" ,
112+ "last_calculation_duration_ms" ,
113+ "errors_calculating" ,
114+ "last_error_at" ,
115+ # NOTE: `groups` is the legacy cohort-condition field (deprecated in favour of
116+ # `filters`). calculate_people_ch() always saves it in update_fields even when
117+ # unchanged (see cohort.py:347). Real definition changes go through a full save
118+ # (update_fields=None), so they still trigger invalidation.
119+ "groups" ,
120+ "cohort_type" ,
121+ ]
122+ )
123+
124+
125+ def _extract_cohort_ids_from_flag_filters (flags_data : list [dict [str , Any ]]) -> set [int ]:
126+ """Extract cohort IDs directly referenced in active flag filters.
127+
128+ Only scans ``groups`` — the other filter sections cannot contain cohort
129+ properties:
130+ - ``super_groups`` are early-access enrollment gates that only use person
131+ properties (``$feature_enrollment/*``).
132+ - ``holdout`` uses a different schema for configuring experiment holdouts
133+ with no property filters at all.
134+ """
135+ cohort_ids : set [int ] = set ()
136+ for flag in flags_data :
137+ if not flag .get ("active" , True ) or flag .get ("deleted" , False ):
138+ continue
139+ for group in flag .get ("filters" , {}).get ("groups" ) or []:
140+ for prop in group .get ("properties" ) or []:
141+ if prop .get ("type" ) == "cohort" :
142+ try :
143+ cohort_ids .add (int (prop ["value" ]))
144+ except (ValueError , KeyError , TypeError ):
145+ continue
146+ return cohort_ids
147+
148+
149+ _MAX_COHORT_DEPENDENCY_DEPTH = 20
150+
151+
def _load_cohorts_with_deps(seed_ids: set[int], **team_filter: Any) -> dict[int, Cohort]:
    """BFS-load cohorts by seed IDs, resolving transitive cohort-on-cohort deps.

    Args:
        seed_ids: Initial cohort IDs to load.
        **team_filter: Passed to Cohort.objects.filter() for team scoping,
            e.g. team_id=5 or team_id__in={5, 6}.

    Returns:
        Dict mapping cohort PK to loaded Cohort instance.
    """
    if not seed_ids:
        return {}

    resolved: dict[int, Cohort] = {}
    seen_ids = set(seed_ids)
    frontier = set(seed_ids)

    # One DB query per dependency level, capped to avoid runaway traversal
    # on pathological or cyclic graphs.
    for _round in range(_MAX_COHORT_DEPENDENCY_DEPTH):
        if not frontier:
            break
        batch = list(Cohort.objects.filter(pk__in=frontier, deleted=False, **team_filter))
        next_frontier: set[int] = set()
        for cohort in batch:
            resolved[cohort.pk] = cohort
            for dep_id in extract_cohort_dependencies(cohort):
                if dep_id not in seen_ids:
                    seen_ids.add(dep_id)
                    next_frontier.add(dep_id)
        frontier = next_frontier
    else:
        # Loop exhausted the depth budget without draining the frontier.
        if frontier:
            logger.warning(
                "Cohort dependency depth limit reached",
                depth=_MAX_COHORT_DEPENDENCY_DEPTH,
                remaining_ids=frontier,
            )

    return resolved
194+
195+
def _get_referenced_cohorts(team_id: int, flags_data: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Fetch cohort definitions referenced by flags, including transitive cohort-on-cohort deps."""
    seed_ids = _extract_cohort_ids_from_flag_filters(flags_data)
    cohorts_by_pk = _load_cohorts_with_deps(seed_ids, team_id=team_id)
    return [_serialize_cohort(cohort) for cohort in cohorts_by_pk.values()]
201+
202+
def _serialize_cohort(cohort: Cohort) -> dict[str, Any]:
    """Serialize a Cohort to a dict matching the Rust Cohort struct field names.

    Keep in sync with rust/feature-flags/src/cohorts/cohort_models.rs::Cohort.
    The Rust struct has no serde(default) on required fields (deleted, is_calculating,
    is_static, errors_calculating, groups), so omitting any of these causes a
    deserialization failure.
    """
    # Every JSON key matches the model attribute name one-for-one, so the
    # payload can be built by straight attribute lookup.
    field_names = (
        "id",
        "name",
        "description",
        "team_id",
        "deleted",
        "filters",
        "query",
        "version",
        "pending_version",
        "count",
        "is_calculating",
        "is_static",
        "errors_calculating",
        "groups",
        "created_by_id",
        "cohort_type",
    )
    return {name: getattr(cohort, name) for name in field_names}
229+
230+
95231def _compute_flag_dependencies (flags_data : list [dict [str , Any ]]) -> dict [str , Any ]:
96232 """
97233 Compute flag dependency metadata and return evaluation metadata.
@@ -185,24 +321,28 @@ def _get_feature_flags_for_service(team: Team) -> dict[str, Any]:
185321 in /flags would return unusable encrypted ciphertext.
186322
187323 Returns:
188- dict: {"flags": [...], "evaluation_metadata": {...}} where flags is a list
189- of flag dictionaries and evaluation_metadata contains pre-computed dependency
190- metadata (stages, missing deps, transitive deps).
324+ dict: {"flags": [...], "evaluation_metadata": {...}, "cohorts": [...]} where
325+ flags is a list of flag dictionaries, evaluation_metadata contains pre-computed
326+ dependency metadata (stages, missing deps, transitive deps), and cohorts contains
327+ serialized cohort definitions referenced by the flags (including transitive deps).
191328 """
192329 # Exclude encrypted remote config flags at DB level for efficiency
193330 flags = get_feature_flags (team = team , exclude_encrypted_remote_config = True )
194331 flags_data = serialize_feature_flags (flags )
195332 evaluation_metadata = _compute_flag_dependencies (flags_data )
196333
334+ cohorts = _get_referenced_cohorts (team .id , flags_data )
335+
197336 logger .info (
198337 "Loaded feature flags for service cache" ,
199338 team_id = team .id ,
200339 project_id = team .project_id ,
201340 flag_count = len (flags_data ),
341+ cohort_count = len (cohorts ),
202342 )
203343
204344 # Wrap in dict for HyperCache compatibility
205- return {"flags" : flags_data , "evaluation_metadata" : evaluation_metadata }
345+ return {"flags" : flags_data , "evaluation_metadata" : evaluation_metadata , "cohorts" : cohorts }
206346
207347
208348def _get_feature_flags_for_teams_batch (teams : list [Team ]) -> dict [int , dict [str , Any ]]:
@@ -220,7 +360,7 @@ def _get_feature_flags_for_teams_batch(teams: list[Team]) -> dict[int, dict[str,
220360 teams: List of Team objects to load flags for
221361
222362 Returns:
223- Dict mapping team_id to {"flags": [...], "evaluation_metadata": {...}} for each team
363+ Dict mapping team_id to {"flags": [...], "evaluation_metadata": {...}, "cohorts": [...] } for each team
224364 """
225365 if not teams :
226366 return {}
@@ -255,21 +395,44 @@ def _get_feature_flags_for_teams_batch(teams: list[Team]) -> dict[int, dict[str,
255395 for flag in all_flags :
256396 flags_by_team_id [flag .team_id ].append (flag )
257397
258- # Serialize flags for each team
259- result : dict [int , dict [str , Any ]] = {}
398+ # Serialize flags for each team and collect all referenced cohort IDs
399+ flags_data_by_team : dict [int , list [dict [str , Any ]]] = {}
400+ all_cohort_ids : set [int ] = set ()
260401 for team in teams :
261402 team_flags = flags_by_team_id .get (team .id , [])
262403 flags_data = serialize_feature_flags (team_flags )
404+ flags_data_by_team [team .id ] = flags_data
405+ all_cohort_ids .update (_extract_cohort_ids_from_flag_filters (flags_data ))
406+
407+ # Batch-load all referenced cohorts across all teams (including transitive deps)
408+ team_ids = {t .id for t in teams }
409+ loaded_cohorts = _load_cohorts_with_deps (all_cohort_ids , team_id__in = team_ids )
410+
411+ # Group loaded cohorts by team_id
412+ cohorts_by_team : dict [int , list [dict [str , Any ]]] = defaultdict (list )
413+ for cohort in loaded_cohorts .values ():
414+ cohorts_by_team [cohort .team_id ].append (_serialize_cohort (cohort ))
415+
416+ # Build result for each team
417+ result : dict [int , dict [str , Any ]] = {}
418+ for team in teams :
419+ flags_data = flags_data_by_team [team .id ]
263420 evaluation_metadata = _compute_flag_dependencies (flags_data )
264421
422+ team_cohorts = cohorts_by_team .get (team .id , [])
265423 logger .info (
266424 "Loaded feature flags for service cache (batch)" ,
267425 team_id = team .id ,
268426 project_id = team .project_id ,
269427 flag_count = len (flags_data ),
428+ cohort_count = len (team_cohorts ),
270429 )
271430
272- result [team .id ] = {"flags" : flags_data , "evaluation_metadata" : evaluation_metadata }
431+ result [team .id ] = {
432+ "flags" : flags_data ,
433+ "evaluation_metadata" : evaluation_metadata ,
434+ "cohorts" : team_cohorts ,
435+ }
273436
274437 return result
275438
@@ -746,3 +909,25 @@ def tag_changed_flags_cache(sender, instance: "Tag", created: bool, **kwargs):
746909 for team_id in FeatureFlagEvaluationTag .get_team_ids_using_tag (instance ):
747910 # Capture team_id in closure to avoid late binding issues
748911 transaction .on_commit (lambda tid = team_id : update_team_service_flags_cache .delay (tid )) # type: ignore[misc]
912+
913+
@receiver(post_save, sender=Cohort)
@receiver(post_delete, sender=Cohort)
def cohort_changed_flags_cache(sender, instance: "Cohort", **kwargs):
    """
    Invalidate flags cache when a cohort definition changes.

    Skips recalculation-only saves (count, version, is_calculating, etc.) to avoid
    rebuilding the flags cache on every static cohort recalculation.
    Only operates when FLAGS_REDIS_URL is configured.
    """
    if not settings.FLAGS_REDIS_URL:
        return

    touched = kwargs.get("update_fields")
    # A save restricted to recalculation bookkeeping leaves the definition
    # untouched, so the cached payload stays valid. Deletes (and full saves
    # with update_fields=None) always invalidate.
    if touched is not None and not (frozenset(touched) - _COHORT_RECALCULATION_FIELDS):
        return

    from posthog.tasks.feature_flags import update_team_service_flags_cache

    transaction.on_commit(lambda: update_team_service_flags_cache.delay(instance.team_id))
0 commit comments