@@ -227,8 +227,24 @@ impl WebSocketStrategy {
227227
228228impl LibsqlInterface for WebSocketStrategy {
229229 async fn get_transaction_baton ( & mut self , sql : & str ) -> Result < String , SqliteError > {
230- // Implementation for WebSocket transport
231- unimplemented ! ( )
230+ let ( stream_id, _) = self . open_stream ( ) . await ?;
231+ let mut request = serde_json:: json!( {
232+ "type" : "execute" ,
233+ "stream_id" : stream_id,
234+ "stmt" : {
235+ "sql" : sql
236+ }
237+ } ) ;
238+
239+ let result = self . send ( & mut request) . await ;
240+ if let Err ( e) = result {
241+ return Err ( SqliteError :: new (
242+ format ! ( "Failed to get transaction baton: {}" , e) ,
243+ Some ( SQLITE_ERROR ) ,
244+ ) ) ;
245+ }
246+
247+ Ok ( stream_id. to_string ( ) )
232248 }
233249
234250 async fn send (
@@ -242,10 +258,17 @@ impl LibsqlInterface for WebSocketStrategy {
242258 ) ) ;
243259 }
244260
245- let ( stream_id, bus) = self . open_stream ( ) . await ?;
246- let request_id = WebSocketStrategy :: next_request_id ( ) ;
247- request[ "stream_id" ] = serde_json:: Value :: from ( stream_id) ;
261+ let bus: ResponseBus ;
248262
263+ if request. get ( "stream_id" ) . is_none ( ) {
264+ let ( stream_id, actual_bus) = self . open_stream ( ) . await ?;
265+ request[ "stream_id" ] = serde_json:: Value :: from ( stream_id) ;
266+ bus = actual_bus;
267+ } else {
268+ bus = self . bus . clone ( ) ;
269+ }
270+
271+ let request_id = WebSocketStrategy :: next_request_id ( ) ;
249272 let request = serde_json:: json!( {
250273 "type" : "request" ,
251274 "request_id" : request_id,
@@ -310,16 +333,23 @@ impl LibsqlInterface for WebSocketStrategy {
310333 & self ,
311334 sql : & str ,
312335 params : & Vec < serde_json:: Value > ,
313- baton : Option < & String > ,
336+ stream_id : Option < & String > ,
314337 is_transacting : bool ,
315338 ) -> serde_json:: Value {
316- serde_json:: json!( {
339+ let mut request = serde_json:: json!( {
317340 "type" : "execute" ,
318341 "stmt" : {
319342 "sql" : sql,
320343 "args" : params
321344 }
322- } )
345+ } ) ;
346+
347+ if is_transacting {
348+ let stream_id: i32 = stream_id. and_then ( |s| s. parse :: < i32 > ( ) . ok ( ) ) . unwrap ( ) ;
349+ request[ "stream_id" ] = serde_json:: json!( stream_id) ;
350+ }
351+
352+ request
323353 }
324354}
325355
0 commit comments