Skip to content

Commit b6f80f8

Browse files
authored
fix(aci): don't fire action if there is a conflict when creating WAGS (#104002)
1 parent a7abad1 commit b6f80f8

File tree

2 files changed

+109
-28
lines changed

2 files changed

+109
-28
lines changed

src/sentry/workflow_engine/processors/action.py

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import defaultdict
33
from datetime import datetime, timedelta
44

5-
from django.db import models
5+
from django.db import connection, models
66
from django.db.models import Case, Value, When
77
from django.utils import timezone
88

@@ -36,6 +36,9 @@
3636
logger = logging.getLogger(__name__)
3737

3838
EnqueuedAction = tuple[DataConditionGroup, list[DataCondition]]
39+
UpdatedStatuses = int
40+
CreatedStatuses = int
41+
ConflictedStatuses = list[tuple[int, int]] # (workflow_id, action_id)
3942

4043

4144
def get_workflow_action_group_statuses(
@@ -71,13 +74,13 @@ def process_workflow_action_group_statuses(
7174
workflows: BaseQuerySet[Workflow],
7275
group: Group,
7376
now: datetime,
74-
) -> tuple[dict[int, int], set[int], list[WorkflowActionGroupStatus]]:
77+
) -> tuple[dict[int, set[int]], set[int], list[WorkflowActionGroupStatus]]:
7578
"""
7679
Determine which workflow actions should be fired based on their statuses.
7780
Prepare the statuses to update and create.
7881
"""
7982

80-
action_to_workflow_ids: dict[int, int] = {} # will dedupe because there can be only 1
83+
updated_action_to_workflows_ids: dict[int, set[int]] = defaultdict(set)
8184
workflow_frequencies: dict[int, timedelta] = {
8285
workflow.id: workflow.config.get("frequency", 0) * timedelta(minutes=1)
8386
for workflow in workflows
@@ -91,7 +94,7 @@ def process_workflow_action_group_statuses(
9194
status.workflow_id, zero_timedelta
9295
):
9396
# we should fire the workflow for this action
94-
action_to_workflow_ids[action_id] = status.workflow_id
97+
updated_action_to_workflows_ids[action_id].add(status.workflow_id)
9598
statuses_to_update.add(status.id)
9699

97100
missing_statuses: list[WorkflowActionGroupStatus] = []
@@ -107,31 +110,51 @@ def process_workflow_action_group_statuses(
107110
workflow_id=workflow_id, action_id=action_id, group=group, date_updated=now
108111
)
109112
)
110-
action_to_workflow_ids[action_id] = workflow_id
113+
updated_action_to_workflows_ids[action_id].add(workflow_id)
111114

112-
return action_to_workflow_ids, statuses_to_update, missing_statuses
115+
return updated_action_to_workflows_ids, statuses_to_update, missing_statuses
113116

114117

115118
def update_workflow_action_group_statuses(
116119
now: datetime, statuses_to_update: set[int], missing_statuses: list[WorkflowActionGroupStatus]
117-
) -> None:
118-
WorkflowActionGroupStatus.objects.filter(
120+
) -> tuple[UpdatedStatuses, CreatedStatuses, ConflictedStatuses]:
121+
updated_count = WorkflowActionGroupStatus.objects.filter(
119122
id__in=statuses_to_update, date_updated__lt=now
120123
).update(date_updated=now)
121124

