Skip to content

Commit cf31ef3

Browse files
committed
src/rptocol/tds: add TDS 8.0 protocol interface
TDS 8.0 Allows a TLS only protocol and can switch between MySQL and TDS wire protocols. TDS supports inline cancelation and transaction rpc start, rollback, and commit.
1 parent b370f89 commit cf31ef3

File tree

13 files changed

+3314
-11
lines changed

13 files changed

+3314
-11
lines changed

src/executor/scan.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ use std::sync::Arc;
66

77
use async_trait::async_trait;
88

9+
use tracing::debug;
10+
911
use crate::planner::logical::ResolvedExpr;
1012
use crate::txn::MvccStorage;
1113

@@ -58,14 +60,23 @@ impl Executor for TableScan {
5860
let prefix = table_key_prefix(&self.table);
5961
let end = table_key_end(&self.table);
6062

63+
debug!(table = %self.table, "TableScan: starting scan");
64+
6165
// Use MVCC scan with visibility filtering if we have a transaction context,
6266
// otherwise fall back to raw storage scan (for DDL or legacy tests)
6367
let kv_pairs = if let Some(ref ctx) = self.txn_context {
64-
self.mvcc
68+
debug!(table = %self.table, "TableScan: calling mvcc.scan()");
69+
let result = self
70+
.mvcc
6571
.scan(Some(&prefix), Some(&end), &ctx.read_view)
66-
.await?
72+
.await?;
73+
debug!(table = %self.table, rows = result.len(), "TableScan: mvcc.scan() returned");
74+
result
6775
} else {
68-
self.mvcc.inner().scan(Some(&prefix), Some(&end)).await?
76+
debug!(table = %self.table, "TableScan: calling inner.scan()");
77+
let result = self.mvcc.inner().scan(Some(&prefix), Some(&end)).await?;
78+
debug!(table = %self.table, rows = result.len(), "TableScan: inner.scan() returned");
79+
result
6980
};
7081

7182
// Collect keys that are buffered (for read-your-writes merge)

src/main.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ struct Cli {
3030
key_path: PathBuf,
3131
#[arg(long, default_value = "./certs/ca.crt", env = "ROODB_CA_CERT_PATH")]
3232
raft_ca_cert_path: PathBuf,
33+
/// Enable TDS 8.0 protocol support (in addition to MySQL protocol)
34+
#[arg(long, default_value_t = false, env = "ROODB_TDS")]
35+
tds: bool,
3336
}
3437

3538
#[tokio::main]
@@ -51,6 +54,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
5154
cert_path,
5255
key_path,
5356
raft_ca_cert_path,
57+
tds,
5458
} = Cli::parse();
5559

5660
tracing::info!(port, ?data_dir, "Starting RooDB");
@@ -107,7 +111,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
107111

108112
// Start RooDB server
109113
let addr: SocketAddr = format!("0.0.0.0:{}", port).parse()?;
110-
let server = RooDbServer::new(addr, tls_config, storage, catalog, raft_node);
114+
let mut server = RooDbServer::new(addr, tls_config, storage, catalog, raft_node);
115+
if tds {
116+
tracing::info!("TDS 8.0 protocol enabled");
117+
server.enable_tds();
118+
}
111119
server.run().await?;
112120

113121
Ok(())

src/protocol/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
//! Wire protocol implementations
22
//!
3-
//! Currently supports RooDB client protocol over TLS.
3+
//! Supports RooDB client protocol (MySQL-compatible) and TDS protocol over TLS.
44
55
pub mod roodb;
6+
pub mod tds;
67

78
pub use roodb::RooDbConnection;
9+
pub use tds::TdsConnection;

