Skip to content

Commit 65c981f

Browse files
committed
Add restore subscription/connections
1 parent 4d60dde commit 65c981f

File tree

7 files changed

+109
-26
lines changed

7 files changed

+109
-26
lines changed

dev/dev.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ ALTER TABLE connections RENAME COLUMN user_id TO subscription_id;
158158

159159
CREATE TABLE subscriptions (
160160
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
161-
expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
161+
expires_at TIMESTAMP WITH TIME ZONE ,
162162
referred_by CHAR(13),
163163

164164
created_at TIMESTAMP WITH TIME ZONE DEFAULT now(),

src/bin/api/core/http/handlers/sub.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -545,11 +545,6 @@ where
545545
"Subscription {} are expired",
546546
sub_param.id
547547
))));
548-
} else {
549-
return Ok(Box::new(http::not_found(&format!(
550-
"Subscription {} are not found",
551-
sub_param.id
552-
))));
553548
}
554549
}
555550

src/bin/api/core/tasks.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use chrono::NaiveTime;
33
use chrono::TimeZone;
44
use chrono::Utc;
55
use futures::future::join_all;
6+
use pony::http::requests::ConnUpdateRequest;
67
use rand::Rng;
78
use std::collections::HashMap;
89
use std::time::Duration;
@@ -37,6 +38,7 @@ pub trait Tasks {
3738
async fn collect_conn_stat(&self) -> Result<()>;
3839
async fn cleanup_expired_connections(&self, interval_sec: u64, publisher: ZmqPublisher);
3940
async fn cleanup_expired_subscriptions(&self, interval_sec: u64, publisher: ZmqPublisher);
41+
async fn restore_subscriptions(&self, interval_sec: u64, publisher: ZmqPublisher);
4042
}
4143

4244
#[async_trait]
@@ -168,6 +170,77 @@ impl Tasks for Api<HashMap<String, Vec<Node>>, Connection, Subscription> {
168170
}
169171
}
170172

