Skip to content

Commit 4bccdbd

Browse files
heathsCopilot
andauthored
Use PagerState, PollerState instead of Option for callbacks (#2837)
* Use PagerState, PollerState instead of Option for callbacks For `ItemIterator`, `PageIterator`, and `Poller`, define a bespoke enum that works like an `Option` but that we can add to or redefine as needed, as will most likely be the case for `Poller`. * Apply suggestion from @Copilot Co-authored-by: Copilot <[email protected]> * Retain Option<String> option for reconstitution support --------- Co-authored-by: Copilot <[email protected]>
1 parent f035043 commit 4bccdbd

File tree

10 files changed

+188
-93
lines changed

10 files changed

+188
-93
lines changed

sdk/core/azure_core/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
### Breaking Changes
88

9+
- `Pager::from_callback` and `PageIterator::from_callback` define a parameter of type `PagerState<C>` instead of `Option<C>`, where `None` => `Initial` and `Some(C)` => `More(C)`.
10+
- `Poller::from_callback` defines a parameter of type `PollerState<N>` instead of `Option<N>`, where `None` => `Initial` and `Some(N)` => `More(N)`.
11+
912
### Bugs Fixed
1013

1114
### Other Changes

sdk/core/azure_core/src/http/pager.rs

Lines changed: 89 additions & 31 deletions
Large diffs are not rendered by default.

sdk/core/azure_core/src/http/poller.rs

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,40 @@ use std::{
2727
const DEFAULT_RETRY_TIME: Duration = Duration::seconds(30);
2828
const MIN_RETRY_TIME: Duration = Duration::seconds(1);
2929

30+
/// Represents the state of a [`Poller`].
31+
#[derive(Debug, Default, PartialEq, Eq)]
32+
pub enum PollerState<N> {
33+
/// The poller should fetch the initial status.
34+
#[default]
35+
Initial,
36+
/// The poller should fetch subsequent status.
37+
More(N),
38+
}
39+
40+
impl<N> PollerState<N> {
41+
/// Maps a [`PollerState<N>`] to a [`PollerState<U>`] by applying a function to a next link `N` (if `PollerState::More`) or returns `PollerState::Initial` (if `PollerState::Initial`).
42+
#[inline]
43+
pub fn map<U, F>(self, f: F) -> PollerState<U>
44+
where
45+
F: FnOnce(N) -> U,
46+
{
47+
match self {
48+
PollerState::Initial => PollerState::Initial,
49+
PollerState::More(c) => PollerState::More(f(c)),
50+
}
51+
}
52+
}
53+
54+
impl<N: Clone> Clone for PollerState<N> {
55+
#[inline]
56+
fn clone(&self) -> Self {
57+
match self {
58+
PollerState::Initial => PollerState::Initial,
59+
PollerState::More(c) => PollerState::More(c.clone()),
60+
}
61+
}
62+
}
63+
3064
/// Long-running operation (LRO) status.
3165
#[derive(Debug, Default, Clone, PartialEq, Eq)]
3266
pub enum PollerStatus {
@@ -193,7 +227,7 @@ where
193227
{
194228
/// Creates a [`Poller<M>`] from a callback that will be called repeatedly to monitor a long-running operation (LRO).
195229
///
196-
/// This method expects a callback that accepts a single `Option<N>` parameter, and returns a [`PollerResult<M, N>`] value asynchronously.
230+
/// This method expects a callback that accepts a single [`PollerState<N>`] parameter, and returns a [`PollerResult<M, N>`] value asynchronously.
197231
/// The `N` type parameter is the type of the next link/continuation token. It may be any [`Send`]able type.
198232
/// The `M` type parameter must implement [`StatusMonitor`].
199233
///
@@ -210,7 +244,7 @@ where
210244
/// To poll a long-running operation:
211245
///
212246
/// ```rust,no_run
213-
/// # use azure_core::{Result, http::{Context, Pipeline, RawResponse, Request, Response, Method, Url, poller::{Poller, PollerResult, PollerStatus, StatusMonitor}}, json};
247+
/// # use azure_core::{Result, http::{Context, Pipeline, RawResponse, Request, Response, Method, Url, poller::{Poller, PollerResult, PollerState, PollerStatus, StatusMonitor}}, json};
214248
/// # use serde::Deserialize;
215249
/// # let api_version = "2025-06-04".to_string();
216250
/// # let pipeline: Pipeline = panic!("Not a runnable example");
@@ -232,13 +266,13 @@ where
232266
/// let url = "https://example.com/my_operation".parse().unwrap();
233267
/// let mut req = Request::new(url, Method::Post);
234268
///
235-
/// let poller = Poller::from_callback(move |operation_url: Option<Url>| {
269+
/// let poller = Poller::from_callback(move |operation_url: PollerState<Url>| {
236270
/// // The callback must be 'static, so you have to clone and move any values you want to use.
237271
/// let pipeline = pipeline.clone();
238272
/// let api_version = api_version.clone();
239273
/// let mut req = req.clone();
240274
/// async move {
241-
/// if let Some(operation_url) = operation_url {
275+
/// if let PollerState::More(operation_url) = operation_url {
242276
/// // Use the operation URL for polling
243277
/// *req.url_mut() = operation_url;
244278
/// req.set_method(Method::Get);
@@ -275,10 +309,10 @@ where
275309
/// ```
276310
pub fn from_callback<
277311
#[cfg(not(target_arch = "wasm32"))] N: Send + 'static,
278-
#[cfg(not(target_arch = "wasm32"))] F: Fn(Option<N>) -> Fut + Send + 'static,
312+
#[cfg(not(target_arch = "wasm32"))] F: Fn(PollerState<N>) -> Fut + Send + 'static,
279313
#[cfg(not(target_arch = "wasm32"))] Fut: Future<Output = crate::Result<PollerResult<M, N>>> + Send + 'static,
280314
#[cfg(target_arch = "wasm32")] N: 'static,
281-
#[cfg(target_arch = "wasm32")] F: Fn(Option<N>) -> Fut + 'static,
315+
#[cfg(target_arch = "wasm32")] F: Fn(PollerState<N>) -> Fut + 'static,
282316
#[cfg(target_arch = "wasm32")] Fut: Future<Output = crate::Result<PollerResult<M, N>>> + 'static,
283317
>(
284318
make_request: F,
@@ -396,10 +430,10 @@ enum State<N> {
396430
fn create_poller_stream<
397431
M,
398432
#[cfg(not(target_arch = "wasm32"))] N: Send + 'static,
399-
#[cfg(not(target_arch = "wasm32"))] F: Fn(Option<N>) -> Fut + Send + 'static,
433+
#[cfg(not(target_arch = "wasm32"))] F: Fn(PollerState<N>) -> Fut + Send + 'static,
400434
#[cfg(not(target_arch = "wasm32"))] Fut: Future<Output = crate::Result<PollerResult<M, N>>> + Send + 'static,
401435
#[cfg(target_arch = "wasm32")] N: 'static,
402-
#[cfg(target_arch = "wasm32")] F: Fn(Option<N>) -> Fut + 'static,
436+
#[cfg(target_arch = "wasm32")] F: Fn(PollerState<N>) -> Fut + 'static,
403437
#[cfg(target_arch = "wasm32")] Fut: Future<Output = crate::Result<PollerResult<M, N>>> + 'static,
404438
>(
405439
make_request: F,
@@ -412,8 +446,8 @@ where
412446
(State::Init, make_request),
413447
|(state, make_request)| async move {
414448
let result = match state {
415-
State::Init => make_request(None).await,
416-
State::InProgress(n) => make_request(Some(n)).await,
449+
State::Init => make_request(PollerState::Initial).await,
450+
State::InProgress(n) => make_request(PollerState::More(n)).await,
417451
State::Done => return None,
418452
};
419453
let (item, next_state) = match result {

sdk/cosmos/azure_data_cosmos/src/pipeline/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub use authorization_policy::AuthorizationPolicy;
1010
use azure_core::http::{
1111
request::{options::ContentType, Request},
1212
response::Response,
13-
ClientOptions, Context, Method, RawResponse,
13+
ClientOptions, Context, Method, PagerState, RawResponse,
1414
};
1515
use futures::TryStreamExt;
1616
use serde::de::DeserializeOwned;
@@ -101,7 +101,7 @@ impl CosmosPipeline {
101101
let mut req = base_request.clone();
102102
let ctx = ctx.clone();
103103
async move {
104-
if let Some(continuation) = continuation {
104+
if let PagerState::More(continuation) = continuation {
105105
req.insert_header(constants::CONTINUATION, continuation);
106106
}
107107

sdk/keyvault/azure_security_keyvault_certificates/src/clients.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::models::{
99
};
1010
use azure_core::{
1111
http::{
12-
poller::{get_retry_after, PollerResult, StatusMonitor as _},
12+
poller::{get_retry_after, PollerResult, PollerState, StatusMonitor as _},
1313
Body, Method, Poller, PollerStatus, RawResponse, Request, RequestContent, Url,
1414
},
1515
json, Result,
@@ -177,9 +177,9 @@ impl CertificateClientExt for CertificateClient {
177177
let parameters: Body = parameters.into();
178178

179179
Ok(Poller::from_callback(
180-
move |next_link: Option<Url>| {
180+
move |next_link: PollerState<Url>| {
181181
let (mut request, next_link) = match next_link {
182-
Some(next_link) => {
182+
PollerState::More(next_link) => {
183183
// Make sure the `api-version` is set appropriately.
184184
let qp = next_link
185185
.query_pairs()
@@ -196,7 +196,7 @@ impl CertificateClientExt for CertificateClient {
196196

197197
(request, next_link)
198198
}
199-
None => {
199+
PollerState::Initial => {
200200
let mut request = Request::new(url.clone(), Method::Post);
201201
request.insert_header("accept", "application/json");
202202
request.insert_header("content-type", "application/json");

sdk/keyvault/azure_security_keyvault_certificates/src/generated/clients/certificate_client.rs

Lines changed: 14 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sdk/keyvault/azure_security_keyvault_keys/src/generated/clients/key_client.rs

Lines changed: 11 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)