Skip to content

Commit 77ad191

Browse files
committed
Make 'Context' immutable.
Move mutable properties back to ServerSession. Signed-off-by: Bradley Grainger <[email protected]>
1 parent ef40088 commit 77ad191

File tree

6 files changed

+50
-41
lines changed

6 files changed

+50
-41
lines changed

src/MySqlConnector/Core/ConnectionPool.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public async ValueTask<ServerSession> GetSessionAsync(MySqlConnection connection
6565
}
6666
else
6767
{
68-
if (ConnectionSettings.ConnectionReset || !session.Context.IsInitialDatabase())
68+
if (ConnectionSettings.ConnectionReset || session.DatabaseOverride is not null)
6969
{
7070
if (timeoutMilliseconds != 0)
7171
session.SetTimeout(Math.Max(1, timeoutMilliseconds - Utility.GetElapsedMilliseconds(startingTimestamp)));

src/MySqlConnector/Core/Context.cs

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,16 @@ namespace MySqlConnector.Core;
44

55
internal sealed class Context
66
{
7-
public Context(ProtocolCapabilities protocolCapabilities, string? database, int connectionId)
7+
public Context(ProtocolCapabilities protocolCapabilities)
88
{
99
SupportsDeprecateEof = (protocolCapabilities & ProtocolCapabilities.DeprecateEof) != 0;
1010
SupportsCachedPreparedMetadata = (protocolCapabilities & ProtocolCapabilities.MariaDbCacheMetadata) != 0;
1111
SupportsQueryAttributes = (protocolCapabilities & ProtocolCapabilities.QueryAttributes) != 0;
1212
SupportsSessionTrack = (protocolCapabilities & ProtocolCapabilities.SessionTrack) != 0;
13-
ConnectionId = connectionId;
14-
Database = database;
15-
m_initialDatabase = database;
1613
}
1714

1815
public bool SupportsDeprecateEof { get; }
1916
public bool SupportsQueryAttributes { get; }
2017
public bool SupportsSessionTrack { get; }
2118
public bool SupportsCachedPreparedMetadata { get; }
22-
public string? ClientCharset { get; set; }
23-
24-
public string? Database { get; set; }
25-
private readonly string? m_initialDatabase;
26-
public bool IsInitialDatabase() => string.Equals(m_initialDatabase, Database, StringComparison.Ordinal);
27-
28-
public int ConnectionId { get; set; }
2919
}

src/MySqlConnector/Core/ResultSet.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ public async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior)
4848
if (ok.LastInsertId != 0)
4949
Command?.SetLastInsertedId((long) ok.LastInsertId);
5050
WarningCount = ok.WarningCount;
51+
if (ok.NewSchema is not null)
52+
Connection.Session.DatabaseOverride = ok.NewSchema;
5153
m_columnDefinitions = default;
5254
State = (ok.ServerStatus & ServerStatus.MoreResultsExist) == 0
5355
? ResultSetState.NoMoreData

src/MySqlConnector/Core/ServerSession.cs

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,22 @@ public ServerSession(ILogger logger, ConnectionPool? pool, int poolGeneration, i
4444
m_activityTags = [];
4545
DataReader = new();
4646
Log.CreatedNewSession(m_logger, Id);
47-
Context = new Context(0, null, 0);
47+
Context = new Context(default);
4848
}
4949

5050
public string Id { get; }
5151
public ServerVersion ServerVersion { get; set; }
5252
public bool SupportsPerQueryVariables => ServerVersion.IsMariaDb && ServerVersion.Version >= ServerVersions.MariaDbSupportsPerQueryVariables;
5353
public int ActiveCommandId { get; private set; }
5454
public int CancellationTimeout { get; private set; }
55+
public int ConnectionId { get; set; }
5556
public byte[]? AuthPluginData { get; set; }
5657
public long CreatedTimestamp { get; }
5758
public ConnectionPool? Pool { get; }
5859
public int PoolGeneration { get; }
5960
public long LastLeasedTimestamp { get; set; }
6061
public long LastReturnedTimestamp { get; private set; }
62+
public string? DatabaseOverride { get; set; }
6163

6264
public string HostName { get; private set; }
6365
public IPEndPoint? IPEndPoint => m_tcpClient?.Client.RemoteEndPoint as IPEndPoint;
@@ -338,8 +340,8 @@ public void FinishQuerying()
338340
var activity = ActivitySourceHelper.StartActivity(name, m_activityTags);
339341
if (activity is { IsAllDataRequested: true })
340342
{
341-
if (!Context.IsInitialDatabase())
342-
activity.SetTag(ActivitySourceHelper.DatabaseNameTagName, Context.Database);
343+
if (DatabaseOverride is not null)
344+
activity.SetTag(ActivitySourceHelper.DatabaseNameTagName, DatabaseOverride);
343345
if (tagName1 is not null)
344346
activity.SetTag(tagName1, tagValue1);
345347
}
@@ -452,15 +454,16 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
452454
}
453455

