diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index d286b5a95434..9bc6e22c6cd5 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -1,12 +1,10 @@ # Release History -## 1.37.1 (Unreleased) - -### Features Added +## 1.38.0 (2026-01-08) ### Breaking Changes -### Bugs Fixed +- Changed the continuation token format. Continuation tokens generated by previous versions of azure-core are not compatible with this version. ### Other Changes diff --git a/sdk/core/azure-core/TROUBLESHOOTING.md b/sdk/core/azure-core/TROUBLESHOOTING.md new file mode 100644 index 000000000000..e2638fe241fe --- /dev/null +++ b/sdk/core/azure-core/TROUBLESHOOTING.md @@ -0,0 +1,40 @@ +# Troubleshooting Azure Core + +This document provides solutions to common issues you may encounter when using the Azure Core library. + +## Continuation Token Compatibility Issues + +### Error: "Continuation token from a previous version is not compatible" + +**Symptoms:** + +You may encounter an error message like: + +``` +ValueError: This continuation token is not compatible with this version of azure-core. It may have been generated by a previous version. +``` + +**Cause:** + +Starting from azure-core version 1.38.0, the continuation token format was changed. This change was made to improve security and portability. Continuation tokens are opaque strings and their internal format is not guaranteed to be stable across versions. + +Continuation tokens generated by previous versions of azure-core are not compatible with version 1.38.0 and later. + +**Solution:** + +Unfortunately, old continuation tokens cannot be migrated to the new version. You will need to: + +1. **Start a new long-running operation**: Instead of using the old continuation token, initiate a new request for your long-running operation. + +2. **Check operation status via Azure Portal or CLI**: If you need to check the status of an operation that was started with an old token, you can use the Azure Portal or Azure CLI to check the operation status directly. + +3. **Update or pin your dependencies**: Ensure that any new continuation tokens are generated and consumed using the same version of azure-core (1.38.0 or later). + +**Prevention:** + +To avoid this issue in the future: + +- When upgrading azure-core, ensure that any stored continuation tokens are either consumed before the upgrade or discarded. +- Design your application to handle the case where a continuation token may become invalid. + +For more information, see the [CHANGELOG](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/CHANGELOG.md) for version 1.38.0. diff --git a/sdk/core/azure-core/azure/core/_version.py b/sdk/core/azure-core/azure/core/_version.py index 2b16c46cdc5d..5d327ea975bf 100644 --- a/sdk/core/azure-core/azure/core/_version.py +++ b/sdk/core/azure-core/azure/core/_version.py @@ -9,4 +9,4 @@ # regenerated. # -------------------------------------------------------------------------- -VERSION = "1.37.1" +VERSION = "1.38.0" diff --git a/sdk/core/azure-core/azure/core/polling/_poller.py b/sdk/core/azure-core/azure/core/polling/_poller.py index f984a5bbd995..13699b92d50a 100644 --- a/sdk/core/azure-core/azure/core/polling/_poller.py +++ b/sdk/core/azure-core/azure/core/polling/_poller.py @@ -23,7 +23,6 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- -import base64 import logging import threading import uuid @@ -31,6 +30,7 @@ from azure.core.exceptions import AzureError from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.common import with_current_context +from ._utils import _encode_continuation_token, _decode_continuation_token PollingReturnType_co = TypeVar("PollingReturnType_co", covariant=True) @@ -162,9 +162,7 @@ def get_continuation_token(self) -> str: :rtype: str :return: An opaque continuation token """ - import pickle - - return base64.b64encode(pickle.dumps(self._initial_response)).decode("ascii") + return _encode_continuation_token(self._initial_response) @classmethod def from_continuation_token( @@ -182,9 +180,8 @@ def from_continuation_token( deserialization_callback = kwargs["deserialization_callback"] except KeyError: raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") from None - import pickle - initial_response = pickle.loads(base64.b64decode(continuation_token)) # nosec + initial_response = _decode_continuation_token(continuation_token) return None, initial_response, deserialization_callback diff --git a/sdk/core/azure-core/azure/core/polling/_utils.py b/sdk/core/azure-core/azure/core/polling/_utils.py new file mode 100644 index 000000000000..86d6907c0559 --- /dev/null +++ b/sdk/core/azure-core/azure/core/polling/_utils.py @@ -0,0 +1,140 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +"""Shared utilities for polling continuation token serialization.""" + +import base64 +import binascii +import json +from typing import Any, Dict, Mapping + + +# Current continuation token version +_CONTINUATION_TOKEN_VERSION = 1 + +# Error message for incompatible continuation tokens from older versions +_INCOMPATIBLE_TOKEN_ERROR_MESSAGE = ( + "This continuation token is not compatible with this version of azure-core. " + "It may have been generated by a previous version. " + "See https://aka.ms/azsdk/python/core/troubleshoot for more information." +) + +# Headers that are needed for LRO rehydration. +# We use an allowlist approach for security - only include headers we actually need. +_LRO_HEADERS = frozenset( + [ + "operation-location", + # azure-asyncoperation is included only for back compat with mgmt-core<=1.6.0 + "azure-asyncoperation", + "location", + "content-type", + "retry-after", + ] +) + + +def _filter_sensitive_headers(headers: Mapping[str, str]) -> Dict[str, str]: + """Filter headers to only include those needed for LRO rehydration. + + Uses an allowlist approach - only headers required for polling are included. + + :param headers: The headers to filter. + :type headers: Mapping[str, str] + :return: A new dictionary with only allowed headers. + :rtype: dict[str, str] + """ + return {k: v for k, v in headers.items() if k.lower() in _LRO_HEADERS} + + +def _is_pickle_format(data: bytes) -> bool: + """Check if the data appears to be in pickle format. + + Pickle protocol markers start with \\x80 followed by a protocol version byte (1-5). + + :param data: The bytes to check. + :type data: bytes + :return: True if the data appears to be pickled, False otherwise. + :rtype: bool + """ + if not data or len(data) < 2: + return False + # Check for pickle protocol marker (0x80) followed by protocol version 1-5 + return data[0:1] == b"\x80" and 1 <= data[1] <= 5 + + +def _decode_continuation_token(continuation_token: str) -> Dict[str, Any]: + """Decode a base64-encoded JSON continuation token. + + :param continuation_token: The base64-encoded continuation token. + :type continuation_token: str + :return: The decoded JSON data as a dictionary (the "data" field from the token). + :rtype: dict + :raises ValueError: If the token is invalid or in an unsupported format. + """ + try: + decoded_bytes = base64.b64decode(continuation_token) + token = json.loads(decoded_bytes.decode("utf-8")) + except binascii.Error: + # Invalid base64 input + raise ValueError("This doesn't look like a continuation token the sdk created.") from None + except (json.JSONDecodeError, UnicodeDecodeError): + # Check if the data appears to be from an older version + if _is_pickle_format(decoded_bytes): + raise ValueError(_INCOMPATIBLE_TOKEN_ERROR_MESSAGE) from None + raise ValueError("Invalid continuation token format.") from None + + # Validate token schema - must be a dict with a version field + if not isinstance(token, dict) or "version" not in token: + raise ValueError("Invalid continuation token format.") from None + + # For now, we only support version 1 + # Future versions can add handling for older versions here if needed + if token["version"] != _CONTINUATION_TOKEN_VERSION: + raise ValueError(_INCOMPATIBLE_TOKEN_ERROR_MESSAGE) from None + + return token["data"] + + +def _encode_continuation_token(data: Any) -> str: + """Encode data as a base64-encoded JSON continuation token. + + The token includes a version field for future compatibility checking. + + :param data: The data to encode. Must be JSON-serializable. + :type data: any + :return: The base64-encoded JSON string. + :rtype: str + :raises TypeError: If the data is not JSON-serializable. + """ + token = { + "version": _CONTINUATION_TOKEN_VERSION, + "data": data, + } + try: + return base64.b64encode(json.dumps(token, separators=(",", ":")).encode("utf-8")).decode("ascii") + except (TypeError, ValueError) as err: + raise TypeError( + "Unable to generate a continuation token for this operation. Payload is not JSON-serializable." + ) from err diff --git a/sdk/core/azure-core/azure/core/polling/base_polling.py b/sdk/core/azure-core/azure/core/polling/base_polling.py index 7bc5669d358b..b88070a59d8c 100644 --- a/sdk/core/azure-core/azure/core/polling/base_polling.py +++ b/sdk/core/azure-core/azure/core/polling/base_polling.py @@ -33,6 +33,7 @@ Tuple, Callable, Dict, + Mapping, Sequence, Generic, TypeVar, @@ -46,7 +47,9 @@ from ..pipeline._tools import is_rest from .._enum_meta import CaseInsensitiveEnumMeta from .. import PipelineClient -from ..pipeline import PipelineResponse +from ..pipeline import PipelineResponse, PipelineContext +from ..rest._helpers import decode_to_text, get_charset_encoding +from ..utils._utils import case_insensitive_dict from ..pipeline.transport import ( HttpTransport, HttpRequest as LegacyHttpRequest, @@ -54,6 +57,11 @@ AsyncHttpResponse as LegacyAsyncHttpResponse, ) from ..rest import HttpRequest, HttpResponse, AsyncHttpResponse +from ._utils import ( + _encode_continuation_token, + _decode_continuation_token, + _filter_sensitive_headers, +) HttpRequestType = Union[LegacyHttpRequest, HttpRequest] @@ -80,6 +88,56 @@ _SUCCEEDED = frozenset(["succeeded"]) +class _ContinuationTokenHttpResponse: + """A minimal HTTP response class for reconstructing responses from continuation tokens. + + This class provides just enough interface to be used with LRO polling operations + when restoring from a continuation token. + + :param request: The HTTP request (optional, may be None if not available in the continuation token) + :type request: ~azure.core.rest.HttpRequest or None + :param status_code: The HTTP status code + :type status_code: int + :param headers: The response headers + :type headers: dict + :param content: The response content + :type content: bytes + """ + + def __init__( + self, + request: Optional[HttpRequest], + status_code: int, + headers: Dict[str, str], + content: bytes, + ): + self.request = request + self.status_code = status_code + self.headers = case_insensitive_dict(headers) + self._content = content + + @property + def content(self) -> bytes: + """Return the response content. + + :return: The response content + :rtype: bytes + """ + return self._content + + def text(self) -> str: + """Return the response content as text. + + Uses the charset from Content-Type header if available, otherwise falls back + to UTF-8 with replacement for invalid characters. + + :return: The response content as text + :rtype: str + """ + encoding = get_charset_encoding(self) + return decode_to_text(encoding, self._content) + + def _get_content(response: AllHttpResponseType) -> bytes: """Get the content of this response. This is designed specifically to avoid a warning of mypy for body() access, as this method is deprecated. @@ -645,18 +703,82 @@ def initialize( except OperationFailed as err: raise HttpResponseError(response=initial_response.http_response, error=err) from err + def _filter_headers_for_continuation_token(self, headers: Mapping[str, str]) -> Dict[str, str]: + """Filter headers to include in the continuation token. + + Subclasses can override this method to include additional headers needed + for their specific LRO implementation. + + :param headers: The response headers to filter. + :type headers: Mapping[str, str] + :return: A filtered dictionary of headers to include in the continuation token. + :rtype: dict[str, str] + """ + return _filter_sensitive_headers(headers) + def get_continuation_token(self) -> str: """Get a continuation token that can be used to recreate this poller. - The continuation token is a base64 encoded string that contains the initial response - serialized with pickle. :rtype: str - :return: The continuation token. + :return: An opaque continuation token. :raises ValueError: If the initial response is not set. """ - import pickle - - return base64.b64encode(pickle.dumps(self._initial_response)).decode("ascii") + response = self._initial_response.http_response + request = response.request + # Serialize the essential parts of the PipelineResponse to JSON. + if request: + request_headers = {} + # Preserve x-ms-client-request-id for request correlation + if "x-ms-client-request-id" in request.headers: + request_headers["x-ms-client-request-id"] = request.headers["x-ms-client-request-id"] + request_state = { + "method": request.method, + "url": request.url, + "headers": request_headers, + } + else: + request_state = None + # Get response content, handling the case where it might not be read yet + try: + content = _get_content(response) or b"" + except Exception: # pylint: disable=broad-except + content = b"" + # Get deserialized data from context if available (optimization). + # If context doesn't have it, fall back to parsing the response body directly. + # Note: deserialized_data is only included if it's JSON-serializable. + # Non-JSON-serializable types (e.g., XML ElementTree) are skipped and set to None. + # In such cases, the data can still be re-parsed from the raw content bytes. + deserialized_data = None + raw_deserialized = None + if self._initial_response.context is not None: + raw_deserialized = self._initial_response.context.get("deserialized_data") + # Fallback: try to get deserialized data from the response body if context didn't have it + if raw_deserialized is None and content: + try: + raw_deserialized = json.loads(content) + except (json.JSONDecodeError, ValueError, TypeError): + # Response body is not valid JSON, leave as None + pass + if raw_deserialized is not None: + try: + # Test if the data is JSON-serializable + json.dumps(raw_deserialized) + deserialized_data = raw_deserialized + except (TypeError, ValueError): + # Skip non-JSON-serializable data (e.g., XML ElementTree objects) + deserialized_data = None + state = { + "request": request_state, + "response": { + "status_code": response.status_code, + "headers": self._filter_headers_for_continuation_token(response.headers), + "content": base64.b64encode(content).decode("ascii"), + }, + "context": { + "deserialized_data": deserialized_data, + }, + } + return _encode_continuation_token(state) @classmethod def from_continuation_token( @@ -681,11 +803,34 @@ def from_continuation_token( except KeyError: raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") from None - import pickle - - initial_response = pickle.loads(base64.b64decode(continuation_token)) # nosec - # Restore the transport in the context - initial_response.context.transport = client._pipeline._transport # pylint: disable=protected-access + state = _decode_continuation_token(continuation_token) + # Reconstruct HttpRequest if present + request_state = state.get("request") + http_request = None + if request_state is not None: + http_request = HttpRequest( + method=request_state["method"], + url=request_state["url"], + headers=request_state.get("headers", {}), + ) + # Reconstruct HttpResponse using the minimal response class + response_state = state["response"] + http_response = _ContinuationTokenHttpResponse( + request=http_request, + status_code=response_state["status_code"], + headers=response_state["headers"], + content=base64.b64decode(response_state["content"]), + ) + # Reconstruct PipelineResponse + context = PipelineContext(client._pipeline._transport) # pylint: disable=protected-access + context_state = state.get("context", {}) + if context_state.get("deserialized_data") is not None: + context["deserialized_data"] = context_state["deserialized_data"] + initial_response = PipelineResponse( + http_request=http_request, + http_response=http_response, + context=context, + ) return client, initial_response, deserialization_callback def status(self) -> str: diff --git a/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py b/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py index 16cbfd48f19f..56e5e44a4e24 100644 --- a/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py @@ -25,7 +25,6 @@ # -------------------------------------------------------------------------- import base64 import json -import pickle import re from utils import HTTP_REQUESTS from azure.core.pipeline._tools import is_rest @@ -690,7 +689,13 @@ async def test_long_running_negative(http_request, http_response): poll = async_poller(CLIENT, response, TestBasePolling.mock_outputs, AsyncLROBasePolling(0)) with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization await poll - assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode("ascii") + # Verify continuation token is set and is a valid JSON-encoded token + assert error.value.continuation_token is not None + assert isinstance(error.value.continuation_token, str) + # Verify the token can be decoded + decoded = json.loads(base64.b64decode(error.value.continuation_token).decode("utf-8")) + assert "request" in decoded["data"] + assert "response" in decoded["data"] LOCATION_BODY = json.dumps({"name": TEST_NAME}) POLLING_STATUS = 200 diff --git a/sdk/core/azure-core/tests/test_base_polling.py b/sdk/core/azure-core/tests/test_base_polling.py index 3a4b6d75f135..a23f4824acf1 100644 --- a/sdk/core/azure-core/tests/test_base_polling.py +++ b/sdk/core/azure-core/tests/test_base_polling.py @@ -28,7 +28,6 @@ import json import re import types -import pickle import platform try: @@ -177,6 +176,67 @@ def test_base_polling_continuation_token(client, polling_response, http_response new_polling.initialize(*polling_args) +def test_base_polling_continuation_token_pickle_incompatibility(client): + """Test that from_continuation_token raises ValueError with helpful message for old pickle tokens.""" + import pickle + + # Simulate an old pickle-based continuation token (would have been a pickled PipelineResponse) + old_pickle_data = pickle.dumps({"some": "data"}) + old_continuation_token = base64.b64encode(old_pickle_data).decode("ascii") + + with pytest.raises(ValueError) as excinfo: + LROBasePolling.from_continuation_token( + old_continuation_token, + deserialization_callback=lambda x: x, + client=client, + ) + + error_message = str(excinfo.value) + assert "aka.ms" in error_message + + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_base_polling_continuation_token_with_stream_response(port, http_request, deserialization_cb): + """Test that get_continuation_token works correctly with real server responses.""" + client = MockRestClient(port) + request = http_request( + "POST", + "http://localhost:{}/polling/continuation-token-stream".format(port), + ) + initial_response = client._client._pipeline.run(request) + + # Create polling operation + polling = LROBasePolling(timeout=0) + polling.initialize( + client._client, + initial_response, + deserialization_cb, + ) + + # get_continuation_token should work with real server response + continuation_token = polling.get_continuation_token() + assert isinstance(continuation_token, str) + + # Verify the token can be decoded and contains expected structure + decoded = json.loads(base64.b64decode(continuation_token).decode("utf-8")) + assert decoded["version"] == 1 + assert "request" in decoded["data"] + assert "response" in decoded["data"] + assert decoded["data"]["response"]["status_code"] == 202 + # Content should be preserved + content_bytes = base64.b64decode(decoded["data"]["response"]["content"]) + assert b"InProgress" in content_bytes + + # Verify we can restore from the continuation token + polling_args = LROBasePolling.from_continuation_token( + continuation_token, + deserialization_callback=deserialization_cb, + client=client._client, + ) + new_polling = LROBasePolling() + new_polling.initialize(*polling_args) + + @pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) def test_delay_extraction_int(polling_response, http_response): polling = polling_response(http_response, {"Retry-After": "10"}) @@ -714,7 +774,13 @@ def test_long_running_negative(self, http_request, http_response): poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization poll.result() - assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode("ascii") + # Verify continuation token is set and is a valid JSON-encoded token + assert error.value.continuation_token is not None + assert isinstance(error.value.continuation_token, str) + # Verify the token can be decoded + decoded = json.loads(base64.b64decode(error.value.continuation_token).decode("utf-8")) + assert "request" in decoded["data"] + assert "response" in decoded["data"] LOCATION_BODY = json.dumps({"name": TEST_NAME}) POLLING_STATUS = 200 @@ -833,3 +899,102 @@ def test_post_check_patch(http_request): with pytest.raises(AttributeError) as ex: algorithm.get_final_get_url(None) assert "'NoneType' object has no attribute 'http_response'" in str(ex.value) + + +def test_continuation_token_with_non_json_serializable_data(port, deserialization_cb): + """Test that continuation token gracefully handles non-JSON-serializable data like XML.""" + import base64 + import json + import xml.etree.ElementTree as ET + + from azure.core.polling.base_polling import LROBasePolling + from azure.core.rest import HttpRequest + + client = MockRestClient(port) + request = HttpRequest( + "POST", + "http://localhost:{}/polling/continuation-token-xml".format(port), + ) + initial_response = client._client._pipeline.run(request) + + # Simulate XML deserialized data (non-JSON-serializable) + xml_element = ET.fromstring(b"InProgress") + initial_response.context["deserialized_data"] = xml_element + + # Create polling operation + polling = LROBasePolling(timeout=0) + polling.initialize( + client._client, + initial_response, + deserialization_cb, + ) + + # Get continuation token - this should NOT raise an error + token = polling.get_continuation_token() + + # Decode and verify the token structure + decoded = json.loads(base64.b64decode(token).decode("utf-8")) + + # deserialized_data should be None because XML is not JSON-serializable + assert decoded["data"]["context"]["deserialized_data"] is None + + # The raw content should still be preserved + content_bytes = base64.b64decode(decoded["data"]["response"]["content"]) + assert b"" in content_bytes or b"InProgress" in content_bytes + + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_continuation_token_excludes_request_headers(port, http_request, deserialization_cb): + """Test that continuation token does not include sensitive request headers for security.""" + import base64 + import json + + from azure.core.polling.base_polling import LROBasePolling + + client = MockRestClient(port) + request = http_request( + "POST", + "http://localhost:{}/polling/continuation-token".format(port), + ) + # Add headers that should NOT be included in the continuation token + request.headers["Authorization"] = "Bearer super-secret-token" + request.headers["x-ms-authorization-auxiliary"] = "auxiliary-secret" + request.headers["x-custom-header"] = "custom-value" + # Add header that SHOULD be included for request correlation + request.headers["x-ms-client-request-id"] = "test-request-id-12345" + + initial_response = client._client._pipeline.run(request) + + # Create polling operation + polling = LROBasePolling(timeout=0) + polling.initialize( + client._client, + initial_response, + deserialization_cb, + ) + + token = polling.get_continuation_token() + + # Decode and verify sensitive request headers are not included + decoded = json.loads(base64.b64decode(token).decode("utf-8")) + + # Request should contain method, url, and only safe headers (x-ms-client-request-id) + request_state = decoded["data"]["request"] + assert request_state["method"] == "POST" + assert "continuation-token" in request_state["url"] + # Only x-ms-client-request-id should be in headers + assert "headers" in request_state + assert request_state["headers"].get("x-ms-client-request-id") == "test-request-id-12345" + # Sensitive headers should NOT be included + assert "Authorization" not in request_state["headers"] + assert "x-ms-authorization-auxiliary" not in request_state["headers"] + assert "x-custom-header" not in request_state["headers"] + + # Verify we can restore from the continuation token + polling_args = LROBasePolling.from_continuation_token( + token, + deserialization_callback=deserialization_cb, + client=client._client, + ) + new_polling = LROBasePolling() + new_polling.initialize(*polling_args) diff --git a/sdk/core/azure-core/tests/test_polling.py b/sdk/core/azure-core/tests/test_polling.py index e3fb92e124b6..cdabd4575b90 100644 --- a/sdk/core/azure-core/tests/test_polling.py +++ b/sdk/core/azure-core/tests/test_polling.py @@ -23,6 +23,7 @@ # THE SOFTWARE. # # -------------------------------------------------------------------------- +import base64 import time try: @@ -99,6 +100,53 @@ def deserialization_cb(response): assert no_polling_revived.resource() == "Treated: " + initial_response +def test_no_polling_continuation_token_missing_callback(): + """Test that from_continuation_token raises ValueError when deserialization_callback is missing.""" + no_polling = NoPolling() + no_polling.initialize(None, "test", lambda x: x) + + continuation_token = no_polling.get_continuation_token() + + with pytest.raises(ValueError) as excinfo: + NoPolling.from_continuation_token(continuation_token) + assert "deserialization_callback" in str(excinfo.value) + + +def test_no_polling_continuation_token_pickle_incompatibility(): + """Test that from_continuation_token raises ValueError with helpful message for old pickle tokens.""" + import pickle + + # Simulate an old pickle-based continuation token + old_pickle_data = pickle.dumps({"some": "data"}) + old_continuation_token = base64.b64encode(old_pickle_data).decode("ascii") + + with pytest.raises(ValueError) as excinfo: + NoPolling.from_continuation_token(old_continuation_token, deserialization_callback=lambda x: x) + + error_message = str(excinfo.value) + assert "aka.ms" in error_message + + +def test_no_polling_continuation_token_non_serializable(): + """Test that get_continuation_token raises TypeError for non-JSON-serializable initial responses.""" + no_polling = NoPolling() + + # Create a non-JSON-serializable object + class CustomObject: + def __init__(self, value): + self.value = value + + initial_response = CustomObject("test") + + no_polling.initialize(None, initial_response, lambda x: x) + + with pytest.raises(TypeError) as excinfo: + no_polling.get_continuation_token() + + error_message = str(excinfo.value) + assert "not JSON-serializable" in error_message + + def test_polling_with_path_format_arguments(client): method = LROBasePolling(timeout=0, path_format_arguments={"host": "host:3000", "accountName": "local"}) client._base_url = "http://{accountName}{host}" diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/polling.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/polling.py index 005867a38e35..527fe77e37ff 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/polling.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/polling.py @@ -169,3 +169,57 @@ def polling_with_options_first(): @polling_api.route("/final-get-with-location", methods=["GET"]) def polling_with_options_final_get_with_location(): return Response('{"returnedFrom": "locationHeaderUrl"}', status=200) + + +@polling_api.route("/continuation-token", methods=["POST"]) +def continuation_token_initial(): + """Initial LRO response for continuation token tests.""" + base_url = get_base_url(request) + return Response( + '{"properties":{"provisioningState": "InProgress"}}', + headers={ + "operation-location": "{}/polling/continuation-token-status".format(base_url), + "x-ms-request-id": "test-request-id-12345", + }, + status=202, + ) + + +@polling_api.route("/continuation-token-status", methods=["GET"]) +def continuation_token_status(): + """Status endpoint for continuation token tests.""" + return Response('{"status": "Succeeded"}', status=200) + + +@polling_api.route("/continuation-token-xml", methods=["POST"]) +def continuation_token_xml_initial(): + """Initial LRO response with XML body for continuation token tests.""" + base_url = get_base_url(request) + return Response( + "InProgress", + headers={ + "operation-location": "{}/polling/continuation-token-xml-status".format(base_url), + "content-type": "application/xml", + }, + status=202, + ) + + +@polling_api.route("/continuation-token-xml-status", methods=["GET"]) +def continuation_token_xml_status(): + """Status endpoint for XML continuation token tests.""" + return Response('{"status": "Succeeded"}', status=200) + + +@polling_api.route("/continuation-token-stream", methods=["POST"]) +def continuation_token_stream_initial(): + """Initial LRO response for stream continuation token tests.""" + base_url = get_base_url(request) + return Response( + '{"status": "InProgress"}', + headers={ + "operation-location": "{}/polling/continuation-token-status".format(base_url), + "content-type": "application/json", + }, + status=202, + ) diff --git a/sdk/core/azure-mgmt-core/tests/asynctests/test_async_arm_polling.py b/sdk/core/azure-mgmt-core/tests/asynctests/test_async_arm_polling.py index 278b03797349..d19c4f5ba7dc 100644 --- a/sdk/core/azure-mgmt-core/tests/asynctests/test_async_arm_polling.py +++ b/sdk/core/azure-mgmt-core/tests/asynctests/test_async_arm_polling.py @@ -523,10 +523,12 @@ async def test_long_running_negative(): LOCATION_BODY = "{" POLLING_STATUS = 203 response = TestArmPolling.mock_send("POST", 202, {"location": LOCATION_URL}) - poll = async_poller(CLIENT, response, TestArmPolling.mock_outputs, AsyncARMPolling(0)) + polling_method = AsyncARMPolling(0) + poll = async_poller(CLIENT, response, TestArmPolling.mock_outputs, polling_method) with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization await poll - assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode("ascii") + # Verify continuation token is set + assert error.value.continuation_token == polling_method.get_continuation_token() LOCATION_BODY = json.dumps({"name": TEST_NAME}) POLLING_STATUS = 200 diff --git a/sdk/core/azure-mgmt-core/tests/test_arm_polling.py b/sdk/core/azure-mgmt-core/tests/test_arm_polling.py index 5432b8955382..9c93b0e1ef76 100644 --- a/sdk/core/azure-mgmt-core/tests/test_arm_polling.py +++ b/sdk/core/azure-mgmt-core/tests/test_arm_polling.py @@ -528,7 +528,8 @@ def test_long_running_negative(self): poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization poll.result() - assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode("ascii") + # Verify continuation token is set + assert error.value.continuation_token == poll.continuation_token() LOCATION_BODY = json.dumps({"name": TEST_NAME}) POLLING_STATUS = 200