11use std:: {
22 path:: PathBuf ,
3- pin:: pin,
43 sync:: { Arc , LazyLock } ,
5- time:: { Duration , Instant } ,
64} ;
75
86use miniserve:: { http:: StatusCode , Content , Request , Response } ;
97use serde:: { Deserialize , Serialize } ;
10- use tokio:: {
11- fs, join,
12- sync:: { mpsc, oneshot} ,
13- task:: JoinSet ,
14- } ;
8+ use stateful:: StatefulThread ;
9+ use tokio:: { fs, join, task:: JoinSet } ;
10+
11+ mod stateful;
1512
1613async fn index ( _req : Request ) -> Response {
1714 let content = include_str ! ( "../index.html" ) . to_string ( ) ;
@@ -42,55 +39,49 @@ async fn load_docs(paths: Vec<PathBuf>) -> Vec<String> {
4239 docs
4340}
4441
45- type Payload = ( Arc < Vec < String > > , oneshot:: Sender < Option < Vec < String > > > ) ;
46-
47- fn chatbot_thread ( ) -> ( mpsc:: Sender < Payload > , mpsc:: Sender < ( ) > ) {
48- let ( req_tx, mut req_rx) = mpsc:: channel :: < Payload > ( 1024 ) ;
49- let ( cancel_tx, mut cancel_rx) = mpsc:: channel :: < ( ) > ( 1 ) ;
50- tokio:: spawn ( async move {
51- let mut chatbot = chatbot:: Chatbot :: new ( vec ! [ ":-)" . into( ) , "^^" . into( ) ] ) ;
52- while let Some ( ( messages, responder) ) = req_rx. recv ( ) . await {
53- let doc_paths = chatbot. retrieval_documents ( & messages) ;
54- let docs = load_docs ( doc_paths) . await ;
55- let mut chat_fut = pin ! ( chatbot. query_chat( & messages, & docs) ) ;
56- let mut cancel_fut = pin ! ( cancel_rx. recv( ) ) ;
57- let start = Instant :: now ( ) ;
58- loop {
59- let log_fut = tokio:: time:: sleep ( Duration :: from_secs ( 1 ) ) ;
60- tokio:: select! {
61- response = & mut chat_fut => {
62- responder. send( Some ( response) ) . unwrap( ) ;
63- break ;
64- }
65- _ = & mut cancel_fut => {
66- responder. send( None ) . unwrap( ) ;
67- break ;
68- }
69- _ = log_fut => {
70- println!( "Waiting for {} seconds" , start. elapsed( ) . as_secs( ) ) ;
71- }
72- }
73- }
74- }
75- } ) ;
76- ( req_tx, cancel_tx)
42+ struct LogFunction {
43+ logger : chatbot:: Logger ,
7744}
7845
79- static CHATBOT_THREAD : LazyLock < ( mpsc:: Sender < Payload > , mpsc:: Sender < ( ) > ) > =
80- LazyLock :: new ( chatbot_thread) ;
46+ impl stateful:: StatefulFunction for LogFunction {
47+ type Input = Arc < Vec < String > > ;
48+ type Output = ( ) ;
8149
82- async fn query_chat ( messages : & Arc < Vec < String > > ) -> Option < Vec < String > > {
83- let ( tx, rx) = oneshot:: channel ( ) ;
84- CHATBOT_THREAD
85- . 0
86- . send ( ( Arc :: clone ( messages) , tx) )
87- . await
88- . unwrap ( ) ;
89- rx. await . unwrap ( )
50+ async fn call ( & mut self , messages : Self :: Input ) -> Self :: Output {
51+ self . logger . append ( messages. last ( ) . unwrap ( ) ) ;
52+ self . logger . save ( ) . await . unwrap ( ) ;
53+ }
54+ }
55+
56+ static LOG_THREAD : LazyLock < StatefulThread < LogFunction > > = LazyLock :: new ( || {
57+ StatefulThread :: new ( LogFunction {
58+ logger : chatbot:: Logger :: default ( ) ,
59+ } )
60+ } ) ;
61+
62+ struct ChatbotFunction {
63+ chatbot : chatbot:: Chatbot ,
9064}
9165
66+ impl stateful:: StatefulFunction for ChatbotFunction {
67+ type Input = Arc < Vec < String > > ;
68+ type Output = Vec < String > ;
69+
70+ async fn call ( & mut self , messages : Self :: Input ) -> Self :: Output {
71+ let doc_paths = self . chatbot . retrieval_documents ( & messages) ;
72+ let docs = load_docs ( doc_paths) . await ;
73+ self . chatbot . query_chat ( & messages, & docs) . await
74+ }
75+ }
76+
77+ static CHATBOT_THREAD : LazyLock < StatefulThread < ChatbotFunction > > = LazyLock :: new ( || {
78+ StatefulThread :: new ( ChatbotFunction {
79+ chatbot : chatbot:: Chatbot :: new ( vec ! [ ":-)" . into( ) , "^^" . into( ) ] ) ,
80+ } )
81+ } ) ;
82+
9283async fn cancel ( _req : Request ) -> Response {
93- CHATBOT_THREAD . 1 . send ( ( ) ) . await . unwrap ( ) ;
84+ CHATBOT_THREAD . cancel ( ) . await ;
9485 Ok ( Content :: Html ( "success" . into ( ) ) )
9586}
9687
@@ -103,7 +94,11 @@ async fn chat(req: Request) -> Response {
10394 } ;
10495
10596 let messages = Arc :: new ( data. messages ) ;
106- let ( i, responses_opt) = join ! ( chatbot:: gen_random_number( ) , query_chat( & messages) ) ;
97+ let ( i, responses_opt, _) = join ! (
98+ chatbot:: gen_random_number( ) ,
99+ CHATBOT_THREAD . call( Arc :: clone( & messages) ) ,
100+ LOG_THREAD . call( Arc :: clone( & messages) )
101+ ) ;
107102
108103 let response = match responses_opt {
109104 Some ( mut responses) => {
0 commit comments