Skip to content

Commit 62acd57

Browse files
authored
feat(flags): preload referenced cohorts in flags hypercache (#52023)
1 parent 3cb6f2c commit 62acd57

File tree

13 files changed

+799
-32
lines changed

13 files changed

+799
-32
lines changed

posthog/models/feature_flag/flags_cache.py

Lines changed: 194 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
44
This module provides a HyperCache that stores feature flags for the feature-flags service.
55
Unlike the local_evaluation.py cache which provides rich data for SDKs (including cohorts
6-
and group type mappings), this cache provides just the raw flag data.
6+
and group type mappings), this cache provides flag data plus the cohort definitions that
7+
those flags actually reference.
78
89
The cache is automatically invalidated when:
910
- FeatureFlag models are created, updated, or deleted
1011
- Team models are created or deleted (to ensure flag caches are cleaned up)
1112
- FeatureFlagEvaluationTag models are created or deleted
1213
- Tag models are updated (since tag names are cached in evaluation_tags)
14+
- Cohort definitions are created, updated, or deleted (not recalculation-only saves)
1315
- Hourly refresh job detects expiring entries (TTL < 24h)
1416
1517
Cache Key Pattern:
@@ -44,6 +46,8 @@
4446

4547
from posthog.caching.flags_redis_cache import FLAGS_DEDICATED_CACHE_ALIAS
4648
from posthog.metrics import TOMBSTONE_COUNTER
49+
from posthog.models.cohort.cohort import Cohort
50+
from posthog.models.cohort.dependencies import extract_cohort_dependencies
4751
from posthog.models.feature_flag import FeatureFlag
4852
from posthog.models.feature_flag.feature_flag import (
4953
FeatureFlagEvaluationTag,
@@ -92,6 +96,138 @@ def _extract_direct_dependency_ids(flag_data: dict[str, Any]) -> set[int]:
9296
return dep_ids
9397

9498

99+
# Cohort model fields that change only during recalculation, not definition edits.
100+
# Used by the post_save signal to avoid rebuilding the flags cache on every
101+
# static cohort recalculation or count update.
102+
# NOTE: cohort_type is included because calculate_people_ch() always saves it in
103+
# update_fields even when unchanged. The rare actual cohort_type change (realtime
104+
# exceeding person limit) uses Cohort.objects.filter().update() which bypasses signals.
105+
_COHORT_RECALCULATION_FIELDS = frozenset(
106+
[
107+
"count",
108+
"version",
109+
"pending_version",
110+
"is_calculating",
111+
"last_calculation",
112+
"last_calculation_duration_ms",
113+
"errors_calculating",
114+
"last_error_at",
115+
# NOTE: `groups` is the legacy cohort-condition field (deprecated in favour of
116+
# `filters`). calculate_people_ch() always saves it in update_fields even when
117+
# unchanged (see cohort.py:347). Real definition changes go through a full save
118+
# (update_fields=None), so they still trigger invalidation.
119+
"groups",
120+
"cohort_type",
121+
]
122+
)
123+
124+
125+
def _extract_cohort_ids_from_flag_filters(flags_data: list[dict[str, Any]]) -> set[int]:
126+
"""Extract cohort IDs directly referenced in active flag filters.
127+
128+
Only scans ``groups`` — the other filter sections cannot contain cohort
129+
properties:
130+
- ``super_groups`` are early-access enrollment gates that only use person
131+
properties (``$feature_enrollment/*``).
132+
- ``holdout`` uses a different schema for configuring experiment holdouts
133+
with no property filters at all.
134+
"""
135+
cohort_ids: set[int] = set()
136+
for flag in flags_data:
137+
if not flag.get("active", True) or flag.get("deleted", False):
138+
continue
139+
for group in flag.get("filters", {}).get("groups") or []:
140+
for prop in group.get("properties") or []:
141+
if prop.get("type") == "cohort":
142+
try:
143+
cohort_ids.add(int(prop["value"]))
144+
except (ValueError, KeyError, TypeError):
145+
continue
146+
return cohort_ids
147+
148+
149+
# Upper bound on BFS levels when resolving cohort-on-cohort dependencies;
# guards against pathologically deep dependency chains.
_MAX_COHORT_DEPENDENCY_DEPTH = 20
150+
151+
152+
def _load_cohorts_with_deps(seed_ids: set[int], **team_filter: Any) -> dict[int, Cohort]:
    """Load cohorts breadth-first, following cohort-on-cohort dependencies.

    Args:
        seed_ids: Cohort IDs to start traversal from.
        **team_filter: Extra kwargs for Cohort.objects.filter() that scope the
            query to one or more teams, e.g. team_id=5 or team_id__in={5, 6}.

    Returns:
        Mapping of cohort PK to the loaded Cohort instance. Soft-deleted
        cohorts are never loaded, and traversal stops (with a warning) after
        _MAX_COHORT_DEPENDENCY_DEPTH levels.
    """
    if not seed_ids:
        return {}

    loaded: dict[int, Cohort] = {}
    seen = set(seed_ids)  # every ID ever queued, so cycles aren't revisited
    frontier = set(seed_ids)

    for _level in range(_MAX_COHORT_DEPENDENCY_DEPTH):
        if not frontier:
            break
        batch = list(Cohort.objects.filter(pk__in=frontier, deleted=False, **team_filter))
        for cohort in batch:
            loaded[cohort.pk] = cohort

        # Queue dependencies we haven't encountered before as the next level.
        discovered: set[int] = set()
        for cohort in batch:
            discovered.update(extract_cohort_dependencies(cohort))
        frontier = discovered - seen
        seen |= frontier
    else:
        # Ran out of levels with work still pending — dependency chain too deep.
        if frontier:
            logger.warning(
                "Cohort dependency depth limit reached",
                depth=_MAX_COHORT_DEPENDENCY_DEPTH,
                remaining_ids=frontier,
            )

    return loaded
194+
195+
196+
def _get_referenced_cohorts(team_id: int, flags_data: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Serialize every cohort the team's flags reference, transitive deps included."""
    seed_ids = _extract_cohort_ids_from_flag_filters(flags_data)
    cohorts_by_pk = _load_cohorts_with_deps(seed_ids, team_id=team_id)
    return [_serialize_cohort(cohort) for cohort in cohorts_by_pk.values()]
201+
202+
203+
def _serialize_cohort(cohort: Cohort) -> dict[str, Any]:
    """Turn a Cohort into the JSON shape the Rust feature-flags service expects.

    Keep in sync with rust/feature-flags/src/cohorts/cohort_models.rs::Cohort.
    Several fields there are required (no serde(default)): deleted,
    is_calculating, is_static, errors_calculating, and groups — omitting any of
    them fails deserialization, so every field is always emitted here.
    """
    fields = (
        "id",
        "name",
        "description",
        "team_id",
        "deleted",
        "filters",
        "query",
        "version",
        "pending_version",
        "count",
        "is_calculating",
        "is_static",
        "errors_calculating",
        "groups",
        "created_by_id",
        "cohort_type",
    )
    return {field: getattr(cohort, field) for field in fields}
229+
230+
95231
def _compute_flag_dependencies(flags_data: list[dict[str, Any]]) -> dict[str, Any]:
96232
"""
97233
Compute flag dependency metadata and return evaluation metadata.
@@ -185,24 +321,28 @@ def _get_feature_flags_for_service(team: Team) -> dict[str, Any]:
185321
in /flags would return unusable encrypted ciphertext.
186322
187323
Returns:
188-
dict: {"flags": [...], "evaluation_metadata": {...}} where flags is a list
189-
of flag dictionaries and evaluation_metadata contains pre-computed dependency
190-
metadata (stages, missing deps, transitive deps).
324+
dict: {"flags": [...], "evaluation_metadata": {...}, "cohorts": [...]} where
325+
flags is a list of flag dictionaries, evaluation_metadata contains pre-computed
326+
dependency metadata (stages, missing deps, transitive deps), and cohorts contains
327+
serialized cohort definitions referenced by the flags (including transitive deps).
191328
"""
192329
# Exclude encrypted remote config flags at DB level for efficiency
193330
flags = get_feature_flags(team=team, exclude_encrypted_remote_config=True)
194331
flags_data = serialize_feature_flags(flags)
195332
evaluation_metadata = _compute_flag_dependencies(flags_data)
196333

334+
cohorts = _get_referenced_cohorts(team.id, flags_data)
335+
197336
logger.info(
198337
"Loaded feature flags for service cache",
199338
team_id=team.id,
200339
project_id=team.project_id,
201340
flag_count=len(flags_data),
341+
cohort_count=len(cohorts),
202342
)
203343

204344
# Wrap in dict for HyperCache compatibility
205-
return {"flags": flags_data, "evaluation_metadata": evaluation_metadata}
345+
return {"flags": flags_data, "evaluation_metadata": evaluation_metadata, "cohorts": cohorts}
206346

207347

208348
def _get_feature_flags_for_teams_batch(teams: list[Team]) -> dict[int, dict[str, Any]]:
@@ -220,7 +360,7 @@ def _get_feature_flags_for_teams_batch(teams: list[Team]) -> dict[int, dict[str,
220360
teams: List of Team objects to load flags for
221361
222362
Returns:
223-
Dict mapping team_id to {"flags": [...], "evaluation_metadata": {...}} for each team
363+
Dict mapping team_id to {"flags": [...], "evaluation_metadata": {...}, "cohorts": [...]} for each team
224364
"""
225365
if not teams:
226366
return {}
@@ -255,21 +395,44 @@ def _get_feature_flags_for_teams_batch(teams: list[Team]) -> dict[int, dict[str,
255395
for flag in all_flags:
256396
flags_by_team_id[flag.team_id].append(flag)
257397

258-
# Serialize flags for each team
259-
result: dict[int, dict[str, Any]] = {}
398+
# Serialize flags for each team and collect all referenced cohort IDs
399+
flags_data_by_team: dict[int, list[dict[str, Any]]] = {}
400+
all_cohort_ids: set[int] = set()
260401
for team in teams:
261402
team_flags = flags_by_team_id.get(team.id, [])
262403
flags_data = serialize_feature_flags(team_flags)
404+
flags_data_by_team[team.id] = flags_data
405+
all_cohort_ids.update(_extract_cohort_ids_from_flag_filters(flags_data))
406+
407+
# Batch-load all referenced cohorts across all teams (including transitive deps)
408+
team_ids = {t.id for t in teams}
409+
loaded_cohorts = _load_cohorts_with_deps(all_cohort_ids, team_id__in=team_ids)
410+
411+
# Group loaded cohorts by team_id
412+
cohorts_by_team: dict[int, list[dict[str, Any]]] = defaultdict(list)
413+
for cohort in loaded_cohorts.values():
414+
cohorts_by_team[cohort.team_id].append(_serialize_cohort(cohort))
415+
416+
# Build result for each team
417+
result: dict[int, dict[str, Any]] = {}
418+
for team in teams:
419+
flags_data = flags_data_by_team[team.id]
263420
evaluation_metadata = _compute_flag_dependencies(flags_data)
264421

422+
team_cohorts = cohorts_by_team.get(team.id, [])
265423
logger.info(
266424
"Loaded feature flags for service cache (batch)",
267425
team_id=team.id,
268426
project_id=team.project_id,
269427
flag_count=len(flags_data),
428+
cohort_count=len(team_cohorts),
270429
)
271430

272-
result[team.id] = {"flags": flags_data, "evaluation_metadata": evaluation_metadata}
431+
result[team.id] = {
432+
"flags": flags_data,
433+
"evaluation_metadata": evaluation_metadata,
434+
"cohorts": team_cohorts,
435+
}
273436

274437
return result
275438

@@ -746,3 +909,25 @@ def tag_changed_flags_cache(sender, instance: "Tag", created: bool, **kwargs):
746909
for team_id in FeatureFlagEvaluationTag.get_team_ids_using_tag(instance):
747910
# Capture team_id in closure to avoid late binding issues
748911
transaction.on_commit(lambda tid=team_id: update_team_service_flags_cache.delay(tid)) # type: ignore[misc]
912+
913+
914+
@receiver(post_save, sender=Cohort)
@receiver(post_delete, sender=Cohort)
def cohort_changed_flags_cache(sender, instance: "Cohort", **kwargs):
    """
    Rebuild a team's flags cache when one of its cohort definitions changes.

    Saves whose update_fields are all recalculation bookkeeping (count,
    version, is_calculating, ...) are ignored so that routine static-cohort
    recalculations don't thrash the cache.
    No-op unless FLAGS_REDIS_URL is configured.
    """
    if not settings.FLAGS_REDIS_URL:
        return

    # A full save (update_fields=None) always invalidates; a partial save only
    # does when it touches something beyond the recalculation-only fields.
    updated = kwargs.get("update_fields")
    if updated is not None and frozenset(updated) <= _COHORT_RECALCULATION_FIELDS:
        return

    from posthog.tasks.feature_flags import update_team_service_flags_cache

    transaction.on_commit(lambda: update_team_service_flags_cache.delay(instance.team_id))

0 commit comments

Comments
 (0)