Skip to content

Commit 3e141d3

Browse files
Copilotbgrainger
andcommitted
Add server ID verification for KILL QUERY commands
Co-authored-by: bgrainger <[email protected]>
1 parent a6b86d7 commit 3e141d3

File tree

5 files changed

+290
-8
lines changed

5 files changed

+290
-8
lines changed

src/MySqlConnector/Core/ServerSession.cs

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ public ServerSession(ILogger logger, IConnectionPoolMetadata pool)
4747
public int ActiveCommandId { get; private set; }
4848
public int CancellationTimeout { get; private set; }
4949
public int ConnectionId { get; set; }
50+
public string? ServerUuid { get; set; }
51+
public long? ServerId { get; set; }
5052
public byte[]? AuthPluginData { get; set; }
5153
public long CreatedTimestamp { get; }
5254
public ConnectionPool? Pool { get; }
@@ -117,6 +119,14 @@ public void DoCancel(ICancellableCommand commandToCancel, MySqlCommand killComma
117119
return;
118120
}
119121

122+
// Verify server identity before executing KILL QUERY to prevent cancelling on the wrong server
123+
var killSession = killCommand.Connection!.Session;
124+
if (!VerifyServerIdentity(killSession))
125+
{
126+
Log.IgnoringCancellationForDifferentServer(m_logger, Id, killSession.Id, ServerUuid, killSession.ServerUuid, ServerId, killSession.ServerId);
127+
return;
128+
}
129+
120130
// NOTE: This command is executed while holding the lock to prevent race conditions during asynchronous cancellation.
121131
// For example, if the lock weren't held, the current command could finish and the other thread could set ActiveCommandId
122132
// to zero, then start executing a new command. By the time this "KILL QUERY" command reached the server, the wrong
@@ -137,6 +147,26 @@ public void AbortCancel(ICancellableCommand command)
137147
}
138148
}
139149

150+
private bool VerifyServerIdentity(ServerSession otherSession)
151+
{
152+
// If server UUID is available, use it as the primary identifier (most unique)
153+
if (!string.IsNullOrEmpty(ServerUuid) && !string.IsNullOrEmpty(otherSession.ServerUuid))
154+
{
155+
return string.Equals(ServerUuid, otherSession.ServerUuid, StringComparison.Ordinal);
156+
}
157+
158+
// Fall back to server ID if UUID is not available
159+
if (ServerId.HasValue && otherSession.ServerId.HasValue)
160+
{
161+
return ServerId.Value == otherSession.ServerId.Value;
162+
}
163+
164+
// If no server identification is available, allow the operation to proceed
165+
// This maintains backward compatibility with older MySQL versions
166+
Log.NoServerIdentificationForVerification(m_logger, Id, otherSession.Id);
167+
return true;
168+
}
169+
140170
public bool IsCancelingQuery => m_state == State.CancelingQuery;
141171

142172
public async Task PrepareAsync(IMySqlCommand command, IOBehavior ioBehavior, CancellationToken cancellationToken)
@@ -635,6 +665,9 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
635665
ConnectionId = newConnectionId;
636666
}
637667

668+
// Get server identification for KILL QUERY verification
669+
await GetServerIdentificationAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
670+
638671
m_payloadHandler.ByteHandler.RemainingTimeout = Constants.InfiniteTimeout;
639672
return redirectionUrl;
640673
}
@@ -1951,6 +1984,90 @@ private async Task GetRealServerDetailsAsync(IOBehavior ioBehavior, Cancellation
19511984
}
19521985
}
19531986