173+
async fn restore_subscriptions(&self, interval_sec: u64, publisher: ZmqPublisher) {
174+
let mut interval = tokio::time::interval(Duration::from_secs(interval_sec));
175+
176+
loop {
177+
interval.tick().await;
178+
log::debug!("Run restore subscriptions task");
179+
180+
let expired_subs: Vec<uuid::Uuid> = {
181+
let mem = self.sync.memory.read().await;
182+
mem.subscriptions
183+
.iter()
184+
.filter_map(|(id, sub)| if sub.is_active() { Some(*id) } else { None })
185+
.collect()
186+
};
187+
188+
for sub_id in expired_subs {
189+
let conns_to_restore: Vec<(uuid::Uuid, Connection)> = {
190+
let mem = self.sync.memory.read().await;
191+
mem.connections
192+
.get_by_subscription_id(&sub_id)
193+
.map(|conns| {
194+
conns
195+
.iter()
196+
.filter(|(_id, c)| c.get_deleted())
197+
.filter_map(|(id, c)| Some((*id, c.clone().into())))
198+
.collect()
199+
})
200+
.unwrap_or_default()
201+
};
202+
203+
for (conn_id, conn) in conns_to_restore {
204+
let msg = conn.as_update_message(&conn_id);
205+
if let Ok(bytes) = rkyv::to_bytes::<_, 1024>(&msg) {
206+
let key = conn
207+
.node_id
208+
.map(|id| id.to_string())
209+
.unwrap_or_else(|| conn.get_env());
210+
let _ = publisher.send_binary(&key, bytes.as_ref()).await;
211+
}
212+
213+
let conn_upd = ConnUpdateRequest {
214+
env: Some(conn.get_env()),
215+
is_deleted: Some(false),
216+
password: conn.get_password(),
217+
days: None,
218+
};
219+
220+
match SyncOp::update_conn(&self.sync, &conn_id, conn_upd).await {
221+
Ok(StorageOperationStatus::Updated(_)) => {
222+
log::info!("Expired connection {} restored", conn_id);
223+
}
224+
Ok(status) => {
225+
log::warn!(
226+
"Connection {} could not be restored: {:?}",
227+
conn_id,
228+
status
229+
);
230+
}
231+
Err(e) => {
232+
log::error!(
233+
"Failed to restore expired connection {}: {:?}",
234+
conn_id,
235+
e
236+
);
237+
}
238+
}
239+
}
240+
}
241+
}
242+
}
243+
171244
async fn periodic_db_sync(&self, interval_sec: u64) {
172245
let base = Duration::from_secs(interval_sec);
173246

src/bin/api/main.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ async fn main() -> Result<()> {
175175
let _ = tokio::spawn({
176176
let api = api.clone();
177177
let publisher = publisher.clone();
178-
let job_interval = Duration::from_secs(60);
178+
let job_interval = Duration::from_secs(settings.api.subscription_expire_interval);
179179
log::info!("cleanup_expired_subscriptions task started");
180180

181181
async move {
@@ -184,6 +184,18 @@ async fn main() -> Result<()> {
184184
}
185185
});
186186

187+
let _ = tokio::spawn({
188+
let api = api.clone();
189+
let publisher = publisher.clone();
190+
let job_interval = Duration::from_secs(settings.api.subscription_restore_interval);
191+
log::info!("restore_subscriptions task started");
192+
193+
async move {
194+
api.restore_subscriptions(job_interval.as_secs(), publisher)
195+
.await;
196+
}
197+
});
198+
187199
let api = api.clone();
188200
let api_handle = tokio::spawn(async move {
189201
if let Err(e) = api.run(settings.api.hostname).await {

src/config/settings.rs

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,6 @@ fn default_api_token() -> String {
8080
fn default_label() -> String {
8181
"🏴‍☠️🏴‍☠️🏴‍☠️ dev".to_string()
8282
}
83-
fn default_node_healthcheck_timeout() -> i16 {
84-
60
85-
}
86-
fn default_conn_limit_check_interval() -> u64 {
87-
60
88-
}
8983

9084
fn default_stat_job_interval() -> u64 {
9185
60
@@ -117,12 +111,24 @@ fn default_db_sync_interval_sec() -> u64 {
117111
300
118112
}
119113

114+
fn default_subscription_restore_interval_sec() -> u64 {
115+
60
116+
}
117+
118+
fn default_subscription_expire_interval_sec() -> u64 {
119+
60
120+
}
121+
122+
fn default_node_healthcheck_timeout() -> i16 {
123+
60
124+
}
125+
120126
fn default_max_bandwidth_bps() -> i64 {
121127
100_000_000
122128
}
123129

124130
fn default_hostname() -> String {
125-
"http://localhost".to_string()
131+
"http://localhost:5005".to_string()
126132
}
127133

128134
#[derive(Clone, Debug, Deserialize, Default)]
@@ -131,10 +137,6 @@ pub struct ApiServiceConfig {
131137
pub address: Option<Ipv4Addr>,
132138
#[serde(default = "default_api_web_port")]
133139
pub port: u16,
134-
#[serde(default = "default_node_healthcheck_timeout")]
135-
pub node_health_check_timeout: i16,
136-
#[serde(default = "default_conn_limit_check_interval")]
137-
pub conn_limit_check_interval: u64,
138140
#[serde(default = "default_collect_conn_stat_interval")]
139141
pub collect_conn_stat_interval: u64,
140142
#[serde(default = "default_healthcheck_interval")]
@@ -151,6 +153,12 @@ pub struct ApiServiceConfig {
151153
pub db_sync_interval_sec: u64,
152154
#[serde(default = "default_hostname")]
153155
pub hostname: String,
156+
#[serde(default = "default_subscription_restore_interval_sec")]
157+
pub subscription_restore_interval: u64,
158+
#[serde(default = "default_subscription_expire_interval_sec")]
159+
pub subscription_expire_interval: u64,
160+
#[serde(default = "default_node_healthcheck_timeout")]
161+
pub node_health_check_timeout: i16,
154162
}
155163

156164
#[derive(Clone, Debug, Deserialize, Default)]

src/memory/storage/subscription.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,6 @@ where
2222
{
2323
fn count_invited_by(&self, referral_code: &str) -> usize {
2424
self.values()
25-
.inspect(|s| {
26-
log::debug!(
27-
"cmp: referred_by={:?} target={:?}",
28-
s.referred_by(),
29-
referral_code
30-
);
31-
})
3225
.filter(|s| s.referred_by().as_deref() == Some(referral_code))
3326
.count()
3427
}

src/xray_op/client.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ impl HandlerActions for Arc<Mutex<HandlerClient>> {
145145
"Remove SS user error, password not provided".to_string(),
146146
))
147147
}
148-
Tag::Wireguard => todo!(),
148+
Tag::Wireguard => Err(crate::PonyError::Custom(
149+
"Removing Wireguard is not implemented".to_string(),
150+
)),
149151
}
150152
}
151153
}

0 commit comments

Comments
 (0)