Skip to content

Commit e24a554

Browse files
committed
rpc: refactor controller deployment middleware flow
1 parent 6a38439 commit e24a554

File tree

1 file changed

+151
-120
lines changed

1 file changed

+151
-120
lines changed

crates/rpc/rpc-server/src/middleware/cartridge.rs

Lines changed: 151 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::borrow::Cow;
2+
use std::collections::HashSet;
23
use std::future::Future;
34

45
use cartridge::CartridgeApiClient;
@@ -19,6 +20,7 @@ use katana_rpc_api::error::cartridge::CartridgeApiError;
1920
use katana_rpc_api::error::starknet::StarknetApiError;
2021
use katana_rpc_types::broadcasted::{BroadcastedTx, BroadcastedTxWithChainId};
2122
use katana_rpc_types::{BroadcastedInvokeTx, FeeEstimate, FeeSource, OutsideExecution};
23+
use serde::de::DeserializeOwned;
2224
use serde::Deserialize;
2325
use starknet::core::types::SimulationFlagForEstimateFee;
2426
use starknet::macros::selector;
@@ -116,11 +118,41 @@ where
116118
PF: ProviderFactory,
117119
<PF as ProviderFactory>::Provider: ProviderRO,
118120
{
119-
// if `handle_estimate_fees` has added some new transactions at the
120-
// beginning of updated_txs, we have to remove
121-
// extras results from estimate_fees to be
122-
// sure to return the same number of result than the number
123-
// of transactions in the request.
121+
fn controller_deployment_error(reason: impl Into<String>) -> CartridgeApiError {
122+
CartridgeApiError::ControllerDeployment { reason: reason.into() }
123+
}
124+
125+
fn estimate_fee_candidate_addresses(transactions: &[BroadcastedTx]) -> Vec<ContractAddress> {
126+
transactions
127+
.iter()
128+
.filter_map(|tx| match tx {
129+
BroadcastedTx::Invoke(tx) => Some(tx.sender_address),
130+
BroadcastedTx::Declare(tx) => Some(tx.sender_address),
131+
_ => None,
132+
})
133+
.collect()
134+
}
135+
136+
fn build_estimate_fee_request<'a>(
137+
request: &Request<'a>,
138+
transactions: Vec<BroadcastedTx>,
139+
simulation_flags: Vec<SimulationFlagForEstimateFee>,
140+
block_id: BlockIdOrTag,
141+
) -> Result<Request<'a>, CartridgeApiError> {
142+
let params = rpc_params!(transactions, simulation_flags, block_id);
143+
let params = params.to_rpc_params().map_err(|err| {
144+
Self::controller_deployment_error(format!(
145+
"failed to serialize augmented estimateFee params: {err}"
146+
))
147+
})?;
148+
149+
let mut new_request = request.clone();
150+
new_request.params = params.map(Cow::Owned);
151+
152+
Ok(new_request)
153+
}
154+
155+
// If deployment txs are added, return the no-fee estimates for the original requests only.
124156
async fn starknet_estimate_fee<'a>(
125157
&self,
126158
params: EstimateFeeParams,
@@ -151,56 +183,51 @@ where
151183
request: Request<'a>,
152184
) -> Result<S::MethodResponse, CartridgeApiError> {
153185
let EstimateFeeParams { block_id, simulation_flags, transactions } = params;
154-
155-
let mut undeployed_addresses: Vec<ContractAddress> = Vec::new();
156-
157-
// iterate thru all txs and deploy any undeployed contract (if they are a Controller)
158-
for tx in &transactions {
159-
let address = match tx {
160-
BroadcastedTx::Invoke(tx) => tx.sender_address,
161-
BroadcastedTx::Declare(tx) => tx.sender_address,
162-
_ => continue,
163-
};
164-
165-
undeployed_addresses.push(address);
166-
}
167-
168-
let deployer_nonce =
169-
self.context.starknet.nonce_at(block_id, self.context.deployer_address).await.unwrap();
170-
let deploy_controller_txs =
171-
self.get_controller_deployment_txs(undeployed_addresses, deployer_nonce).await.unwrap();
186+
let candidate_addresses = Self::estimate_fee_candidate_addresses(&transactions);
187+
188+
let deployer_nonce = self
189+
.context
190+
.starknet
191+
.nonce_at(block_id, self.context.deployer_address)
192+
.await
193+
.map_err(|err| {
194+
Self::controller_deployment_error(format!("failed to get deployer nonce: {err}"))
195+
})?;
196+
let deploy_controller_txs = self
197+
.get_controller_deployment_txs(candidate_addresses, deployer_nonce)
198+
.await
199+
.map_err(|err| Self::controller_deployment_error(err.to_string()))?;
172200

173201
// no Controller to deploy, simply forward the request
174202
if deploy_controller_txs.is_empty() {
175203
return Ok(self.service.call(request).await);
176204
}
177205

178206
let original_txs_count = transactions.len();
179-
let deploy_controller_txs_count = deploy_controller_txs.len();
180-
181207
let new_txs = [deploy_controller_txs, transactions].concat();
182208
let new_txs_count = new_txs.len();
183-
184-
// craft a new estimate fee request with the deploy Controller txs included
185-
let new_request = {
186-
let params = rpc_params!(new_txs, simulation_flags, block_id);
187-
let params = params.to_rpc_params().unwrap();
188-
189-
let mut new_request = request.clone();
190-
new_request.params = params.map(Cow::Owned);
191-
192-
new_request
193-
};
209+
let new_request =
210+
Self::build_estimate_fee_request(&request, new_txs, simulation_flags, block_id)?;
194211

195212
let response = self.service.call(new_request).await;
196-
197-
let res = response.as_json().get();
198-
let res = serde_json::from_str::<Response<'_, Vec<FeeEstimate>>>(res).unwrap();
213+
let response_body = response.as_json().get();
214+
let res = serde_json::from_str::<Response<'_, Vec<FeeEstimate>>>(response_body).map_err(
215+
|err| {
216+
Self::controller_deployment_error(format!(
217+
"failed to parse estimateFee response: {err}"
218+
))
219+
},
220+
)?;
199221

