diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index aa0bd36de11ad..e858231534dad 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -264,6 +264,7 @@ comparator
compat
Compute
compute
+ComputeManagementClient
Computenodes
ComputeNodeState
concat
@@ -406,6 +407,9 @@ dbt
dbutils
ddl
de
+deallocate
+deallocated
+Debounce
debuggability
declaratively
decommissioning
diff --git a/providers/microsoft/azure/docs/index.rst b/providers/microsoft/azure/docs/index.rst
index 8ba0c07f63738..39bf993fa5fd8 100644
--- a/providers/microsoft/azure/docs/index.rst
+++ b/providers/microsoft/azure/docs/index.rst
@@ -128,6 +128,7 @@ PIP package Version required
``azure-kusto-data`` ``>=4.1.0,!=4.6.0,!=5.0.0``
``azure-mgmt-datafactory`` ``>=2.0.0``
``azure-mgmt-containerregistry`` ``>=8.0.0``
+``azure-mgmt-compute`` ``>=33.0.0``
``azure-mgmt-containerinstance`` ``>=10.1.0``
``msgraph-core`` ``>=1.3.3``
``msgraphfs`` ``>=0.3.0``
diff --git a/providers/microsoft/azure/docs/operators/compute.rst b/providers/microsoft/azure/docs/operators/compute.rst
new file mode 100644
index 0000000000000..6a9728d09dbcf
--- /dev/null
+++ b/providers/microsoft/azure/docs/operators/compute.rst
@@ -0,0 +1,108 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+Azure Virtual Machine Operators
+================================
+
+Waiting Strategy
+----------------
+The VM action operators support two patterns:
+
+* ``wait_for_completion=True`` (default): operator blocks until Azure operation finishes.
+* ``wait_for_completion=False``: operator submits the operation and returns quickly.
+
+When you want to reduce worker slot usage for long VM state transitions, use
+``wait_for_completion=False`` together with
+:class:`~airflow.providers.microsoft.azure.sensors.compute.AzureVirtualMachineStateSensor`
+in ``deferrable=True`` mode to move the waiting to the triggerer.
+
+.. _howto/operator:AzureVirtualMachineStartOperator:
+
+AzureVirtualMachineStartOperator
+---------------------------------
+Use the
+:class:`~airflow.providers.microsoft.azure.operators.compute.AzureVirtualMachineStartOperator`
+to start an Azure Virtual Machine.
+
+Below is an example of using this operator to start a VM:
+
+.. exampleinclude:: /../tests/system/microsoft/azure/example_azure_compute.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_azure_vm_start]
+ :end-before: [END howto_operator_azure_vm_start]
+
+
+.. _howto/operator:AzureVirtualMachineStopOperator:
+
+AzureVirtualMachineStopOperator
+--------------------------------
+Use the
+:class:`~airflow.providers.microsoft.azure.operators.compute.AzureVirtualMachineStopOperator`
+to stop (deallocate) an Azure Virtual Machine. This releases compute resources and stops billing.
+
+Below is an example of using this operator to stop a VM:
+
+.. exampleinclude:: /../tests/system/microsoft/azure/example_azure_compute.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_azure_vm_stop]
+ :end-before: [END howto_operator_azure_vm_stop]
+
+
+.. _howto/operator:AzureVirtualMachineRestartOperator:
+
+AzureVirtualMachineRestartOperator
+-----------------------------------
+Use the
+:class:`~airflow.providers.microsoft.azure.operators.compute.AzureVirtualMachineRestartOperator`
+to restart an Azure Virtual Machine.
+
+Below is an example of using this operator to restart a VM:
+
+.. exampleinclude:: /../tests/system/microsoft/azure/example_azure_compute.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_azure_vm_restart]
+ :end-before: [END howto_operator_azure_vm_restart]
+
+
+.. _howto/sensor:AzureVirtualMachineStateSensor:
+
+AzureVirtualMachineStateSensor
+-------------------------------
+Use the
+:class:`~airflow.providers.microsoft.azure.sensors.compute.AzureVirtualMachineStateSensor`
+to poll a VM until it reaches a target power state (e.g., ``running``, ``deallocated``).
+This sensor supports deferrable mode.
+
+Below is an example of using this sensor:
+
+.. exampleinclude:: /../tests/system/microsoft/azure/example_azure_compute.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_azure_vm_state]
+ :end-before: [END howto_sensor_azure_vm_state]
+
+
+Reference
+---------
+
+For further information, look at:
+
+* `Azure Virtual Machines Documentation `__
diff --git a/providers/microsoft/azure/provider.yaml b/providers/microsoft/azure/provider.yaml
index 8b89a7cbd0a0e..fdf39ed0d2c96 100644
--- a/providers/microsoft/azure/provider.yaml
+++ b/providers/microsoft/azure/provider.yaml
@@ -118,6 +118,12 @@ integrations:
- /docs/apache-airflow-providers-microsoft-azure/operators/batch.rst
logo: /docs/integration-logos/Microsoft-Azure-Batch.png
tags: [azure]
+ - integration-name: Microsoft Azure Compute
+ external-doc-url: https://azure.microsoft.com/en-us/products/virtual-machines/
+ how-to-guide:
+ - /docs/apache-airflow-providers-microsoft-azure/operators/compute.rst
+ logo: /docs/integration-logos/Microsoft-Azure.png
+ tags: [azure]
- integration-name: Microsoft Azure Blob Storage
external-doc-url: https://azure.microsoft.com/en-us/services/storage/blobs/
how-to-guide:
@@ -191,6 +197,9 @@ integrations:
tags: [azure]
operators:
+ - integration-name: Microsoft Azure Compute
+ python-modules:
+ - airflow.providers.microsoft.azure.operators.compute
- integration-name: Microsoft Azure Data Lake Storage
python-modules:
- airflow.providers.microsoft.azure.operators.adls
@@ -226,6 +235,9 @@ operators:
- airflow.providers.microsoft.azure.operators.powerbi
sensors:
+ - integration-name: Microsoft Azure Compute
+ python-modules:
+ - airflow.providers.microsoft.azure.sensors.compute
- integration-name: Microsoft Azure Cosmos DB
python-modules:
- airflow.providers.microsoft.azure.sensors.cosmos
@@ -244,6 +256,9 @@ filesystems:
- airflow.providers.microsoft.azure.fs.msgraph
hooks:
+ - integration-name: Microsoft Azure Compute
+ python-modules:
+ - airflow.providers.microsoft.azure.hooks.compute
- integration-name: Microsoft Azure Container Instances
python-modules:
- airflow.providers.microsoft.azure.hooks.container_volume
@@ -290,6 +305,9 @@ hooks:
- airflow.providers.microsoft.azure.hooks.powerbi
triggers:
+ - integration-name: Microsoft Azure Compute
+ python-modules:
+ - airflow.providers.microsoft.azure.triggers.compute
- integration-name: Microsoft Azure Data Factory
python-modules:
- airflow.providers.microsoft.azure.triggers.data_factory
@@ -362,6 +380,36 @@ connection-types:
label: Workload Identity Tenant ID
schema:
type: ["string", "null"]
+ - hook-class-name: airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook
+ connection-type: azure_compute
+ ui-field-behaviour:
+ hidden-fields: ["schema", "port", "host"]
+ relabeling:
+ login: Client ID
+ password: Client Secret
+ placeholders:
+ extra: '{"key_path": "path to json file for auth", "key_json": "specifies json dict for auth"}'
+ login: client_id (token credentials auth)
+ password: secret (token credentials auth)
+ tenantId: tenantId (token credentials auth)
+ subscriptionId: subscriptionId (token credentials auth)
+ conn-fields:
+ tenantId:
+ label: Azure Tenant ID
+ schema:
+ type: ["string", "null"]
+ subscriptionId:
+ label: Azure Subscription ID
+ schema:
+ type: ["string", "null"]
+ managed_identity_client_id:
+ label: Managed Identity Client ID
+ schema:
+ type: ["string", "null"]
+ workload_identity_tenant_id:
+ label: Workload Identity Tenant ID
+ schema:
+ type: ["string", "null"]
- hook-class-name: airflow.providers.microsoft.azure.hooks.adx.AzureDataExplorerHook
connection-type: azure_data_explorer
ui-field-behaviour:
diff --git a/providers/microsoft/azure/pyproject.toml b/providers/microsoft/azure/pyproject.toml
index 3cc4edf2c5376..37f5021c84cc2 100644
--- a/providers/microsoft/azure/pyproject.toml
+++ b/providers/microsoft/azure/pyproject.toml
@@ -81,6 +81,7 @@ dependencies = [
"azure-kusto-data>=4.1.0,!=4.6.0,!=5.0.0",
"azure-mgmt-datafactory>=2.0.0",
"azure-mgmt-containerregistry>=8.0.0",
+ "azure-mgmt-compute>=33.0.0",
"azure-mgmt-containerinstance>=10.1.0",
"msgraph-core>=1.3.3",
"msgraphfs>=0.3.0",
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
index 80cc07ce077de..a13e4e6229b1c 100644
--- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
@@ -34,6 +34,13 @@ def get_provider_info():
"logo": "/docs/integration-logos/Microsoft-Azure-Batch.png",
"tags": ["azure"],
},
+ {
+ "integration-name": "Microsoft Azure Compute",
+ "external-doc-url": "https://azure.microsoft.com/en-us/products/virtual-machines/",
+ "how-to-guide": ["/docs/apache-airflow-providers-microsoft-azure/operators/compute.rst"],
+ "logo": "/docs/integration-logos/Microsoft-Azure.png",
+ "tags": ["azure"],
+ },
{
"integration-name": "Microsoft Azure Blob Storage",
"external-doc-url": "https://azure.microsoft.com/en-us/services/storage/blobs/",
@@ -135,6 +142,10 @@ def get_provider_info():
},
],
"operators": [
+ {
+ "integration-name": "Microsoft Azure Compute",
+ "python-modules": ["airflow.providers.microsoft.azure.operators.compute"],
+ },
{
"integration-name": "Microsoft Azure Data Lake Storage",
"python-modules": ["airflow.providers.microsoft.azure.operators.adls"],
@@ -181,6 +192,10 @@ def get_provider_info():
},
],
"sensors": [
+ {
+ "integration-name": "Microsoft Azure Compute",
+ "python-modules": ["airflow.providers.microsoft.azure.sensors.compute"],
+ },
{
"integration-name": "Microsoft Azure Cosmos DB",
"python-modules": ["airflow.providers.microsoft.azure.sensors.cosmos"],
@@ -203,6 +218,10 @@ def get_provider_info():
"airflow.providers.microsoft.azure.fs.msgraph",
],
"hooks": [
+ {
+ "integration-name": "Microsoft Azure Compute",
+ "python-modules": ["airflow.providers.microsoft.azure.hooks.compute"],
+ },
{
"integration-name": "Microsoft Azure Container Instances",
"python-modules": [
@@ -265,6 +284,10 @@ def get_provider_info():
},
],
"triggers": [
+ {
+ "integration-name": "Microsoft Azure Compute",
+ "python-modules": ["airflow.providers.microsoft.azure.triggers.compute"],
+ },
{
"integration-name": "Microsoft Azure Data Factory",
"python-modules": ["airflow.providers.microsoft.azure.triggers.data_factory"],
@@ -349,6 +372,36 @@ def get_provider_info():
},
},
},
+ {
+ "hook-class-name": "airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook",
+ "connection-type": "azure_compute",
+ "ui-field-behaviour": {
+ "hidden-fields": ["schema", "port", "host"],
+ "relabeling": {"login": "Client ID", "password": "Client Secret"},
+ "placeholders": {
+ "extra": '{"key_path": "path to json file for auth", "key_json": "specifies json dict for auth"}',
+ "login": "client_id (token credentials auth)",
+ "password": "secret (token credentials auth)",
+ "tenantId": "tenantId (token credentials auth)",
+ "subscriptionId": "subscriptionId (token credentials auth)",
+ },
+ },
+ "conn-fields": {
+ "tenantId": {"label": "Azure Tenant ID", "schema": {"type": ["string", "null"]}},
+ "subscriptionId": {
+ "label": "Azure Subscription ID",
+ "schema": {"type": ["string", "null"]},
+ },
+ "managed_identity_client_id": {
+ "label": "Managed Identity Client ID",
+ "schema": {"type": ["string", "null"]},
+ },
+ "workload_identity_tenant_id": {
+ "label": "Workload Identity Tenant ID",
+ "schema": {"type": ["string", "null"]},
+ },
+ },
+ },
{
"hook-class-name": "airflow.providers.microsoft.azure.hooks.adx.AzureDataExplorerHook",
"connection-type": "azure_data_explorer",
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/compute.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/compute.py
new file mode 100644
index 0000000000000..6d5564aeafefd
--- /dev/null
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/compute.py
@@ -0,0 +1,237 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from functools import cached_property
+from typing import Any, cast
+
+from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
+from azure.identity import ClientSecretCredential, DefaultAzureCredential
+from azure.identity.aio import (
+ ClientSecretCredential as AsyncClientSecretCredential,
+ DefaultAzureCredential as AsyncDefaultAzureCredential,
+)
+from azure.mgmt.compute import ComputeManagementClient
+from azure.mgmt.compute.aio import ComputeManagementClient as AsyncComputeManagementClient
+
+from airflow.providers.common.compat.connection import get_async_connection
+from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
+from airflow.providers.microsoft.azure.utils import (
+ get_async_default_azure_credential,
+ get_sync_default_azure_credential,
+)
+
+
+class AzureComputeHook(AzureBaseHook):
+ """
+ A hook to interact with Azure Compute to manage Virtual Machines.
+
+ :param azure_conn_id: :ref:`Azure connection id` of
+ a service principal which will be used to manage virtual machines.
+ """
+
+ conn_name_attr = "azure_conn_id"
+ default_conn_name = "azure_default"
+ conn_type = "azure_compute"
+ hook_name = "Azure Compute"
+
+ def __init__(self, azure_conn_id: str = default_conn_name) -> None:
+ super().__init__(sdk_client=ComputeManagementClient, conn_id=azure_conn_id)
+ self._async_conn: AsyncComputeManagementClient | None = None
+
+ @cached_property
+ def connection(self) -> ComputeManagementClient:
+ return self.get_conn()
+
+ def get_conn(self) -> ComputeManagementClient:
+ """
+ Authenticate the resource using the connection id passed during init.
+
+ :return: the authenticated ComputeManagementClient.
+ """
+ conn = self.get_connection(self.conn_id)
+ tenant = conn.extra_dejson.get("tenantId")
+
+ key_path = conn.extra_dejson.get("key_path")
+ if key_path:
+ if not key_path.endswith(".json"):
+ raise ValueError("Unrecognised extension for key file.")
+ self.log.info("Getting connection using a JSON key file.")
+ return get_client_from_auth_file(client_class=self.sdk_client, auth_path=key_path)
+
+ key_json = conn.extra_dejson.get("key_json")
+ if key_json:
+ self.log.info("Getting connection using a JSON config.")
+ return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)
+
+ credential: ClientSecretCredential | DefaultAzureCredential
+ if all([conn.login, conn.password, tenant]):
+ self.log.info("Getting connection using specific credentials and subscription_id.")
+ credential = ClientSecretCredential(
+ client_id=cast("str", conn.login),
+ client_secret=cast("str", conn.password),
+ tenant_id=cast("str", tenant),
+ )
+ else:
+ self.log.info("Using DefaultAzureCredential as credential")
+ managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id")
+ workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id")
+ credential = get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
+
+ subscription_id = cast("str", conn.extra_dejson.get("subscriptionId"))
+ return ComputeManagementClient(
+ credential=credential,
+ subscription_id=subscription_id,
+ )
+
+ def start_instance(
+ self, resource_group_name: str, vm_name: str, wait_for_completion: bool = True
+ ) -> None:
+ """
+ Start a virtual machine instance.
+
+ :param resource_group_name: Name of the resource group.
+ :param vm_name: Name of the virtual machine.
+ :param wait_for_completion: Wait for the operation to complete.
+ """
+ self.log.info("Starting VM %s in resource group %s", vm_name, resource_group_name)
+ poller = self.connection.virtual_machines.begin_start(resource_group_name, vm_name)
+ if wait_for_completion:
+ poller.result()
+
+ def stop_instance(self, resource_group_name: str, vm_name: str, wait_for_completion: bool = True) -> None:
+ """
+ Stop (deallocate) a virtual machine instance.
+
+ Uses ``begin_deallocate`` to release compute resources and stop billing.
+
+ :param resource_group_name: Name of the resource group.
+ :param vm_name: Name of the virtual machine.
+ :param wait_for_completion: Wait for the operation to complete.
+ """
+ self.log.info("Stopping (deallocating) VM %s in resource group %s", vm_name, resource_group_name)
+ poller = self.connection.virtual_machines.begin_deallocate(resource_group_name, vm_name)
+ if wait_for_completion:
+ poller.result()
+
+ def restart_instance(
+ self, resource_group_name: str, vm_name: str, wait_for_completion: bool = True
+ ) -> None:
+ """
+ Restart a virtual machine instance.
+
+ :param resource_group_name: Name of the resource group.
+ :param vm_name: Name of the virtual machine.
+ :param wait_for_completion: Wait for the operation to complete.
+ """
+ self.log.info("Restarting VM %s in resource group %s", vm_name, resource_group_name)
+ poller = self.connection.virtual_machines.begin_restart(resource_group_name, vm_name)
+ if wait_for_completion:
+ poller.result()
+
+ def get_power_state(self, resource_group_name: str, vm_name: str) -> str:
+ """
+ Get the power state of a virtual machine.
+
+ :param resource_group_name: Name of the resource group.
+ :param vm_name: Name of the virtual machine.
+ :return: Power state string, e.g. ``running``, ``deallocated``, ``stopped``.
+ """
+ instance_view = self.connection.virtual_machines.instance_view(resource_group_name, vm_name)
+ for status in instance_view.statuses or []:
+ if status.code and status.code.startswith("PowerState/"):
+ return status.code.split("/", 1)[1]
+ return "unknown"
+
+ def test_connection(self) -> tuple[bool, str]:
+ """Test the Azure Compute connection."""
+ try:
+ next(self.connection.virtual_machines.list_all(), None)
+ except Exception as e:
+ return False, str(e)
+ return True, "Successfully connected to Azure Compute."
+
+ # ------------------------------------------------------------------
+ # Async interface (used by AzureVirtualMachineStateTrigger)
+ # ------------------------------------------------------------------
+
+ async def get_async_conn(self) -> AsyncComputeManagementClient:
+ """
+ Return an authenticated async :class:`~azure.mgmt.compute.aio.ComputeManagementClient`.
+
+ Supports service-principal (login/password + tenantId) and
+ DefaultAzureCredential auth. Legacy ``key_path`` / ``key_json``
+ auth files are not supported in the async path.
+ """
+ if self._async_conn is not None:
+ return self._async_conn
+
+ conn = await get_async_connection(self.conn_id)
+ tenant = conn.extra_dejson.get("tenantId")
+ subscription_id = cast("str", conn.extra_dejson.get("subscriptionId"))
+
+ credential: AsyncClientSecretCredential | AsyncDefaultAzureCredential
+ if conn.login and conn.password and tenant:
+ credential = AsyncClientSecretCredential(
+ client_id=conn.login,
+ client_secret=conn.password,
+ tenant_id=tenant,
+ )
+ else:
+ managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id")
+ workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id")
+ credential = get_async_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
+
+ self._async_conn = AsyncComputeManagementClient(
+ credential=credential,
+ subscription_id=subscription_id,
+ )
+ return self._async_conn
+
+ async def async_get_power_state(self, resource_group_name: str, vm_name: str) -> str:
+ """
+ Get the power state of a virtual machine using the native async client.
+
+ :param resource_group_name: Name of the resource group.
+ :param vm_name: Name of the virtual machine.
+ :return: Power state string, e.g. ``running``, ``deallocated``, ``stopped``.
+ """
+ client = await self.get_async_conn()
+ instance_view = await client.virtual_machines.instance_view(resource_group_name, vm_name)
+ for status in instance_view.statuses or []:
+ if status.code and status.code.startswith("PowerState/"):
+ return status.code.split("/", 1)[1]
+ return "unknown"
+
+ async def close(self) -> None:
+ """Close the async connection."""
+ if self._async_conn is not None:
+ await self._async_conn.close()
+ self._async_conn = None
+
+ async def __aenter__(self) -> AzureComputeHook:
+ return self
+
+ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+ await self.close()
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/compute.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/compute.py
new file mode 100644
index 0000000000000..b4bf06f17e03f
--- /dev/null
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/compute.py
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from abc import abstractmethod
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING
+
+from airflow.providers.common.compat.sdk import BaseOperator
+from airflow.providers.microsoft.azure.hooks.compute import AzureComputeHook
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class BaseAzureVirtualMachineOperator(BaseOperator):
+ """
+ Base operator for Azure Virtual Machine operations.
+
+ :param resource_group_name: Name of the Azure resource group.
+ :param vm_name: Name of the virtual machine.
+ :param wait_for_completion: Wait for the operation to complete. Default True.
+ :param azure_conn_id: Azure connection id.
+ """
+
+ template_fields: Sequence[str] = ("resource_group_name", "vm_name")
+ ui_color = "#0078d4"
+ ui_fgcolor = "#ffffff"
+
+ def __init__(
+ self,
+ *,
+ resource_group_name: str,
+ vm_name: str,
+ wait_for_completion: bool = True,
+ azure_conn_id: str = "azure_default",
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.resource_group_name = resource_group_name
+ self.vm_name = vm_name
+ self.wait_for_completion = wait_for_completion
+ self.azure_conn_id = azure_conn_id
+
+ @cached_property
+ def hook(self) -> AzureComputeHook:
+ return AzureComputeHook(azure_conn_id=self.azure_conn_id)
+
+ @abstractmethod
+ def execute(self, context: Context) -> None: ...
+
+
+class AzureVirtualMachineStartOperator(BaseAzureVirtualMachineOperator):
+ """
+ Start an Azure Virtual Machine.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AzureVirtualMachineStartOperator`
+
+ :param resource_group_name: Name of the Azure resource group.
+ :param vm_name: Name of the virtual machine.
+ :param wait_for_completion: Wait for the VM to reach 'running' state. Default True.
+ :param azure_conn_id: Azure connection id.
+ """
+
+ def execute(self, context: Context) -> None:
+ self.hook.start_instance(
+ resource_group_name=self.resource_group_name,
+ vm_name=self.vm_name,
+ wait_for_completion=self.wait_for_completion,
+ )
+ self.log.info("VM %s started successfully.", self.vm_name)
+
+
+class AzureVirtualMachineStopOperator(BaseAzureVirtualMachineOperator):
+ """
+ Stop (deallocate) an Azure Virtual Machine.
+
+ Uses ``begin_deallocate`` to release compute resources and stop billing.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AzureVirtualMachineStopOperator`
+
+ :param resource_group_name: Name of the Azure resource group.
+ :param vm_name: Name of the virtual machine.
+ :param wait_for_completion: Wait for the VM to reach 'deallocated' state. Default True.
+ :param azure_conn_id: Azure connection id.
+ """
+
+ def execute(self, context: Context) -> None:
+ self.hook.stop_instance(
+ resource_group_name=self.resource_group_name,
+ vm_name=self.vm_name,
+ wait_for_completion=self.wait_for_completion,
+ )
+ self.log.info("VM %s stopped (deallocated) successfully.", self.vm_name)
+
+
+class AzureVirtualMachineRestartOperator(BaseAzureVirtualMachineOperator):
+ """
+ Restart an Azure Virtual Machine.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AzureVirtualMachineRestartOperator`
+
+ :param resource_group_name: Name of the Azure resource group.
+ :param vm_name: Name of the virtual machine.
+ :param wait_for_completion: Wait for the VM to reach 'running' state. Default True.
+ :param azure_conn_id: Azure connection id.
+ """
+
+ def execute(self, context: Context) -> None:
+ self.hook.restart_instance(
+ resource_group_name=self.resource_group_name,
+ vm_name=self.vm_name,
+ wait_for_completion=self.wait_for_completion,
+ )
+ self.log.info("VM %s restarted successfully.", self.vm_name)
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/compute.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/compute.py
new file mode 100644
index 0000000000000..1a9c05185dcb6
--- /dev/null
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/compute.py
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from collections.abc import Sequence
+from datetime import timedelta
+from typing import TYPE_CHECKING
+
+from airflow.providers.common.compat.sdk import BaseSensorOperator, conf
+from airflow.providers.microsoft.azure.hooks.compute import AzureComputeHook
+from airflow.providers.microsoft.azure.triggers.compute import AzureVirtualMachineStateTrigger
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class AzureVirtualMachineStateSensor(BaseSensorOperator):
+ """
+ Poll an Azure Virtual Machine until it reaches a target power state.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:AzureVirtualMachineStateSensor`
+
+ :param resource_group_name: Name of the Azure resource group.
+ :param vm_name: Name of the virtual machine.
+ :param target_state: Desired power state, e.g. ``running``, ``deallocated``.
+ :param azure_conn_id: Azure connection id.
+ :param deferrable: If True, run in deferrable mode.
+ """
+
+ template_fields: Sequence[str] = ("resource_group_name", "vm_name", "target_state")
+ ui_color = "#0078d4"
+ ui_fgcolor = "#ffffff"
+
+ VALID_STATES = frozenset({"running", "deallocated", "stopped", "starting", "deallocating"})
+
+ def __init__(
+ self,
+ *,
+ resource_group_name: str,
+ vm_name: str,
+ target_state: str,
+ azure_conn_id: str = "azure_default",
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ **kwargs,
+ ) -> None:
+ if target_state not in self.VALID_STATES:
+ raise ValueError(
+ f"Invalid target_state: {target_state}. Must be one of {sorted(self.VALID_STATES)}"
+ )
+ super().__init__(**kwargs)
+ self.resource_group_name = resource_group_name
+ self.vm_name = vm_name
+ self.target_state = target_state
+ self.azure_conn_id = azure_conn_id
+ self.deferrable = deferrable
+
+ def poke(self, context: Context) -> bool:
+ hook = AzureComputeHook(azure_conn_id=self.azure_conn_id)
+ current_state = hook.get_power_state(self.resource_group_name, self.vm_name)
+ self.log.info("VM %s power state: %s", self.vm_name, current_state)
+ return current_state == self.target_state
+
+ def execute(self, context: Context) -> None:
+ """
+ Poll for the VM power state.
+
+ In deferrable mode, the polling is deferred to the triggerer. Otherwise
+ the sensor waits synchronously.
+ """
+ if not self.deferrable:
+ super().execute(context=context)
+ else:
+ if not self.poke(context=context):
+ self.defer(
+ timeout=timedelta(seconds=self.timeout),
+ trigger=AzureVirtualMachineStateTrigger(
+ resource_group_name=self.resource_group_name,
+ vm_name=self.vm_name,
+ target_state=self.target_state,
+ azure_conn_id=self.azure_conn_id,
+ poke_interval=self.poke_interval,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context: Context, event: dict[str, str]) -> None:
+ """
+ Handle callback from the trigger.
+
+ Relies on trigger to throw an exception, otherwise it assumes execution was successful.
+ """
+ if event:
+ if event["status"] == "error":
+ raise RuntimeError(event["message"])
+ self.log.info(event["message"])
+ else:
+ raise RuntimeError("Did not receive valid event from the triggerer")
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/compute.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/compute.py
new file mode 100644
index 0000000000000..e1682af6b4bd7
--- /dev/null
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/compute.py
@@ -0,0 +1,91 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator
+from typing import Any
+
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class AzureVirtualMachineStateTrigger(BaseTrigger):
+ """
+ Poll the Azure VM power state and yield a TriggerEvent once it matches the target.
+
+ Uses the native async Azure SDK client (``azure.mgmt.compute.aio``) so that
+ the triggerer event loop is never blocked.
+
+ :param resource_group_name: Name of the Azure resource group.
+ :param vm_name: Name of the virtual machine.
+ :param target_state: Desired power state, e.g. ``running``, ``deallocated``.
+ :param azure_conn_id: Azure connection id.
+ :param poke_interval: Polling interval in seconds.
+ """
+
+ def __init__(
+ self,
+ resource_group_name: str,
+ vm_name: str,
+ target_state: str,
+ azure_conn_id: str = "azure_default",
+ poke_interval: float = 30.0,
+ ) -> None:
+ super().__init__()
+ self.resource_group_name = resource_group_name
+ self.vm_name = vm_name
+ self.target_state = target_state
+ self.azure_conn_id = azure_conn_id
+ self.poke_interval = poke_interval
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serialize AzureVirtualMachineStateTrigger arguments and classpath."""
+ return (
+ f"{self.__class__.__module__}.{self.__class__.__name__}",
+ {
+ "resource_group_name": self.resource_group_name,
+ "vm_name": self.vm_name,
+ "target_state": self.target_state,
+ "azure_conn_id": self.azure_conn_id,
+ "poke_interval": self.poke_interval,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ """Poll VM power state asynchronously until it matches the target state."""
+ from airflow.providers.microsoft.azure.hooks.compute import AzureComputeHook
+
+ try:
+ async with AzureComputeHook(azure_conn_id=self.azure_conn_id) as hook:
+ while True:
+ power_state = await hook.async_get_power_state(self.resource_group_name, self.vm_name)
+ if power_state == self.target_state:
+ message = f"VM {self.vm_name} reached state '{self.target_state}'."
+ yield TriggerEvent({"status": "success", "message": message})
+ return
+ self.log.info(
+ "VM %s power state: %s. Waiting for %s. Sleeping for %s seconds.",
+ self.vm_name,
+ power_state,
+ self.target_state,
+ self.poke_interval,
+ )
+ await asyncio.sleep(self.poke_interval)
+ except Exception as e:
+ yield TriggerEvent({"status": "error", "message": str(e)})
+ return
diff --git a/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_compute.py b/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_compute.py
new file mode 100644
index 0000000000000..116056d4ea12a
--- /dev/null
+++ b/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_compute.py
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import DAG
+from airflow.providers.microsoft.azure.operators.compute import (
+ AzureVirtualMachineRestartOperator,
+ AzureVirtualMachineStartOperator,
+ AzureVirtualMachineStopOperator,
+)
+from airflow.providers.microsoft.azure.sensors.compute import (
+ AzureVirtualMachineStateSensor,
+)
+
+DAG_ID = "example_azure_compute"
+RESOURCE_GROUP = os.environ.get("AZURE_RESOURCE_GROUP", "airflow-test-rg")
+VM_NAME = os.environ.get("AZURE_VM_NAME", "airflow-test-vm")
+
+with DAG(
+ dag_id=DAG_ID,
+ schedule=None,
+ start_date=datetime(2026, 1, 1),
+ catchup=False,
+ tags=["example", "azure", "compute"],
+) as dag:
+ # [START howto_operator_azure_vm_start]
+ start_vm = AzureVirtualMachineStartOperator(
+ task_id="start_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ wait_for_completion=False,
+ )
+ # [END howto_operator_azure_vm_start]
+
+ # [START howto_sensor_azure_vm_state]
+ sense_running = AzureVirtualMachineStateSensor(
+ task_id="sense_running",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state="running",
+ deferrable=True,
+ poke_interval=10,
+ timeout=300,
+ )
+ # [END howto_sensor_azure_vm_state]
+
+ # [START howto_operator_azure_vm_restart]
+ restart_vm = AzureVirtualMachineRestartOperator(
+ task_id="restart_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ wait_for_completion=True,
+ )
+ # [END howto_operator_azure_vm_restart]
+
+ # [START howto_operator_azure_vm_stop]
+ stop_vm = AzureVirtualMachineStopOperator(
+ task_id="stop_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ wait_for_completion=False,
+ )
+ # [END howto_operator_azure_vm_stop]
+
+ sense_deallocated = AzureVirtualMachineStateSensor(
+ task_id="sense_deallocated",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state="deallocated",
+ deferrable=True,
+ poke_interval=10,
+ timeout=300,
+ )
+
+ start_vm >> sense_running >> restart_vm >> stop_vm >> sense_deallocated
+
+from tests_common.test_utils.system_tests import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: contributing-docs/testing/system_tests.rst)
+test_run = get_test_run(dag)
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_compute.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_compute.py
new file mode 100644
index 0000000000000..9be756422b8bb
--- /dev/null
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_compute.py
@@ -0,0 +1,273 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.compute import AzureComputeHook
+
+CONN_ID = "azure_compute_test"
+
+
+class TestAzureComputeHook:
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, create_mock_connection):
+ create_mock_connection(
+ Connection(
+ conn_id=CONN_ID,
+ conn_type="azure_compute",
+ login="client_id",
+ password="client_secret",
+ extra={
+ "tenantId": "tenant_id",
+ "subscriptionId": "subscription_id",
+ },
+ )
+ )
+
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+ def test_get_conn_returns_compute_management_client(self, mock_credential, mock_client):
+ hook = AzureComputeHook(azure_conn_id=CONN_ID)
+ result = hook.get_conn()
+
+ mock_credential.assert_called_once_with(
+ client_id="client_id",
+ client_secret="client_secret",
+ tenant_id="tenant_id",
+ )
+ mock_client.assert_called_once_with(
+ credential=mock_credential.return_value,
+ subscription_id="subscription_id",
+ )
+ assert result == mock_client.return_value
+
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+ @patch("airflow.providers.microsoft.azure.hooks.compute.get_sync_default_azure_credential")
+ def test_get_conn_with_default_azure_credential(
+ self, mock_default_cred, mock_client, create_mock_connection
+ ):
+ create_mock_connection(
+ Connection(
+ conn_id="azure_no_login",
+ conn_type="azure_compute",
+ extra={"subscriptionId": "sub_id"},
+ )
+ )
+ hook = AzureComputeHook(azure_conn_id="azure_no_login")
+ hook.get_conn()
+
+ mock_default_cred.assert_called_once_with(
+ managed_identity_client_id=None,
+ workload_identity_tenant_id=None,
+ )
+ mock_client.assert_called_once_with(
+ credential=mock_default_cred.return_value,
+ subscription_id="sub_id",
+ )
+
+ @patch("airflow.providers.microsoft.azure.hooks.compute.get_client_from_auth_file")
+ def test_get_conn_with_key_path(self, mock_get_client_from_auth_file, create_mock_connection):
+ create_mock_connection(
+ Connection(
+ conn_id="azure_with_key_path",
+ conn_type="azure_compute",
+ extra={"key_path": "/tmp/azure-key.json"},
+ )
+ )
+ hook = AzureComputeHook(azure_conn_id="azure_with_key_path")
+ conn = hook.get_conn()
+
+ mock_get_client_from_auth_file.assert_called_once_with(
+ client_class=hook.sdk_client, auth_path="/tmp/azure-key.json"
+ )
+ assert conn == mock_get_client_from_auth_file.return_value
+
+ @patch("airflow.providers.microsoft.azure.hooks.compute.get_client_from_json_dict")
+ def test_get_conn_with_key_json(self, mock_get_client_from_json_dict, create_mock_connection):
+ create_mock_connection(
+ Connection(
+ conn_id="azure_with_key_json",
+ conn_type="azure_compute",
+ extra={"key_json": {"tenantId": "tenant", "subscriptionId": "sub"}},
+ )
+ )
+ hook = AzureComputeHook(azure_conn_id="azure_with_key_json")
+ hook.get_conn()
+
+ mock_get_client_from_json_dict.assert_called_once_with(
+ client_class=hook.sdk_client, config_dict={"tenantId": "tenant", "subscriptionId": "sub"}
+ )
+
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+ def test_start_instance(self, mock_credential, mock_client):
+ mock_poller = MagicMock()
+ mock_client.return_value.virtual_machines.begin_start.return_value = mock_poller
+
+ hook = AzureComputeHook(azure_conn_id=CONN_ID)
+ hook.start_instance(resource_group_name="rg", vm_name="vm1")
+
+ mock_client.return_value.virtual_machines.begin_start.assert_called_once_with("rg", "vm1")
+ mock_poller.result.assert_called_once()
+ mock_poller.result.reset_mock()
+
+ hook.start_instance(resource_group_name="rg", vm_name="vm1", wait_for_completion=False)
+ mock_poller.result.assert_not_called()
+
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+ def test_stop_instance(self, mock_credential, mock_client):
+ mock_poller = MagicMock()
+ mock_client.return_value.virtual_machines.begin_deallocate.return_value = mock_poller
+
+ hook = AzureComputeHook(azure_conn_id=CONN_ID)
+ hook.stop_instance(resource_group_name="rg", vm_name="vm1")
+
+ mock_client.return_value.virtual_machines.begin_deallocate.assert_called_once_with("rg", "vm1")
+ mock_poller.result.assert_called_once()
+
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+ def test_restart_instance(self, mock_credential, mock_client):
+ mock_poller = MagicMock()
+ mock_client.return_value.virtual_machines.begin_restart.return_value = mock_poller
+
+ hook = AzureComputeHook(azure_conn_id=CONN_ID)
+ hook.restart_instance(resource_group_name="rg", vm_name="vm1")
+
+ mock_client.return_value.virtual_machines.begin_restart.assert_called_once_with("rg", "vm1")
+ mock_poller.result.assert_called_once()
+
+ @pytest.mark.parametrize(
+ ("status_code", "expected"),
+ [
+ ("PowerState/running", "running"),
+ ("PowerState/deallocated", "deallocated"),
+ ("foo/bar", "unknown"),
+ (None, "unknown"),
+ ],
+ )
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+ def test_get_power_state(self, mock_credential, mock_client, status_code, expected):
+ mock_instance_view = MagicMock()
+ mock_instance_view.statuses = [
+ MagicMock(code="ProvisioningState/succeeded"),
+ MagicMock(code=status_code),
+ ]
+ mock_client.return_value.virtual_machines.instance_view.return_value = mock_instance_view
+
+ hook = AzureComputeHook(azure_conn_id=CONN_ID)
+ state = hook.get_power_state(resource_group_name="rg", vm_name="vm1")
+
+ assert state == expected
+
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+ def test_test_connection_success(self, mock_credential, mock_client):
+ mock_client.return_value.virtual_machines.list_all.return_value = iter([])
+
+ hook = AzureComputeHook(azure_conn_id=CONN_ID)
+ result, message = hook.test_connection()
+
+ assert result is True
+ assert "Successfully" in message
+
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ComputeManagementClient")
+ @patch("airflow.providers.microsoft.azure.hooks.compute.ClientSecretCredential")
+ def test_test_connection_failure(self, mock_credential, mock_client):
+ mock_client.return_value.virtual_machines.list_all.side_effect = Exception("Auth failed")
+
+ hook = AzureComputeHook(azure_conn_id=CONN_ID)
+ result, message = hook.test_connection()
+
+ assert result is False
+ assert "Auth failed" in message
+
+ # ------------------------------------------------------------------
+ # Async interface tests
+ # ------------------------------------------------------------------
+
+ @pytest.mark.asyncio
+ @patch("airflow.providers.microsoft.azure.hooks.compute.AsyncComputeManagementClient")
+ @patch("airflow.providers.microsoft.azure.hooks.compute.AsyncClientSecretCredential")
+ @patch(
+ "airflow.providers.microsoft.azure.hooks.compute.get_async_connection",
+ new_callable=AsyncMock,
+ )
+ async def test_get_async_conn_with_service_principal(
+ self, mock_get_async_connection, mock_async_cred, mock_async_client
+ ):
+ mock_conn = MagicMock()
+ mock_conn.login = "client_id"
+ mock_conn.password = "client_secret"
+ mock_conn.extra_dejson = {"tenantId": "tenant_id", "subscriptionId": "sub_id"}
+ mock_get_async_connection.return_value = mock_conn
+
+ hook = AzureComputeHook(azure_conn_id=CONN_ID)
+ conn = await hook.get_async_conn()
+
+ mock_async_cred.assert_called_once_with(
+ client_id="client_id",
+ client_secret="client_secret",
+ tenant_id="tenant_id",
+ )
+ mock_async_client.assert_called_once_with(
+ credential=mock_async_cred.return_value,
+ subscription_id="sub_id",
+ )
+ assert conn == mock_async_client.return_value
+
+ @pytest.mark.asyncio
+ @patch(
+ "airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook.get_async_conn",
+ new_callable=AsyncMock,
+ )
+ async def test_async_get_power_state(self, mock_get_async_conn):
+ mock_status_power = MagicMock(code="PowerState/running")
+ mock_status_prov = MagicMock(code="ProvisioningState/succeeded")
+ mock_instance_view = MagicMock(statuses=[mock_status_prov, mock_status_power])
+ mock_client = AsyncMock()
+ mock_client.virtual_machines.instance_view.return_value = mock_instance_view
+ mock_get_async_conn.return_value = mock_client
+
+ hook = AzureComputeHook(azure_conn_id=CONN_ID)
+ state = await hook.async_get_power_state("rg", "vm1")
+
+ assert state == "running"
+ mock_client.virtual_machines.instance_view.assert_called_once_with("rg", "vm1")
+
+ @pytest.mark.asyncio
+ @patch(
+ "airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook.get_async_conn",
+ new_callable=AsyncMock,
+ )
+ async def test_async_context_manager_closes_conn(self, mock_get_async_conn):
+ mock_client = AsyncMock()
+ mock_get_async_conn.return_value = mock_client
+
+ hook = AzureComputeHook(azure_conn_id=CONN_ID)
+ async with hook as h:
+ assert h is hook
+ h._async_conn = mock_client
+
+ mock_client.close.assert_called_once()
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_compute.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_compute.py
new file mode 100644
index 0000000000000..879dd219da7bc
--- /dev/null
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_compute.py
@@ -0,0 +1,138 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import patch
+
+from airflow.providers.microsoft.azure.operators.compute import (
+ AzureVirtualMachineRestartOperator,
+ AzureVirtualMachineStartOperator,
+ AzureVirtualMachineStopOperator,
+)
+
+RESOURCE_GROUP = "test-rg"
+VM_NAME = "test-vm"
+CONN_ID = "azure_default"
+
+
+class TestAzureVirtualMachineStartOperator:
+ def test_init(self):
+ op = AzureVirtualMachineStartOperator(
+ task_id="start_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ azure_conn_id=CONN_ID,
+ )
+ assert op.resource_group_name == RESOURCE_GROUP
+ assert op.vm_name == VM_NAME
+ assert op.wait_for_completion is True
+ assert op.azure_conn_id == CONN_ID
+
+ def test_template_fields(self):
+ op = AzureVirtualMachineStartOperator(
+ task_id="start_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ )
+ assert "resource_group_name" in op.template_fields
+ assert "vm_name" in op.template_fields
+
+ @patch("airflow.providers.microsoft.azure.operators.compute.AzureComputeHook")
+ def test_execute_start_instance(self, mock_hook_cls):
+ op = AzureVirtualMachineStartOperator(
+ task_id="start_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ )
+ op.execute(context=None)
+
+ mock_hook_cls.return_value.start_instance.assert_called_once_with(
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ wait_for_completion=True,
+ )
+
+ @patch("airflow.providers.microsoft.azure.operators.compute.AzureComputeHook")
+ def test_execute_start_instance_no_wait(self, mock_hook_cls):
+ op = AzureVirtualMachineStartOperator(
+ task_id="start_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ wait_for_completion=False,
+ )
+ op.execute(context=None)
+
+ mock_hook_cls.return_value.start_instance.assert_called_once_with(
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ wait_for_completion=False,
+ )
+
+
+class TestAzureVirtualMachineStopOperator:
+ def test_init(self):
+ op = AzureVirtualMachineStopOperator(
+ task_id="stop_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ )
+ assert op.resource_group_name == RESOURCE_GROUP
+ assert op.vm_name == VM_NAME
+ assert op.wait_for_completion is True
+
+ @patch("airflow.providers.microsoft.azure.operators.compute.AzureComputeHook")
+ def test_execute_stop_instance(self, mock_hook_cls):
+ op = AzureVirtualMachineStopOperator(
+ task_id="stop_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ )
+ op.execute(context=None)
+
+ mock_hook_cls.return_value.stop_instance.assert_called_once_with(
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ wait_for_completion=True,
+ )
+
+
+class TestAzureVirtualMachineRestartOperator:
+ def test_init(self):
+ op = AzureVirtualMachineRestartOperator(
+ task_id="restart_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ )
+ assert op.resource_group_name == RESOURCE_GROUP
+ assert op.vm_name == VM_NAME
+ assert op.wait_for_completion is True
+
+ @patch("airflow.providers.microsoft.azure.operators.compute.AzureComputeHook")
+ def test_execute_restart_instance(self, mock_hook_cls):
+ op = AzureVirtualMachineRestartOperator(
+ task_id="restart_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ )
+ op.execute(context=None)
+
+ mock_hook_cls.return_value.restart_instance.assert_called_once_with(
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ wait_for_completion=True,
+ )
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_compute.py b/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_compute.py
new file mode 100644
index 0000000000000..47e4da37a8d46
--- /dev/null
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_compute.py
@@ -0,0 +1,123 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import patch
+
+import pytest
+
+from airflow.providers.common.compat.sdk import TaskDeferred
+from airflow.providers.microsoft.azure.sensors.compute import AzureVirtualMachineStateSensor
+
+RESOURCE_GROUP = "test-rg"
+VM_NAME = "test-vm"
+CONN_ID = "azure_default"
+
+
+class TestAzureVirtualMachineStateSensor:
+ def test_init(self):
+ sensor = AzureVirtualMachineStateSensor(
+ task_id="sense_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state="running",
+ azure_conn_id=CONN_ID,
+ )
+ assert sensor.resource_group_name == RESOURCE_GROUP
+ assert sensor.vm_name == VM_NAME
+ assert sensor.target_state == "running"
+ assert sensor.azure_conn_id == CONN_ID
+
+ def test_init_invalid_target_state(self):
+ with pytest.raises(ValueError, match="Invalid target_state"):
+ AzureVirtualMachineStateSensor(
+ task_id="sense_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state="invalid_state",
+ )
+
+ def test_template_fields(self):
+ sensor = AzureVirtualMachineStateSensor(
+ task_id="sense_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state="running",
+ )
+ assert "resource_group_name" in sensor.template_fields
+ assert "vm_name" in sensor.template_fields
+ assert "target_state" in sensor.template_fields
+
+ @pytest.mark.parametrize(
+ ("return_value", "expected"),
+ [
+ ("running", True),
+ ("deallocated", False),
+ ],
+ )
+ @patch("airflow.providers.microsoft.azure.sensors.compute.AzureComputeHook")
+ def test_poke(self, mock_hook_cls, return_value, expected):
+ mock_hook_cls.return_value.get_power_state.return_value = return_value
+
+ sensor = AzureVirtualMachineStateSensor(
+ task_id="sense_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state="running",
+ )
+ assert sensor.poke(context=None) is expected
+
+ @patch("airflow.providers.microsoft.azure.sensors.compute.AzureComputeHook")
+ def test_deferrable_mode(self, mock_hook_cls):
+ mock_hook_cls.return_value.get_power_state.return_value = "deallocated"
+
+ sensor = AzureVirtualMachineStateSensor(
+ task_id="sense_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state="running",
+ deferrable=True,
+ )
+ with pytest.raises(TaskDeferred):
+ sensor.execute(context=None)
+
+ def test_execute_complete_success(self):
+ sensor = AzureVirtualMachineStateSensor(
+ task_id="sense_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state="running",
+ )
+ # Should not raise
+ sensor.execute_complete(
+ context=None,
+ event={"status": "success", "message": "VM reached running state"},
+ )
+
+ def test_execute_complete_error(self):
+ sensor = AzureVirtualMachineStateSensor(
+ task_id="sense_vm",
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state="running",
+ )
+ with pytest.raises(RuntimeError, match="Something went wrong"):
+ sensor.execute_complete(
+ context=None,
+ event={"status": "error", "message": "Something went wrong"},
+ )
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_compute.py b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_compute.py
new file mode 100644
index 0000000000000..80deba518fb6a
--- /dev/null
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_compute.py
@@ -0,0 +1,126 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.microsoft.azure.triggers.compute import AzureVirtualMachineStateTrigger
+from airflow.triggers.base import TriggerEvent
+
+RESOURCE_GROUP = "test-rg"
+VM_NAME = "test-vm"
+TARGET_STATE = "running"
+CONN_ID = "azure_default"
+POKE_INTERVAL = 10.0
+
+
+class TestAzureVirtualMachineStateTrigger:
+ def test_serialize(self):
+ trigger = AzureVirtualMachineStateTrigger(
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state=TARGET_STATE,
+ azure_conn_id=CONN_ID,
+ poke_interval=POKE_INTERVAL,
+ )
+
+ actual = trigger.serialize()
+
+ assert isinstance(actual, tuple)
+ assert (
+ actual[0]
+ == f"{AzureVirtualMachineStateTrigger.__module__}.{AzureVirtualMachineStateTrigger.__name__}"
+ )
+ assert actual[1] == {
+ "resource_group_name": RESOURCE_GROUP,
+ "vm_name": VM_NAME,
+ "target_state": TARGET_STATE,
+ "azure_conn_id": CONN_ID,
+ "poke_interval": POKE_INTERVAL,
+ }
+
+ @pytest.mark.asyncio
+ @mock.patch(
+ "airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook.async_get_power_state",
+ new_callable=mock.AsyncMock,
+ )
+ async def test_run_immediate_success(self, mock_get_power_state):
+ mock_get_power_state.return_value = TARGET_STATE
+
+ trigger = AzureVirtualMachineStateTrigger(
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state=TARGET_STATE,
+ azure_conn_id=CONN_ID,
+ poke_interval=POKE_INTERVAL,
+ )
+
+ generator = trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent(
+ {"status": "success", "message": f"VM {VM_NAME} reached state '{TARGET_STATE}'."}
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch("asyncio.sleep", return_value=None)
+ @mock.patch(
+ "airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook.async_get_power_state",
+ new_callable=mock.AsyncMock,
+ )
+ async def test_run_polls_until_success(self, mock_get_power_state, mock_sleep):
+ mock_get_power_state.side_effect = ["deallocated", TARGET_STATE]
+
+ trigger = AzureVirtualMachineStateTrigger(
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state=TARGET_STATE,
+ azure_conn_id=CONN_ID,
+ poke_interval=POKE_INTERVAL,
+ )
+
+ generator = trigger.run()
+ response = await generator.asend(None)
+
+ assert mock_get_power_state.call_count == 2
+ assert response == TriggerEvent(
+ {"status": "success", "message": f"VM {VM_NAME} reached state '{TARGET_STATE}'."}
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(
+ "airflow.providers.microsoft.azure.hooks.compute.AzureComputeHook.async_get_power_state",
+ new_callable=mock.AsyncMock,
+ )
+ async def test_run_error(self, mock_get_power_state):
+ mock_get_power_state.side_effect = Exception("API error")
+
+ trigger = AzureVirtualMachineStateTrigger(
+ resource_group_name=RESOURCE_GROUP,
+ vm_name=VM_NAME,
+ target_state=TARGET_STATE,
+ azure_conn_id=CONN_ID,
+ poke_interval=POKE_INTERVAL,
+ )
+
+ generator = trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent({"status": "error", "message": "API error"})