1987+
private async Task GetServerIdentificationAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
1988+
{
1989+
Log.GettingServerIdentification(m_logger, Id);
1990+
try
1991+
{
1992+
PayloadData payload;
1993+
1994+
// Try to get both server_uuid and server_id if server supports server_uuid (MySQL 5.6+)
1995+
if (!ServerVersion.IsMariaDb && ServerVersion.Version >= ServerVersions.SupportsServerUuid)
1996+
{
1997+
payload = SupportsQueryAttributes ? s_selectServerIdWithAttributesPayload : s_selectServerIdNoAttributesPayload;
1998+
await SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false);
1999+
2000+
// column count: 2
2001+
_ = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
2002+
2003+
// @@server_uuid and @@server_id columns
2004+
_ = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
2005+
_ = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
2006+
2007+
if (!SupportsDeprecateEof)
2008+
{
2009+
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
2010+
_ = EofPayload.Create(payload.Span);
2011+
}
2012+
2013+
// first (and only) row
2014+
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
2015+
2016+
var reader = new ByteArrayReader(payload.Span);
2017+
var length = reader.ReadLengthEncodedIntegerOrNull();
2018+
var serverUuid = length != -1 ? Encoding.UTF8.GetString(reader.ReadByteString(length)) : null;
2019+
length = reader.ReadLengthEncodedIntegerOrNull();
2020+
var serverId = (length != -1 && Utf8Parser.TryParse(reader.ReadByteString(length), out long id, out _)) ? id : default(long?);
2021+
2022+
ServerUuid = serverUuid;
2023+
ServerId = serverId;
2024+
2025+
Log.RetrievedServerIdentification(m_logger, Id, serverUuid, serverId);
2026+
}
2027+
else
2028+
{
2029+
// Fall back to just server_id for older versions or MariaDB
2030+
payload = SupportsQueryAttributes ? s_selectServerIdOnlyWithAttributesPayload : s_selectServerIdOnlyNoAttributesPayload;
2031+
await SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false);
2032+
2033+
// column count: 1
2034+
_ = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
2035+
2036+
// @@server_id column
2037+
_ = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
2038+
2039+
if (!SupportsDeprecateEof)
2040+
{
2041+
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
2042+
_ = EofPayload.Create(payload.Span);
2043+
}
2044+
2045+
// first (and only) row
2046+
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
2047+
2048+
var reader = new ByteArrayReader(payload.Span);
2049+
var length = reader.ReadLengthEncodedIntegerOrNull();
2050+
var serverId = (length != -1 && Utf8Parser.TryParse(reader.ReadByteString(length), out long id, out _)) ? id : default(long?);
2051+
2052+
ServerUuid = null;
2053+
ServerId = serverId;
2054+
2055+
Log.RetrievedServerIdentification(m_logger, Id, null, serverId);
2056+
}
2057+
2058+
// OK/EOF payload
2059+
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
2060+
if (OkPayload.IsOk(payload.Span, this))
2061+
OkPayload.Verify(payload.Span, this);
2062+
else
2063+
EofPayload.Create(payload.Span);
2064+
}
2065+
catch (MySqlException ex)
2066+
{
2067+
Log.FailedToGetServerIdentification(m_logger, ex, Id);
2068+
}
2069+
}
2070+
19542071
private void ShutdownSocket()
19552072
{
19562073
Log.ClosingStreamSocket(m_logger, Id);
@@ -2182,6 +2299,10 @@ protected override void OnStatementBegin(int index)
21822299
private static readonly PayloadData s_sleepWithAttributesPayload = QueryPayload.Create(true, "SELECT SLEEP(0) INTO @__MySqlConnector__Sleep;"u8);
21832300
private static readonly PayloadData s_selectConnectionIdVersionNoAttributesPayload = QueryPayload.Create(false, "SELECT CONNECTION_ID(), VERSION();"u8);
21842301
private static readonly PayloadData s_selectConnectionIdVersionWithAttributesPayload = QueryPayload.Create(true, "SELECT CONNECTION_ID(), VERSION();"u8);
2302+
private static readonly PayloadData s_selectServerIdNoAttributesPayload = QueryPayload.Create(false, "SELECT @@server_uuid, @@server_id;"u8);
2303+
private static readonly PayloadData s_selectServerIdWithAttributesPayload = QueryPayload.Create(true, "SELECT @@server_uuid, @@server_id;"u8);
2304+
private static readonly PayloadData s_selectServerIdOnlyNoAttributesPayload = QueryPayload.Create(false, "SELECT @@server_id;"u8);
2305+
private static readonly PayloadData s_selectServerIdOnlyWithAttributesPayload = QueryPayload.Create(true, "SELECT @@server_id;"u8);
21852306

21862307
private readonly ILogger m_logger;
21872308
#if NET9_0_OR_GREATER

src/MySqlConnector/Core/ServerVersions.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,7 @@ internal static class ServerVersions
1919

2020
// https://mariadb.com/kb/en/set-statement/
2121
public static readonly Version MariaDbSupportsPerQueryVariables = new(10, 1, 2);
22+
23+
// https://dev.mysql.com/doc/refman/5.6/en/replication-options.html#sysvar_server_uuid
24+
public static readonly Version SupportsServerUuid = new(5, 6, 0);
2225
}