122-
all_statuses = WorkflowActionGroupStatus.objects.bulk_create(
123-
missing_statuses,
124-
batch_size=1000,
125-
ignore_conflicts=True,
126-
)
127-
missing_status_pairs = [
128-
(status.workflow_id, status.action_id) for status in all_statuses if status.id is None
125+
if not missing_statuses:
126+
return updated_count, 0, []
127+
128+
# Use raw SQL: only returns successfully created rows
129+
# XXX: the query does not currently include batch size limit like bulk_create does
130+
with connection.cursor() as cursor:
131+
# Build values for batch insert
132+
values_placeholders = []
133+
values_data = []
134+
for s in missing_statuses:
135+
values_placeholders.append("(%s, %s, %s, %s, %s)")
136+
values_data.extend([s.workflow_id, s.action_id, s.group_id, now, now])
137+
138+
sql = f"""
139+
INSERT INTO workflow_engine_workflowactiongroupstatus
140+
(workflow_id, action_id, group_id, date_added, date_updated)
141+
VALUES {', '.join(values_placeholders)}
142+
ON CONFLICT (workflow_id, action_id, group_id) DO NOTHING
143+
RETURNING workflow_id, action_id
144+
"""
145+
146+
cursor.execute(sql, values_data)
147+
created_rows = set(cursor.fetchall()) # Only returns newly inserted rows
148+
149+
# Figure out which ones conflicted (weren't returned)
150+
conflicted_statuses = [
151+
(s.workflow_id, s.action_id)
152+
for s in missing_statuses
153+
if (s.workflow_id, s.action_id) not in created_rows
129154
]
130-
if missing_status_pairs:
131-
logger.warning(
132-
"Failed to create WorkflowActionGroupStatus objects",
133-
extra={"missing_status_pairs": missing_status_pairs},
134-
)
155+
156+
created_count = len(created_rows)
157+
return updated_count, created_count, conflicted_statuses
135158

136159

137160
def get_unique_active_actions(
@@ -190,7 +213,7 @@ def filter_recently_fired_workflow_actions(
190213
workflow_ids=workflow_ids,
191214
)
192215
now = timezone.now()
193-
action_to_workflow_ids, statuses_to_update, missing_statuses = (
216+
action_to_workflows_ids, statuses_to_update, missing_statuses = (
194217
process_workflow_action_group_statuses(
195218
action_to_workflows_ids=action_to_workflows_ids,
196219
action_to_statuses=action_to_statuses,
@@ -199,14 +222,24 @@ def filter_recently_fired_workflow_actions(
199222
now=now,
200223
)
201224
)
202-
update_workflow_action_group_statuses(now, statuses_to_update, missing_statuses)
225+
_, _, conflicted_statuses = update_workflow_action_group_statuses(
226+
now, statuses_to_update, missing_statuses
227+
)
228+
229+
# if statuses were not created for some reason, we should not fire for them
230+
for workflow_id, action_id in conflicted_statuses:
231+
action_to_workflows_ids[action_id].remove(workflow_id)
232+
if not action_to_workflows_ids[action_id]:
233+
action_to_workflows_ids.pop(action_id)
203234

204-
actions_queryset = Action.objects.filter(id__in=list(action_to_workflow_ids.keys()))
235+
actions_queryset = Action.objects.filter(id__in=list(action_to_workflows_ids.keys()))
205236

206237
# annotate actions with workflow_id they are firing for (deduped)
207238
workflow_id_cases = [
208-
When(id=action_id, then=Value(workflow_id))
209-
for action_id, workflow_id in action_to_workflow_ids.items()
239+
When(
240+
id=action_id, then=Value(min(list(workflow_ids)))
241+
) # select 1 workflow to fire for, this is arbitrary but deterministic
242+
for action_id, workflow_ids in action_to_workflows_ids.items()
210243
]
211244

212245
return actions_queryset.annotate(

tests/sentry/workflow_engine/processors/test_action.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_multiple_workflows_single_action__first_fire(self) -> None:
107107
# dedupes action if both workflows will fire it
108108
assert set(triggered_actions) == {self.action}
109109
# Dedupes action so we have a single workflow_id -> environment to fire with
110-
assert {getattr(action, "workflow_id") for action in triggered_actions} == {workflow.id}
110+
assert getattr(triggered_actions[0], "workflow_id") == self.workflow.id
111111

112112
assert WorkflowActionGroupStatus.objects.filter(action=self.action).count() == 2
113113

@@ -191,8 +191,8 @@ def test_process_workflow_action_group_statuses(self) -> None:
191191
)
192192

193193
assert action_to_workflow_ids == {
194-
self.action.id: self.workflow.id,
195-
action.id: workflow.id,
194+
self.action.id: {self.workflow.id},
195+
action.id: {workflow.id},
196196
}
197197
assert statuses_to_update == {status_2.id}
198198

@@ -222,6 +222,54 @@ def test_update_workflow_action_group_statuses(self) -> None:
222222
for status in all_statuses:
223223
assert status.date_updated == timezone.now()
224224

225+
def test_returns_uncreated_statuses(self) -> None:
226+
WorkflowActionGroupStatus.objects.create(
227+
workflow=self.workflow, action=self.action, group=self.group
228+
)
229+
230+
statuses_to_create = [
231+
WorkflowActionGroupStatus(
232+
workflow=self.workflow,
233+
action=self.action,
234+
group=self.group,
235+
date_updated=timezone.now(),
236+
)
237+
]
238+
_, _, uncreated_statuses = update_workflow_action_group_statuses(
239+
timezone.now(), set(), statuses_to_create
240+
)
241+
242+
assert uncreated_statuses == [(self.workflow.id, self.action.id)]
243+
244+
@patch("sentry.workflow_engine.processors.action.update_workflow_action_group_statuses")
245+
def test_does_not_fire_for_uncreated_statuses(self, mock_update: MagicMock) -> None:
246+
mock_update.return_value = (0, 0, [(self.workflow.id, self.action.id)])
247+
248+
triggered_actions = filter_recently_fired_workflow_actions(
249+
set(DataConditionGroup.objects.all()), self.event_data
250+
)
251+
252+
assert set(triggered_actions) == set()
253+
254+
@patch("sentry.workflow_engine.processors.action.update_workflow_action_group_statuses")
255+
def test_fires_for_non_conflicting_workflow(self, mock_update: MagicMock) -> None:
256+
workflow = self.create_workflow(organization=self.organization, config={"frequency": 1440})
257+
action_group = self.create_data_condition_group(logic_type="any-short")
258+
self.create_data_condition_group_action(
259+
condition_group=action_group,
260+
action=self.action,
261+
) # shared action
262+
self.create_workflow_data_condition_group(workflow, action_group)
263+
264+
mock_update.return_value = (0, 0, [(self.workflow.id, self.action.id)])
265+
266+
triggered_actions = filter_recently_fired_workflow_actions(
267+
set(DataConditionGroup.objects.all()), self.event_data
268+
)
269+
270+
assert set(triggered_actions) == {self.action}
271+
assert getattr(triggered_actions[0], "workflow_id") == workflow.id
272+
225273

226274
class TestIsActionPermitted(BaseWorkflowTest):
227275
@patch("sentry.workflow_engine.processors.action._get_integration_features")

0 commit comments

Comments
 (0)