diff --git a/composer/tools/composer_dags.py b/composer/tools/composer_dags.py index a5306fa52d5..ed08b0e429b 100644 --- a/composer/tools/composer_dags.py +++ b/composer/tools/composer_dags.py @@ -111,6 +111,22 @@ def pause_dag( logger.info("Unable to pause DAG %s", dag_id) logger.info(command_output[1]) + @staticmethod + def pause_all_dags( + project_name: str, + environment: str, + location: str, + sdk_endpoint: str, + ) -> None: + """Pause all the DAGs in the given environment.""" + command = ( + f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer environments" + f" run {environment} --project={project_name} --location={location}" + f" dags pause -- \"^(?!airflow_monitoring$).*\" --treat-dag-id-as-regex -y" + ) + command_output = DAG._run_shell_command_locally_once(command=command) + logger.info(command_output[1]) + @staticmethod def unpause_dag( project_name: str, @@ -136,6 +152,22 @@ def unpause_dag( logger.info("Unable to Unpause DAG %s", dag_id) logger.info(command_output[1]) + @staticmethod + def unpause_all_dags( + project_name: str, + environment: str, + location: str, + sdk_endpoint: str, + ) -> None: + """UnPause all the DAGs in the given environment.""" + command = ( + f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer environments" + f" run {environment} --project={project_name} --location={location}" + f" dags unpause -- \".*\" --treat-dag-id-as-regex -y" + ) + command_output = DAG._run_shell_command_locally_once(command=command) + logger.info(command_output[1]) + @staticmethod def describe_environment( project_name: str, environment: str, location: str, sdk_endpoint: str @@ -151,25 +183,7 @@ def describe_environment( logger.info("Environment Info:\n %s", environment_json["name"]) return environment_json - -def main( - project_name: str, environment: str, location: str, operation: str, sdk_endpoint=str -) -> int: - logger.info("DAG Pause/UnPause Script for Cloud Composer") - environment_info = DAG.describe_environment( - project_name=project_name, - environment=environment, - location=location, - sdk_endpoint=sdk_endpoint, - ) - versions = DAG.COMPOSER_AF_VERSION_RE.match( - environment_info["config"]["softwareConfig"]["imageVersion"] - ).groups() - logger.info( - "Image version: %s", - environment_info["config"]["softwareConfig"]["imageVersion"], - ) - airflow_version = (int(versions[3]), int(versions[4]), int(versions[5])) +def legacy_operations(project_name: str, environment: str, location: str, sdk_endpoint: str, airflow_version: tuple[int, int, int], operation: str) -> None: list_of_dags = DAG.get_list_of_dags( project_name=project_name, environment=environment, @@ -201,6 +215,45 @@ def main( dag_id=dag, airflow_version=airflow_version, ) + +def modern_operations(project_name: str, environment: str, location: str, sdk_endpoint: str, operation: str) -> None: + if operation == "pause": + DAG.pause_all_dags( + project_name=project_name, + environment=environment, + location=location, + sdk_endpoint=sdk_endpoint, + ) + else: + DAG.unpause_all_dags( + project_name=project_name, + environment=environment, + location=location, + sdk_endpoint=sdk_endpoint, + ) + +def main( + project_name: str, environment: str, location: str, operation: str, sdk_endpoint: str +) -> int: + logger.info("DAG Pause/UnPause Script for Cloud Composer") + environment_info = DAG.describe_environment( + project_name=project_name, + environment=environment, + location=location, + sdk_endpoint=sdk_endpoint, + ) + versions = DAG.COMPOSER_AF_VERSION_RE.match( + environment_info["config"]["softwareConfig"]["imageVersion"] + ).groups() + logger.info( + "Image version: %s", + environment_info["config"]["softwareConfig"]["imageVersion"], + ) + airflow_version = (int(versions[3]), int(versions[4]), int(versions[5])) + if airflow_version < (2, 9, 0): + legacy_operations(project_name, environment, location, sdk_endpoint, airflow_version, operation) + else: + modern_operations(project_name, environment, location, sdk_endpoint, operation) return 0