src/MySqlConnector/Logging/EventIds.cs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,17 @@ internal static class EventIds
7878
public const int DetectedProxy = 2150;
7979
public const int ChangingConnectionId = 2151;
8080
public const int FailedToGetConnectionId = 2152;
81-
public const int CreatingConnectionAttributes = 2153;
82-
public const int ObtainingPasswordViaProvidePasswordCallback = 2154;
83-
public const int FailedToObtainPassword = 2155;
84-
public const int ConnectedTlsBasicPreliminary = 2156;
85-
public const int ConnectedTlsDetailedPreliminary = 2157;
86-
public const int CertificateErrorUnixSocket = 2158;
87-
public const int CertificateErrorNoPassword = 2159;
88-
public const int CertificateErrorValidThumbprint = 2160;
81+
public const int GettingServerIdentification = 2153;
82+
public const int RetrievedServerIdentification = 2154;
83+
public const int FailedToGetServerIdentification = 2155;
84+
public const int CreatingConnectionAttributes = 2156;
85+
public const int ObtainingPasswordViaProvidePasswordCallback = 2157;
86+
public const int FailedToObtainPassword = 2158;
87+
public const int ConnectedTlsBasicPreliminary = 2159;
88+
public const int ConnectedTlsDetailedPreliminary = 2160;
89+
public const int CertificateErrorUnixSocket = 2161;
90+
public const int CertificateErrorNoPassword = 2162;
91+
public const int CertificateErrorValidThumbprint = 2163;
8992

9093
// Command execution events, 2200-2299
9194
public const int CannotExecuteNewCommandInState = 2200;
@@ -108,6 +111,8 @@ internal static class EventIds
108111
public const int IgnoringCancellationForInactiveCommand = 2306;
109112
public const int CancelingCommand = 2307;
110113
public const int SendingSleepToClearPendingCancellation = 2308;
114+
public const int IgnoringCancellationForDifferentServer = 2309;
115+
public const int NoServerIdentificationForVerification = 2310;
111116

112117
// Cached procedure events, 2400-2499
113118
public const int GettingCachedProcedure = 2400;

src/MySqlConnector/Logging/Log.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,21 @@ internal static partial class Log
189189
[LoggerMessage(EventIds.FailedToGetConnectionId, LogLevel.Information, "Session {SessionId} failed to get CONNECTION_ID(), VERSION()")]
190190
public static partial void FailedToGetConnectionId(ILogger logger, Exception exception, string sessionId);
191191

