|
1 | | -# # Copyright 2025 Google LLC |
2 | | -# # |
3 | | -# # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | -# # you may not use this file except in compliance with the License. |
5 | | -# # You may obtain a copy of the License at |
6 | | -# # |
7 | | -# # http://www.apache.org/licenses/LICENSE-2.0 |
8 | | -# # |
9 | | -# # Unless required by applicable law or agreed to in writing, software |
10 | | -# # distributed under the License is distributed on an "AS IS" BASIS, |
11 | | -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | -# # See the License for the specific language governing permissions and |
13 | | -# # limitations under the License. |
14 | | -"""The test file of scheduling helper using absltest.""" |
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
15 | 14 |
|
16 | | -import datetime as dt |
17 | | -from absl.testing import absltest |
18 | | -from absl.testing import parameterized |
19 | | -from airflow.models import DagBag |
| 15 | +"""Unit tests for scheduling_helper.py.""" |
20 | 16 |
|
| 17 | +import datetime as dt |
| 18 | +from unittest.mock import patch |
| 19 | +from absl.testing import absltest, parameterized |
21 | 20 | from dags.common.scheduling_helper import scheduling_helper |
22 | 21 |
|
23 | 22 |
|
24 | | -class TestSchedulingHelper(parameterized.TestCase): |
25 | | - """Test cases for the SchedulingHelper class logic.""" |
| 23 | +class TestSchedulingHelperBase(parameterized.TestCase): |
| 24 | + """Base class for SchedulingHelper tests with shared mock data.""" |
26 | 25 |
|
27 | 26 | def setUp(self): |
28 | 27 | super().setUp() |
29 | | - self.dag_folder = "dags/tpu_observability" |
30 | | - # Mock data to simulate the stacking logic |
31 | | - # Offset calculation: Start = 08:00 + Sum(Previous Timeouts + 15m Margin) |
32 | | - self.fake_registered_dags = { |
33 | | - "fake_cluster": { |
34 | | - "dag_1": dt.timedelta(minutes=30), # Start: 08:00 |
35 | | - "dag_2": dt.timedelta(minutes=30), # Start: 08:00 + 30 + 15 = 08:45 |
36 | | - "dag_3": dt.timedelta(minutes=60), # Start: 08:45 + 30 + 15 = 09:30 |
| 28 | + # Mock data with non-round numbers to ensure precise calculation |
| 29 | + self.mock_registry = { |
| 30 | + "cluster_a": { |
| 31 | + "dag_1": dt.timedelta(minutes=12), # Start: 08:00 |
| 32 | + "dag_2": dt.timedelta( |
| 33 | + minutes=33 |
| 34 | + ), # Start: 08:00 + 12m + 15m = 08:27 |
| 35 | + "dag_3": dt.timedelta( |
| 36 | + seconds=45 |
| 37 | + ), # Start: 08:27 + 33m + 15m = 09:15 |
| 38 | + "dag_4": dt.timedelta( |
| 39 | + minutes=20 |
| 40 | + ), # Start: 09:15 + 45s + 15m = 09:30:45 |
| 41 | + "dag_5": dt.timedelta( |
| 42 | + minutes=10 |
| 43 | + ), # Start: 09:30:45 + 20m + 15m = 10:05:45 |
| 44 | + "dag_6": dt.timedelta( |
| 45 | + minutes=5 |
| 46 | + ), # Start: 10:05:45 + 10m + 15m = 10:30:45 |
37 | 47 | }, |
38 | | - "overtime_cluster": { |
39 | | - "extreme_dag": dt.timedelta(hours=25), |
| 48 | + "cluster_b": { |
| 49 | + "dag_x": dt.timedelta(minutes=5), |
| 50 | + "dag_y": dt.timedelta(minutes=10), |
40 | 51 | }, |
41 | 52 | } |
42 | | - # Patch the global REGISTERED_DAGS in the module |
43 | | - self.patcher = absltest.mock.patch( |
44 | | - "dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS", |
45 | | - self.fake_registered_dags, |
46 | | - ) |
47 | | - self.patcher.start() |
48 | 53 |
|
49 | | - def tearDown(self): |
50 | | - self.patcher.stop() |
51 | | - super().tearDown() |
52 | 54 |
|
53 | | - # --- Unit Tests --- |
| 55 | +class TestSchedulingLogic(TestSchedulingHelperBase): |
| 56 | + """Validates the cron string generation and stacking logic.""" |
54 | 57 |
|
55 | | - def test_get_dag_timeout_is_correct(self): |
56 | | - """Verifies that get_dag_timeout retrieves the correct timedelta.""" |
57 | | - timeout = scheduling_helper.get_dag_timeout("dag_2") |
58 | | - self.assertEqual(timeout, dt.timedelta(minutes=30)) |
| 58 | + @patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS") |
| 59 | + def test_alignment_with_anchor(self, mock_registered): |
| 60 | + mock_registered.items.return_value = self.mock_registry.items() |
| 61 | + # The first DAG should always align with DEFAULT_ANCHOR (08:00 UTC) |
| 62 | + schedule = scheduling_helper.SchedulingHelper.arrange_schedule_time("dag_1") |
| 63 | + self.assertEqual(schedule, "0 8 * * *") |
59 | 64 |
|
60 | | - def test_arrange_schedule_time_logic(self): |
61 | | - """Tests the stacking logic (Anchor + Offset + Margin).""" |
62 | | - # 1st DAG should be at the anchor (08:00) |
63 | | - self.assertEqual( |
64 | | - scheduling_helper.SchedulingHelper.arrange_schedule_time("dag_1"), |
65 | | - "0 8 * * *", |
66 | | - ) |
67 | | - # 2nd DAG = 08:00 + 30m (timeout) + 15m (margin) = 08:45 |
68 | | - self.assertEqual( |
69 | | - scheduling_helper.SchedulingHelper.arrange_schedule_time("dag_2"), |
70 | | - "45 8 * * *", |
71 | | - ) |
72 | | - # 3rd DAG = 08:45 + 30m (timeout) + 15m (margin) = 09:30 |
73 | | - self.assertEqual( |
74 | | - scheduling_helper.SchedulingHelper.arrange_schedule_time("dag_3"), |
75 | | - "30 9 * * *", |
76 | | - ) |
| 65 | + @patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS") |
| 66 | + def test_complex_calculation(self, mock_registered): |
| 67 | + mock_registered.items.return_value = self.mock_registry.items() |
| 68 | + # Testing the 'stacking' effect with non-standard durations |
| 69 | + schedule = scheduling_helper.SchedulingHelper.arrange_schedule_time("dag_2") |
| 70 | + self.assertEqual(schedule, "27 8 * * *") |
77 | 71 |
|
78 | | - def test_day_of_week_options(self): |
79 | | - """Verifies that DayOfWeek enum correctly applies to the Cron string.""" |
80 | | - dag_id = "dag_1" |
81 | | - # Weekend mode |
| 72 | + @parameterized.named_parameters( |
| 73 | + ("all", scheduling_helper.DayOfWeek.ALL, "*"), |
| 74 | + ("weekday", scheduling_helper.DayOfWeek.WEEK_DAY, "1-5"), |
| 75 | + ("weekend", scheduling_helper.DayOfWeek.WEEKEND, "0,6"), |
| 76 | + ) |
| 77 | + @patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS") |
| 78 | + def test_day_of_week_options( |
| 79 | + self, day_enum, expected_suffix, mock_registered |
| 80 | + ): |
| 81 | + mock_registered.items.return_value = self.mock_registry.items() |
82 | 82 | schedule = scheduling_helper.SchedulingHelper.arrange_schedule_time( |
83 | | - dag_id, scheduling_helper.DayOfWeek.WEEKEND |
| 83 | + "dag_1", day_of_week=day_enum |
84 | 84 | ) |
85 | | - self.assertEqual(schedule, "0 8 * * 0,6") |
| 85 | + self.assertTrue(schedule.endswith(expected_suffix)) |
86 | 86 |
|
87 | | - # --- Error Handling Tests --- |
88 | 87 |
|
89 | | - def test_nonexist_dag(self): |
90 | | - """Tests that a ValueError is raised for unregistered DAGs.""" |
| 88 | +class TestErrorHandling(TestSchedulingHelperBase): |
| 89 | + """Validates boundary conditions and registration checks.""" |
| 90 | + |
| 91 | + @patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS") |
| 92 | + def test_unregistered_dag(self, mock_registered): |
| 93 | + mock_registered.items.return_value = self.mock_registry.items() |
91 | 94 | with self.assertRaisesRegex(ValueError, "is not registered"): |
92 | 95 | scheduling_helper.SchedulingHelper.arrange_schedule_time("ghost_dag") |
93 | 96 |
|
94 | | - def test_overtime_error(self): |
95 | | - """Tests that schedules exceeding 24 hours trigger a ValueError.""" |
| 97 | + @patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS") |
| 98 | + def test_24hours_window_single_dag(self, mock_registered): |
| 99 | + mock_registered.items.return_value = { |
| 100 | + "c1": {"huge_dag": dt.timedelta(hours=25)} |
| 101 | + }.items() |
| 102 | + with self.assertRaisesRegex(ValueError, "Schedule exceeds 24h window"): |
| 103 | + scheduling_helper.SchedulingHelper.arrange_schedule_time("huge_dag") |
| 104 | + |
| 105 | + @patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS") |
| 106 | + def test_24hours_window_cumulative(self, mock_registered): |
| 107 | + # Six DAGs (d0-d5) @ 5 hours each: the five ahead of d5 already total 25 hours, so d5 should trigger the error. |
| 108 | + long_dags = {f"d{i}": dt.timedelta(hours=5) for i in range(6)} |
| 109 | + mock_registered.items.return_value = {"c1": long_dags}.items() |
96 | 110 | with self.assertRaisesRegex(ValueError, "Schedule exceeds 24h window"): |
97 | | - scheduling_helper.SchedulingHelper.arrange_schedule_time("extreme_dag") |
| 111 | + scheduling_helper.SchedulingHelper.arrange_schedule_time("d5") |
| 112 | + |
| 113 | + |
| 114 | +class TestFormatConsistency(TestSchedulingHelperBase): |
| 115 | + """Ensures output is valid and deterministic.""" |
| 116 | + |
| 117 | + @patch("dags.common.scheduling_helper.scheduling_helper.REGISTERED_DAGS") |
| 118 | + def test_output_is_valid_cron(self, mock_registered): |
| 119 | + mock_registered.items.return_value = self.mock_registry.items() |
| 120 | + cron_pattern = r"^\d{1,2} \d{1,2} \* \* (\*|1-5|0,6)$" |
| 121 | + res = scheduling_helper.SchedulingHelper.arrange_schedule_time("dag_1") |
| 122 | + self.assertRegex(res, cron_pattern) |
98 | 123 |
|
99 | 124 |
|
100 | 125 | if __name__ == "__main__": |
|
0 commit comments