200222
match res.payload {
201-
ResponsePayload::Success(mut estimates) => {
202-
assert_eq!(estimates.len(), new_txs_count);
203-
estimates.to_mut().drain(0..deploy_controller_txs_count);
223+
ResponsePayload::Success(estimates) => {
224+
if estimates.len() != new_txs_count {
225+
return Err(Self::controller_deployment_error(format!(
226+
"unexpected estimateFee response length: expected {new_txs_count}, got {}",
227+
estimates.len()
228+
)));
229+
}
230+
204231
Ok(build_no_fee_response(&request, original_txs_count))
205232
}
206233

@@ -216,7 +243,8 @@ where
216243
let block_id = BlockIdOrTag::PreConfirmed;
217244

218245
// check if the address has already been deployed.
219-
let is_deployed = match self.context.starknet.class_hash_at_address(block_id, address).await {
246+
let is_deployed = match self.context.starknet.class_hash_at_address(block_id, address).await
247+
{
220248
Ok(..) => true,
221249
Err(StarknetApiError::ContractNotFound) => false,
222250

@@ -231,23 +259,18 @@ where
231259
return Ok(());
232260
}
233261

234-
let result = self.context.starknet.nonce_at(block_id, self.context.deployer_address).await;
235-
let nonce = match result {
236-
Ok(nonce) => nonce,
237-
Err(e) => {
238-
return Err(CartridgeApiError::ControllerDeployment {
239-
reason: format!("failed to get deployer nonce: {e}"),
240-
});
241-
}
242-
};
243-
244-
let result = self.get_controller_deployment_tx(address, nonce).await;
245-
let deploy_tx = match result {
246-
Ok(tx) => tx,
247-
Err(e) => {
248-
return Err(CartridgeApiError::ControllerDeployment { reason: e.to_string() });
249-
}
250-
};
262+
let nonce = self
263+
.context
264+
.starknet
265+
.nonce_at(block_id, self.context.deployer_address)
266+
.await
267+
.map_err(|err| {
268+
Self::controller_deployment_error(format!("failed to get deployer nonce: {err}"))
269+
})?;
270+
let deploy_tx = self
271+
.get_controller_deployment_tx(address, nonce)
272+
.await
273+
.map_err(|err| Self::controller_deployment_error(err.to_string()))?;
251274

252275
// None means the address is not of a Controller
253276
if let Some(tx) = deploy_tx {
@@ -263,15 +286,15 @@ where
263286

264287
async fn get_controller_deployment_txs(
265288
&self,
266-
controller_addreses: Vec<ContractAddress>,
289+
controller_addresses: Vec<ContractAddress>,
267290
initial_nonce: Nonce,
268291
) -> Result<Vec<BroadcastedTx>, Error> {
269292
let mut deploy_transactions: Vec<BroadcastedTx> = Vec::new();
270-
let mut processed_addresses: Vec<ContractAddress> = Vec::new();
293+
let mut processed_addresses: HashSet<ContractAddress> = HashSet::new();
271294

272295
let mut deployer_nonce = initial_nonce;
273296

274-
for address in controller_addreses {
297+
for address in controller_addresses {
275298
// If the address has already been processed in this txs batch, just skip.
276299
if processed_addresses.contains(&address) {
277300
continue;
@@ -282,7 +305,7 @@ where
282305
// None means the address is not a Controller
283306
if let Some(tx) = deploy_tx {
284307
deployer_nonce += Nonce::ONE;
285-
processed_addresses.push(address);
308+
processed_addresses.insert(address);
286309
deploy_transactions.push(BroadcastedTx::Invoke(tx));
287310
}
288311
}
@@ -426,73 +449,81 @@ struct EstimateFeeParams {
426449
block_id: BlockIdOrTag,
427450
}
428451

429-
fn parse_execute_outside_params(request: &Request<'_>) -> Option<AddExecuteOutsideParams> {
430-
let params = request.params();
431-
432-
if params.is_object() {
433-
match params.parse() {
434-
Ok(p) => Some(p),
435-
Err(..) => {
436-
debug!(target: "cartridge", "Failed to parse execute outside params.");
437-
None
438-
}
439-
}
440-
} else {
441-
let mut seq = params.sequence();
442-
443-
let address: Result<ContractAddress, _> = seq.next();
444-
let outside_execution: Result<OutsideExecution, _> = seq.next();
445-
let signature: Result<Vec<Felt>, _> = seq.next();
446-
let fee_source: Result<Option<FeeSource>, _> = seq.next();
447-
448-
match (address, outside_execution, signature) {
449-
(Ok(address), Ok(outside_execution), Ok(signature)) => Some(AddExecuteOutsideParams {
450-
address,
451-
outside_execution,
452-
signature,
453-
fee_source: fee_source.ok().flatten(),
454-
}),
455-
_ => {
456-
debug!(target: "cartridge", "Failed to parse execute outside params.");
457-
None
458-
}
452+
#[derive(Deserialize)]
453+
struct AddExecuteOutsidePositionalParams(
454+
ContractAddress,
455+
OutsideExecution,
456+
Vec<Felt>,
457+
#[serde(default)] Option<FeeSource>,
458+
);
459+
460+
#[derive(Deserialize)]
461+
#[serde(untagged)]
462+
enum AddExecuteOutsideRequestParams {
463+
Named(AddExecuteOutsideParams),
464+
Positional(AddExecuteOutsidePositionalParams),
465+
}
466+
467+
impl From<AddExecuteOutsideRequestParams> for AddExecuteOutsideParams {
468+
fn from(value: AddExecuteOutsideRequestParams) -> Self {
469+
match value {
470+
AddExecuteOutsideRequestParams::Named(params) => params,
471+
AddExecuteOutsideRequestParams::Positional(params) => Self {
472+
address: params.0,
473+
outside_execution: params.1,
474+
signature: params.2,
475+
fee_source: params.3,
476+
},
459477
}
460478
}
461479
}
462480

463-
/// Extract estimate_fee parameters from the request.
464-
fn parse_estimate_fee_params(request: &Request<'_>) -> Option<EstimateFeeParams> {
465-
let params = request.params();
466-
467-
if params.is_object() {
468-
match params.parse() {
469-
Ok(p) => Some(p),
470-
Err(..) => {
471-
debug!(target: "cartridge", "Failed to parse estimate fee params.");
472-
None
473-
}
474-
}
475-
} else {
476-
let mut seq = params.sequence();
481+
#[derive(Deserialize)]
482+
struct EstimateFeePositionalParams(
483+
Vec<BroadcastedTx>,
484+
Vec<SimulationFlagForEstimateFee>,
485+
BlockIdOrTag,
486+
);
477487

478-
let txs_result: Result<Vec<BroadcastedTx>, _> = seq.next();
479-
let simulation_flags_result: Result<Vec<SimulationFlagForEstimateFee>, _> = seq.next();
480-
let block_id_result: Result<BlockIdOrTag, _> = seq.next();
488+
#[derive(Deserialize)]
489+
#[serde(untagged)]
490+
enum EstimateFeeRequestParams {
491+
Named(EstimateFeeParams),
492+
Positional(EstimateFeePositionalParams),
493+
}
481494

482-
match (txs_result, simulation_flags_result, block_id_result) {
483-
(Ok(txs), Ok(simulation_flags), Ok(block_id)) => {
484-
Some(EstimateFeeParams { transactions: txs, simulation_flags, block_id })
485-
}
486-
_ => {
487-
debug!(target: "cartridge", "Failed to parse estimate fee params.");
488-
None
495+
impl From<EstimateFeeRequestParams> for EstimateFeeParams {
496+
fn from(value: EstimateFeeRequestParams) -> Self {
497+
match value {
498+
EstimateFeeRequestParams::Named(params) => params,
499+
EstimateFeeRequestParams::Positional(params) => {
500+
Self { transactions: params.0, simulation_flags: params.1, block_id: params.2 }
489501
}
490502
}
491503
}
492504
}
493505

494-
// <--- TODO: this function should be removed once estimateFee will return 0 fees
495-
// when --dev.no-fee is used.
506+
fn parse_params<T: DeserializeOwned>(request: &Request<'_>, method: &str) -> Option<T> {
507+
match request.params().parse() {
508+
Ok(params) => Some(params),
509+
Err(..) => {
510+
debug!(target: "cartridge", "Failed to parse {method} params.");
511+
None
512+
}
513+
}
514+
}
515+
516+
fn parse_execute_outside_params(request: &Request<'_>) -> Option<AddExecuteOutsideParams> {
517+
parse_params::<AddExecuteOutsideRequestParams>(request, "execute outside").map(Into::into)
518+
}
519+
520+
/// Extract estimate_fee parameters from the request.
521+
fn parse_estimate_fee_params(request: &Request<'_>) -> Option<EstimateFeeParams> {
522+
parse_params::<EstimateFeeRequestParams>(request, "estimate fee").map(Into::into)
523+
}
524+
525+
// Temporary shim for --dev.no-fee when deployment txs are prepended for controllers.
526+
// Remove once starknet_estimateFee natively returns zeroed fees in this scenario.
496527
fn build_no_fee_response(request: &Request<'_>, count: usize) -> MethodResponse {
497528
let estimate_fees = vec![
498529
FeeEstimate {

0 commit comments

Comments
 (0)