@@ -628,3 +628,57 @@ def test_dynamic_dag_update_success(self, dag_maker, session):
628628 updated_sdag = SDM .get ("test_dynamic_success" , session = session )
629629 assert updated_sdag .dag_hash != initial_hash # Hash should change
630630 assert len (updated_sdag .dag .task_dict ) == 2 # Should have 2 tasks now
631+
632+ def test_write_dag_atomicity_on_dagcode_failure (self , dag_maker , session ):
633+ """
634+ Test that SerializedDagModel.write_dag maintains atomicity.
635+
636+ If DagCode.write_code fails, the entire transaction should rollback,
637+ including the DagVersion. This test verifies that DagVersion is not
638+ committed separately, which would leave orphaned records.
639+
640+ This test would fail if DagVersion.write_dag() was used (which commits
641+ immediately), because the DagVersion would be persisted even though
642+ the rest of the transaction failed.
643+ """
644+ from airflow .models .dagcode import DagCode
645+
646+ with dag_maker ("test_atomicity_dag" ):
647+ EmptyOperator (task_id = "task1" )
648+
649+ dag = dag_maker .dag
650+ initial_version_count = session .query (DagVersion ).filter (DagVersion .dag_id == dag .dag_id ).count ()
651+ assert initial_version_count == 1 , "Should have one DagVersion after initial write"
652+ dag_maker .create_dagrun () # ensure the second dag version is created
653+
654+ EmptyOperator (task_id = "task2" , dag = dag )
655+ modified_lazy_dag = LazyDeserializedDAG .from_dag (dag )
656+
657+ # Mock DagCode.write_code to raise an exception
658+ with mock .patch .object (
659+ DagCode , "write_code" , side_effect = RuntimeError ("Simulated DagCode.write_code failure" )
660+ ):
661+ with pytest .raises (RuntimeError , match = "Simulated DagCode.write_code failure" ):
662+ SDM .write_dag (
663+ dag = modified_lazy_dag ,
664+ bundle_name = "testing" ,
665+ bundle_version = None ,
666+ session = session ,
667+ )
668+ session .rollback ()
669+
670+ # Verify that no new DagVersion was committed
671+ # Use a fresh session to ensure we're reading from committed data
672+ with create_session () as fresh_session :
673+ final_version_count = (
674+ fresh_session .query (DagVersion ).filter (DagVersion .dag_id == dag .dag_id ).count ()
675+ )
676+ assert final_version_count == initial_version_count , (
677+ "DagVersion should not be committed when DagCode.write_code fails"
678+ )
679+
680+ sdag = SDM .get (dag .dag_id , session = fresh_session )
681+ assert sdag is not None , "Original SerializedDagModel should still exist"
682+ assert len (sdag .dag .task_dict ) == 1 , (
683+ "SerializedDagModel should not be updated when write fails"
684+ )
0 commit comments