Skip to content

Commit ddc248e

Browse files
authored
Merge pull request #145 from at-microcosm/relay-get-host-status
Add `com.atproto.sync.getHostStatus` xrpc endpoint
2 parents f84a597 + 2b3f62f commit ddc248e

File tree

2 files changed

+74
-27
lines changed

2 files changed

+74
-27
lines changed

rsky-relay/src/server/server.rs

Lines changed: 67 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use color_eyre::eyre::eyre;
1313
use httparse::{EMPTY_HEADER, Status};
1414
#[cfg(not(feature = "labeler"))]
1515
use rusqlite::named_params;
16-
use rusqlite::{Connection, OpenFlags};
16+
use rusqlite::{Connection, OpenFlags, OptionalExtension};
1717
use rustls::{ServerConfig, ServerConnection, StreamOwned};
1818
use thiserror::Error;
1919
use url::Url;
@@ -25,13 +25,16 @@ use crate::config::{HOSTS_MIN_ACCOUNTS, HOSTS_RELAY};
2525
use crate::crawler::{RequestCrawl, RequestCrawlSender};
2626
use crate::publisher::{MaybeTlsStream, SubscribeRepos, SubscribeReposSender};
2727
#[cfg(not(feature = "labeler"))]
28-
use crate::server::types::{Host, HostStatus, ListHosts};
28+
use crate::server::types::{GetHostStatus, Host, HostStatus, ListHosts};
2929

3030
const SLEEP: Duration = Duration::from_millis(10);
3131

3232
#[cfg(not(feature = "labeler"))]
3333
const PATH_LIST_HOSTS: &str = "/xrpc/com.atproto.sync.listHosts";
3434

35+
#[cfg(not(feature = "labeler"))]
36+
const PATH_HOST_STATUS: &str = "/xrpc/com.atproto.sync.getHostStatus";
37+
3538
const PATH_SUBSCRIBE: &str = if cfg!(feature = "labeler") {
3639
"/xrpc/com.atproto.label.subscribeLabels"
3740
} else {
@@ -215,23 +218,25 @@ impl Server {
215218
let method = parser.method.ok_or_else(|| eyre!("method missing"))?;
216219
let path = parser.path.ok_or_else(|| eyre!("path missing"))?;
217220
let url = Url::options().base_url(Some(&self.base_url)).parse(path)?;
221+
222+
let body_response = |status, body: &str| -> Vec<u8> {
223+
format!(
224+
"HTTP/1.1 {status}\r\n\
225+
Content-Type: text/plain; charset=utf-8\r\n\
226+
Content-Length: {}\r\n\
227+
Connection: close\r\n\
228+
\r\n\
229+
{body}",
230+
body.len()
231+
)
232+
.into()
233+
};
234+
218235
match (method, url.path()) {
219236
("GET", "/") => {
220-
let body = INDEX_ASCII;
221-
let response = format!(
222-
"HTTP/1.1 200 OK\r\n\
223-
Content-Type: text/plain; charset=utf-8\r\n\
224-
Content-Length: {}\r\n\
225-
Connection: close\r\n\
226-
\r\n\
227-
{}",
228-
body.len(),
229-
body
230-
);
231-
232237
#[expect(clippy::unwrap_used)]
233238
let mut stream = stream.0.take().unwrap();
234-
stream.write_all(response.as_bytes())?;
239+
stream.write_all(&body_response("200 OK", INDEX_ASCII))?;
235240
stream.flush()?;
236241
stream.shutdown()?;
237242
Ok(())
@@ -249,21 +254,29 @@ impl Server {
249254
}
250255
};
251256

252-
let response = format!(
253-
"HTTP/1.1 {}\r\n\
254-
Content-Type: application/json; charset=utf-8\r\n\
255-
Content-Length: {}\r\n\
256-
Connection: close\r\n\
257-
\r\n\
258-
{}",
259-
status,
260-
body.len(),
261-
body
262-
);
257+
#[expect(clippy::unwrap_used)]
258+
let mut stream = stream.0.take().unwrap();
259+
stream.write_all(&body_response(status, &body))?;
260+
stream.flush()?;
261+
stream.shutdown()?;
262+
Ok(())
263+
}
264+
#[cfg(not(feature = "labeler"))]
265+
("GET", PATH_HOST_STATUS) => {
266+
let (status, body) = match self.host_status(&url) {
267+
Ok(hosts) => ("200 OK", serde_json::to_string(&hosts)?),
268+
Err(e) => {
269+
let error = serde_json::json!({
270+
"error": "BadRequest",
271+
"message": e.to_string(),
272+
});
273+
("400 Bad Request", serde_json::to_string(&error)?)
274+
}
275+
};
263276

264277
#[expect(clippy::unwrap_used)]
265278
let mut stream = stream.0.take().unwrap();
266-
stream.write_all(response.as_bytes())?;
279+
stream.write_all(&body_response(status, &body))?;
267280
stream.flush()?;
268281
stream.shutdown()?;
269282
Ok(())
@@ -362,6 +375,33 @@ impl Server {
362375
Ok(ListHosts { cursor, hosts })
363376
}
364377

378+
#[cfg(not(feature = "labeler"))]
379+
fn host_status(&mut self, url: &Url) -> Result<GetHostStatus> {
380+
let mut hostname = None;
381+
for (key, value) in url.query_pairs() {
382+
match key.as_ref() {
383+
"hostname" => hostname = Some(value.to_string()),
384+
// Ignore unknown query parameters.
385+
_ => (),
386+
}
387+
}
388+
let hostname = hostname.ok_or(eyre!("hostname param is required"))?;
389+
390+
Ok(self
391+
.relay_conn
392+
.prepare_cached("SELECT cursor FROM hosts WHERE host = :host")?
393+
.query_one(named_params! { ":host": hostname.clone() }, |row| {
394+
Ok(GetHostStatus {
395+
hostname: hostname.clone(),
396+
seq: row.get("cursor")?,
397+
// TODO: Track status of hosts.
398+
status: HostStatus::Active,
399+
})
400+
})
401+
.optional()?
402+
.ok_or(eyre!("hostname {hostname:?} not found"))?)
403+
}
404+
365405
#[cfg(not(feature = "labeler"))]
366406
fn query_hosts(&mut self) -> Result<()> {
367407
let client = reqwest::blocking::Client::builder()

rsky-relay/src/server/types.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,10 @@ pub enum HostStatus {
2424
Throttled,
2525
Banned,
2626
}
27+
28+
#[derive(Debug, Serialize, Deserialize)]
29+
pub struct GetHostStatus {
30+
pub hostname: String,
31+
pub seq: u64,
32+
pub status: HostStatus,
33+
}

0 commit comments

Comments
 (0)