|
16 | 16 | # under the License. |
17 | 17 | from __future__ import annotations |
18 | 18 |
|
| 19 | +import json |
19 | 20 | import os |
20 | 21 | import re |
21 | 22 | import subprocess |
@@ -58,6 +59,9 @@ class BaseK8STest: |
58 | 59 |
|
59 | 60 | @pytest.fixture(autouse=True) |
60 | 61 | def base_tests_setup(self, request): |
| 62 | + self.set_api_server_base_url_config() |
| 63 | + self.rollout_restart_deployment("airflow-api-server") |
| 64 | + self.ensure_deployment_health("airflow-api-server") |
61 | 65 | # Replacement for unittests.TestCase.id() |
62 | 66 | self.test_id = f"{request.node.cls.__name__}_{request.node.name}" |
63 | 67 | self.session = self._get_session_with_retries() |
@@ -204,6 +208,63 @@ def ensure_deployment_health(deployment_name: str, namespace: str = "airflow"): |
204 | 208 | ).decode() |
205 | 209 | assert "successfully rolled out" in deployment_rollout_status |
206 | 210 |
|
| 211 | + @staticmethod |
| 212 | + def rollout_restart_deployment(deployment_name: str, namespace: str = "airflow"): |
| 213 | + """Rollout restart the deployment.""" |
| 214 | + check_call(["kubectl", "rollout", "restart", "deployment", deployment_name, "-n", namespace]) |
| 215 | + |
| 216 | + def _parse_airflow_cfg_as_dict(self, airflow_cfg: str) -> dict[str, dict[str, str]]: |
| 217 | + """Parse the airflow.cfg file as a dictionary.""" |
| 218 | + parsed_airflow_cfg: dict[str, dict[str, str]] = {} |
| 219 | + for line in airflow_cfg.splitlines(): |
| 220 | + if line.startswith("["): |
| 221 | + section = line[1:-1] |
| 222 | + parsed_airflow_cfg[section] = {} |
| 223 | + elif "=" in line: |
| 224 | + key, value = line.split("=", 1) |
| 225 | + parsed_airflow_cfg[section][key.strip()] = value.strip() |
| 226 | + return parsed_airflow_cfg |
| 227 | + |
| 228 | + def _parse_airflow_cfg_dict_as_escaped_toml(self, airflow_cfg_dict: dict) -> str: |
| 229 | + """Parse the airflow.cfg dictionary as a toml string.""" |
| 230 | + airflow_cfg_str = "" |
| 231 | + for section, section_dict in airflow_cfg_dict.items(): |
| 232 | + airflow_cfg_str += f"[{section}]\n" |
| 233 | + for key, value in section_dict.items(): |
| 234 | + airflow_cfg_str += f"{key} = {value}\n" |
| 235 | + airflow_cfg_str += "\n" |
| 236 | + # escape newlines and double quotes |
| 237 | + return airflow_cfg_str.replace("\n", "\\n").replace('"', '\\"') |
| 238 | + |
| 239 | + def set_api_server_base_url_config(self): |
| 240 | + """Set [api/base_url] with `f"http://{KUBERNETES_HOST_PORT}"` as env in k8s configmap.""" |
| 241 | + configmap_name = "airflow-config" |
| 242 | + configmap_key = "airflow.cfg" |
| 243 | + original_configmap_json_str = check_output( |
| 244 | + ["kubectl", "get", "configmap", configmap_name, "-n", "airflow", "-o", "json"] |
| 245 | + ).decode() |
| 246 | + original_config_map = json.loads(original_configmap_json_str) |
| 247 | + original_airflow_cfg = original_config_map["data"][configmap_key] |
| 248 | + # set [api/base_url] with `f"http://{KUBERNETES_HOST_PORT}"` in airflow.cfg |
| 249 | + # The airflow.cfg is toml format, so we need to convert it to json |
| 250 | + airflow_cfg_dict = self._parse_airflow_cfg_as_dict(original_airflow_cfg) |
| 251 | + airflow_cfg_dict["api"]["base_url"] = f"http://{KUBERNETES_HOST_PORT}" |
| 252 | + # update the configmap with the new airflow.cfg |
| 253 | + check_call( |
| 254 | + [ |
| 255 | + "kubectl", |
| 256 | + "patch", |
| 257 | + "configmap", |
| 258 | + configmap_name, |
| 259 | + "-n", |
| 260 | + "airflow", |
| 261 | + "--type", |
| 262 | + "merge", |
| 263 | + "-p", |
| 264 | + f'{{"data": {{"{configmap_key}": "{self._parse_airflow_cfg_dict_as_escaped_toml(airflow_cfg_dict)}"}}}}', |
| 265 | + ] |
| 266 | + ) |
| 267 | + |
207 | 268 | def ensure_dag_expected_state(self, host, logical_date, dag_id, expected_final_state, timeout): |
208 | 269 | tries = 0 |
209 | 270 | state = "" |
|
0 commit comments