Skip to content

Commit 68de11e

Browse files
authored
refactor(feeder-gateway): clean up types + 0.14.0 support (#272)
This is primarily to update the feeder gateway objects to ensure it is compatible with the recent feeder API [upgrade] of Starknet 0.14.0 and to do some clean up to prevent redundant definitions whenever possible as most of them have similar network format as the RPC types. The 0.14.0 upgrade also introduces a new endpoint to the feeder gateway, `get_preconfirmed_block`, but I didn't manage to have it working (keep getting not found and bad request errors on both mainnet and sepolia). But, we're not using method anywhere yet so I don't include it in this commit. [upgrade]: https://community.starknet.io/t/sn-0-14-0-pre-release-notes/115618#p-2359352-feeder-api-15
1 parent 4f11144 commit 68de11e

File tree

11 files changed

+6041
-588
lines changed

11 files changed

+6041
-588
lines changed
Lines changed: 160 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
use std::str::FromStr;
2+
13
use katana_primitives::class::CasmContractClass;
24
use katana_primitives::Felt;
3-
use reqwest::header::{HeaderMap, HeaderValue};
4-
use reqwest::{Client, StatusCode};
5+
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
6+
use reqwest::{Client, Method, Request, StatusCode};
57
use serde::de::DeserializeOwned;
68
use serde::{Deserialize, Serialize};
79
use tracing::error;
@@ -12,47 +14,11 @@ use crate::types::{Block, BlockId, ContractClass, StateUpdate, StateUpdateWithBl
1214
/// HTTP request header for the feeder gateway API key. This allow bypassing the rate limiting.
1315
const X_THROTTLING_BYPASS: &str = "X-Throttling-Bypass";
1416

15-
#[derive(Debug, thiserror::Error)]
16-
pub enum Error {
17-
#[error(transparent)]
18-
Network(reqwest::Error),
19-
20-
#[error(transparent)]
21-
Sequencer(SequencerError),
22-
23-
#[error("failed to parse header value '{value}'")]
24-
InvalidHeaderValue { value: String },
25-
26-
#[error("request rate limited")]
27-
RateLimited,
28-
}
29-
30-
impl From<reqwest::Error> for Error {
31-
fn from(err: reqwest::Error) -> Self {
32-
if let Some(status) = err.status() {
33-
if status == StatusCode::TOO_MANY_REQUESTS {
34-
return Self::RateLimited;
35-
}
36-
}
37-
38-
Self::Network(err)
39-
}
40-
}
41-
42-
impl Error {
43-
/// Returns `true` if the error is due to rate limiting.
44-
pub fn is_rate_limited(&self) -> bool {
45-
matches!(self, Self::RateLimited)
46-
}
47-
}
48-
4917
/// Client for interacting with the Starknet's feeder gateway.
5018
#[derive(Debug, Clone)]
5119
pub struct SequencerGateway {
5220
/// The feeder gateway base URL.
5321
base_url: Url,
54-
/// The HTTP client used to send the requests.
55-
http_client: Client,
5622
/// The API key used to bypass the rate limiting of the feeder gateway.
5723
api_key: Option<String>,
5824
}
@@ -74,9 +40,7 @@ impl SequencerGateway {
7440

7541
/// Creates a new gateway client at the given base URL.
7642
pub fn new(base_url: Url) -> Self {
77-
let api_key = None;
78-
let client = Client::new();
79-
Self { http_client: client, base_url, api_key }
43+
Self { base_url, api_key: None }
8044
}
8145

8246
/// Sets the API key.
@@ -86,28 +50,28 @@ impl SequencerGateway {
8650
}
8751

8852
pub async fn get_block(&self, block_id: BlockId) -> Result<Block, Error> {
89-
self.feeder_gateway("get_block").with_block_id(block_id).send().await
53+
self.feeder_gateway("get_block").block_id(block_id).send().await
9054
}
9155

9256
pub async fn get_state_update(&self, block_id: BlockId) -> Result<StateUpdate, Error> {
93-
self.feeder_gateway("get_state_update").with_block_id(block_id).send().await
57+
self.feeder_gateway("get_state_update").block_id(block_id).send().await
9458
}
9559

9660
pub async fn get_state_update_with_block(
9761
&self,
9862
block_id: BlockId,
9963
) -> Result<StateUpdateWithBlock, Error> {
10064
self.feeder_gateway("get_state_update")
101-
.add_query_param("includeBlock", "true")
102-
.with_block_id(block_id)
65+
.query_param("includeBlock", "true")
66+
.block_id(block_id)
10367
.send()
10468
.await
10569
}
10670

10771
pub async fn get_class(&self, hash: Felt, block_id: BlockId) -> Result<ContractClass, Error> {
10872
self.feeder_gateway("get_class_by_hash")
109-
.add_query_param("classHash", &format!("{hash:#x}"))
110-
.with_block_id(block_id)
73+
.query_param("classHash", &format!("{hash:#x}"))
74+
.block_id(block_id)
11175
.send()
11276
.await
11377
}
@@ -118,66 +82,80 @@ impl SequencerGateway {
11882
block_id: BlockId,
11983
) -> Result<CasmContractClass, Error> {
12084
self.feeder_gateway("get_compiled_class_by_class_hash")
121-
.add_query_param("classHash", &format!("{hash:#x}"))
122-
.with_block_id(block_id)
85+
.query_param("classHash", &format!("{hash:#x}"))
86+
.block_id(block_id)
12387
.send()
12488
.await
12589
}
12690

127-
fn feeder_gateway(&self, method: &str) -> RequestBuilder<'_> {
91+
/// Creates a [`RequestBuilder`] for a feeder gateway endpoint.
92+
///
93+
/// This method constructs a URL by appending "feeder_gateway" and the specified endpoint
94+
/// to the base URL, then returns a [`RequestBuilder`] that can be used to build and send the
95+
/// request.
96+
///
97+
/// ## Arguments
98+
///
99+
/// * `endpoint` - The specific feeder gateway endpoint to call (e.g., "get_block",
100+
/// "get_state_update")
101+
///
102+
/// # Example
103+
///
104+
/// ```rust,ignore
105+
/// let gateway = SequencerGateway::sn_mainnet();
106+
/// let request = gateway.feeder_gateway("get_block")
107+
/// .block_id(BlockId::Latest)
108+
/// .send()
109+
/// .await?;
110+
/// ```
111+
fn feeder_gateway(&self, endpoint: &str) -> RequestBuilder<'_> {
128112
let mut url = self.base_url.clone();
129-
url.path_segments_mut().expect("invalid base url").extend(["feeder_gateway", method]);
130-
RequestBuilder { gateway_client: self, url }
113+
url.path_segments_mut().expect("invalid base url").extend(["feeder_gateway", endpoint]);
114+
RequestBuilder::new(self, url)
131115
}
132116
}
133117

134-
#[derive(Debug, Deserialize)]
135-
#[serde(untagged)]
136-
enum Response<T> {
137-
Data(T),
138-
Error(SequencerError),
139-
}
140-
141-
#[derive(Debug, Clone)]
142-
struct RequestBuilder<'a> {
143-
gateway_client: &'a SequencerGateway,
144-
url: Url,
145-
}
118+
#[derive(Debug, thiserror::Error)]
119+
pub enum Error {
120+
#[error(transparent)]
121+
Network(reqwest::Error),
146122

147-
impl RequestBuilder<'_> {
148-
fn with_block_id(self, block_id: BlockId) -> Self {
149-
match block_id {
150-
// latest block is implied, if no block id specified
151-
BlockId::Latest => self,
152-
BlockId::Hash(hash) => self.add_query_param("blockHash", &format!("{hash:#x}")),
153-
BlockId::Number(num) => self.add_query_param("blockNumber", &num.to_string()),
154-
}
155-
}
123+
#[error(transparent)]
124+
Sequencer(SequencerError),
156125

157-
fn add_query_param(mut self, key: &str, value: &str) -> Self {
158-
self.url.query_pairs_mut().append_pair(key, value);
159-
self
160-
}
126+
#[error("failed to parse header value '{value}'")]
127+
InvalidHeaderValue { value: String },
161128

162-
async fn send<T: DeserializeOwned>(self) -> Result<T, Error> {
163-
let mut headers = HeaderMap::new();
129+
#[error("request rate limited")]
130+
RateLimited,
131+
}
164132

165-
if let Some(key) = self.gateway_client.api_key.as_ref() {
166-
let value = HeaderValue::from_str(key)
167-
.map_err(|_| Error::InvalidHeaderValue { value: key.to_string() })?;
168-
headers.insert(X_THROTTLING_BYPASS, value);
133+
impl From<reqwest::Error> for Error {
134+
fn from(err: reqwest::Error) -> Self {
135+
if let Some(status) = err.status() {
136+
if status == StatusCode::TOO_MANY_REQUESTS {
137+
return Self::RateLimited;
138+
}
169139
}
170140

171-
let request = self.gateway_client.http_client.get(self.url).headers(headers);
172-
let response = request.send().await?.error_for_status()?;
141+
Self::Network(err)
142+
}
143+
}
173144

174-
match response.json::<Response<T>>().await? {
175-
Response::Data(data) => Ok(data),
176-
Response::Error(error) => Err(Error::Sequencer(error)),
177-
}
145+
impl Error {
146+
/// Returns `true` if the error is due to rate limiting.
147+
pub fn is_rate_limited(&self) -> bool {
148+
matches!(self, Self::RateLimited)
178149
}
179150
}
180151

152+
#[derive(Debug, Deserialize)]
153+
#[serde(untagged)]
154+
pub enum Response<T> {
155+
Data(T),
156+
Error(SequencerError),
157+
}
158+
181159
#[derive(Debug, thiserror::Error, Deserialize)]
182160
#[error("{message} ({code:?})")]
183161
pub struct SequencerError {
@@ -221,6 +199,74 @@ pub enum ErrorCode {
221199
DeprecatedEndpoint,
222200
}
223201

202+
#[derive(Debug, Clone)]
203+
struct RequestBuilder<'a> {
204+
gateway_client: &'a SequencerGateway,
205+
block_id: Option<BlockId>,
206+
url: Url,
207+
}
208+
209+
impl<'a> RequestBuilder<'a> {
210+
fn new(gateway_client: &'a SequencerGateway, url: Url) -> Self {
211+
Self { gateway_client, block_id: None, url }
212+
}
213+
214+
fn block_id(mut self, block_id: BlockId) -> Self {
215+
self.block_id = Some(block_id);
216+
self
217+
}
218+
219+
/// Adds a query parameter to the request URL.
220+
fn query_param(mut self, key: &str, value: &str) -> Self {
221+
self.url.query_pairs_mut().append_pair(key, value);
222+
self
223+
}
224+
}
225+
226+
impl RequestBuilder<'_> {
227+
/// Send the request.
228+
async fn send<T: DeserializeOwned>(self) -> Result<T, Error> {
229+
let request = self.build()?;
230+
let response = Client::new().execute(request).await?.error_for_status()?;
231+
232+
match response.json::<Response<T>>().await? {
233+
Response::Data(data) => Ok(data),
234+
Response::Error(error) => Err(Error::Sequencer(error)),
235+
}
236+
}
237+
238+
/// Build the request.
239+
fn build(self) -> Result<Request, Error> {
240+
let mut url = self.url;
241+
242+
if let Some(id) = self.block_id {
243+
match id {
244+
BlockId::Hash(hash) => {
245+
url.query_pairs_mut().append_pair("blockHash", &format!("{hash:#x}"));
246+
}
247+
BlockId::Number(num) => {
248+
url.query_pairs_mut().append_pair("blockNumber", &num.to_string());
249+
}
250+
BlockId::Latest => {
251+
// latest block is implied, if no block id is specified
252+
}
253+
}
254+
}
255+
256+
let mut request = Request::new(Method::GET, url);
257+
258+
if let Some(value) = self.gateway_client.api_key.as_ref() {
259+
let key = HeaderName::from_str(X_THROTTLING_BYPASS).expect("valid header name");
260+
let value = HeaderValue::from_str(value)
261+
.map_err(|_| Error::InvalidHeaderValue { value: value.to_string() })?;
262+
263+
*request.headers_mut() = HeaderMap::from_iter([(key, value)]);
264+
}
265+
266+
Ok(request)
267+
}
268+
}
269+
224270
#[cfg(test)]
225271
mod tests {
226272

@@ -230,19 +276,22 @@ mod tests {
230276
fn request_block_id() {
231277
let base_url = Url::parse("https://example.com/").unwrap();
232278
let client = SequencerGateway::new(base_url);
233-
let req = client.feeder_gateway("test");
279+
let builder = client.feeder_gateway("test");
234280

235281
// Test block hash
236282
let hash = Felt::from(123);
237-
let hash_url = req.clone().with_block_id(BlockId::Hash(hash)).url;
283+
let req = builder.clone().block_id(BlockId::Hash(hash)).build().unwrap();
284+
let hash_url = req.url();
238285
assert_eq!(hash_url.query(), Some("blockHash=0x7b"));
239286

240287
// Test block number
241-
let num_url = req.clone().with_block_id(BlockId::Number(42)).url;
288+
let req = builder.clone().block_id(BlockId::Number(42)).build().unwrap();
289+
let num_url = req.url();
242290
assert_eq!(num_url.query(), Some("blockNumber=42"));
243291

244292
// Test latest block (should have no query params)
245-
let latest_url = req.with_block_id(BlockId::Latest).url;
293+
let req = builder.clone().block_id(BlockId::Latest).build().unwrap();
294+
let latest_url = req.url();
246295
assert_eq!(latest_url.query(), None);
247296
}
248297

@@ -253,9 +302,9 @@ mod tests {
253302
let req = client.feeder_gateway("test");
254303

255304
let url = req
256-
.add_query_param("param1", "value1")
257-
.add_query_param("param2", "value2")
258-
.add_query_param("param3", "value3")
305+
.query_param("param1", "value1")
306+
.query_param("param2", "value2")
307+
.query_param("param3", "value3")
259308
.url;
260309

261310
let query = url.query().unwrap();
@@ -265,19 +314,24 @@ mod tests {
265314
}
266315

267316
#[test]
268-
#[ignore]
269-
fn request_block_id_overwrite() {
270-
let base_url = Url::parse("https://example.com/").unwrap();
271-
let client = SequencerGateway::new(base_url);
272-
let req = client.feeder_gateway("test");
317+
fn api_key_header() {
318+
let url = Url::parse("https://example.com/").unwrap();
273319

274-
let url = req.clone().with_block_id(BlockId::Latest).with_block_id(BlockId::Number(42)).url;
320+
// Test with API key set
321+
let api_key = "test-api-key-12345";
322+
let client_with_key = SequencerGateway::new(url.clone()).with_api_key(api_key.to_string());
323+
let req = client_with_key.feeder_gateway("test").build().unwrap();
275324

276-
assert_eq!(url.query(), Some("blockNumber=42"));
325+
// Check that the X-Throttling-Bypass header is set with the correct API key
326+
let headers = req.headers();
327+
assert_eq!(headers.get(X_THROTTLING_BYPASS).unwrap().to_str().unwrap(), api_key);
277328

278-
let hash = Felt::from(123);
279-
let url = req.clone().with_block_id(BlockId::Hash(hash)).url;
329+
// Test without API key
330+
let client_without_key = SequencerGateway::new(url);
331+
let req = client_without_key.feeder_gateway("test").build().unwrap();
280332

281-
assert_eq!(url.query(), Some("blockNumber=pending"));
333+
// Check that the X-Throttling-Bypass header is not present
334+
let headers = req.headers();
335+
assert!(headers.get(X_THROTTLING_BYPASS).is_none());
282336
}
283337
}

0 commit comments

Comments
 (0)