Skip to content

Commit 96e510a

Browse files
committed
feat: Enhance unit tests for scheduling_helper with improved mock data and validation
1 parent 2c265a1 commit 96e510a

File tree

2 files changed

+99
-73
lines changed

2 files changed

+99
-73
lines changed

dags/common/scheduling_helper/scheduling_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Helper module for scheduling DAGs across clusters."""
16+
1617
import datetime as dt
1718
import enum
1819
from typing import TypeAlias

dags/common/scheduling_helper/scheduling_helper_test.py

Lines changed: 98 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,100 +1,125 @@
1-
# # Copyright 2025 Google LLC
2-
# #
3-
# # Licensed under the Apache License, Version 2.0 (the "License");
4-
# # you may not use this file except in compliance with the License.
5-
# # You may obtain a copy of the License at
6-
# #
7-
# # http://www.apache.org/licenses/LICENSE-2.0
8-
# #
9-
# # Unless required by applicable law or agreed to in writing, software
10-
# # distributed under the License is distributed on an "AS IS" BASIS,
11-
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# # See the License for the specific language governing permissions and
13-
# # limitations under the License.
14-
"""The test file of scheduling helper using absltest."""
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
1514

16-
import datetime as dt
17-
from absl.testing import absltest
18-
from absl.testing import parameterized
19-
from airflow.models import DagBag
15+
"""Unit tests for scheduling_helper.py."""
2016

17+
import datetime as dt
18+
from unittest.mock import patch
19+
from absl.testing import absltest, parameterized
2120
from dags.common.scheduling_helper import scheduling_helper
2221

2322

24-
class TestSchedulingHelper(parameterized.TestCase):
25-
"""Test cases for the SchedulingHelper class logic."""
23+
class TestSchedulingHelperBase(parameterized.TestCase):
24+
"""Base class for SchedulingHelper tests with shared mock data."""
2625

2726
def setUp(self):
2827
super().setUp()
29-
self.dag_folder = "dags/tpu_observability"
30-
# Mock data to simulate the stacking logic
31-
# Offset calculation: Start = 08:00 + Sum(Previous Timeouts + 15m Margin)
32-
self.fake_registered_dags = {
33-
"fake_cluster": {
34-
"dag_1": dt.timedelta(minutes=30), # Start: 08:00
35-
"dag_2": dt.timedelta(minutes=30), # Start: 08:00 + 30 + 15 = 08:45
36-
"dag_3": dt.timedelta(minutes=60), # Start: 08:45 + 30 + 15 = 09:30
28+
# Mock data with non-round numbers to ensure precise calculation
29+
self.mock_registry = {
30+
"cluster_a": {
31+
"dag_1": dt.timedelta(minutes=12), # Start: 08:00
32+
"dag_2": dt.timedelta(
33+
minutes=33
34+
), # Start: 08:00 + 12m + 15m = 08:27
35+
"dag_3": dt.timedelta(
36+
seconds=45
37+
), # Start: 08:27 + 33m + 15m = 09:15
38+
"dag_4": dt.timedelta(
39+
minutes=20
40+
), # Start: 09:15 + 45s + 15m = 09:30:45
41+
"dag_5": dt.timedelta(
42+
minutes=10
43+
), # Start: 09:30:45 + 20m + 15m = 10:05:45
44+
"dag_6": dt.timedelta(
45+
minutes=5
46+
), # Start: 10:05:45 + 10m + 15m = 10:30:45
3747
},
38-
"overtime_cluster": {
39-
"extreme_dag": dt.timedelta(hours=25),
48+
"cluster_b": {
49+
"dag_x": dt.timedelta(minutes=5),
50+
"dag_y": dt.timedelta(minutes=10),
4051
},
4152
}
42-
# Patch the global REGISTERED_DAGS in the module
43-
self.patcher = absltest.mock.patch(
44-
"dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS",
45-
self.fake_registered_dags,
46-
)
47-
self.patcher.start()
4853

49-
def tearDown(self):
50-
self.patcher.stop()
51-
super().tearDown()
5254

53-
# --- Unit Tests ---
55+
class TestSchedulingLogic(TestSchedulingHelperBase):
56+
"""Validates the cron string generation and stacking logic."""
5457

55-
def test_get_dag_timeout_is_correct(self):
56-
"""Verifies that get_dag_timeout retrieves the correct timedelta."""
57-
timeout = scheduling_helper.get_dag_timeout("dag_2")
58-
self.assertEqual(timeout, dt.timedelta(minutes=30))
58+
@patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS")
59+
def test_alignment_with_anchor(self, mock_registered):
60+
mock_registered.items.return_value = self.mock_registry.items()
61+
# The first DAG should always align with DEFAULT_ANCHOR (08:00 UTC)
62+
schedule = scheduling_helper.SchedulingHelper.arrange_schedule_time("dag_1")
63+
self.assertEqual(schedule, "0 8 * * *")
5964

60-
def test_arrange_schedule_time_logic(self):
61-
"""Tests the stacking logic (Anchor + Offset + Margin)."""
62-
# 1st DAG should be at the anchor (08:00)
63-
self.assertEqual(
64-
scheduling_helper.SchedulingHelper.arrange_schedule_time("dag_1"),
65-
"0 8 * * *",
66-
)
67-
# 2nd DAG = 08:00 + 30m (timeout) + 15m (margin) = 08:45
68-
self.assertEqual(
69-
scheduling_helper.SchedulingHelper.arrange_schedule_time("dag_2"),
70-
"45 8 * * *",
71-
)
72-
# 3rd DAG = 08:45 + 30m (timeout) + 15m (margin) = 09:30
73-
self.assertEqual(
74-
scheduling_helper.SchedulingHelper.arrange_schedule_time("dag_3"),
75-
"30 9 * * *",
76-
)
65+
@patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS")
66+
def test_complex_calculation(self, mock_registered):
67+
mock_registered.items.return_value = self.mock_registry.items()
68+
# Testing the 'stacking' effect with non-standard durations
69+
schedule = scheduling_helper.SchedulingHelper.arrange_schedule_time("dag_2")
70+
self.assertEqual(schedule, "27 8 * * *")
7771

78-
def test_day_of_week_options(self):
79-
"""Verifies that DayOfWeek enum correctly applies to the Cron string."""
80-
dag_id = "dag_1"
81-
# Weekend mode
72+
@parameterized.named_parameters(
73+
("all", scheduling_helper.DayOfWeek.ALL, "*"),
74+
("weekday", scheduling_helper.DayOfWeek.WEEK_DAY, "1-5"),
75+
("weekend", scheduling_helper.DayOfWeek.WEEKEND, "0,6"),
76+
)
77+
@patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS")
78+
def test_day_of_week_options(
79+
self, day_enum, expected_suffix, mock_registered
80+
):
81+
mock_registered.items.return_value = self.mock_registry.items()
8282
schedule = scheduling_helper.SchedulingHelper.arrange_schedule_time(
83-
dag_id, scheduling_helper.DayOfWeek.WEEKEND
83+
"dag_1", day_of_week=day_enum
8484
)
85-
self.assertEqual(schedule, "0 8 * * 0,6")
85+
self.assertTrue(schedule.endswith(expected_suffix))
8686

87-
# --- Error Handling Tests ---
8887

89-
def test_nonexist_dag(self):
90-
"""Tests that a ValueError is raised for unregistered DAGs."""
88+
class TestErrorHandling(TestSchedulingHelperBase):
89+
"""Validates boundary conditions and registration checks."""
90+
91+
@patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS")
92+
def test_unregistered_dag(self, mock_registered):
93+
mock_registered.items.return_value = self.mock_registry.items()
9194
with self.assertRaisesRegex(ValueError, "is not registered"):
9295
scheduling_helper.SchedulingHelper.arrange_schedule_time("ghost_dag")
9396

94-
def test_overtime_error(self):
95-
"""Tests that schedules exceeding 24 hours trigger a ValueError."""
97+
@patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS")
98+
def test_24hours_window_single_dag(self, mock_registered):
99+
mock_registered.items.return_value = {
100+
"c1": {"huge_dag": dt.timedelta(hours=25)}
101+
}.items()
102+
with self.assertRaisesRegex(ValueError, "Schedule exceeds 24h window"):
103+
scheduling_helper.SchedulingHelper.arrange_schedule_time("huge_dag")
104+
105+
@patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS")
106+
def test_24hours_window_cumulative(self, mock_registered):
107+
# Six DAGs (d0-d5) at 5 hours each: the five DAGs preceding d5 already add up to 25 hours of timeouts, so the 6th DAG (d5) should trigger the error.
108+
long_dags = {f"d{i}": dt.timedelta(hours=5) for i in range(6)}
109+
mock_registered.items.return_value = {"c1": long_dags}.items()
96110
with self.assertRaisesRegex(ValueError, "Schedule exceeds 24h window"):
97-
scheduling_helper.SchedulingHelper.arrange_schedule_time("extreme_dag")
111+
scheduling_helper.SchedulingHelper.arrange_schedule_time("d5")
112+
113+
114+
class TestFormatConsistency(TestSchedulingHelperBase):
115+
"""Ensures output is valid and deterministic."""
116+
117+
@patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS")
118+
def test_output_is_valid_cron(self, mock_registered):
119+
mock_registered.items.return_value = self.mock_registry.items()
120+
cron_pattern = r"^\d{1,2} \d{1,2} \* \* (\*|1-5|0,6)$"
121+
res = scheduling_helper.SchedulingHelper.arrange_schedule_time("dag_1")
122+
self.assertRegex(res, cron_pattern)
98123

99124

100125
if __name__ == "__main__":

0 commit comments

Comments
 (0)