454456
ServerVersion = new(initialHandshake.ServerVersion);
455-
Context = new Context(initialHandshake.ProtocolCapabilities, cs.Database, initialHandshake.ConnectionId);
457+
ConnectionId = initialHandshake.ConnectionId;
458+
Context = new Context(initialHandshake.ProtocolCapabilities);
456459
AuthPluginData = initialHandshake.AuthPluginData;
457460
m_useCompression = cs.UseCompression && (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.Compress) != 0;
458461
CancellationTimeout = cs.CancellationTimeout;
459462
UserID = cs.UserID;
460463

461464
// set activity tags
462465
{
463-
var connectionId = Context.ConnectionId.ToString(CultureInfo.InvariantCulture);
466+
var connectionId = ConnectionId.ToString(CultureInfo.InvariantCulture);
464467
m_activityTags[ActivitySourceHelper.DatabaseConnectionIdTagName] = connectionId;
465468
if (activity is { IsAllDataRequested: true })
466469
activity.SetTag(ActivitySourceHelper.DatabaseConnectionIdTagName, connectionId);
@@ -499,7 +502,7 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
499502
}
500503
}
501504

502-
Log.SessionMadeConnection(m_logger, Id, ServerVersion.OriginalString, Context.ConnectionId, m_useCompression, m_supportsConnectionAttributes, Context.SupportsDeprecateEof, Context.SupportsCachedPreparedMetadata, serverSupportsSsl, Context.SupportsSessionTrack, m_supportsPipelining, Context.SupportsQueryAttributes);
505+
Log.SessionMadeConnection(m_logger, Id, ServerVersion.OriginalString, ConnectionId, m_useCompression, m_supportsConnectionAttributes, Context.SupportsDeprecateEof, Context.SupportsCachedPreparedMetadata, serverSupportsSsl, Context.SupportsSessionTrack, m_supportsPipelining, Context.SupportsQueryAttributes);
503506

