Skip to content

Commit 2e10d68

Browse files
alubbadAya Lubbad
andauthored
feat: change default behavior for endpoints (#505)
* feat: change default behavior for endpoints * made default use cache_endpoint * fixed default access * fixing lint errors * removed override function * minor fixes * fix: renaming variables * fix: changed error messages --------- Co-authored-by: Aya Lubbad <aya@momentohq.com>
1 parent 419df26 commit 2e10d68

File tree

2 files changed

+84
-27
lines changed

2 files changed

+84
-27
lines changed

src/credential_provider.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub struct CredentialProvider {
3131
pub(crate) token_endpoint: String,
3232
pub(crate) endpoint_security: EndpointSecurity,
3333
pub(crate) use_private_endpoints: bool,
34+
pub(crate) use_endpoints_http_api: bool,
3435
}
3536

3637
impl Display for CredentialProvider {
@@ -53,6 +54,7 @@ impl Debug for CredentialProvider {
5354
.field("token_endpoint", &self.token_endpoint)
5455
.field("endpoint_security", &self.endpoint_security)
5556
.field("use_private_endpoints", &self.use_private_endpoints)
57+
.field("use_endpoints_http_api", &self.use_endpoints_http_api)
5658
.finish()
5759
}
5860
}
@@ -196,6 +198,14 @@ impl CredentialProvider {
196198
/// addresses to connect to.
197199
pub fn with_private_endpoints(mut self) -> CredentialProvider {
198200
self.use_private_endpoints = true;
201+
self.use_endpoints_http_api = true;
202+
self
203+
}
204+
/// Directs the ProtosocketCacheClient to look up public endpoints when discovering
205+
/// addresses to connect to.
206+
pub fn with_endpoints(mut self) -> CredentialProvider {
207+
self.use_private_endpoints = false;
208+
self.use_endpoints_http_api = true;
199209
self
200210
}
201211

@@ -236,6 +246,7 @@ fn process_v1_token(auth_token_bytes: Vec<u8>) -> MomentoResult<CredentialProvid
236246
token_endpoint: https_endpoint(get_token_endpoint(&json.endpoint)),
237247
endpoint_security: EndpointSecurity::Tls,
238248
use_private_endpoints: false,
249+
use_endpoints_http_api: false,
239250
})
240251
}
241252

src/protosocket/cache/connection_manager.rs

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use std::{
2626
};
2727

2828
use crate::{MomentoError, MomentoResult};
29+
use std::net::ToSocketAddrs;
2930

3031
#[derive(Debug)]
3132
struct BackgroundAddressLoader {
@@ -122,39 +123,75 @@ impl ClientConnector for ProtosocketConnectionManager {
122123
let address = match self.credential_provider.endpoint_security {
123124
EndpointSecurity::Tls => {
124125
log::debug!("selecting address from address provider for TLS endpoint");
125-
if self
126-
.address_provider
127-
.get_addresses()
128-
.for_az(self.az_id.as_deref())
129-
.is_empty()
130-
{
131-
if let Err(e) = self.address_provider.try_refresh_addresses().await {
132-
log::warn!("error refreshing address list: {e:?}");
126+
if self.credential_provider.use_endpoints_http_api {
127+
if self
128+
.address_provider
129+
.get_addresses()
130+
.for_az(self.az_id.as_deref())
131+
.is_empty()
132+
{
133+
if let Err(e) = self.address_provider.try_refresh_addresses().await {
134+
log::warn!("error refreshing address list: {e:?}");
135+
}
133136
}
137+
let addresses = self
138+
.address_provider
139+
.get_addresses()
140+
.for_az(self.az_id.as_deref());
141+
if addresses.is_empty() {
142+
return Err(protosocket_rpc::Error::IoFailure(
143+
std::io::Error::new(
144+
std::io::ErrorKind::AddrNotAvailable,
145+
"No addresses available from address provider",
146+
)
147+
.into(),
148+
));
149+
}
150+
addresses[self
151+
.connection_sequence
152+
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
153+
% addresses.len()]
154+
} else {
155+
// Use the modified cache_endpoint with :9004 appended and https:// prefix removed
156+
let mut cache_endpoint = self
157+
.credential_provider
158+
.cache_endpoint
159+
.strip_prefix("https://")
160+
.unwrap_or(&self.credential_provider.cache_endpoint)
161+
.to_string();
162+
cache_endpoint.push_str(":9004");
163+
164+
cache_endpoint
165+
.to_socket_addrs()
166+
.map_err(|e| {
167+
protosocket_rpc::Error::IoFailure(
168+
std::io::Error::other(format!(
169+
"could not parse address from endpoint: {}: {:?}",
170+
&self.credential_provider.cache_endpoint, e
171+
))
172+
.into(),
173+
)
174+
})?
175+
.next()
176+
.ok_or_else(|| {
177+
protosocket_rpc::Error::IoFailure(
178+
std::io::Error::new(
179+
std::io::ErrorKind::AddrNotAvailable,
180+
format!(
181+
"Unable to connect: endpoint '{}' did not resolve.",
182+
cache_endpoint
183+
),
184+
)
185+
.into(),
186+
)
187+
})?
134188
}
135-
let addresses = self
136-
.address_provider
137-
.get_addresses()
138-
.for_az(self.az_id.as_deref());
139-
if addresses.is_empty() {
140-
return Err(protosocket_rpc::Error::IoFailure(
141-
std::io::Error::new(
142-
std::io::ErrorKind::AddrNotAvailable,
143-
"No addresses available from address provider",
144-
)
145-
.into(),
146-
));
147-
}
148-
addresses[self
149-
.connection_sequence
150-
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
151-
% addresses.len()]
152189
}
153190
_ => {
154191
log::debug!("using endpoint address directly for endpoint override");
155192
self.credential_provider
156193
.cache_endpoint
157-
.parse()
194+
.to_socket_addrs()
158195
.map_err(|e| {
159196
protosocket_rpc::Error::IoFailure(
160197
std::io::Error::other(format!(
@@ -164,10 +201,19 @@ impl ClientConnector for ProtosocketConnectionManager {
164201
.into(),
165202
)
166203
})?
204+
.next()
205+
.ok_or_else(|| {
206+
protosocket_rpc::Error::IoFailure(
207+
std::io::Error::new(
208+
std::io::ErrorKind::AddrNotAvailable,
209+
"Failed to resolve endpoint hostname into a valid address",
210+
)
211+
.into(),
212+
)
213+
})?
167214
}
168215
};
169216
log::debug!("connecting over protosocket to {address}");
170-
171217
let unauthenticated_client = create_protosocket_connection(
172218
self.credential_provider.clone(),
173219
self.runtime.clone(),

0 commit comments

Comments
 (0)