src/protocol/tds/codec.rs

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
//! TDS packet framing: 8-byte header, multi-packet message assembly.
2+
3+
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4+
5+
/// TDS packet header size
6+
const HEADER_SIZE: usize = 8;
7+
/// Maximum TDS packet size (matches client default)
8+
const MAX_PACKET_SIZE: usize = 4096;
9+
/// Maximum body per packet
10+
const MAX_BODY_SIZE: usize = MAX_PACKET_SIZE - HEADER_SIZE;
11+
12+
/// TDS packet types
13+
#[repr(u8)]
14+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15+
pub enum PacketType {
16+
SqlBatch = 1,
17+
Rpc = 3,
18+
TabularResult = 4,
19+
Attention = 6,
20+
BulkLoad = 7,
21+
TransactionManager = 0x0E,
22+
Tds7Login = 0x10,
23+
PreLogin = 0x12,
24+
}
25+
26+
impl TryFrom<u8> for PacketType {
27+
type Error = PacketError;
28+
fn try_from(v: u8) -> Result<Self, Self::Error> {
29+
match v {
30+
1 => Ok(Self::SqlBatch),
31+
3 => Ok(Self::Rpc),
32+
4 => Ok(Self::TabularResult),
33+
6 => Ok(Self::Attention),
34+
7 => Ok(Self::BulkLoad),
35+
0x0E => Ok(Self::TransactionManager),
36+
0x10 => Ok(Self::Tds7Login),
37+
0x12 => Ok(Self::PreLogin),
38+
_ => Err(PacketError::UnknownType(v)),
39+
}
40+
}
41+
}
42+
43+
/// Packet status flags
44+
const STATUS_EOM: u8 = 0x01;
45+
46+
#[derive(Debug, thiserror::Error)]
47+
pub enum PacketError {
48+
#[error("IO error: {0}")]
49+
Io(#[from] std::io::Error),
50+
#[error("unknown packet type: 0x{0:02x}")]
51+
UnknownType(u8),
52+
#[error("packet too large: {0}")]
53+
TooLarge(usize),
54+
#[error("connection closed")]
55+
ConnectionClosed,
56+
#[error("unexpected packet type: expected {expected:?}, got {got:?}")]
57+
UnexpectedType {
58+
expected: PacketType,
59+
got: PacketType,
60+
},
61+
}
62+
63+
/// Reads complete TDS messages (potentially spanning multiple packets).
64+
pub struct TdsReader<R> {
65+
reader: R,
66+
}
67+
68+
impl<R: AsyncRead + Unpin> TdsReader<R> {
69+
pub fn new(reader: R) -> Self {
70+
Self { reader }
71+
}
72+
73+
/// Read a complete TDS message. Returns (packet_type, payload).
74+
/// Assembles multiple packets until EOM is set.
75+
pub async fn read_message(&mut self) -> Result<(PacketType, Vec<u8>), PacketError> {
76+
let mut payload = Vec::new();
77+
let mut msg_type = None;
78+
79+
loop {
80+
let mut header = [0u8; HEADER_SIZE];
81+
match self.reader.read_exact(&mut header).await {
82+
Ok(_) => {}
83+
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
84+
return Err(PacketError::ConnectionClosed);
85+
}
86+
Err(e) => return Err(PacketError::Io(e)),
87+
}
88+
89+
let pkt_type = PacketType::try_from(header[0])?;
90+
let status = header[1];
91+
let length = u16::from_be_bytes([header[2], header[3]]) as usize;
92+
93+
if length < HEADER_SIZE || length > MAX_PACKET_SIZE {
94+
return Err(PacketError::TooLarge(length));
95+
}
96+
97+
// Verify consistent packet type across multi-packet messages
98+
if let Some(expected) = msg_type {
99+
if pkt_type != expected {
100+
return Err(PacketError::UnexpectedType {
101+
expected,
102+
got: pkt_type,
103+
});
104+
}
105+
} else {
106+
msg_type = Some(pkt_type);
107+
}
108+
109+
let body_len = length - HEADER_SIZE;
110+
if body_len > 0 {
111+
let start = payload.len();
112+
payload.resize(start + body_len, 0);
113+
self.reader.read_exact(&mut payload[start..]).await?;
114+
}
115+
116+
if status & STATUS_EOM != 0 {
117+
break;
118+
}
119+
}
120+
121+
Ok((msg_type.unwrap(), payload))
122+
}
123+
}
124+
125+
/// Writes TDS messages, splitting into packets as needed.
126+
pub struct TdsWriter<W> {
127+
writer: W,
128+
packet_number: u8,
129+
}
130+
131+
impl<W: AsyncWrite + Unpin> TdsWriter<W> {
132+
pub fn new(writer: W) -> Self {
133+
Self {
134+
writer,
135+
packet_number: 0,
136+
}
137+
}
138+
139+
/// Write a complete message, splitting into packets if needed.
140+
pub async fn write_message(
141+
&mut self,
142+
pkt_type: PacketType,
143+
data: &[u8],
144+
) -> Result<(), PacketError> {
145+
self.packet_number = 0;
146+
147+
if data.is_empty() {
148+
// Send a single empty EOM packet
149+
self.write_packet(pkt_type, &[], true).await?;
150+
return Ok(());
151+
}
152+
153+
let mut offset = 0;
154+
while offset < data.len() {
155+
let remaining = data.len() - offset;
156+
let chunk_size = remaining.min(MAX_BODY_SIZE);
157+
let is_last = offset + chunk_size >= data.len();
158+
159+
self.write_packet(pkt_type, &data[offset..offset + chunk_size], is_last)
160+
.await?;
161+
offset += chunk_size;
162+
}
163+
164+
Ok(())
165+
}
166+
167+
async fn write_packet(
168+
&mut self,
169+
pkt_type: PacketType,
170+
body: &[u8],
171+
eom: bool,
172+
) -> Result<(), PacketError> {
173+
let length = (HEADER_SIZE + body.len()) as u16;
174+
let status = if eom { STATUS_EOM } else { 0 };
175+
176+
let mut header = [0u8; HEADER_SIZE];
177+
header[0] = pkt_type as u8;
178+
header[1] = status;
179+
header[2..4].copy_from_slice(&length.to_be_bytes());
180+
// SPID = 0
181+
header[6] = self.packet_number;
182+
self.packet_number = self.packet_number.wrapping_add(1);
183+
// Window = 0
184+
185+
self.writer.write_all(&header).await?;
186+
if !body.is_empty() {
187+
self.writer.write_all(body).await?;
188+
}
189+
self.writer.flush().await?;
190+
191+
Ok(())
192+
}
193+
}
194+
195+
#[cfg(test)]
196+
mod tests {
197+
use super::*;
198+
199+
#[tokio::test]
200+
async fn test_roundtrip() {
201+
let data = b"hello world";
202+
let mut buf = Vec::new();
203+
204+
// Write
205+
{
206+
let mut writer = TdsWriter::new(&mut buf);
207+
writer
208+
.write_message(PacketType::TabularResult, data)
209+
.await
210+
.unwrap();
211+
}
212+
213+
// Verify header
214+
assert_eq!(buf[0], PacketType::TabularResult as u8);
215+
assert_eq!(buf[1], STATUS_EOM);
216+
let len = u16::from_be_bytes([buf[2], buf[3]]) as usize;
217+
assert_eq!(len, HEADER_SIZE + data.len());
218+
219+
// Read back
220+
let mut reader = TdsReader::new(&buf[..]);
221+
let (pkt_type, payload) = reader.read_message().await.unwrap();
222+
assert_eq!(pkt_type, PacketType::TabularResult);
223+
assert_eq!(payload, data);
224+
}
225+
}

0 commit comments

Comments
 (0)