@@ -6,15 +6,26 @@ use async_std::io::Read as AsyncRead;
66use async_std:: prelude:: * ;
77use async_std:: task:: { ready, Context , Poll } ;
88use std:: pin:: Pin ;
9+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
10+ use std:: sync:: Arc ;
911
10- pin_project_lite:: pin_project! {
11- /// An SSE protocol encoder.
12- #[ derive( Debug ) ]
13- pub struct Encoder {
14- buf: Option <Vec <u8 >>,
15- #[ pin]
16- receiver: sync:: Receiver <Vec <u8 >>,
17- cursor: usize ,
12+ use pin_project:: { pin_project, pinned_drop} ;
13+
14+ #[ pin_project( PinnedDrop ) ]
15+ /// An SSE protocol encoder.
16+ #[ derive( Debug ) ]
17+ pub struct Encoder {
18+ buf : Option < Vec < u8 > > ,
19+ #[ pin]
20+ receiver : sync:: Receiver < Vec < u8 > > ,
21+ cursor : usize ,
22+ disconnected : Arc < AtomicBool > ,
23+ }
24+
25+ #[ pinned_drop]
26+ impl PinnedDrop for Encoder {
27+ fn drop ( self : Pin < & mut Self > ) {
28+ self . disconnected . store ( true , Ordering :: Relaxed ) ;
1829 }
1930}
2031
@@ -79,53 +90,80 @@ impl AsyncRead for Encoder {
7990// }
8091
8192/// The sending side of the encoder.
82- #[ derive( Debug ) ]
83- pub struct Sender ( sync:: Sender < Vec < u8 > > ) ;
93+ #[ derive( Debug , Clone ) ]
94+ pub struct Sender {
95+ sender : sync:: Sender < Vec < u8 > > ,
96+ disconnected : Arc < std:: sync:: atomic:: AtomicBool > ,
97+ }
8498
8599/// Create a new SSE encoder.
86100pub fn encode ( ) -> ( Sender , Encoder ) {
87101 let ( sender, receiver) = sync:: channel ( 1 ) ;
102+ let disconnected = Arc :: new ( AtomicBool :: new ( false ) ) ;
103+
88104 let encoder = Encoder {
89105 receiver,
90106 buf : None ,
91107 cursor : 0 ,
108+ disconnected : disconnected. clone ( ) ,
109+ } ;
110+
111+ let sender = Sender {
112+ sender,
113+ disconnected,
92114 } ;
93- ( Sender ( sender) , encoder)
115+
116+ ( sender, encoder)
94117}
95118
119+ /// An error that represents that the [Encoder] has been dropped.
120+ #[ derive( Debug , Eq , PartialEq ) ]
121+ pub struct DisconnectedError ;
122+ impl std:: error:: Error for DisconnectedError { }
123+ impl std:: fmt:: Display for DisconnectedError {
124+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
125+ write ! ( f, "Disconnected" )
126+ }
127+ }
128+
129+ #[ must_use]
96130impl Sender {
97131 /// Send a new message over SSE.
98- pub async fn send ( & self , name : & str , data : & str , id : Option < & str > ) {
132+ pub async fn send (
133+ & self ,
134+ name : & str ,
135+ data : & str ,
136+ id : Option < & str > ,
137+ ) -> Result < ( ) , DisconnectedError > {
138+ if self . disconnected . load ( Ordering :: Relaxed ) {
139+ return Err ( DisconnectedError ) ;
140+ }
141+
99142 // Write the event name
100143 let msg = format ! ( "event:{}\n " , name) ;
101- self . 0 . send ( msg. into_bytes ( ) ) . await ;
144+ self . sender . send ( msg. into_bytes ( ) ) . await ;
102145
103146 // Write the id
104147 if let Some ( id) = id {
105- self . 0 . send ( format ! ( "id:{}\n " , id) . into_bytes ( ) ) . await ;
148+ self . sender . send ( format ! ( "id:{}\n " , id) . into_bytes ( ) ) . await ;
106149 }
107150
108151 // Write the data section, and end.
109152 let msg = format ! ( "data:{}\n \n " , data) ;
110- self . 0 . send ( msg. into_bytes ( ) ) . await ;
153+ self . sender . send ( msg. into_bytes ( ) ) . await ;
154+ Ok ( ( ) )
111155 }
112156
113157 /// Send a new "retry" message over SSE.
114158 pub async fn send_retry ( & self , dur : Duration , id : Option < & str > ) {
115159 // Write the id
116160 if let Some ( id) = id {
117- self . 0 . send ( format ! ( "id:{}\n " , id) . into_bytes ( ) ) . await ;
161+ self . sender . send ( format ! ( "id:{}\n " , id) . into_bytes ( ) ) . await ;
118162 }
119163
120164 // Write the retry section, and end.
121165 let dur = dur. as_secs_f64 ( ) as u64 ;
122166 let msg = format ! ( "retry:{}\n \n " , dur) ;
123- self . 0 . send ( msg. into_bytes ( ) ) . await ;
124- }
125- }
126-
127- impl Clone for Sender {
128- fn clone ( & self ) -> Self {
129- Self ( self . 0 . clone ( ) )
167+ self . sender . send ( msg. into_bytes ( ) ) . await ;
130168 }
131169}
0 commit comments