Skip to content

Commit b740827

Browse files
Only defer EmrCreateJobFlowOperator when wait_policy is set (apache#56077)
* fixing emr operator deferral logic
* fixed failing test ruff lint
1 parent e1bea44 commit b740827

File tree

2 files changed

+38
-23
lines changed

2 files changed

+38
-23
lines changed

providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -748,30 +748,32 @@ def execute(self, context: Context) -> str | None:
748748
job_flow_id=self._job_flow_id,
749749
log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
750750
)
751-
if self.deferrable:
752-
self.defer(
753-
trigger=EmrCreateJobFlowTrigger(
754-
job_flow_id=self._job_flow_id,
755-
aws_conn_id=self.aws_conn_id,
756-
waiter_delay=self.waiter_delay,
757-
waiter_max_attempts=self.waiter_max_attempts,
758-
),
759-
method_name="execute_complete",
760-
# timeout is set to ensure that if a trigger dies, the timeout does not restart
761-
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
762-
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
763-
)
764751
if self.wait_policy:
765752
waiter_name = WAITER_POLICY_NAME_MAPPING[self.wait_policy]
766-
self.hook.get_waiter(waiter_name).wait(
767-
ClusterId=self._job_flow_id,
768-
WaiterConfig=prune_dict(
769-
{
770-
"Delay": self.waiter_delay,
771-
"MaxAttempts": self.waiter_max_attempts,
772-
}
773-
),
774-
)
753+
754+
if self.deferrable:
755+
self.defer(
756+
trigger=EmrCreateJobFlowTrigger(
757+
job_flow_id=self._job_flow_id,
758+
aws_conn_id=self.aws_conn_id,
759+
waiter_delay=self.waiter_delay,
760+
waiter_max_attempts=self.waiter_max_attempts,
761+
),
762+
method_name="execute_complete",
763+
# timeout is set to ensure that if a trigger dies, the timeout does not restart
764+
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
765+
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
766+
)
767+
else:
768+
self.hook.get_waiter(waiter_name).wait(
769+
ClusterId=self._job_flow_id,
770+
WaiterConfig=prune_dict(
771+
{
772+
"Delay": self.waiter_delay,
773+
"MaxAttempts": self.waiter_max_attempts,
774+
}
775+
),
776+
)
775777
return self._job_flow_id
776778

777779
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:

providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,17 +238,30 @@ def test_execute_with_wait_policy(self, mock_waiter, _, mocked_hook_client, wait
238238
def test_create_job_flow_deferrable(self, mocked_hook_client):
239239
"""
240240
Test to make sure that the operator raises a TaskDeferred exception
241-
if run in deferrable mode.
241+
if run in deferrable mode and wait_policy is set.
242242
"""
243243
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
244244

245245
self.operator.deferrable = True
246+
self.operator.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION
246247
with pytest.raises(TaskDeferred) as exc:
247248
self.operator.execute(self.mock_context)
248249

249250
assert isinstance(exc.value.trigger, EmrCreateJobFlowTrigger), (
250251
"Trigger is not a EmrCreateJobFlowTrigger"
251252
)
252253

254+
def test_create_job_flow_deferrable_no_wait(self, mocked_hook_client):
255+
"""
256+
Test to make sure that the operator does NOT raise a TaskDeferred exception
257+
if run in deferrable mode but wait_policy is not set.
258+
"""
259+
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
260+
261+
self.operator.deferrable = True
262+
# wait_policy is None by default
263+
result = self.operator.execute(self.mock_context)
264+
assert result == JOB_FLOW_ID
265+
253266
def test_template_fields(self):
254267
validate_template_fields(self.operator)

0 commit comments

Comments
 (0)