Skip to content

Commit c9cba16

Browse files
Subscribe to Multiple Subaccounts at once (#133)
* feat: multiple sub_accounts to sub to * fix: bad comment and method for default subaccount id on AppState * revert: option wrap of def subacct id * combine sub_account_id configs combine default sub_account_id and active_sub_accounts Co-authored-by: jordy25519 <beauchjord@gmail.com> * feat: sub_account_ids deduping * feat: default func for swift node --------- Co-authored-by: jordy25519 <beauchjord@gmail.com>
1 parent 4c96703 commit c9cba16

File tree

4 files changed

+130
-49
lines changed

4 files changed

+130
-49
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ Options:
157157
--commitment solana commitment level to use for state updates (default:
158158
confirmed)
159159
--default-sub-account-id
160-
default sub_account_id to use (default: 0)
160+
default sub_account_id to use as default. Use the new active-sub-accounts param to subscribe to multiple sub accounts. This param will override active-sub-accounts.
161+
--active-sub-accounts sub accounts to subscribe to. (default: 0)
161162
--skip-tx-preflight
162163
skip tx preflight checks
163164
--extra-rpcs extra solana RPC urls for improved Tx broadcast

src/controller.rs

Lines changed: 65 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,16 @@ pub struct AppState {
9090
pub client: Arc<DriftClient>,
9191
/// Solana tx commitment level for preflight confirmation
9292
tx_commitment: CommitmentConfig,
93-
/// default sub_account_id to use if not provided
94-
default_subaccount_id: u16,
93+
/// sub_account_ids to subscribe to
94+
sub_account_ids: Vec<u16>,
9595
/// skip tx preflight on send or not (default: false)
9696
skip_tx_preflight: bool,
9797
priority_fee_subscriber: Arc<PriorityFeeSubscriber>,
9898
slot_subscriber: Arc<SlotSubscriber>,
9999
/// list of additional RPC endpoints for tx broadcast
100100
extra_rpcs: Vec<Arc<RpcClient>>,
101101
/// swift node url
102-
swift_node: Option<String>,
102+
swift_node: String,
103103
}
104104

105105
impl AppState {
@@ -111,12 +111,12 @@ impl AppState {
111111
pub fn signer(&self) -> Pubkey {
112112
self.wallet.signer()
113113
}
114-
pub fn default_sub_account(&self) -> Pubkey {
115-
self.wallet.sub_account(self.default_subaccount_id)
114+
pub fn sub_account(&self, sub_account_id: u16) -> Pubkey {
115+
self.wallet.sub_account(sub_account_id)
116116
}
117117
pub fn resolve_sub_account(&self, sub_account_id: Option<u16>) -> Pubkey {
118118
self.wallet
119-
.sub_account(sub_account_id.unwrap_or(self.default_subaccount_id))
119+
.sub_account(sub_account_id.unwrap_or(self.default_sub_account_id()))
120120
}
121121

122122
/// Initialize Gateway Drift client
@@ -125,18 +125,18 @@ impl AppState {
125125
/// * `devnet` - whether to run against devnet or not
126126
/// * `wallet` - wallet to use for tx signing
127127
/// * `commitment` - Slot finalisation/commitement levels
128-
/// * `default_subaccount_id` - by default all queries will use this sub-account
128+
/// * `sub_account_ids` - the sub_accounts to subscribe too. In your query specify a specific subaccount, otherwise subaccount 0 will be used as default
129129
/// * `skip_tx_preflight` - submit txs without checking preflight results
130130
/// * `extra_rpcs` - list of additional RPC endpoints for tx submission
131131
pub async fn new(
132132
endpoint: &str,
133133
devnet: bool,
134134
wallet: Wallet,
135135
commitment: Option<(CommitmentConfig, CommitmentConfig)>,
136-
default_subaccount_id: Option<u16>,
136+
sub_account_ids: Vec<u16>,
137137
skip_tx_preflight: bool,
138138
extra_rpcs: Vec<&str>,
139-
swift_node: Option<String>,
139+
swift_node: String,
140140
) -> Self {
141141
let (state_commitment, tx_commitment) =
142142
commitment.unwrap_or((CommitmentConfig::confirmed(), CommitmentConfig::confirmed()));
@@ -151,11 +151,13 @@ impl AppState {
151151
.await
152152
.expect("ok");
153153

154-
let default_subaccount = wallet.sub_account(default_subaccount_id.unwrap_or(0));
155-
if let Err(err) = client.subscribe_account(&default_subaccount).await {
156-
log::error!(target: LOG_TARGET, "couldn't subscribe to user updates: {err:?}");
157-
} else {
158-
log::info!(target: LOG_TARGET, "subscribed to subaccount: {default_subaccount}");
154+
for sub_account_id in &sub_account_ids {
155+
let sub_account = wallet.sub_account(*sub_account_id);
156+
if let Err(err) = client.subscribe_account(&sub_account).await {
157+
log::error!(target: LOG_TARGET, "couldn't subscribe to user updates: {err:?}. subaccount: {sub_account_id}");
158+
} else {
159+
log::info!(target: LOG_TARGET, "subscribed to subaccount: {sub_account}");
160+
}
159161
}
160162

161163
let priority_fee_subscriber = PriorityFeeSubscriber::with_config(
@@ -190,7 +192,7 @@ impl AppState {
190192
Self {
191193
client: Arc::new(client),
192194
tx_commitment,
193-
default_subaccount_id: default_subaccount_id.unwrap_or(0),
195+
sub_account_ids,
194196
skip_tx_preflight,
195197
priority_fee_subscriber,
196198
slot_subscriber: Arc::new(slot_subscriber),
@@ -207,11 +209,26 @@ impl AppState {
207209
&self,
208210
configured_markets: &[MarketId],
209211
) -> Result<(), SdkError> {
210-
let default_sub_account = self.default_sub_account();
212+
let sub_account_ids = self.sub_account_ids.clone();
213+
for id in sub_account_ids {
214+
self.sync_market_subscriptions_on_user_subaccount_changes(configured_markets, id)
215+
.await?;
216+
}
217+
218+
Ok(())
219+
}
220+
221+
async fn sync_market_subscriptions_on_user_subaccount_changes(
222+
&self,
223+
configured_markets: &[MarketId],
224+
sub_account_id: u16,
225+
) -> Result<(), SdkError> {
226+
let sub_account = self.sub_account(sub_account_id);
211227
let state_commitment = self.tx_commitment;
212228
let configured_markets_vec = configured_markets.to_vec();
213229
let self_clone = self.clone();
214-
let mut current_user_markets_to_subscribe = self.get_marketids_to_subscribe().await?;
230+
let mut current_user_markets_to_subscribe =
231+
self.get_marketids_to_subscribe(sub_account).await?;
215232

216233
tokio::spawn(async move {
217234
let pubsub_config = RpcAccountInfoConfig {
@@ -224,7 +241,7 @@ impl AppState {
224241
let pubsub_client = self_clone.client.ws();
225242

226243
let (mut account_subscription, unsubscribe_fn) = match pubsub_client
227-
.account_subscribe(&default_sub_account, Some(pubsub_config))
244+
.account_subscribe(&sub_account, Some(pubsub_config))
228245
.await
229246
{
230247
Ok(res) => res,
@@ -239,7 +256,7 @@ impl AppState {
239256
// Process incoming account updates
240257
while let Some(_) = account_subscription.next().await {
241258
let current_market_ids_count = current_user_markets_to_subscribe.len();
242-
match self_clone.get_marketids_to_subscribe().await {
259+
match self_clone.get_marketids_to_subscribe(sub_account).await {
243260
Ok(new_market_ids) => {
244261
if new_market_ids.len() != current_market_ids_count {
245262
if let Err(err) = self_clone
@@ -267,13 +284,13 @@ impl AppState {
267284
Ok(())
268285
}
269286

270-
async fn get_marketids_to_subscribe(&self) -> Result<Vec<MarketId>, SdkError> {
271-
let (all_spot, all_perp) = self
272-
.client
273-
.all_positions(&self.default_sub_account())
274-
.await?;
287+
async fn get_marketids_to_subscribe(
288+
&self,
289+
sub_account: Pubkey,
290+
) -> Result<Vec<MarketId>, SdkError> {
291+
let (all_spot, all_perp) = self.client.all_positions(&sub_account).await?;
275292

276-
let open_orders = self.client.all_orders(&self.default_sub_account()).await?;
293+
let open_orders = self.client.all_orders(&sub_account).await?;
277294

278295
let user_markets: Vec<MarketId> = all_spot
279296
.iter()
@@ -296,11 +313,25 @@ impl AppState {
296313
/// * configured_markets - list of static markets provided by user
297314
///
298315
/// additional subscriptions will be included based on user's current positions (on default sub-account)
316+
299317
pub(crate) async fn subscribe_market_data(
300318
&self,
301319
configured_markets: &[MarketId],
302320
) -> Result<(), SdkError> {
303-
let mut user_markets = self.get_marketids_to_subscribe().await?;
321+
for id in self.sub_account_ids.clone() {
322+
self.subscribe_market_data_for_subaccount(configured_markets, id)
323+
.await?;
324+
}
325+
Ok(())
326+
}
327+
328+
async fn subscribe_market_data_for_subaccount(
329+
&self,
330+
configured_markets: &[MarketId],
331+
sub_account_id: u16,
332+
) -> Result<(), SdkError> {
333+
let sub_account = self.sub_account(sub_account_id);
334+
let mut user_markets = self.get_marketids_to_subscribe(sub_account).await?;
304335
user_markets.extend_from_slice(configured_markets);
305336

306337
let init_rpc_throttle: u64 = std::env::var("INIT_RPC_THROTTLE")
@@ -636,7 +667,7 @@ impl AppState {
636667
let orders_len = orders_iter.len();
637668
let mut signed_messages = Vec::with_capacity(orders_len);
638669
let mut hashes: Vec<String> = Vec::with_capacity(orders_len);
639-
let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_subaccount_id);
670+
let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_sub_account_id());
640671
let current_slot = self.slot_subscriber.current_slot();
641672
let orders_with_hex: Vec<(OrderParams, Vec<u8>)> = orders_iter
642673
.map(|order| {
@@ -663,7 +694,7 @@ impl AppState {
663694
};
664695
let incoming_msg = IncomingSignedMessage {
665696
taker_authority: self.authority().to_string(),
666-
signature: general_purpose::STANDARD.encode(signature),
697+
signature: general_purpose::STANDARD.encode(signature), // TODO: test just using .to_string() for base64 encoding
667698
message: String::from_utf8(message).unwrap(),
668699
signing_authority: self.signer().to_string(),
669700
market_type,
@@ -677,11 +708,7 @@ impl AppState {
677708

678709
let client = reqwest::Client::new();
679710

680-
let swift_orders_url = self
681-
.swift_node
682-
.clone()
683-
.unwrap_or("https://master.swift.drift.trade".to_string())
684-
+ "/orders";
711+
let swift_orders_url = self.swift_node.clone() + "/orders";
685712

686713
let mut futures = FuturesOrdered::new();
687714
for msg in signed_messages {
@@ -900,7 +927,7 @@ impl AppState {
900927
ctx: Context,
901928
new_margin_ratio: Decimal,
902929
) -> GatewayResult<TxResponse> {
903-
let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_subaccount_id);
930+
let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_sub_account_id());
904931
let sub_account_address = self.wallet.sub_account(sub_account_id);
905932
let account_data = self.client.get_user_account(&sub_account_address).await?;
906933

@@ -921,6 +948,10 @@ impl AppState {
921948
self.send_tx(tx, "set_margin_ratio", ctx.ttl).await
922949
}
923950

951+
pub fn default_sub_account_id(&self) -> u16 {
952+
self.sub_account_ids[0]
953+
}
954+
924955
fn get_priority_fee(&self) -> u64 {
925956
self.priority_fee_subscriber.priority_fee_nth(0.9)
926957
}

src/main.rs

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,22 @@ async fn main() -> std::io::Result<()> {
277277
let tx_commitment = CommitmentConfig::from_str(&config.tx_commitment)
278278
.expect("one of: processed | confirmed | finalized");
279279
let extra_rpcs = config.extra_rpcs.as_ref();
280+
281+
let mut sub_account_ids = vec![config.default_sub_account_id];
282+
sub_account_ids.extend(
283+
config
284+
.active_sub_accounts
285+
.split(",")
286+
.map(|s| s.parse::<u16>().unwrap()),
287+
);
288+
sub_account_ids.dedup();
289+
280290
let state = AppState::new(
281291
&config.rpc_host,
282292
config.dev,
283293
wallet,
284294
Some((state_commitment, tx_commitment)),
285-
Some(config.default_sub_account_id),
295+
sub_account_ids.clone(),
286296
config.skip_tx_preflight,
287297
extra_rpcs
288298
.map(|s| s.split(",").collect())
@@ -330,17 +340,17 @@ async fn main() -> std::io::Result<()> {
330340
if delegate.is_some() {
331341
info!(
332342
target: LOG_TARGET,
333-
"🪪 authority: {:?}, default sub-account: {:?}, 🔑 delegate: {:?}",
343+
"🪪 authority: {:?}, sub-accounts: {:?}, 🔑 delegate: {:?}",
334344
state.authority(),
335-
state.default_sub_account(),
345+
sub_account_ids,
336346
state.signer(),
337347
);
338348
} else {
339349
info!(
340350
target: LOG_TARGET,
341-
"🪪 authority: {:?}, default sub-account: {:?}",
351+
"🪪 authority: {:?}, sub-accounts: {:?}",
342352
state.authority(),
343-
state.default_sub_account()
353+
sub_account_ids
344354
);
345355
if emulate.is_some() {
346356
warn!("using emulation mode, tx signing unavailable");
@@ -455,6 +465,22 @@ fn handle_deser_error<T>(err: serde_json::Error) -> Either<HttpResponse, Json<T>
455465
)))
456466
}
457467

468+
fn default_swift_node() -> String {
469+
let strings: Vec<String> = std::env::args_os()
470+
.map(|s| s.into_string())
471+
.collect::<Result<Vec<_>, _>>()
472+
.unwrap_or_else(|arg| {
473+
eprintln!("Invalid utf8: {}", arg.to_string_lossy());
474+
std::process::exit(1)
475+
});
476+
let is_dev = strings.iter().any(|s| s.to_string() == "--dev".to_string());
477+
if is_dev {
478+
"https://master.swift.drift.trade".to_string()
479+
} else {
480+
"https://swift.drift.trade".to_string()
481+
}
482+
}
483+
458484
#[derive(FromArgs)]
459485
/// Drift gateway server
460486
struct GatewayConfig {
@@ -467,8 +493,8 @@ struct GatewayConfig {
467493
#[argh(option)]
468494
markets: Option<String>,
469495
/// swift node url
470-
#[argh(option)]
471-
swift_node: Option<String>,
496+
#[argh(option, default = "default_swift_node()")]
497+
swift_node: String,
472498
/// run in devnet mode
473499
#[argh(switch)]
474500
dev: bool,
@@ -501,6 +527,9 @@ struct GatewayConfig {
501527
/// default sub_account_id to use (default: 0)
502528
#[argh(option, default = "0")]
503529
default_sub_account_id: u16,
530+
/// list of active sub_account_ids to use (default: 0)
531+
#[argh(option, default = "String::from(\"0\")")]
532+
active_sub_accounts: String,
504533
/// skip tx preflight checks
505534
#[argh(switch)]
506535
skip_tx_preflight: bool,
@@ -552,7 +581,17 @@ mod tests {
552581
};
553582
let rpc_endpoint = std::env::var("TEST_RPC_ENDPOINT")
554583
.unwrap_or_else(|_| "https://api.devnet.solana.com".to_string());
555-
AppState::new(&rpc_endpoint, true, wallet, None, None, false, vec![], None).await
584+
AppState::new(
585+
&rpc_endpoint,
586+
true,
587+
wallet,
588+
None,
589+
vec![0],
590+
false,
591+
vec![],
592+
"https://master.swift.drift.trade".to_string(),
593+
)
594+
.await
556595
}
557596

558597
// likely safe to ignore during development, mainly regression test for CI
@@ -573,8 +612,17 @@ mod tests {
573612

574613
let rpc_endpoint = std::env::var("TEST_MAINNET_RPC_ENDPOINT")
575614
.unwrap_or_else(|_| "https://api.mainnet-beta.solana.com".to_string());
576-
let state =
577-
AppState::new(&rpc_endpoint, true, wallet, None, None, false, vec![], None).await;
615+
let state = AppState::new(
616+
&rpc_endpoint,
617+
true,
618+
wallet,
619+
None,
620+
vec![],
621+
false,
622+
vec![],
623+
"https://master.swift.drift.trade".to_string(),
624+
)
625+
.await;
578626

579627
let app = test::init_service(
580628
App::new()
@@ -617,10 +665,10 @@ mod tests {
617665
false,
618666
wallet,
619667
None,
620-
None,
668+
vec![],
621669
false,
622670
vec![],
623-
None,
671+
"https://master.swift.drift.trade".to_string(),
624672
)
625673
.await;
626674

@@ -660,10 +708,10 @@ mod tests {
660708
false,
661709
wallet,
662710
None,
663-
None,
711+
vec![],
664712
false,
665713
vec![],
666-
None,
714+
"https://master.swift.drift.trade".to_string(),
667715
)
668716
.await;
669717

0 commit comments

Comments
 (0)