diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index aa0bd36de11ad..e858231534dad 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -264,6 +264,7 @@ comparator compat Compute compute +ComputeManagementClient Computenodes ComputeNodeState concat @@ -406,6 +407,9 @@ dbt dbutils ddl de +deallocate +deallocated +Debounce debuggability declaratively decommissioning diff --git a/providers/microsoft/azure/docs/index.rst b/providers/microsoft/azure/docs/index.rst index 8ba0c07f63738..39bf993fa5fd8 100644 --- a/providers/microsoft/azure/docs/index.rst +++ b/providers/microsoft/azure/docs/index.rst @@ -128,6 +128,7 @@ PIP package Version required ``azure-kusto-data`` ``>=4.1.0,!=4.6.0,!=5.0.0`` ``azure-mgmt-datafactory`` ``>=2.0.0`` ``azure-mgmt-containerregistry`` ``>=8.0.0`` +``azure-mgmt-compute`` ``>=33.0.0`` ``azure-mgmt-containerinstance`` ``>=10.1.0`` ``msgraph-core`` ``>=1.3.3`` ``msgraphfs`` ``>=0.3.0`` diff --git a/providers/microsoft/azure/docs/operators/compute.rst b/providers/microsoft/azure/docs/operators/compute.rst new file mode 100644 index 0000000000000..6a9728d09dbcf --- /dev/null +++ b/providers/microsoft/azure/docs/operators/compute.rst @@ -0,0 +1,108 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. 
See the License for the + specific language governing permissions and limitations + under the License. + + +Azure Virtual Machine Operators +================================ + +Waiting Strategy +---------------- +The VM action operators support two patterns: + +* ``wait_for_completion=True`` (default): operator blocks until Azure operation finishes. +* ``wait_for_completion=False``: operator submits the operation and returns quickly. + +When you want to reduce worker slot usage for long VM state transitions, use +``wait_for_completion=False`` together with +:class:`~airflow.providers.microsoft.azure.sensors.compute.AzureVirtualMachineStateSensor` +in ``deferrable=True`` mode to move the waiting to the triggerer. + +.. _howto/operator:AzureVirtualMachineStartOperator: + +AzureVirtualMachineStartOperator +--------------------------------- +Use the +:class:`~airflow.providers.microsoft.azure.operators.compute.AzureVirtualMachineStartOperator` +to start an Azure Virtual Machine. + +Below is an example of using this operator to start a VM: + +.. exampleinclude:: /../tests/system/microsoft/azure/example_azure_compute.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_azure_vm_start] + :end-before: [END howto_operator_azure_vm_start] + + +.. _howto/operator:AzureVirtualMachineStopOperator: + +AzureVirtualMachineStopOperator +-------------------------------- +Use the +:class:`~airflow.providers.microsoft.azure.operators.compute.AzureVirtualMachineStopOperator` +to stop (deallocate) an Azure Virtual Machine. This releases compute resources and stops billing. + +Below is an example of using this operator to stop a VM: + +.. exampleinclude:: /../tests/system/microsoft/azure/example_azure_compute.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_azure_vm_stop] + :end-before: [END howto_operator_azure_vm_stop] + + +.. 
_howto/operator:AzureVirtualMachineRestartOperator: + +AzureVirtualMachineRestartOperator +----------------------------------- +Use the +:class:`~airflow.providers.microsoft.azure.operators.compute.AzureVirtualMachineRestartOperator` +to restart an Azure Virtual Machine. + +Below is an example of using this operator to restart a VM: + +.. exampleinclude:: /../tests/system/microsoft/azure/example_azure_compute.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_azure_vm_restart] + :end-before: [END howto_operator_azure_vm_restart] + + +.. _howto/sensor:AzureVirtualMachineStateSensor: + +AzureVirtualMachineStateSensor +------------------------------- +Use the +:class:`~airflow.providers.microsoft.azure.sensors.compute.AzureVirtualMachineStateSensor` +to poll a VM until it reaches a target power state (e.g., ``running``, ``deallocated``). +This sensor supports deferrable mode. + +Below is an example of using this sensor: + +.. exampleinclude:: /../tests/system/microsoft/azure/example_azure_compute.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_azure_vm_state] + :end-before: [END howto_sensor_azure_vm_state] + + +Reference +--------- + +For further information, look at: + +* `Azure Virtual Machines Documentation <https://learn.microsoft.com/en-us/azure/virtual-machines/>`__ diff --git a/providers/microsoft/azure/provider.yaml b/providers/microsoft/azure/provider.yaml index 8b89a7cbd0a0e..fdf39ed0d2c96 100644 --- a/providers/microsoft/azure/provider.yaml +++ b/providers/microsoft/azure/provider.yaml @@ -118,6 +118,12 @@ integrations: - /docs/apache-airflow-providers-microsoft-azure/operators/batch.rst logo: /docs/integration-logos/Microsoft-Azure-Batch.png tags: [azure] + - integration-name: Microsoft Azure Compute + external-doc-url: https://azure.microsoft.com/en-us/products/virtual-machines/ + how-to-guide: + - /docs/apache-airflow-providers-microsoft-azure/operators/compute.rst + logo: /docs/integration-logos/Microsoft-Azure.png + tags: [azure] - integration-name: Microsoft
Azure Blob Storage external-doc-url: https://azure.microsoft.com/en-us/services/storage/blobs/ how-to-guide: @@ -191,6 +197,9 @@ integrations: tags: [azure] operators: + - integration-name: Microsoft Azure Compute + python-modules: + - airflow.providers.microsoft.azure.operators.compute - integration-name: Microsoft Azure Data Lake Storage python-modules: - airflow.providers.microsoft.azure.operators.adls @@ -226,6 +235,9 @@ operators: - airflow.providers.microsoft.azure.operators.powerbi sensors: + - integration-name: Microsoft Azure Compute + python-modules: + - airflow.providers.microsoft.azure.sensors.compute - integration-name: Microsoft Azure Cosmos DB python-modules: - airflow.providers.microsoft.azure.sensors.cosmos @@ -244,6 +256,9 @@ filesystems: - airflow.providers.microsoft.azure.fs.msgraph hooks: + - integration-name: Microsoft Azure Compute + python-modules: + - airflow.providers.microsoft.azure.hooks.compute - integration-name: Microsoft Azure Container Instances python-modules: - airflow.providers.microsoft.azure.hooks.container_volume @@ -290,6 +305,9 @@ hooks: - airflow.providers.microsoft.azure.hooks.powerbi triggers: + - integration-name: Microsoft Azure Compute + python-modules: + - airflow.providers.microsoft.azure.triggers.compute - integration-name: Microsoft Azure Data Factory python-modules: - airflow.providers.microsoft.azure.triggers.data_factory @@ -362,6 +380,36 @@ connection-types: label: Workload Identity Tenant ID schema: type: ["string", "null"] + - hook-class-name: airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook + connection-type: azure_compute + ui-field-behaviour: + hidden-fields: ["schema", "port", "host"] + relabeling: + login: Client ID + password: Client Secret + placeholders: + extra: '{"key_path": "path to json file for auth", "key_json": "specifies json dict for auth"}' + login: client_id (token credentials auth) + password: secret (token credentials auth) + tenantId: tenantId (token credentials auth) + 
subscriptionId: subscriptionId (token credentials auth) + conn-fields: + tenantId: + label: Azure Tenant ID + schema: + type: ["string", "null"] + subscriptionId: + label: Azure Subscription ID + schema: + type: ["string", "null"] + managed_identity_client_id: + label: Managed Identity Client ID + schema: + type: ["string", "null"] + workload_identity_tenant_id: + label: Workload Identity Tenant ID + schema: + type: ["string", "null"] - hook-class-name: airflow.providers.microsoft.azure.hooks.adx.AzureDataExplorerHook connection-type: azure_data_explorer ui-field-behaviour: diff --git a/providers/microsoft/azure/pyproject.toml b/providers/microsoft/azure/pyproject.toml index 3cc4edf2c5376..37f5021c84cc2 100644 --- a/providers/microsoft/azure/pyproject.toml +++ b/providers/microsoft/azure/pyproject.toml @@ -81,6 +81,7 @@ dependencies = [ "azure-kusto-data>=4.1.0,!=4.6.0,!=5.0.0", "azure-mgmt-datafactory>=2.0.0", "azure-mgmt-containerregistry>=8.0.0", + "azure-mgmt-compute>=33.0.0", "azure-mgmt-containerinstance>=10.1.0", "msgraph-core>=1.3.3", "msgraphfs>=0.3.0", diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py index 80cc07ce077de..a13e4e6229b1c 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py @@ -34,6 +34,13 @@ def get_provider_info(): "logo": "/docs/integration-logos/Microsoft-Azure-Batch.png", "tags": ["azure"], }, + { + "integration-name": "Microsoft Azure Compute", + "external-doc-url": "https://azure.microsoft.com/en-us/products/virtual-machines/", + "how-to-guide": ["/docs/apache-airflow-providers-microsoft-azure/operators/compute.rst"], + "logo": "/docs/integration-logos/Microsoft-Azure.png", + "tags": ["azure"], + }, { "integration-name": "Microsoft Azure Blob Storage", 
"external-doc-url": "https://azure.microsoft.com/en-us/services/storage/blobs/", @@ -135,6 +142,10 @@ def get_provider_info(): }, ], "operators": [ + { + "integration-name": "Microsoft Azure Compute", + "python-modules": ["airflow.providers.microsoft.azure.operators.compute"], + }, { "integration-name": "Microsoft Azure Data Lake Storage", "python-modules": ["airflow.providers.microsoft.azure.operators.adls"], @@ -181,6 +192,10 @@ def get_provider_info(): }, ], "sensors": [ + { + "integration-name": "Microsoft Azure Compute", + "python-modules": ["airflow.providers.microsoft.azure.sensors.compute"], + }, { "integration-name": "Microsoft Azure Cosmos DB", "python-modules": ["airflow.providers.microsoft.azure.sensors.cosmos"], @@ -203,6 +218,10 @@ def get_provider_info(): "airflow.providers.microsoft.azure.fs.msgraph", ], "hooks": [ + { + "integration-name": "Microsoft Azure Compute", + "python-modules": ["airflow.providers.microsoft.azure.hooks.compute"], + }, { "integration-name": "Microsoft Azure Container Instances", "python-modules": [ @@ -265,6 +284,10 @@ def get_provider_info(): }, ], "triggers": [ + { + "integration-name": "Microsoft Azure Compute", + "python-modules": ["airflow.providers.microsoft.azure.triggers.compute"], + }, { "integration-name": "Microsoft Azure Data Factory", "python-modules": ["airflow.providers.microsoft.azure.triggers.data_factory"], @@ -349,6 +372,36 @@ def get_provider_info(): }, }, }, + { + "hook-class-name": "airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook", + "connection-type": "azure_compute", + "ui-field-behaviour": { + "hidden-fields": ["schema", "port", "host"], + "relabeling": {"login": "Client ID", "password": "Client Secret"}, + "placeholders": { + "extra": '{"key_path": "path to json file for auth", "key_json": "specifies json dict for auth"}', + "login": "client_id (token credentials auth)", + "password": "secret (token credentials auth)", + "tenantId": "tenantId (token credentials auth)", + 
"subscriptionId": "subscriptionId (token credentials auth)", + }, + }, + "conn-fields": { + "tenantId": {"label": "Azure Tenant ID", "schema": {"type": ["string", "null"]}}, + "subscriptionId": { + "label": "Azure Subscription ID", + "schema": {"type": ["string", "null"]}, + }, + "managed_identity_client_id": { + "label": "Managed Identity Client ID", + "schema": {"type": ["string", "null"]}, + }, + "workload_identity_tenant_id": { + "label": "Workload Identity Tenant ID", + "schema": {"type": ["string", "null"]}, + }, + }, + }, { "hook-class-name": "airflow.providers.microsoft.azure.hooks.adx.AzureDataExplorerHook", "connection-type": "azure_data_explorer", diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/compute.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/compute.py new file mode 100644 index 0000000000000..6d5564aeafefd --- /dev/null +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/compute.py @@ -0,0 +1,237 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from __future__ import annotations

from functools import cached_property
from typing import Any, cast

from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.identity.aio import (
    ClientSecretCredential as AsyncClientSecretCredential,
    DefaultAzureCredential as AsyncDefaultAzureCredential,
)
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.compute.aio import ComputeManagementClient as AsyncComputeManagementClient

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
from airflow.providers.microsoft.azure.utils import (
    get_async_default_azure_credential,
    get_sync_default_azure_credential,
)


class AzureComputeHook(AzureBaseHook):
    """
    A hook to interact with Azure Compute to manage Virtual Machines.

    The hook offers a synchronous API (``start_instance``, ``stop_instance``,
    ``restart_instance``, ``get_power_state``) and an asynchronous one
    (``get_async_conn``, ``async_get_power_state``). It can be used as an
    async context manager; leaving the context calls :meth:`close`.

    :param azure_conn_id: :ref:`Azure connection id<howto/connection:azure>` of
        a service principal which will be used to manage virtual machines.
    """

    conn_name_attr = "azure_conn_id"
    default_conn_name = "azure_default"
    conn_type = "azure_compute"
    hook_name = "Azure Compute"

    # Prefix of the instance-view status code that carries the power state,
    # e.g. ``PowerState/running``.
    _POWER_STATE_PREFIX = "PowerState/"

    def __init__(self, azure_conn_id: str = default_conn_name) -> None:
        super().__init__(sdk_client=ComputeManagementClient, conn_id=azure_conn_id)
        # Lazily created by get_async_conn(); released by close().
        self._async_conn: AsyncComputeManagementClient | None = None
        # Kept separately because closing the management client does not close
        # the credential that was handed to it.
        self._async_credential: AsyncClientSecretCredential | AsyncDefaultAzureCredential | None = None

    @cached_property
    def connection(self) -> ComputeManagementClient:
        """Return the synchronous client, created once per hook instance."""
        return self.get_conn()

    def get_conn(self) -> ComputeManagementClient:
        """
        Authenticate the resource using the connection id passed during init.

        Auth methods are tried in this order: ``key_path`` (JSON auth file),
        ``key_json`` (JSON auth dict), service principal (login, password and
        ``tenantId`` all set), and finally ``DefaultAzureCredential``.

        :return: the authenticated ComputeManagementClient.
        :raises ValueError: if ``key_path`` is set but does not end in ``.json``.
        """
        conn = self.get_connection(self.conn_id)
        extras = conn.extra_dejson
        tenant = extras.get("tenantId")

        key_path = extras.get("key_path")
        if key_path:
            if not key_path.endswith(".json"):
                raise ValueError("Unrecognised extension for key file.")
            self.log.info("Getting connection using a JSON key file.")
            return get_client_from_auth_file(client_class=self.sdk_client, auth_path=key_path)

        key_json = extras.get("key_json")
        if key_json:
            self.log.info("Getting connection using a JSON config.")
            return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)

        credential: ClientSecretCredential | DefaultAzureCredential
        if conn.login and conn.password and tenant:
            self.log.info("Getting connection using specific credentials and subscription_id.")
            credential = ClientSecretCredential(
                client_id=cast("str", conn.login),
                client_secret=cast("str", conn.password),
                tenant_id=cast("str", tenant),
            )
        else:
            self.log.info("Using DefaultAzureCredential as credential")
            credential = get_sync_default_azure_credential(
                managed_identity_client_id=extras.get("managed_identity_client_id"),
                workload_identity_tenant_id=extras.get("workload_identity_tenant_id"),
            )

        subscription_id = cast("str", extras.get("subscriptionId"))
        return ComputeManagementClient(
            credential=credential,
            subscription_id=subscription_id,
        )

    @classmethod
    def _extract_power_state(cls, statuses: Any) -> str:
        """
        Return the power state found in an instance view's ``statuses`` list.

        The first status whose ``code`` starts with ``PowerState/`` wins and the
        part after the slash is returned; ``"unknown"`` is returned when there
        is no such status (including when ``statuses`` is ``None`` or empty).
        """
        for status in statuses or []:
            code = status.code
            if code and code.startswith(cls._POWER_STATE_PREFIX):
                return code.split("/", 1)[1]
        return "unknown"

    def start_instance(
        self, resource_group_name: str, vm_name: str, wait_for_completion: bool = True
    ) -> None:
        """
        Start a virtual machine instance.

        :param resource_group_name: Name of the resource group.
        :param vm_name: Name of the virtual machine.
        :param wait_for_completion: Wait for the operation to complete.
        """
        self.log.info("Starting VM %s in resource group %s", vm_name, resource_group_name)
        poller = self.connection.virtual_machines.begin_start(resource_group_name, vm_name)
        if wait_for_completion:
            poller.result()

    def stop_instance(self, resource_group_name: str, vm_name: str, wait_for_completion: bool = True) -> None:
        """
        Stop (deallocate) a virtual machine instance.

        Uses ``begin_deallocate`` to release compute resources and stop billing.

        :param resource_group_name: Name of the resource group.
        :param vm_name: Name of the virtual machine.
        :param wait_for_completion: Wait for the operation to complete.
        """
        self.log.info("Stopping (deallocating) VM %s in resource group %s", vm_name, resource_group_name)
        poller = self.connection.virtual_machines.begin_deallocate(resource_group_name, vm_name)
        if wait_for_completion:
            poller.result()

    def restart_instance(
        self, resource_group_name: str, vm_name: str, wait_for_completion: bool = True
    ) -> None:
        """
        Restart a virtual machine instance.

        :param resource_group_name: Name of the resource group.
        :param vm_name: Name of the virtual machine.
        :param wait_for_completion: Wait for the operation to complete.
        """
        self.log.info("Restarting VM %s in resource group %s", vm_name, resource_group_name)
        poller = self.connection.virtual_machines.begin_restart(resource_group_name, vm_name)
        if wait_for_completion:
            poller.result()

    def get_power_state(self, resource_group_name: str, vm_name: str) -> str:
        """
        Get the power state of a virtual machine.

        :param resource_group_name: Name of the resource group.
        :param vm_name: Name of the virtual machine.
        :return: Power state string, e.g. ``running``, ``deallocated``, ``stopped``;
            ``unknown`` when the instance view carries no power-state status.
        """
        instance_view = self.connection.virtual_machines.instance_view(resource_group_name, vm_name)
        return self._extract_power_state(instance_view.statuses)

    def test_connection(self) -> tuple[bool, str]:
        """
        Test the Azure Compute connection.

        Requests the first item of the subscription-wide VM listing; any
        exception is reported as a failed test with its message.
        """
        try:
            next(self.connection.virtual_machines.list_all(), None)
        except Exception as e:
            return False, str(e)
        return True, "Successfully connected to Azure Compute."

    # ------------------------------------------------------------------
    # Async interface (used by AzureVirtualMachineStateTrigger)
    # ------------------------------------------------------------------

    async def get_async_conn(self) -> AsyncComputeManagementClient:
        """
        Return an authenticated async :class:`~azure.mgmt.compute.aio.ComputeManagementClient`.

        Supports service-principal (login/password + tenantId) and
        DefaultAzureCredential auth. Legacy ``key_path`` / ``key_json``
        auth files are not supported in the async path.

        The client is created on first use and reused until :meth:`close`.
        """
        if self._async_conn is not None:
            return self._async_conn

        conn = await get_async_connection(self.conn_id)
        extras = conn.extra_dejson
        tenant = extras.get("tenantId")
        subscription_id = cast("str", extras.get("subscriptionId"))

        credential: AsyncClientSecretCredential | AsyncDefaultAzureCredential
        if conn.login and conn.password and tenant:
            credential = AsyncClientSecretCredential(
                client_id=conn.login,
                client_secret=conn.password,
                tenant_id=tenant,
            )
        else:
            credential = get_async_default_azure_credential(
                managed_identity_client_id=extras.get("managed_identity_client_id"),
                workload_identity_tenant_id=extras.get("workload_identity_tenant_id"),
            )

        # Remember the credential so close() can release it together with the client.
        self._async_credential = credential
        self._async_conn = AsyncComputeManagementClient(
            credential=credential,
            subscription_id=subscription_id,
        )
        return self._async_conn

    async def async_get_power_state(self, resource_group_name: str, vm_name: str) -> str:
        """
        Get the power state of a virtual machine using the native async client.

        :param resource_group_name: Name of the resource group.
        :param vm_name: Name of the virtual machine.
        :return: Power state string, e.g. ``running``, ``deallocated``, ``stopped``;
            ``unknown`` when the instance view carries no power-state status.
        """
        client = await self.get_async_conn()
        instance_view = await client.virtual_machines.instance_view(resource_group_name, vm_name)
        return self._extract_power_state(instance_view.statuses)

    async def close(self) -> None:
        """
        Close the async client and the credential it was created with.

        Safe to call when no async connection was ever opened. The credential
        is closed even if closing the client raises.
        """
        client, credential = self._async_conn, self._async_credential
        # Drop the references first so a failing close cannot leave a half-closed
        # client behind for the next get_async_conn() call.
        self._async_conn = None
        self._async_credential = None
        try:
            if client is not None:
                await client.close()
        finally:
            if credential is not None:
                await credential.close()

    async def __aenter__(self) -> AzureComputeHook:
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        await self.close()
from __future__ import annotations

from abc import abstractmethod
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING

from airflow.providers.common.compat.sdk import BaseOperator
from airflow.providers.microsoft.azure.hooks.compute import AzureComputeHook

if TYPE_CHECKING:
    from airflow.sdk import Context


class BaseAzureVirtualMachineOperator(BaseOperator):
    """
    Base operator for Azure Virtual Machine operations.

    Subclasses implement :meth:`execute` by calling the matching
    :class:`~airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook`
    method.

    :param resource_group_name: Name of the Azure resource group.
    :param vm_name: Name of the virtual machine.
    :param wait_for_completion: Wait for the operation to complete. Default True.
    :param azure_conn_id: Azure connection id.
    """

    template_fields: Sequence[str] = ("resource_group_name", "vm_name")
    ui_color = "#0078d4"
    ui_fgcolor = "#ffffff"

    def __init__(
        self,
        *,
        resource_group_name: str,
        vm_name: str,
        wait_for_completion: bool = True,
        azure_conn_id: str = "azure_default",
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.resource_group_name = resource_group_name
        self.vm_name = vm_name
        self.wait_for_completion = wait_for_completion
        self.azure_conn_id = azure_conn_id

    @cached_property
    def hook(self) -> AzureComputeHook:
        """Return the compute hook, created once per operator instance."""
        return AzureComputeHook(azure_conn_id=self.azure_conn_id)

    def _log_outcome(self, completed: str, action: str) -> None:
        """
        Log the result of the hook call without overstating it.

        With ``wait_for_completion=False`` the hook returns as soon as the
        operation is submitted, so success cannot be claimed yet.

        :param completed: Past-tense wording used when the operation was awaited.
        :param action: Name of the operation used when it was only submitted.
        """
        if self.wait_for_completion:
            self.log.info("VM %s %s successfully.", self.vm_name, completed)
        else:
            self.log.info(
                "%s operation submitted for VM %s; not waiting for completion.",
                action,
                self.vm_name,
            )

    @abstractmethod
    def execute(self, context: Context) -> None: ...


class AzureVirtualMachineStartOperator(BaseAzureVirtualMachineOperator):
    """
    Start an Azure Virtual Machine.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:AzureVirtualMachineStartOperator`

    :param resource_group_name: Name of the Azure resource group.
    :param vm_name: Name of the virtual machine.
    :param wait_for_completion: Wait for the VM to reach 'running' state. Default True.
    :param azure_conn_id: Azure connection id.
    """

    def execute(self, context: Context) -> None:
        self.hook.start_instance(
            resource_group_name=self.resource_group_name,
            vm_name=self.vm_name,
            wait_for_completion=self.wait_for_completion,
        )
        self._log_outcome(completed="started", action="Start")


class AzureVirtualMachineStopOperator(BaseAzureVirtualMachineOperator):
    """
    Stop (deallocate) an Azure Virtual Machine.

    Uses ``begin_deallocate`` to release compute resources and stop billing.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:AzureVirtualMachineStopOperator`

    :param resource_group_name: Name of the Azure resource group.
    :param vm_name: Name of the virtual machine.
    :param wait_for_completion: Wait for the VM to reach 'deallocated' state. Default True.
    :param azure_conn_id: Azure connection id.
    """

    def execute(self, context: Context) -> None:
        self.hook.stop_instance(
            resource_group_name=self.resource_group_name,
            vm_name=self.vm_name,
            wait_for_completion=self.wait_for_completion,
        )
        self._log_outcome(completed="stopped (deallocated)", action="Stop (deallocate)")


class AzureVirtualMachineRestartOperator(BaseAzureVirtualMachineOperator):
    """
    Restart an Azure Virtual Machine.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:AzureVirtualMachineRestartOperator`

    :param resource_group_name: Name of the Azure resource group.
    :param vm_name: Name of the virtual machine.
    :param wait_for_completion: Wait for the VM to reach 'running' state. Default True.
    :param azure_conn_id: Azure connection id.
    """

    def execute(self, context: Context) -> None:
        self.hook.restart_instance(
            resource_group_name=self.resource_group_name,
            vm_name=self.vm_name,
            wait_for_completion=self.wait_for_completion,
        )
        self._log_outcome(completed="restarted", action="Restart")
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Sequence
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.providers.common.compat.sdk import BaseSensorOperator, conf
from airflow.providers.microsoft.azure.hooks.compute import AzureComputeHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
    from airflow.sdk import Context


# ----------------------------------------------------------------------
# airflow/providers/microsoft/azure/sensors/compute.py
# ----------------------------------------------------------------------


class AzureVirtualMachineStateSensor(BaseSensorOperator):
    """
    Poll an Azure Virtual Machine until it reaches a target power state.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:AzureVirtualMachineStateSensor`

    :param resource_group_name: Name of the Azure resource group.
    :param vm_name: Name of the virtual machine.
    :param target_state: Desired power state, e.g. ``running``, ``deallocated``.
        Must be one of ``VALID_STATES``. A Jinja-templated value is checked
        after rendering, when the task executes.
    :param azure_conn_id: Azure connection id.
    :param deferrable: If True, run in deferrable mode.
    """

    template_fields: Sequence[str] = ("resource_group_name", "vm_name", "target_state")
    ui_color = "#0078d4"
    ui_fgcolor = "#ffffff"

    VALID_STATES = frozenset({"running", "deallocated", "stopped", "starting", "deallocating"})

    def __init__(
        self,
        *,
        resource_group_name: str,
        vm_name: str,
        target_state: str,
        azure_conn_id: str = "azure_default",
        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
        **kwargs,
    ) -> None:
        # ``target_state`` is a template field: an unrendered Jinja expression can
        # never match VALID_STATES, so it is validated in execute() instead.
        if not self._looks_templated(target_state):
            self._validate_target_state(target_state)
        super().__init__(**kwargs)
        self.resource_group_name = resource_group_name
        self.vm_name = vm_name
        self.target_state = target_state
        self.azure_conn_id = azure_conn_id
        self.deferrable = deferrable

    @staticmethod
    def _looks_templated(value: Any) -> bool:
        """Return True when ``value`` is a string containing Jinja delimiters."""
        return isinstance(value, str) and ("{{" in value or "{%" in value)

    @classmethod
    def _validate_target_state(cls, target_state: Any) -> None:
        """
        Check ``target_state`` against ``VALID_STATES``.

        :raises ValueError: if the state is not one of ``VALID_STATES``.
        """
        if target_state not in cls.VALID_STATES:
            raise ValueError(
                f"Invalid target_state: {target_state}. Must be one of {sorted(cls.VALID_STATES)}"
            )

    @cached_property
    def hook(self) -> AzureComputeHook:
        """Return the compute hook, created once so repeated pokes reuse its client."""
        return AzureComputeHook(azure_conn_id=self.azure_conn_id)

    def poke(self, context: Context) -> bool:
        """Return True when the VM's current power state equals ``target_state``."""
        current_state = self.hook.get_power_state(self.resource_group_name, self.vm_name)
        self.log.info("VM %s power state: %s", self.vm_name, current_state)
        return current_state == self.target_state

    def execute(self, context: Context) -> None:
        """
        Poll for the VM power state.

        In deferrable mode the state is checked once; if it does not match yet,
        polling is deferred to the triggerer. Otherwise the sensor waits
        synchronously.

        :raises ValueError: if the (rendered) ``target_state`` is not valid.
        """
        # Covers templated values, which could not be checked in __init__.
        self._validate_target_state(self.target_state)

        if not self.deferrable:
            super().execute(context=context)
            return

        if self.poke(context=context):
            # Already in the target state; nothing to wait for.
            return

        self.defer(
            timeout=timedelta(seconds=self.timeout),
            trigger=AzureVirtualMachineStateTrigger(
                resource_group_name=self.resource_group_name,
                vm_name=self.vm_name,
                target_state=self.target_state,
                azure_conn_id=self.azure_conn_id,
                poke_interval=self.poke_interval,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event: dict[str, str] | None) -> None:
        """
        Handle callback from the trigger.

        :param event: Payload of the trigger event, with ``status`` and ``message`` keys.
        :raises RuntimeError: if no event was received or the event reports an error.
        """
        if not event:
            raise RuntimeError("Did not receive valid event from the triggerer")
        if event["status"] == "error":
            raise RuntimeError(event["message"])
        self.log.info(event["message"])


# ----------------------------------------------------------------------
# airflow/providers/microsoft/azure/triggers/compute.py
# ----------------------------------------------------------------------


class AzureVirtualMachineStateTrigger(BaseTrigger):
    """
    Poll the Azure VM power state and yield a TriggerEvent once it matches the target.

    Uses the native async Azure SDK client (``azure.mgmt.compute.aio``) so that
    the triggerer event loop is never blocked.

    :param resource_group_name: Name of the Azure resource group.
    :param vm_name: Name of the virtual machine.
    :param target_state: Desired power state, e.g. ``running``, ``deallocated``.
    :param azure_conn_id: Azure connection id.
    :param poke_interval: Polling interval in seconds.
    """

    def __init__(
        self,
        resource_group_name: str,
        vm_name: str,
        target_state: str,
        azure_conn_id: str = "azure_default",
        poke_interval: float = 30.0,
    ) -> None:
        super().__init__()
        self.resource_group_name = resource_group_name
        self.vm_name = vm_name
        self.target_state = target_state
        self.azure_conn_id = azure_conn_id
        self.poke_interval = poke_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize AzureVirtualMachineStateTrigger arguments and classpath."""
        classpath = f"{self.__class__.__module__}.{self.__class__.__name__}"
        kwargs = {
            "resource_group_name": self.resource_group_name,
            "vm_name": self.vm_name,
            "target_state": self.target_state,
            "azure_conn_id": self.azure_conn_id,
            "poke_interval": self.poke_interval,
        }
        return classpath, kwargs

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Poll VM power state asynchronously until it matches the target state.

        Yields a single event: ``status="success"`` once the state matches, or
        ``status="error"`` carrying the exception text if polling fails.
        """
        from airflow.providers.microsoft.azure.hooks.compute import AzureComputeHook

        try:
            # The async context manager closes the hook's async client on exit.
            async with AzureComputeHook(azure_conn_id=self.azure_conn_id) as hook:
                while True:
                    power_state = await hook.async_get_power_state(self.resource_group_name, self.vm_name)
                    if power_state == self.target_state:
                        message = f"VM {self.vm_name} reached state '{self.target_state}'."
                        yield TriggerEvent({"status": "success", "message": message})
                        return
                    self.log.info(
                        "VM %s power state: %s. Waiting for %s. Sleeping for %s seconds.",
                        self.vm_name,
                        power_state,
                        self.target_state,
                        self.poke_interval,
                    )
                    await asyncio.sleep(self.poke_interval)
        except Exception as e:
            # Keep the traceback in the triggerer log; the event only carries the message.
            self.log.exception("Polling power state of VM %s failed", self.vm_name)
            yield TriggerEvent({"status": "error", "message": str(e)})
            return
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import DAG
+from airflow.providers.microsoft.azure.operators.compute import (
+    AzureVirtualMachineRestartOperator,
+    AzureVirtualMachineStartOperator,
+    AzureVirtualMachineStopOperator,
+)
+from airflow.providers.microsoft.azure.sensors.compute import (
+    AzureVirtualMachineStateSensor,
+)
+
+DAG_ID = "example_azure_compute"
+RESOURCE_GROUP = os.environ.get("AZURE_RESOURCE_GROUP", "airflow-test-rg")  # overridable via environment
+VM_NAME = os.environ.get("AZURE_VM_NAME", "airflow-test-vm")  # overridable via environment
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule=None,
+    start_date=datetime(2026, 1, 1),
+    catchup=False,
+    tags=["example", "azure", "compute"],
+) as dag:
+    # [START howto_operator_azure_vm_start]
+    start_vm = AzureVirtualMachineStartOperator(
+        task_id="start_vm",
+        resource_group_name=RESOURCE_GROUP,
+        vm_name=VM_NAME,
+        wait_for_completion=False,  # submit and return; the sensor that follows does the waiting
+    )
+    # [END howto_operator_azure_vm_start]
+
+    # [START howto_sensor_azure_vm_state]
+    sense_running = AzureVirtualMachineStateSensor(
+        task_id="sense_running",
+        resource_group_name=RESOURCE_GROUP,
+        vm_name=VM_NAME,
+        target_state="running",
+        deferrable=True,  # polling is handed over to the triggerer
+        poke_interval=10,
+        timeout=300,
+    )
+    # [END howto_sensor_azure_vm_state]
+
+    # [START howto_operator_azure_vm_restart]
+    restart_vm = AzureVirtualMachineRestartOperator(
+        task_id="restart_vm",
+        resource_group_name=RESOURCE_GROUP,
+        vm_name=VM_NAME,
+        wait_for_completion=True,  # block until the restart operation finishes
+    )
+    # [END howto_operator_azure_vm_restart]
+
+    # [START howto_operator_azure_vm_stop]
+    stop_vm = AzureVirtualMachineStopOperator(
+        task_id="stop_vm",
+        resource_group_name=RESOURCE_GROUP,
+        vm_name=VM_NAME,
+        wait_for_completion=False,  # submit and return; sense_deallocated confirms the result
+    )
+    # [END howto_operator_azure_vm_stop]
+
+    sense_deallocated = AzureVirtualMachineStateSensor(
+        task_id="sense_deallocated",
+        resource_group_name=RESOURCE_GROUP,
+        vm_name=VM_NAME,
+        target_state="deallocated",
+        deferrable=True,  # polling is handed over to the triggerer
+        poke_interval=10,
+        timeout=300,
+    )
+
+    start_vm >> sense_running >> restart_vm >> stop_vm >> sense_deallocated
+
+from tests_common.test_utils.system_tests import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: contributing-docs/testing/system_tests.rst)
+test_run = get_test_run(dag)
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_compute.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_compute.py
new file mode 100644
index 0000000000000..9be756422b8bb
--- /dev/null
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_compute.py
@@ -0,0 +1,273 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.compute import AzureComputeHook
+
+CONN_ID = "azure_compute_test"
+
+
+class TestAzureComputeHook:
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, create_mock_connection):  # registers a service-principal connection for every test
+        create_mock_connection(
+            Connection(
+                conn_id=CONN_ID,
+                conn_type="azure_compute",
+                login="client_id",
+                password="client_secret",
+                extra={
+                    "tenantId": "tenant_id",
+                    "subscriptionId": "subscription_id",
+                },
+            )
+        )
+
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+    def test_get_conn_returns_compute_management_client(self, mock_credential, mock_client):
+        # mocks are injected bottom-up: the innermost @patch supplies the first argument
+        hook = AzureComputeHook(azure_conn_id=CONN_ID)
+        result = hook.get_conn()
+
+        mock_credential.assert_called_once_with(
+            client_id="client_id",
+            client_secret="client_secret",
+            tenant_id="tenant_id",
+        )
+        mock_client.assert_called_once_with(
+            credential=mock_credential.return_value,
+            subscription_id="subscription_id",
+        )
+        assert result == mock_client.return_value
+
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+    @patch("airflow.providers.microsoft.azure.hooks.compute.get_sync_default_azure_credential")
+    def test_get_conn_with_default_azure_credential(
+        self, mock_default_cred, mock_client, create_mock_connection
+    ):
+        create_mock_connection(
+            Connection(
+                conn_id="azure_no_login",
+                conn_type="azure_compute",
+                extra={"subscriptionId": "sub_id"},  # no login/password on this connection
+            )
+        )
+        hook = AzureComputeHook(azure_conn_id="azure_no_login")
+        hook.get_conn()
+
+        mock_default_cred.assert_called_once_with(
+            managed_identity_client_id=None,
+            workload_identity_tenant_id=None,
+        )
+        mock_client.assert_called_once_with(
+            credential=mock_default_cred.return_value,
+            subscription_id="sub_id",
+        )
+
+    @patch("airflow.providers.microsoft.azure.hooks.compute.get_client_from_auth_file")
+    def test_get_conn_with_key_path(self, mock_get_client_from_auth_file, create_mock_connection):
+        create_mock_connection(
+            Connection(
+                conn_id="azure_with_key_path",
+                conn_type="azure_compute",
+                extra={"key_path": "/tmp/azure-key.json"},
+            )
+        )
+        hook = AzureComputeHook(azure_conn_id="azure_with_key_path")
+        conn = hook.get_conn()
+
+        mock_get_client_from_auth_file.assert_called_once_with(
+            client_class=hook.sdk_client, auth_path="/tmp/azure-key.json"
+        )
+        assert conn == mock_get_client_from_auth_file.return_value
+
+    @patch("airflow.providers.microsoft.azure.hooks.compute.get_client_from_json_dict")
+    def test_get_conn_with_key_json(self, mock_get_client_from_json_dict, create_mock_connection):
+        create_mock_connection(
+            Connection(
+                conn_id="azure_with_key_json",
+                conn_type="azure_compute",
+                extra={"key_json": {"tenantId": "tenant", "subscriptionId": "sub"}},
+            )
+        )
+        hook = AzureComputeHook(azure_conn_id="azure_with_key_json")
+        hook.get_conn()
+
+        mock_get_client_from_json_dict.assert_called_once_with(
+            client_class=hook.sdk_client, config_dict={"tenantId": "tenant", "subscriptionId": "sub"}
+        )
+
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+    def test_start_instance(self, mock_credential, mock_client):
+        mock_poller = MagicMock()
+        mock_client.return_value.virtual_machines.begin_start.return_value = mock_poller
+
+        hook = AzureComputeHook(azure_conn_id=CONN_ID)
+        hook.start_instance(resource_group_name="rg", vm_name="vm1")
+
+        mock_client.return_value.virtual_machines.begin_start.assert_called_once_with("rg", "vm1")
+        mock_poller.result.assert_called_once()
+        mock_poller.result.reset_mock()  # forget the first call so the no-wait path is checked in isolation
+
+        hook.start_instance(resource_group_name="rg", vm_name="vm1", wait_for_completion=False)
+        mock_poller.result.assert_not_called()
+
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+    def test_stop_instance(self, mock_credential, mock_client):
+        mock_poller = MagicMock()
+        mock_client.return_value.virtual_machines.begin_deallocate.return_value = mock_poller
+
+        hook = AzureComputeHook(azure_conn_id=CONN_ID)
+        hook.stop_instance(resource_group_name="rg", vm_name="vm1")
+
+        mock_client.return_value.virtual_machines.begin_deallocate.assert_called_once_with("rg", "vm1")
+        mock_poller.result.assert_called_once()
+
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+    def test_restart_instance(self, mock_credential, mock_client):
+        mock_poller = MagicMock()
+        mock_client.return_value.virtual_machines.begin_restart.return_value = mock_poller
+
+        hook = AzureComputeHook(azure_conn_id=CONN_ID)
+        hook.restart_instance(resource_group_name="rg", vm_name="vm1")
+
+        mock_client.return_value.virtual_machines.begin_restart.assert_called_once_with("rg", "vm1")
+        mock_poller.result.assert_called_once()
+
+    @pytest.mark.parametrize(
+        ("status_code", "expected"),
+        [
+            ("PowerState/running", "running"),
+            ("PowerState/deallocated", "deallocated"),
+            ("foo/bar", "unknown"),  # a code without the PowerState/ prefix
+            (None, "unknown"),  # a status with no code at all
+        ],
+    )
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+    def test_get_power_state(self, mock_credential, mock_client, status_code, expected):
+        mock_instance_view = MagicMock()
+        mock_instance_view.statuses = [
+            MagicMock(code="ProvisioningState/succeeded"),  # non-power status listed first
+            MagicMock(code=status_code),
+        ]
+        mock_client.return_value.virtual_machines.instance_view.return_value = mock_instance_view
+
+        hook = AzureComputeHook(azure_conn_id=CONN_ID)
+        state = hook.get_power_state(resource_group_name="rg", vm_name="vm1")
+
+        assert state == expected
+
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+    def test_test_connection_success(self, mock_credential, mock_client):
+        mock_client.return_value.virtual_machines.list_all.return_value = iter([])  # empty listing, no error
+
+        hook = AzureComputeHook(azure_conn_id=CONN_ID)
+        result, message = hook.test_connection()
+
+        assert result is True
+        assert "Successfully" in message
+
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+    @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+    def test_test_connection_failure(self, mock_credential, mock_client):
+        mock_client.return_value.virtual_machines.list_all.side_effect = Exception("Auth failed")
+
+        hook = AzureComputeHook(azure_conn_id=CONN_ID)
+        result, message = hook.test_connection()
+
+        assert result is False
+        assert "Auth failed" in message  # the exception text is surfaced in the returned message
+
+    # ------------------------------------------------------------------
+    # Async interface tests
+    # ------------------------------------------------------------------
+
+    @pytest.mark.asyncio
+    @patch("airflow.providers.microsoft.azure.hooks.compute.AsyncComputeManagementClient")
+    @patch("airflow.providers.microsoft.azure.hooks.compute.AsyncClientSecretCredential")
+    @patch(
+        "airflow.providers.microsoft.azure.hooks.compute.get_async_connection",
+        new_callable=AsyncMock,
+    )
+    async def test_get_async_conn_with_service_principal(
+        self, mock_get_async_connection, mock_async_cred, mock_async_client
+    ):
+        mock_conn = MagicMock()  # stands in for the connection returned by get_async_connection
+        mock_conn.login = "client_id"
+        mock_conn.password = "client_secret"
+        mock_conn.extra_dejson = {"tenantId": "tenant_id", "subscriptionId": "sub_id"}
+        mock_get_async_connection.return_value = mock_conn
+
+        hook = AzureComputeHook(azure_conn_id=CONN_ID)
+        conn = await hook.get_async_conn()
+
+        mock_async_cred.assert_called_once_with(
+            client_id="client_id",
+            client_secret="client_secret",
+            tenant_id="tenant_id",
+        )
+        mock_async_client.assert_called_once_with(
+            credential=mock_async_cred.return_value,
+            subscription_id="sub_id",
+        )
+        assert conn == mock_async_client.return_value
+
+    @pytest.mark.asyncio
+    @patch(
+        "airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook.get_async_conn",
+        new_callable=AsyncMock,
+    )
+    async def test_async_get_power_state(self, mock_get_async_conn):
+        mock_status_power = MagicMock(code="PowerState/running")
+        mock_status_prov = MagicMock(code="ProvisioningState/succeeded")
+        mock_instance_view = MagicMock(statuses=[mock_status_prov, mock_status_power])
+        mock_client = AsyncMock()  # AsyncMock makes instance_view awaitable
+        mock_client.virtual_machines.instance_view.return_value = mock_instance_view
+        mock_get_async_conn.return_value = mock_client
+
+        hook = AzureComputeHook(azure_conn_id=CONN_ID)
+        state = await hook.async_get_power_state("rg", "vm1")
+
+        assert state == "running"
+        mock_client.virtual_machines.instance_view.assert_called_once_with("rg", "vm1")
+
+    @pytest.mark.asyncio
+    @patch(
+        "airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook.get_async_conn",
+        new_callable=AsyncMock,
+    )
+    async def test_async_context_manager_closes_conn(self, mock_get_async_conn):
+        mock_client = AsyncMock()
+        mock_get_async_conn.return_value = mock_client
+
+        hook = AzureComputeHook(azure_conn_id=CONN_ID)
+        async with hook as h:
+            assert h is hook
+            h._async_conn = mock_client  # inject the client that leaving the context is expected to close
+
+        mock_client.close.assert_called_once()
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_compute.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_compute.py
new file mode 100644
index 0000000000000..879dd219da7bc
--- /dev/null
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_compute.py
@@ -0,0 +1,138 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import patch
+
+from airflow.providers.microsoft.azure.operators.compute import (
+    AzureVirtualMachineRestartOperator,
+    AzureVirtualMachineStartOperator,
+    AzureVirtualMachineStopOperator,
+)
+
+RESOURCE_GROUP = "test-rg"
+VM_NAME = "test-vm"
+CONN_ID = "azure_default"
+
+
+class TestAzureVirtualMachineStartOperator:
+    def test_init(self):
+        op = AzureVirtualMachineStartOperator(
+            task_id="start_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            azure_conn_id=CONN_ID,
+        )
+        assert op.resource_group_name == RESOURCE_GROUP
+        assert op.vm_name == VM_NAME
+        assert op.wait_for_completion is True  # waiting is the default when not specified
+        assert op.azure_conn_id == CONN_ID
+
+    def test_template_fields(self):
+        op = AzureVirtualMachineStartOperator(
+            task_id="start_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+        )
+        assert "resource_group_name" in op.template_fields
+        assert "vm_name" in op.template_fields
+
+    @patch("airflow.providers.microsoft.azure.operators.compute.AzureComputeHook")
+    def test_execute_start_instance(self, mock_hook_cls):
+        op = AzureVirtualMachineStartOperator(
+            task_id="start_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+        )
+        op.execute(context=None)  # the hook class is patched, so only the delegation is checked
+
+        mock_hook_cls.return_value.start_instance.assert_called_once_with(
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            wait_for_completion=True,
+        )
+
+    @patch("airflow.providers.microsoft.azure.operators.compute.AzureComputeHook")
+    def test_execute_start_instance_no_wait(self, mock_hook_cls):
+        op = AzureVirtualMachineStartOperator(
+            task_id="start_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            wait_for_completion=False,
+        )
+        op.execute(context=None)
+
+        mock_hook_cls.return_value.start_instance.assert_called_once_with(
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            wait_for_completion=False,  # the operator forwards the flag unchanged
+        )
+
+
+class TestAzureVirtualMachineStopOperator:
+    def test_init(self):
+        op = AzureVirtualMachineStopOperator(
+            task_id="stop_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+        )
+        assert op.resource_group_name == RESOURCE_GROUP
+        assert op.vm_name == VM_NAME
+        assert op.wait_for_completion is True  # waiting is the default when not specified
+
+    @patch("airflow.providers.microsoft.azure.operators.compute.AzureComputeHook")
+    def test_execute_stop_instance(self, mock_hook_cls):
+        op = AzureVirtualMachineStopOperator(
+            task_id="stop_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+        )
+        op.execute(context=None)  # the hook class is patched, so only the delegation is checked
+
+        mock_hook_cls.return_value.stop_instance.assert_called_once_with(
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            wait_for_completion=True,
+        )
+
+
+class TestAzureVirtualMachineRestartOperator:
+    def test_init(self):
+        op = AzureVirtualMachineRestartOperator(
+            task_id="restart_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+        )
+        assert op.resource_group_name == RESOURCE_GROUP
+        assert op.vm_name == VM_NAME
+        assert op.wait_for_completion is True  # waiting is the default when not specified
+
+    @patch("airflow.providers.microsoft.azure.operators.compute.AzureComputeHook")
+    def test_execute_restart_instance(self, mock_hook_cls):
+        op = AzureVirtualMachineRestartOperator(
+            task_id="restart_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+        )
+        op.execute(context=None)  # the hook class is patched, so only the delegation is checked
+
+        mock_hook_cls.return_value.restart_instance.assert_called_once_with(
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            wait_for_completion=True,
+        )
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_compute.py b/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_compute.py
new file mode 100644
index 0000000000000..47e4da37a8d46
--- /dev/null
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_compute.py
@@ -0,0 +1,123 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import patch
+
+import pytest
+
+from airflow.providers.common.compat.sdk import TaskDeferred
+from airflow.providers.microsoft.azure.sensors.compute import AzureVirtualMachineStateSensor
+
+RESOURCE_GROUP = "test-rg"
+VM_NAME = "test-vm"
+CONN_ID = "azure_default"
+
+
+class TestAzureVirtualMachineStateSensor:
+    def test_init(self):
+        sensor = AzureVirtualMachineStateSensor(
+            task_id="sense_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            target_state="running",
+            azure_conn_id=CONN_ID,
+        )
+        assert sensor.resource_group_name == RESOURCE_GROUP
+        assert sensor.vm_name == VM_NAME
+        assert sensor.target_state == "running"
+        assert sensor.azure_conn_id == CONN_ID
+
+    def test_init_invalid_target_state(self):
+        with pytest.raises(ValueError, match="Invalid target_state"):  # validation happens at construction
+            AzureVirtualMachineStateSensor(
+                task_id="sense_vm",
+                resource_group_name=RESOURCE_GROUP,
+                vm_name=VM_NAME,
+                target_state="invalid_state",
+            )
+
+    def test_template_fields(self):
+        sensor = AzureVirtualMachineStateSensor(
+            task_id="sense_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            target_state="running",
+        )
+        assert "resource_group_name" in sensor.template_fields
+        assert "vm_name" in sensor.template_fields
+        assert "target_state" in sensor.template_fields
+
+    @pytest.mark.parametrize(
+        ("return_value", "expected"),
+        [
+            ("running", True),  # reported state equals the target
+            ("deallocated", False),  # reported state differs from the target
+        ],
+    )
+    @patch("airflow.providers.microsoft.azure.sensors.compute.AzureComputeHook")
+    def test_poke(self, mock_hook_cls, return_value, expected):
+        mock_hook_cls.return_value.get_power_state.return_value = return_value
+
+        sensor = AzureVirtualMachineStateSensor(
+            task_id="sense_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            target_state="running",
+        )
+        assert sensor.poke(context=None) is expected
+
+    @patch("airflow.providers.microsoft.azure.sensors.compute.AzureComputeHook")
+    def test_deferrable_mode(self, mock_hook_cls):
+        mock_hook_cls.return_value.get_power_state.return_value = "deallocated"  # not the target state
+
+        sensor = AzureVirtualMachineStateSensor(
+            task_id="sense_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            target_state="running",
+            deferrable=True,
+        )
+        with pytest.raises(TaskDeferred):  # the initial poke misses, so execute() must defer
+            sensor.execute(context=None)
+
+    def test_execute_complete_success(self):
+        sensor = AzureVirtualMachineStateSensor(
+            task_id="sense_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            target_state="running",
+        )
+        # Should not raise
+        sensor.execute_complete(
+            context=None,
+            event={"status": "success", "message": "VM reached running state"},
+        )
+
+    def test_execute_complete_error(self):
+        sensor = AzureVirtualMachineStateSensor(
+            task_id="sense_vm",
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            target_state="running",
+        )
+        with pytest.raises(RuntimeError, match="Something went wrong"):  # the event message becomes the error text
+            sensor.execute_complete(
+                context=None,
+                event={"status": "error", "message": "Something went wrong"},
+            )
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_compute.py b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_compute.py
new file mode 100644
index 0000000000000..80deba518fb6a
--- /dev/null
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_compute.py
@@ -0,0 +1,126 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.microsoft.azure.triggers.compute import AzureVirtualMachineStateTrigger
+from airflow.triggers.base import TriggerEvent
+
+RESOURCE_GROUP = "test-rg"
+VM_NAME = "test-vm"
+TARGET_STATE = "running"
+CONN_ID = "azure_default"
+POKE_INTERVAL = 10.0
+
+
+class TestAzureVirtualMachineStateTrigger:
+    def test_serialize(self):
+        trigger = AzureVirtualMachineStateTrigger(
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            target_state=TARGET_STATE,
+            azure_conn_id=CONN_ID,
+            poke_interval=POKE_INTERVAL,
+        )
+
+        actual = trigger.serialize()
+
+        assert isinstance(actual, tuple)
+        assert (
+            actual[0]
+            == f"{AzureVirtualMachineStateTrigger.__module__}.{AzureVirtualMachineStateTrigger.__name__}"
+        )
+        assert actual[1] == {  # every constructor argument must round-trip
+            "resource_group_name": RESOURCE_GROUP,
+            "vm_name": VM_NAME,
+            "target_state": TARGET_STATE,
+            "azure_conn_id": CONN_ID,
+            "poke_interval": POKE_INTERVAL,
+        }
+
+    @pytest.mark.asyncio
+    @mock.patch(
+        "airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook.async_get_power_state",
+        new_callable=mock.AsyncMock,
+    )
+    async def test_run_immediate_success(self, mock_get_power_state):
+        mock_get_power_state.return_value = TARGET_STATE
+
+        trigger = AzureVirtualMachineStateTrigger(
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            target_state=TARGET_STATE,
+            azure_conn_id=CONN_ID,
+            poke_interval=POKE_INTERVAL,
+        )
+
+        generator = trigger.run()
+        response = await generator.asend(None)  # advance the async generator to its first yielded event
+
+        assert response == TriggerEvent(
+            {"status": "success", "message": f"VM {VM_NAME} reached state '{TARGET_STATE}'."}
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep", return_value=None)  # patch() uses an AsyncMock for coroutine functions
+    @mock.patch(
+        "airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook.async_get_power_state",
+        new_callable=mock.AsyncMock,
+    )
+    async def test_run_polls_until_success(self, mock_get_power_state, mock_sleep):
+        mock_get_power_state.side_effect = ["deallocated", TARGET_STATE]  # first poll misses, second hits
+
+        trigger = AzureVirtualMachineStateTrigger(
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            target_state=TARGET_STATE,
+            azure_conn_id=CONN_ID,
+            poke_interval=POKE_INTERVAL,
+        )
+
+        generator = trigger.run()
+        response = await generator.asend(None)
+
+        assert mock_get_power_state.call_count == 2
+        assert response == TriggerEvent(
+            {"status": "success", "message": f"VM {VM_NAME} reached state '{TARGET_STATE}'."}
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(
+        "airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook.async_get_power_state",
+        new_callable=mock.AsyncMock,
+    )
+    async def test_run_error(self, mock_get_power_state):
+        mock_get_power_state.side_effect = Exception("API error")
+
+        trigger = AzureVirtualMachineStateTrigger(
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            target_state=TARGET_STATE,
+            azure_conn_id=CONN_ID,
+            poke_interval=POKE_INTERVAL,
+        )
+
+        generator = trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "error", "message": "API error"})  # str(exc) is the message