Skip to content

Commit ab87e97

Browse files
committed
fix interactive transactions over websocket
1 parent ad25f94 commit ab87e97

File tree

2 files changed

+40
-10
lines changed

2 files changed

+40
-10
lines changed

src/sqlite.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,10 @@ async fn execute_sql_and_params(
359359
let mut request = db.connection.get_json_request(db, sql, &params);
360360
match db.connection.send(&mut request).await {
361361
Ok(response) => return Ok(response),
362-
Err(_) => {
362+
Err(err) => {
363363
db.connection.strategy = transport::ActiveStrategy::Http;
364364
if cfg!(debug_assertions) {
365-
println!("WebSocket failed, retrying with HTTP...");
365+
eprintln!("WebSocket failed, retrying with HTTP... {}", err);
366366
}
367367
}
368368
}

src/transport/wss.rs

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,24 @@ impl WebSocketStrategy {
227227

228228
impl LibsqlInterface for WebSocketStrategy {
229229
async fn get_transaction_baton(&mut self, sql: &str) -> Result<String, SqliteError> {
230-
// Implementation for WebSocket transport
231-
unimplemented!()
230+
let (stream_id, _) = self.open_stream().await?;
231+
let mut request = serde_json::json!({
232+
"type": "execute",
233+
"stream_id": stream_id,
234+
"stmt": {
235+
"sql": sql
236+
}
237+
});
238+
239+
let result = self.send(&mut request).await;
240+
if let Err(e) = result {
241+
return Err(SqliteError::new(
242+
format!("Failed to get transaction baton: {}", e),
243+
Some(SQLITE_ERROR),
244+
));
245+
}
246+
247+
Ok(stream_id.to_string())
232248
}
233249

234250
async fn send(
@@ -242,10 +258,17 @@ impl LibsqlInterface for WebSocketStrategy {
242258
));
243259
}
244260

245-
let (stream_id, bus) = self.open_stream().await?;
246-
let request_id = WebSocketStrategy::next_request_id();
247-
request["stream_id"] = serde_json::Value::from(stream_id);
261+
let bus: ResponseBus;
248262

263+
if request.get("stream_id").is_none() {
264+
let (stream_id, actual_bus) = self.open_stream().await?;
265+
request["stream_id"] = serde_json::Value::from(stream_id);
266+
bus = actual_bus;
267+
} else {
268+
bus = self.bus.clone();
269+
}
270+
271+
let request_id = WebSocketStrategy::next_request_id();
249272
let request = serde_json::json!({
250273
"type": "request",
251274
"request_id": request_id,
@@ -310,16 +333,23 @@ impl LibsqlInterface for WebSocketStrategy {
310333
&self,
311334
sql: &str,
312335
params: &Vec<serde_json::Value>,
313-
baton: Option<&String>,
336+
stream_id: Option<&String>,
314337
is_transacting: bool,
315338
) -> serde_json::Value {
316-
serde_json::json!({
339+
let mut request = serde_json::json!({
317340
"type": "execute",
318341
"stmt": {
319342
"sql": sql,
320343
"args": params
321344
}
322-
})
345+
});
346+
347+
if is_transacting {
348+
let stream_id: i32 = stream_id.and_then(|s| s.parse::<i32>().ok()).unwrap();
349+
request["stream_id"] = serde_json::json!(stream_id);
350+
}
351+
352+
request
323353
}
324354
}
325355

0 commit comments

Comments
 (0)