192+
[LoggerMessage(EventIds.GettingServerIdentification, LogLevel.Debug, "Session {SessionId} getting server identification")]
193+
public static partial void GettingServerIdentification(ILogger logger, string sessionId);
194+
195+
[LoggerMessage(EventIds.RetrievedServerIdentification, LogLevel.Debug, "Session {SessionId} retrieved server identification: UUID={ServerUuid}, ID={ServerId}")]
196+
public static partial void RetrievedServerIdentification(ILogger logger, string sessionId, string? serverUuid, long? serverId);
197+
198+
[LoggerMessage(EventIds.FailedToGetServerIdentification, LogLevel.Information, "Session {SessionId} failed to get server identification")]
199+
public static partial void FailedToGetServerIdentification(ILogger logger, Exception exception, string sessionId);
200+
201+
[LoggerMessage(EventIds.IgnoringCancellationForDifferentServer, LogLevel.Warning, "Session {SessionId} ignoring cancellation from session {KillSessionId}: server identity mismatch (this UUID={ServerUuid}, kill UUID={KillServerUuid}, this ID={ServerId}, kill ID={KillServerId})")]
202+
public static partial void IgnoringCancellationForDifferentServer(ILogger logger, string sessionId, string killSessionId, string? serverUuid, string? killServerUuid, long? serverId, long? killServerId);
203+
204+
[LoggerMessage(EventIds.NoServerIdentificationForVerification, LogLevel.Debug, "Session {SessionId} and kill session {KillSessionId} have no server identification available for verification")]
205+
public static partial void NoServerIdentificationForVerification(ILogger logger, string sessionId, string killSessionId);
206+
192207
[LoggerMessage(EventIds.ClosingStreamSocket, LogLevel.Debug, "Session {SessionId} closing stream/socket")]
193208
public static partial void ClosingStreamSocket(ILogger logger, string sessionId);
194209

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
using MySqlConnector.Core;
2+
using MySqlConnector.Logging;
3+
using Microsoft.Extensions.Logging.Abstractions;
4+
5+
namespace MySqlConnector.Tests;
6+
7+
public class ServerIdentificationTests
8+
{
9+
[Fact]
10+
public void VerifyServerIdentity_WithMatchingUuids_ReturnsTrue()
11+
{
12+
// Arrange
13+
var session1 = CreateServerSession();
14+
var session2 = CreateServerSession();
15+
session1.ServerUuid = "test-uuid-123";
16+
session1.ServerId = 1;
17+
session2.ServerUuid = "test-uuid-123";
18+
session2.ServerId = 2; // Different server ID, but UUIDs match
19+
20+
// Act
21+
bool result = InvokeVerifyServerIdentity(session1, session2);
22+
23+
// Assert
24+
Assert.True(result);
25+
}
26+
27+
[Fact]
28+
public void VerifyServerIdentity_WithDifferentUuids_ReturnsFalse()
29+
{
30+
// Arrange
31+
var session1 = CreateServerSession();
32+
var session2 = CreateServerSession();
33+
session1.ServerUuid = "test-uuid-123";
34+
session1.ServerId = 1;
35+
session2.ServerUuid = "test-uuid-456";
36+
session2.ServerId = 1; // Same server ID, but UUIDs don't match
37+
38+
// Act
39+
bool result = InvokeVerifyServerIdentity(session1, session2);
40+
41+
// Assert
42+
Assert.False(result);
43+
}
44+
45+
[Fact]
46+
public void VerifyServerIdentity_WithMatchingServerIds_ReturnsTrue()
47+
{
48+
// Arrange
49+
var session1 = CreateServerSession();
50+
var session2 = CreateServerSession();
51+
session1.ServerUuid = null; // No UUID available
52+
session1.ServerId = 1;
53+
session2.ServerUuid = null; // No UUID available
54+
session2.ServerId = 1;
55+
56+
// Act
57+
bool result = InvokeVerifyServerIdentity(session1, session2);
58+
59+
// Assert
60+
Assert.True(result);
61+
}
62+
63+
[Fact]
64+
public void VerifyServerIdentity_WithDifferentServerIds_ReturnsFalse()
65+
{
66+
// Arrange
67+
var session1 = CreateServerSession();
68+
var session2 = CreateServerSession();
69+
session1.ServerUuid = null; // No UUID available
70+
session1.ServerId = 1;
71+
session2.ServerUuid = null; // No UUID available
72+
session2.ServerId = 2;
73+
74+
// Act
75+
bool result = InvokeVerifyServerIdentity(session1, session2);
76+
77+
// Assert
78+
Assert.False(result);
79+
}
80+
81+
[Fact]
82+
public void VerifyServerIdentity_WithNoIdentification_ReturnsTrue()
83+
{
84+
// Arrange
85+
var session1 = CreateServerSession();
86+
var session2 = CreateServerSession();
87+
session1.ServerUuid = null;
88+
session1.ServerId = null;
89+
session2.ServerUuid = null;
90+
session2.ServerId = null;
91+
92+
// Act
93+
bool result = InvokeVerifyServerIdentity(session1, session2);
94+
95+
// Assert
96+
Assert.True(result); // Should allow operation for backward compatibility
97+
}
98+
99+
[Fact]
100+
public void VerifyServerIdentity_UuidTakesPrecedenceOverServerId()
101+
{
102+
// Arrange
103+
var session1 = CreateServerSession();
104+
var session2 = CreateServerSession();
105+
session1.ServerUuid = "test-uuid-123";
106+
session1.ServerId = 1;
107+
session2.ServerUuid = "test-uuid-456"; // Different UUID
108+
session2.ServerId = 1; // Same server ID
109+
110+
// Act
111+
bool result = InvokeVerifyServerIdentity(session1, session2);
112+
113+
// Assert
114+
Assert.False(result); // Should use UUID comparison, not server ID
115+
}
116+
117+
private static ServerSession CreateServerSession()
118+
{
119+
var pool = new TestConnectionPool();
120+
return new ServerSession(NullLogger.Instance, pool);
121+
}
122+
123+
private static bool InvokeVerifyServerIdentity(ServerSession session1, ServerSession session2)
124+
{
125+
// Use reflection to call the private VerifyServerIdentity method
126+
var method = typeof(ServerSession).GetMethod("VerifyServerIdentity",
127+
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
128+
return (bool)method!.Invoke(session1, new object[] { session2 })!;
129+
}
130+
131+
private class TestConnectionPool : IConnectionPoolMetadata
132+
{
133+
public int Id => 1;
134+
public int Generation => 1;
135+
public ConnectionPool? ConnectionPool => null;
136+
public int GetNewSessionId() => 1;
137+
}
138+
}

0 commit comments

Comments
 (0)