Skip to content

Commit 6b15e0c

Browse files
committed
Add support for Logger, refactor threads using new StatefulThread type
1 parent 0e89b4e commit 6b15e0c

File tree

2 files changed

+119
-51
lines changed

2 files changed

+119
-51
lines changed

crates/server/src/main.rs

Lines changed: 46 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
use std::{
22
path::PathBuf,
3-
pin::pin,
43
sync::{Arc, LazyLock},
5-
time::{Duration, Instant},
64
};
75

86
use miniserve::{http::StatusCode, Content, Request, Response};
97
use serde::{Deserialize, Serialize};
10-
use tokio::{
11-
fs, join,
12-
sync::{mpsc, oneshot},
13-
task::JoinSet,
14-
};
8+
use stateful::StatefulThread;
9+
use tokio::{fs, join, task::JoinSet};
10+
11+
mod stateful;
1512

1613
async fn index(_req: Request) -> Response {
1714
let content = include_str!("../index.html").to_string();
@@ -42,55 +39,49 @@ async fn load_docs(paths: Vec<PathBuf>) -> Vec<String> {
4239
docs
4340
}
4441

45-
type Payload = (Arc<Vec<String>>, oneshot::Sender<Option<Vec<String>>>);
46-
47-
fn chatbot_thread() -> (mpsc::Sender<Payload>, mpsc::Sender<()>) {
48-
let (req_tx, mut req_rx) = mpsc::channel::<Payload>(1024);
49-
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
50-
tokio::spawn(async move {
51-
let mut chatbot = chatbot::Chatbot::new(vec![":-)".into(), "^^".into()]);
52-
while let Some((messages, responder)) = req_rx.recv().await {
53-
let doc_paths = chatbot.retrieval_documents(&messages);
54-
let docs = load_docs(doc_paths).await;
55-
let mut chat_fut = pin!(chatbot.query_chat(&messages, &docs));
56-
let mut cancel_fut = pin!(cancel_rx.recv());
57-
let start = Instant::now();
58-
loop {
59-
let log_fut = tokio::time::sleep(Duration::from_secs(1));
60-
tokio::select! {
61-
response = &mut chat_fut => {
62-
responder.send(Some(response)).unwrap();
63-
break;
64-
}
65-
_ = &mut cancel_fut => {
66-
responder.send(None).unwrap();
67-
break;
68-
}
69-
_ = log_fut => {
70-
println!("Waiting for {} seconds", start.elapsed().as_secs());
71-
}
72-
}
73-
}
74-
}
75-
});
76-
(req_tx, cancel_tx)
42+
struct LogFunction {
43+
logger: chatbot::Logger,
7744
}
7845

79-
static CHATBOT_THREAD: LazyLock<(mpsc::Sender<Payload>, mpsc::Sender<()>)> =
80-
LazyLock::new(chatbot_thread);
46+
impl stateful::StatefulFunction for LogFunction {
47+
type Input = Arc<Vec<String>>;
48+
type Output = ();
8149

82-
async fn query_chat(messages: &Arc<Vec<String>>) -> Option<Vec<String>> {
83-
let (tx, rx) = oneshot::channel();
84-
CHATBOT_THREAD
85-
.0
86-
.send((Arc::clone(messages), tx))
87-
.await
88-
.unwrap();
89-
rx.await.unwrap()
50+
async fn call(&mut self, messages: Self::Input) -> Self::Output {
51+
self.logger.append(messages.last().unwrap());
52+
self.logger.save().await.unwrap();
53+
}
54+
}
55+
56+
static LOG_THREAD: LazyLock<StatefulThread<LogFunction>> = LazyLock::new(|| {
57+
StatefulThread::new(LogFunction {
58+
logger: chatbot::Logger::default(),
59+
})
60+
});
61+
62+
struct ChatbotFunction {
63+
chatbot: chatbot::Chatbot,
9064
}
9165

66+
impl stateful::StatefulFunction for ChatbotFunction {
67+
type Input = Arc<Vec<String>>;
68+
type Output = Vec<String>;
69+
70+
async fn call(&mut self, messages: Self::Input) -> Self::Output {
71+
let doc_paths = self.chatbot.retrieval_documents(&messages);
72+
let docs = load_docs(doc_paths).await;
73+
self.chatbot.query_chat(&messages, &docs).await
74+
}
75+
}
76+
77+
static CHATBOT_THREAD: LazyLock<StatefulThread<ChatbotFunction>> = LazyLock::new(|| {
78+
StatefulThread::new(ChatbotFunction {
79+
chatbot: chatbot::Chatbot::new(vec![":-)".into(), "^^".into()]),
80+
})
81+
});
82+
9283
async fn cancel(_req: Request) -> Response {
93-
CHATBOT_THREAD.1.send(()).await.unwrap();
84+
CHATBOT_THREAD.cancel().await;
9485
Ok(Content::Html("success".into()))
9586
}
9687

@@ -103,7 +94,11 @@ async fn chat(req: Request) -> Response {
10394
};
10495

10596
let messages = Arc::new(data.messages);
106-
let (i, responses_opt) = join!(chatbot::gen_random_number(), query_chat(&messages));
97+
let (i, responses_opt, _) = join!(
98+
chatbot::gen_random_number(),
99+
CHATBOT_THREAD.call(Arc::clone(&messages)),
100+
LOG_THREAD.call(Arc::clone(&messages))
101+
);
107102

108103
let response = match responses_opt {
109104
Some(mut responses) => {

crates/server/src/stateful.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
use std::{
2+
fmt::Debug,
3+
future::Future,
4+
pin::pin,
5+
time::{Duration, Instant},
6+
};
7+
8+
use tokio::{
9+
sync::{mpsc, oneshot},
10+
task::JoinHandle,
11+
};
12+
13+
pub trait StatefulFunction: Send + 'static {
14+
type Input: Send;
15+
type Output: Send + Debug;
16+
fn call(&mut self, input: Self::Input) -> impl Future<Output = Self::Output> + Send;
17+
}
18+
19+
type Payload<F> = (
20+
<F as StatefulFunction>::Input,
21+
oneshot::Sender<Option<<F as StatefulFunction>::Output>>,
22+
);
23+
24+
pub struct StatefulThread<F: StatefulFunction> {
25+
_handle: JoinHandle<()>,
26+
input_tx: mpsc::Sender<Payload<F>>,
27+
cancel_tx: mpsc::Sender<()>,
28+
}
29+
30+
impl<F: StatefulFunction> StatefulThread<F> {
31+
pub fn new(mut func: F) -> Self {
32+
let (input_tx, mut input_rx) = mpsc::channel::<Payload<F>>(1024);
33+
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
34+
let _handle = tokio::spawn(async move {
35+
while let Some((input, responder)) = input_rx.recv().await {
36+
let mut output_fut = pin!(func.call(input));
37+
let mut cancel_fut = pin!(cancel_rx.recv());
38+
let start = Instant::now();
39+
loop {
40+
let log_fut = tokio::time::sleep(Duration::from_secs(1));
41+
tokio::select! {
42+
response = &mut output_fut => {
43+
responder.send(Some(response)).unwrap();
44+
break;
45+
}
46+
_ = &mut cancel_fut => {
47+
responder.send(None).unwrap();
48+
break;
49+
}
50+
_ = log_fut => {
51+
println!("Waiting for {} seconds", start.elapsed().as_secs());
52+
}
53+
}
54+
}
55+
}
56+
});
57+
StatefulThread {
58+
_handle,
59+
input_tx,
60+
cancel_tx,
61+
}
62+
}
63+
64+
pub async fn call(&self, input: F::Input) -> Option<F::Output> {
65+
let (tx, rx) = oneshot::channel();
66+
self.input_tx.send((input, tx)).await.unwrap();
67+
rx.await.unwrap()
68+
}
69+
70+
pub async fn cancel(&self) {
71+
self.cancel_tx.send(()).await.unwrap();
72+
}
73+
}

0 commit comments

Comments
 (0)