Skip to content

Commit 5013aad

Browse files
authored
Fix dag-processor crash when renaming DAG tag case on MySQL (apache#57113)
When a user changed only the case of a DAG tag (e.g., 'dangerous' to 'DANGEROUS'), the dag-processor would crash with a duplicate key error on MySQL due to case-insensitive collation in the PRIMARY KEY. This occurred because SQLAlchemy executed INSERT operations before DELETE operations during the flush. The fix ensures DELETE operations complete before attempting INSERT operations by explicitly flushing and refreshing the tag relationship from the database. Fixes apache#56940
1 parent 35bbcd0 commit 5013aad

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

airflow-core/src/airflow/dag_processing/collection.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,32 @@ def calculate(cls, dags: dict[str, LazyDeserializedDAG], *, session: Session) ->
159159

160160
def _update_dag_tags(tag_names: set[str], dm: DagModel, *, session: Session) -> None:
161161
orm_tags = {t.name: t for t in dm.tags}
162+
tags_to_delete = []
162163
for name, orm_tag in orm_tags.items():
163164
if name not in tag_names:
164165
session.delete(orm_tag)
165-
dm.tags.extend(DagTag(name=name, dag_id=dm.dag_id) for name in tag_names.difference(orm_tags))
166+
tags_to_delete.append(orm_tag)
167+
168+
tags_to_add = tag_names.difference(orm_tags)
169+
if tags_to_delete:
170+
# Remove deleted tags from the collection to keep it in sync
171+
for tag in tags_to_delete:
172+
dm.tags.remove(tag)
173+
174+
# Check if there's a potential case-only rename on MySQL (e.g., 'tag' -> 'TAG').
175+
# MySQL uses case-insensitive collation for the (name, dag_id) primary key by default,
176+
# which can cause duplicate key errors when renaming tags with only case changes.
177+
if get_dialect_name(session) == "mysql":
178+
orm_tags_lower = {name.lower(): name for name in orm_tags}
179+
has_case_only_change = any(tag.lower() in orm_tags_lower for tag in tags_to_add)
180+
181+
if has_case_only_change:
182+
# Force DELETE operations to execute before INSERT operations.
183+
session.flush()
184+
# Refresh the tags relationship from the database to reflect the deletions.
185+
session.expire(dm, ["tags"])
186+
187+
dm.tags.extend(DagTag(name=name, dag_id=dm.dag_id) for name in tags_to_add)
166188

167189

168190
def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, session: Session) -> None:

airflow-core/tests/unit/dag_processing/test_collection.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
AssetModelOperation,
3838
DagModelOperation,
3939
_get_latest_runs_stmt,
40+
_update_dag_tags,
4041
update_dag_parsing_results_in_db,
4142
)
4243
from airflow.exceptions import SerializationError
@@ -48,6 +49,7 @@
4849
DagScheduleAssetNameReference,
4950
DagScheduleAssetUriReference,
5051
)
52+
from airflow.models.dag import DagTag
5153
from airflow.models.errors import ParseImportError
5254
from airflow.models.serialized_dag import SerializedDagModel
5355
from airflow.providers.standard.operators.empty import EmptyOperator
@@ -941,3 +943,33 @@ def test_max_consecutive_failed_dag_runs_defaults_from_conf_when_none(
941943
update_dag_parsing_results_in_db("testing", None, [dag], {}, 0.1, set(), session)
942944
orm_dag = session.get(DagModel, "dag_max_failed_runs_default")
943945
assert orm_dag.max_consecutive_failed_dag_runs == 6
946+
947+
948+
@pytest.mark.db_test
949+
class TestUpdateDagTags:
950+
@pytest.fixture(autouse=True)
951+
def setup_teardown(self, session):
952+
yield
953+
session.query(DagModel).filter(DagModel.dag_id == "test_dag").delete()
954+
session.commit()
955+
956+
@pytest.mark.parametrize(
957+
["initial_tags", "new_tags", "expected_tags"],
958+
[
959+
(["dangerous"], {"DANGEROUS"}, {"DANGEROUS"}),
960+
(["existing"], {"existing", "new"}, {"existing", "new"}),
961+
(["tag1", "tag2"], {"tag1"}, {"tag1"}),
962+
(["keep", "remove", "lowercase"], {"keep", "LOWERCASE", "new"}, {"keep", "LOWERCASE", "new"}),
963+
(["tag1", "tag2"], set(), set()),
964+
],
965+
)
966+
def test_update_dag_tags(self, testing_dag_bundle, session, initial_tags, new_tags, expected_tags):
967+
dag_model = DagModel(dag_id="test_dag", bundle_name="testing")
968+
dag_model.tags = [DagTag(name=tag, dag_id="test_dag") for tag in initial_tags]
969+
session.add(dag_model)
970+
session.commit()
971+
972+
_update_dag_tags(new_tags, dag_model, session=session)
973+
session.commit()
974+
975+
assert {t.name for t in dag_model.tags} == expected_tags

0 commit comments

Comments
 (0)