Skip to content

Commit ae91dc3

Browse files
authored
Merge pull request #78 from TCeason/debug-bak
feat(mysql): encode OK packet info when session tracking
2 parents b39d56d + 90d29d4 commit ae91dc3

File tree

8 files changed

+261
-2
lines changed

8 files changed

+261
-2
lines changed

mysql/examples/serve_auth.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ impl<W: AsyncWrite + Send + Unpin> AsyncMysqlShim<W> for Backend {
6363
let cols = &[Column {
6464
table: String::new(),
6565
column: "abc".to_string(),
66+
collen: 0,
6667
coltype: myc::constants::ColumnType::MYSQL_TYPE_LONG,
6768
colflags: myc::constants::ColumnFlags::UNSIGNED_FLAG,
6869
}];

mysql/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ pub struct Column {
6868
///
6969
/// Note that this is *technically* the column's alias.
7070
pub column: String,
71+
/// Column length (in bytes) reported through COLUMN_DEFINITION41.
72+
/// 0 means "use default".
73+
pub collen: u32,
7174
/// This column's type>
7275
pub coltype: ColumnType,
7376
/// Any flags associated with this column.
@@ -310,6 +313,7 @@ where
310313
| CapabilityFlags::CLIENT_PLUGIN_AUTH
311314
| CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
312315
| CapabilityFlags::CLIENT_CONNECT_WITH_DB
316+
| CapabilityFlags::CLIENT_SESSION_TRACK
313317
| CapabilityFlags::CLIENT_DEPRECATE_EOF;
314318

315319
#[cfg(feature = "tls")]
@@ -591,6 +595,7 @@ where
591595
let cols = &[Column {
592596
table: String::new(),
593597
column: String::from_utf8_lossy(var_with_at).to_string(),
598+
collen: 0,
594599
coltype: myc::constants::ColumnType::MYSQL_TYPE_LONG,
595600
colflags: myc::constants::ColumnFlags::UNSIGNED_FLAG,
596601
}];

mysql/src/tests/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
mod commands;
1616
mod packet;
1717
mod value;
18+
mod writers;

mysql/src/tests/value/decode.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ macro_rules! rt {
3030
let mut col = Column {
3131
table: String::new(),
3232
column: String::new(),
33+
collen: 0,
3334
coltype: $ct,
3435
colflags: ColumnFlags::empty(),
3536
};

mysql/src/tests/value/encode.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ mod roundtrip_bin {
105105
let mut col = Column {
106106
table: String::new(),
107107
column: String::new(),
108+
collen: 0,
108109
coltype: $ct,
109110
colflags: ColumnFlags::empty(),
110111
};

mysql/src/tests/writers.rs

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
// Copyright 2021 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// Note to developers: you can find decent overviews of the protocol at
16+
//
17+
// https://github.com/cwarden/mysql-proxy/blob/master/doc/protocol.rst
18+
//
19+
// and
20+
//
21+
// https://mariadb.com/kb/en/library/clientserver-protocol/
22+
//
23+
// Wireshark also does a pretty good job at parsing the MySQL protocol.
24+
25+
use tokio::io::{duplex, AsyncReadExt};
26+
27+
use crate::packet_writer::PacketWriter;
28+
use crate::writers::write_ok_packet;
29+
use crate::{CapabilityFlags, OkResponse};
30+
31+
async fn capture_ok_payload(info: &str, capabilities: CapabilityFlags, header: u8) -> Vec<u8> {
32+
let (mut client, server) = duplex(1024);
33+
let mut writer = PacketWriter::new(server);
34+
35+
let ok_packet = OkResponse {
36+
header,
37+
info: info.to_string(),
38+
..Default::default()
39+
};
40+
41+
write_ok_packet(&mut writer, capabilities, ok_packet)
42+
.await
43+
.expect("write_ok_packet succeeds");
44+
45+
let mut header_buf = [0u8; 4];
46+
client
47+
.read_exact(&mut header_buf)
48+
.await
49+
.expect("payload header available");
50+
let payload_len = (header_buf[0] as usize)
51+
| ((header_buf[1] as usize) << 8)
52+
| ((header_buf[2] as usize) << 16);
53+
let mut payload = vec![0u8; payload_len];
54+
client
55+
.read_exact(&mut payload)
56+
.await
57+
.expect("payload body available");
58+
payload
59+
}
60+
61+
fn parse_lenenc_int(data: &[u8]) -> (u64, usize) {
62+
match data[0] {
63+
0xFC => {
64+
let len = u16::from_le_bytes([data[1], data[2]]) as u64;
65+
(len, 3)
66+
}
67+
0xFD => {
68+
let len = (data[1] as u64) | ((data[2] as u64) << 8) | ((data[3] as u64) << 16);
69+
(len, 4)
70+
}
71+
0xFE => {
72+
let mut buf = [0u8; 8];
73+
buf.copy_from_slice(&data[1..9]);
74+
(u64::from_le_bytes(buf), 9)
75+
}
76+
v => (v as u64, 1),
77+
}
78+
}
79+
80+
fn consume_ok_prefix(payload: &[u8]) -> (usize, u8, u16, u16) {
81+
let mut idx = 0;
82+
let header = payload[idx];
83+
idx += 1;
84+
85+
let (affected_rows, consumed) = parse_lenenc_int(&payload[idx..]);
86+
assert_eq!(affected_rows, 0);
87+
idx += consumed;
88+
89+
let (last_insert_id, consumed) = parse_lenenc_int(&payload[idx..]);
90+
assert_eq!(last_insert_id, 0);
91+
idx += consumed;
92+
93+
let status = u16::from_le_bytes([payload[idx], payload[idx + 1]]);
94+
idx += 2;
95+
96+
let warnings = u16::from_le_bytes([payload[idx], payload[idx + 1]]);
97+
idx += 2;
98+
99+
(idx, header, status, warnings)
100+
}
101+
102+
#[tokio::test]
103+
async fn ok_packet_info_lenenc_when_session_track() {
104+
let info = "Read 1 rows, 1.00 B in 0.007 sec.";
105+
let payload = capture_ok_payload(
106+
info,
107+
CapabilityFlags::CLIENT_PROTOCOL_41 | CapabilityFlags::CLIENT_SESSION_TRACK,
108+
0x00,
109+
)
110+
.await;
111+
112+
let (mut idx, header, status, warnings) = consume_ok_prefix(&payload);
113+
assert_eq!(header, 0x00);
114+
assert_eq!(status, 0);
115+
assert_eq!(warnings, 0);
116+
117+
let (info_len, consumed) = parse_lenenc_int(&payload[idx..]);
118+
assert_eq!(info_len as usize, info.len());
119+
idx += consumed;
120+
121+
let encoded = &payload[idx..idx + info.len()];
122+
assert_eq!(encoded, info.as_bytes());
123+
}
124+
125+
#[tokio::test]
126+
async fn ok_packet_info_lenenc_when_deprecate_eof() {
127+
let info = "Read 1 rows, 1.00 B in 0.007 sec.";
128+
let payload = capture_ok_payload(
129+
info,
130+
CapabilityFlags::CLIENT_PROTOCOL_41 | CapabilityFlags::CLIENT_DEPRECATE_EOF,
131+
0x00,
132+
)
133+
.await;
134+
135+
let (mut idx, header, status, warnings) = consume_ok_prefix(&payload);
136+
assert_eq!(header, 0x00);
137+
assert_eq!(status, 0);
138+
assert_eq!(warnings, 0);
139+
140+
let (info_len, consumed) = parse_lenenc_int(&payload[idx..]);
141+
assert_eq!(info_len as usize, info.len());
142+
idx += consumed;
143+
144+
let encoded = &payload[idx..idx + info.len()];
145+
assert_eq!(encoded, info.as_bytes());
146+
}
147+
148+
#[tokio::test]
149+
async fn ok_packet_info_lenenc_when_header_fe() {
150+
let info = "Read 1 rows, 1.00 B in 0.007 sec.";
151+
let payload = capture_ok_payload(info, CapabilityFlags::CLIENT_PROTOCOL_41, 0xfe).await;
152+
153+
let (mut idx, header, status, warnings) = consume_ok_prefix(&payload);
154+
assert_eq!(header, 0xfe);
155+
assert_eq!(status, 0);
156+
assert_eq!(warnings, 0);
157+
158+
let (info_len, consumed) = parse_lenenc_int(&payload[idx..]);
159+
assert_eq!(info_len as usize, info.len());
160+
idx += consumed;
161+
162+
let encoded = &payload[idx..idx + info.len()];
163+
assert_eq!(encoded, info.as_bytes());
164+
}
165+
166+
#[tokio::test]
167+
async fn ok_packet_info_plain_when_no_flags() {
168+
let info = "Read 1 rows, 1.00 B in 0.007 sec.";
169+
let payload = capture_ok_payload(info, CapabilityFlags::CLIENT_PROTOCOL_41, 0x00).await;
170+
171+
let (idx, header, status, warnings) = consume_ok_prefix(&payload);
172+
assert_eq!(header, 0x00);
173+
assert_eq!(status, 0);
174+
assert_eq!(warnings, 0);
175+
176+
let encoded = &payload[idx..];
177+
assert_eq!(encoded, info.as_bytes());
178+
}
179+
180+
#[tokio::test]
181+
async fn ok_packet_info_extended_lenenc_with_flags() {
182+
let info = "x".repeat(300);
183+
let payload = capture_ok_payload(
184+
&info,
185+
CapabilityFlags::CLIENT_PROTOCOL_41 | CapabilityFlags::CLIENT_SESSION_TRACK,
186+
0x00,
187+
)
188+
.await;
189+
190+
let (mut idx, header, status, warnings) = consume_ok_prefix(&payload);
191+
assert_eq!(header, 0x00);
192+
assert_eq!(status, 0);
193+
assert_eq!(warnings, 0);
194+
195+
let (info_len, consumed) = parse_lenenc_int(&payload[idx..]);
196+
assert_eq!(consumed, 3); // expect 0xFC marker with two-byte length
197+
assert_eq!(payload[idx], 0xFC);
198+
assert_eq!(info_len as usize, info.len());
199+
idx += consumed;
200+
201+
let encoded = &payload[idx..idx + info.len()];
202+
assert_eq!(encoded, info.as_bytes());
203+
}

mysql/src/writers.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,12 @@ pub(crate) async fn write_ok_packet<W: AsyncWrite + Unpin>(
8282

8383
// Only session-tracking clients expect length-encoded info per protocol; otherwise emit raw text.
8484
let has_session_track = client_capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK);
85+
let expect_lenenc_info = has_session_track
86+
|| ok_packet.header == 0xfe
87+
|| client_capabilities.contains(CapabilityFlags::CLIENT_DEPRECATE_EOF);
8588
let send_info = !ok_packet.info.is_empty() || has_session_track;
8689
if send_info {
87-
if has_session_track {
90+
if expect_lenenc_info {
8891
w.write_lenenc_str(ok_packet.info.as_bytes())?;
8992
} else {
9093
w.write_all(ok_packet.info.as_bytes())?;
@@ -171,7 +174,8 @@ where
171174
w.write_lenenc_str(b"")?;
172175
w.write_lenenc_int(0xC)?;
173176
w.write_u16::<LittleEndian>(column_charset(c))?;
174-
w.write_u32::<LittleEndian>(1024)?;
177+
let column_length = if c.collen == 0 { 1024 } else { c.collen };
178+
w.write_u32::<LittleEndian>(column_length)?;
175179
w.write_u8(c.coltype as u8)?;
176180
w.write_u16::<LittleEndian>(c.colflags.bits())?;
177181
w.write_all(&[0x00])?; // decimals
@@ -302,6 +306,7 @@ mod tests {
302306
let column = Column {
303307
table: "t".into(),
304308
column: "c".into(),
309+
collen: 0,
305310
coltype: ColumnType::MYSQL_TYPE_VAR_STRING,
306311
colflags: ColumnFlags::empty(),
307312
};
@@ -314,6 +319,7 @@ mod tests {
314319
let column = Column {
315320
table: "t".into(),
316321
column: "c".into(),
322+
collen: 0,
317323
coltype: ColumnType::MYSQL_TYPE_LONG,
318324
colflags: ColumnFlags::empty(),
319325
};
@@ -326,6 +332,7 @@ mod tests {
326332
let column = Column {
327333
table: "t".into(),
328334
column: "c".into(),
335+
collen: 0,
329336
coltype: ColumnType::MYSQL_TYPE_STRING,
330337
colflags: ColumnFlags::BINARY_FLAG,
331338
};
@@ -338,6 +345,7 @@ mod tests {
338345
let column = Column {
339346
table: "t".into(),
340347
column: "c".into(),
348+
collen: 0,
341349
coltype: ColumnType::MYSQL_TYPE_BLOB,
342350
colflags: ColumnFlags::empty(),
343351
};
@@ -350,6 +358,7 @@ mod tests {
350358
let column = Column {
351359
table: "t".into(),
352360
column: "c".into(),
361+
collen: 0,
353362
coltype: ColumnType::MYSQL_TYPE_VAR_STRING,
354363
colflags: ColumnFlags::BLOB_FLAG,
355364
};

0 commit comments

Comments
 (0)