504507
if (cs.SslMode != MySqlSslMode.None && (cs.SslMode != MySqlSslMode.Preferred || serverSupportsSsl))
505508
{
@@ -532,18 +535,23 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
532535
if (m_useCompression)
533536
m_payloadHandler = new CompressedPayloadHandler(m_payloadHandler.ByteHandler);
534537

535-
// set 'collation_connection' to the server default
536-
if (Context.ClientCharset == null || ServerVersion.Version >= ServerVersions.SupportsUtf8Mb4
537-
? !string.Equals(Context.ClientCharset, "utf8mb4", StringComparison.Ordinal)
538-
: !string.Equals(Context.ClientCharset, "utf8", StringComparison.Ordinal))
538+
if (ok.ClientCharacterSet != (ServerVersion.Version >= ServerVersions.SupportsUtf8Mb4 ? "utf8mb4" : "utf8"))
539539
{
540+
// set 'collation_connection' to the server default
540541
await SendAsync(m_setNamesPayload, ioBehavior, cancellationToken).ConfigureAwait(false);
541542
payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
542543
OkPayload.Verify(payload.Span, Context);
543544
}
544545

545546
if (ShouldGetRealServerDetails(cs))
547+
{
546548
await GetRealServerDetailsAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
549+
}
550+
else if (ok.ConnectionId is int newConnectionId && newConnectionId != ConnectionId)
551+
{
552+
Log.ChangingConnectionId(m_logger, Id, ConnectionId, newConnectionId, ServerVersion.OriginalString, ServerVersion.OriginalString);
553+
ConnectionId = newConnectionId;
554+
}
547555

548556
m_payloadHandler.ByteHandler.RemainingTimeout = Constants.InfiniteTimeout;
549557
return statusInfo;
@@ -570,9 +578,9 @@ public async Task<bool> TryResetConnectionAsync(ConnectionSettings cs, MySqlConn
570578
ClearPreparedStatements();
571579

572580
PayloadData payload;
573-
if (Context.IsInitialDatabase() &&
574-
((!ServerVersion.IsMariaDb && ServerVersion.Version.CompareTo(ServerVersions.SupportsResetConnection) >= 0) ||
575-
(ServerVersion.IsMariaDb && ServerVersion.Version.CompareTo(ServerVersions.MariaDbSupportsResetConnection) >= 0)))
581+
if (DatabaseOverride is null &&
582+
((!ServerVersion.IsMariaDb && ServerVersion.Version.CompareTo(ServerVersions.SupportsResetConnection) >= 0) ||
583+
(ServerVersion.IsMariaDb && ServerVersion.Version.CompareTo(ServerVersions.MariaDbSupportsResetConnection) >= 0)))
576584
{
577585
if (m_supportsPipelining)
578586
{
@@ -599,14 +607,14 @@ public async Task<bool> TryResetConnectionAsync(ConnectionSettings cs, MySqlConn
599607
else
600608
{
601609
// optimistically hash the password with the challenge from the initial handshake (supported by MariaDB; doesn't appear to be supported by MySQL)
602-
if (Context.IsInitialDatabase())
610+
if (DatabaseOverride is null)
603611
{
604612
Log.SendingChangeUserRequest(m_logger, Id, ServerVersion.OriginalString);
605613
}
606614
else
607615
{
608-
Log.SendingChangeUserRequestDueToChangedDatabase(m_logger, Id, Context.Database!);
609-
Context.Database = cs.Database;
616+
Log.SendingChangeUserRequestDueToChangedDatabase(m_logger, Id, DatabaseOverride);
617+
DatabaseOverride = null;
610618
}
611619
var password = GetPassword(cs, connection);
612620
var hashedPassword = AuthenticationUtility.CreateAuthenticationResponse(AuthPluginData!, password);
@@ -1668,8 +1676,8 @@ static void ReadRow(ReadOnlySpan<byte> span, out int? connectionId, out ServerVe
16681676

16691677
if (connectionId is int newConnectionId && serverVersion is not null)
16701678
{
1671-
Log.ChangingConnectionId(m_logger, Id, Context.ConnectionId, newConnectionId, ServerVersion.OriginalString, serverVersion.OriginalString);
1672-
Context.ConnectionId = newConnectionId;
1679+
Log.ChangingConnectionId(m_logger, Id, ConnectionId, newConnectionId, ServerVersion.OriginalString, serverVersion.OriginalString);
1680+
ConnectionId = newConnectionId;
16731681
ServerVersion = serverVersion;
16741682
}
16751683
}

src/MySqlConnector/MySqlConnection.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ private async Task ChangeDatabaseAsync(IOBehavior ioBehavior, string databaseNam
490490
OkPayload.Verify(payload.Span, m_session.Context);
491491

492492
// for non session tracking servers
493-
m_session.Context.Database = databaseName;
493+
m_session.DatabaseOverride = databaseName;
494494
}
495495

496496
public new MySqlCommand CreateCommand() => (MySqlCommand) base.CreateCommand();
@@ -628,7 +628,7 @@ public override string ConnectionString
628628
}
629629
}
630630

631-
public override string Database => m_session?.Context.Database ?? GetConnectionSettings().Database;
631+
public override string Database => m_session?.DatabaseOverride ?? GetConnectionSettings().Database;
632632

633633
public override ConnectionState State => m_connectionState;
634634

@@ -639,7 +639,7 @@ public override string ConnectionString
639639
/// <summary>
640640
/// The connection ID from MySQL Server.
641641
/// </summary>
642-
public int ServerThread => Session.Context.ConnectionId;
642+
public int ServerThread => Session.ConnectionId;
643643

644644
/// <summary>
645645
/// Gets or sets the delegate used to provide client certificates for connecting to a server.

src/MySqlConnector/Protocol/Payloads/OkPayload.cs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ internal sealed class OkPayload
1313
public ServerStatus ServerStatus { get; }
1414
public int WarningCount { get; }
1515
public string? StatusInfo { get; }
16+
public string? NewSchema { get; }
17+
public string? ClientCharacterSet { get; }
18+
public int? ConnectionId { get; }
1619

1720
public const byte Signature = 0x00;
1821

@@ -57,6 +60,9 @@ public static void Verify(ReadOnlySpan<byte> span, Context context) =>
5760
var lastInsertId = reader.ReadLengthEncodedInteger();
5861
var serverStatus = (ServerStatus) reader.ReadUInt16();
5962
var warningCount = (int) reader.ReadUInt16();
63+
string? newSchema = null;
64+
string? clientCharacterSet = null;
65+
int? connectionId = null;
6066
ReadOnlySpan<byte> statusBytes;
6167

6268
if (context.SupportsSessionTrack)
@@ -75,7 +81,7 @@ public static void Verify(ReadOnlySpan<byte> span, Context context) =>
7581
switch (kind)
7682
{
7783
case SessionTrackKind.Schema:
78-
context.Database = Encoding.UTF8.GetString(reader.ReadLengthEncodedByteString());
84+
newSchema = Encoding.UTF8.GetString(reader.ReadLengthEncodedByteString());
7985
break;
8086

8187
case SessionTrackKind.SystemVariables:
@@ -90,10 +96,10 @@ public static void Verify(ReadOnlySpan<byte> span, Context context) =>
9096
switch (variableSv)
9197
{
9298
case "character_set_client":
93-
context.ClientCharset = valueSv;
99+
clientCharacterSet = valueSv;
94100
break;
95101
case "connection_id":
96-
context.ConnectionId = Convert.ToInt32(valueSv, CultureInfo.InvariantCulture);
102+
connectionId = Convert.ToInt32(valueSv, CultureInfo.InvariantCulture);
97103
break;
98104
}
99105
} while (reader.Offset < systemVariableOffset);
@@ -126,31 +132,34 @@ public static void Verify(ReadOnlySpan<byte> span, Context context) =>
126132
{
127133
var statusInfo = statusBytes.Length == 0 ? null : Encoding.UTF8.GetString(statusBytes);
128134

129-
if (affectedRowCount == 0 && lastInsertId == 0 && warningCount == 0 && statusInfo is null)
135+
if (affectedRowCount == 0 && lastInsertId == 0 && warningCount == 0 && statusInfo is null && newSchema is null && clientCharacterSet is null && connectionId is null)
130136
{
131137
if (serverStatus == ServerStatus.AutoCommit)
132138
return s_autoCommitOk;
133139
if (serverStatus == (ServerStatus.AutoCommit | ServerStatus.SessionStateChanged))
134140
return s_autoCommitSessionStateChangedOk;
135141
}
136142

137-
return new OkPayload(affectedRowCount, lastInsertId, serverStatus, warningCount, statusInfo);
143+
return new OkPayload(affectedRowCount, lastInsertId, serverStatus, warningCount, statusInfo, newSchema, clientCharacterSet, connectionId);
138144
}
139145
else
140146
{
141147
return null;
142148
}
143149
}
144150

145-
private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serverStatus, int warningCount, string? statusInfo)
151+
private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serverStatus, int warningCount, string? statusInfo, string? newSchema, string? clientCharacterSet, int? connectionId)
146152
{
147153
AffectedRowCount = affectedRowCount;
148154
LastInsertId = lastInsertId;
149155
ServerStatus = serverStatus;
150156
WarningCount = warningCount;
151157
StatusInfo = statusInfo;
158+
NewSchema = newSchema;
159+
ClientCharacterSet = clientCharacterSet;
160+
ConnectionId = connectionId;
152161
}
153162

154-
private static readonly OkPayload s_autoCommitOk = new(0, 0, ServerStatus.AutoCommit, 0, null);
155-
private static readonly OkPayload s_autoCommitSessionStateChangedOk = new(0, 0, ServerStatus.AutoCommit | ServerStatus.SessionStateChanged, 0, null);
163+
private static readonly OkPayload s_autoCommitOk = new(0, 0, ServerStatus.AutoCommit, 0, default, default, default, default);
164+
private static readonly OkPayload s_autoCommitSessionStateChangedOk = new(0, 0, ServerStatus.AutoCommit | ServerStatus.SessionStateChanged, 0, default, default, default, default);
156165
}

0 commit comments

Comments
 (0)