Skip to content

Commit a12efe0

Browse files
tmasternakMarcWils
andauthored
Use SQL system catalog views to check for the presence of a Recoverable column. This removes the need for SELECT permissions to send a message to a queue table. (#1451)
Co-authored-by: Marc Wils <[email protected]>
1 parent 95b66cc commit a12efe0

File tree

11 files changed

+50
-40
lines changed

11 files changed

+50
-40
lines changed

src/NServiceBus.Transport.SqlServer.AcceptanceTests/NativeTimeouts/When_configured_to_purge_expired_messages_at_startup.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ async Task SetupInputQueue()
6868
(address, isStreamSupported) =>
6969
{
7070
var canonicalAddress = addressTranslator.Parse(address);
71-
return new SqlTableBasedQueue(sqlConstants, canonicalAddress.QualifiedTableName, canonicalAddress.Address, isStreamSupported);
71+
return new SqlTableBasedQueue(sqlConstants, canonicalAddress, canonicalAddress.Address, isStreamSupported);
7272
},
7373
s => addressTranslator.Parse(s).Address,
7474
true);

src/NServiceBus.Transport.SqlServer.IntegrationTests/When_checking_schema.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public async Task SetUp()
2424

2525
await ResetQueue(addressParser, dbConnectionFactory);
2626

27-
queue = new SqlTableBasedQueue(sqlConstants, addressParser.Parse(QueueTableName).QualifiedTableName, QueueTableName, false);
27+
queue = new SqlTableBasedQueue(sqlConstants, addressParser.Parse(QueueTableName), QueueTableName, false);
2828
}
2929

3030
[Test]

src/NServiceBus.Transport.SqlServer.IntegrationTests/When_dispatching_messages.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ async Task PrepareAsync(CancellationToken cancellationToken = default)
107107
(address, isStreamSupported) =>
108108
{
109109
var canonicalAddress = addressTranslator.Parse(address);
110-
return new SqlTableBasedQueue(sqlConstants, canonicalAddress.QualifiedTableName, canonicalAddress.Address, isStreamSupported);
110+
return new SqlTableBasedQueue(sqlConstants, canonicalAddress, canonicalAddress.Address, isStreamSupported);
111111
},
112112
s => addressTranslator.Parse(s).Address,
113113
true);
@@ -122,7 +122,7 @@ async Task PrepareAsync(CancellationToken cancellationToken = default)
122122
Task PurgeOutputQueue(QueueAddressTranslator addressTranslator, CancellationToken cancellationToken = default)
123123
{
124124
purger = new QueuePurger(dbConnectionFactory);
125-
var queueAddress = addressTranslator.Parse(ValidAddress).QualifiedTableName;
125+
var queueAddress = addressTranslator.Parse(ValidAddress);
126126
queue = new SqlTableBasedQueue(sqlConstants, queueAddress, ValidAddress, true);
127127

128128
return purger.Purge(queue, cancellationToken);

src/NServiceBus.Transport.SqlServer.IntegrationTests/When_message_receive_takes_long.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public async Task SetUp()
2929

3030
await CreateQueueIfNotExists(addressParser, dbConnectionFactory);
3131

32-
queue = new SqlTableBasedQueue(sqlConstants, addressParser.Parse(QueueTableName).QualifiedTableName, QueueTableName, true);
32+
queue = new SqlTableBasedQueue(sqlConstants, addressParser.Parse(QueueTableName), QueueTableName, true);
3333
}
3434

3535
[Test]

src/NServiceBus.Transport.SqlServer.IntegrationTests/When_receiving_messages.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public async Task Should_stop_receiving_messages_after_first_unsuccessful_receiv
4343
transport.Testing.QueueFactoryOverride = qa =>
4444
qa == inputQueueAddress
4545
? inputQueue
46-
: new SqlTableBasedQueue(sqlConstants, parser.Parse(qa).QualifiedTableName, qa, true);
46+
: new SqlTableBasedQueue(sqlConstants, parser.Parse(qa), qa, true);
4747

4848
var receiveSettings = new ReceiveSettings("receiver", new Transport.QueueAddress(inputQueueName), true, false, "error");
4949
var hostSettings = new HostSettings("IntegrationTests", string.Empty, new StartupDiagnosticEntries(),
@@ -95,7 +95,7 @@ class FakeTableBasedQueue : SqlTableBasedQueue
9595
int queueSize;
9696
int successfulReceives;
9797

98-
public FakeTableBasedQueue(SqlServerConstants sqlConstants, string address, int queueSize, int successfulReceives) : base(sqlConstants, address, "", true)
98+
public FakeTableBasedQueue(SqlServerConstants sqlConstants, string address, int queueSize, int successfulReceives) : base(sqlConstants, new QueueAddressTranslator("nservicebus", "dbo", null, new QueueSchemaAndCatalogOptions()).Parse(address), "", true)
9999
{
100100
this.queueSize = queueSize;
101101
this.successfulReceives = successfulReceives;

src/NServiceBus.Transport.SqlServer.IntegrationTests/When_recoverable_column_is_removed.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public async Task Should_recover(Type contextProviderType, DispatchConsistency d
4343
(address, isStreamSupported) =>
4444
{
4545
var canonicalAddress = addressTranslator.Parse(address);
46-
return new SqlTableBasedQueue(sqlConstants, canonicalAddress.QualifiedTableName, canonicalAddress.Address, isStreamSupported);
46+
return new SqlTableBasedQueue(sqlConstants, canonicalAddress, canonicalAddress.Address, isStreamSupported);
4747
},
4848
s => addressTranslator.Parse(s).Address,
4949
true);

src/NServiceBus.Transport.SqlServer.IntegrationTests/When_using_ttbr.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ async Task PrepareAsync(CancellationToken cancellationToken = default)
127127
(address, isStreamSupported) =>
128128
{
129129
var canonicalAddress = addressTranslator.Parse(address);
130-
return new SqlTableBasedQueue(sqlConstants, canonicalAddress.QualifiedTableName, canonicalAddress.Address, isStreamSupported);
130+
return new SqlTableBasedQueue(sqlConstants, canonicalAddress, canonicalAddress.Address, isStreamSupported);
131131
},
132132
s => addressTranslator.Parse(s).Address,
133133
true);
@@ -147,7 +147,7 @@ Task PurgeOutputQueue(QueueAddressTranslator addressParser, CancellationToken ca
147147
{
148148
purger = new QueuePurger(dbConnectionFactory);
149149
var queueAddress = addressParser.Parse(ValidAddress);
150-
queue = new SqlTableBasedQueue(sqlConstants, queueAddress.QualifiedTableName, queueAddress.Address, true);
150+
queue = new SqlTableBasedQueue(sqlConstants, queueAddress, queueAddress.Address, true);
151151

152152
return purger.Purge(queue, cancellationToken);
153153
}

src/NServiceBus.Transport.SqlServer.TransportTests/When_receive_takes_long_to_complete.cs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ namespace NServiceBus.TransportTests;
44
using System;
55
using System.Threading;
66
using System.Threading.Tasks;
7+
using Microsoft.Data.SqlClient;
78
using NUnit.Framework;
89
using Transport;
910
using Transport.SqlServer;
@@ -44,15 +45,15 @@ public async Task Peeker_should_provide_accurate_queue_length_estimate(Transport
4445
Assert.That(peekCount, Is.EqualTo(1), "A long running receive transaction should not skew the estimation for number of messages in the queue.");
4546
}
4647

47-
static async Task<SqlTableBasedQueue> CreateATestQueue(SqlServerDbConnectionFactory connectionFactory)
48+
async Task<SqlTableBasedQueue> CreateATestQueue(SqlServerDbConnectionFactory connectionFactory)
4849
{
4950
var queueName = "queue_length_estimation_test";
5051

5152
var sqlConstants = new SqlServerConstants();
5253

53-
var queue = new SqlTableBasedQueue(sqlConstants, queueName, queueName, false);
54+
var queue = new SqlTableBasedQueue(sqlConstants, new CanonicalQueueAddress(queueName, "dbo", catalogName), queueName, false);
5455

55-
var addressTranslator = new QueueAddressTranslator("nservicebus", "dbo", null, null);
56+
var addressTranslator = new QueueAddressTranslator(catalogName, "dbo", null, null);
5657
var queueCreator = new QueueCreator(sqlConstants, connectionFactory, addressTranslator.Parse, false);
5758

5859
await queueCreator.CreateQueueIfNecessary(new[] { queueName }, null);
@@ -98,6 +99,11 @@ await queue.Send(
9899
[SetUp]
99100
public async Task Setup()
100101
{
102+
var connectionString = ConfigureSqlServerTransportInfrastructure.ConnectionString;
103+
var connectionStringBuilder = new SqlConnectionStringBuilder(connectionString);
104+
105+
catalogName = connectionStringBuilder.InitialCatalog;
106+
101107
connectionFactory = new SqlServerDbConnectionFactory(ConfigureSqlServerTransportInfrastructure.ConnectionString);
102108

103109
queue = await CreateATestQueue(connectionFactory);
@@ -119,6 +125,7 @@ public async Task TearDown()
119125
await comm.ExecuteNonQueryAsync(CancellationToken.None);
120126
}
121127

128+
string catalogName;
122129
SqlTableBasedQueue queue;
123130
SqlServerDbConnectionFactory connectionFactory;
124131
}

src/NServiceBus.Transport.SqlServer/Queuing/SqlServerConstants.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ THEN DATEADD(ms, @TimeToBeReceivedMs, GETUTCDATE()) END,
5050
IF (@NOCOUNT = 'ON') SET NOCOUNT ON;
5151
IF (@NOCOUNT = 'OFF') SET NOCOUNT OFF;";
5252

53-
public string CheckIfTableHasRecoverableText { get; set; } = "SELECT TOP (0) * FROM {0} WITH (NOLOCK);";
53+
public string CheckIfTableHasRecoverableText { get; set; } = @"
54+
SELECT COUNT(*)
55+
FROM {0}.sys.columns c
56+
WHERE c.object_id = OBJECT_ID(N'{1}')
57+
AND c.name = 'Recoverable'";
5458

5559
public string StoreDelayedMessageText { get; set; } =
5660
@"

src/NServiceBus.Transport.SqlServer/Queuing/SqlTableBasedQueue.cs

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,16 @@
1313

1414
class SqlTableBasedQueue : TableBasedQueue
1515
{
16-
public SqlTableBasedQueue(SqlServerConstants sqlConstants, string qualifiedTableName, string queueName, bool isStreamSupported) :
17-
base(sqlConstants, qualifiedTableName, queueName, isStreamSupported)
16+
public SqlTableBasedQueue(SqlServerConstants sqlConstants, CanonicalQueueAddress queueAddress, string queueName, bool isStreamSupported) :
17+
base(sqlConstants, queueAddress.QualifiedTableName, queueName, isStreamSupported)
1818
{
1919
sqlServerConstants = sqlConstants;
2020

21-
purgeExpiredCommand = Format(sqlConstants.PurgeBatchOfExpiredMessagesText, this.qualifiedTableName);
22-
checkExpiresIndexCommand = Format(sqlConstants.CheckIfExpiresIndexIsPresent, this.qualifiedTableName);
23-
checkNonClusteredRowVersionIndexCommand = Format(sqlConstants.CheckIfNonClusteredRowVersionIndexIsPresent, this.qualifiedTableName);
24-
checkHeadersColumnTypeCommand = Format(sqlConstants.CheckHeadersColumnType, this.qualifiedTableName);
21+
purgeExpiredCommand = Format(sqlConstants.PurgeBatchOfExpiredMessagesText, qualifiedTableName);
22+
checkExpiresIndexCommand = Format(sqlConstants.CheckIfExpiresIndexIsPresent, qualifiedTableName);
23+
checkNonClusteredRowVersionIndexCommand = Format(sqlConstants.CheckIfNonClusteredRowVersionIndexIsPresent, qualifiedTableName);
24+
checkHeadersColumnTypeCommand = Format(sqlConstants.CheckHeadersColumnType, qualifiedTableName);
25+
checkRecoverableColumnCommand = Format(sqlConstants.CheckIfTableHasRecoverableText, queueAddress.Catalog, qualifiedTableName);
2526
}
2627

2728
public async Task<int> PurgeBatchOfExpiredMessages(DbConnection connection, int purgeBatchSize, CancellationToken cancellationToken = default)
@@ -85,7 +86,7 @@ protected override async Task SendRawMessage(MessageRow message, DbConnection co
8586

8687
message.PrepareSendCommand(command);
8788

88-
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
89+
_ = await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
8990
}
9091
}
9192
// 207 = Invalid column name
@@ -123,27 +124,24 @@ async Task<string> GetSendCommandText(DbConnection connection, DbTransaction tra
123124
return sendCommand;
124125
}
125126

126-
var commandText = Format(sqlServerConstants.CheckIfTableHasRecoverableText, qualifiedTableName);
127127
using (var command = connection.CreateCommand())
128128
{
129+
command.CommandText = checkRecoverableColumnCommand;
129130
command.CommandType = CommandType.Text;
130-
command.CommandText = commandText;
131131
command.Transaction = transaction;
132132

133-
using (var reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false))
133+
var rowsCount = await command.ExecuteScalarAsync<int>(nameof(checkRecoverableColumnCommand), cancellationToken).ConfigureAwait(false);
134+
if (rowsCount > 0)
134135
{
135-
for (int fieldIndex = 0; fieldIndex < reader.FieldCount; fieldIndex++)
136-
{
137-
if (string.Equals("Recoverable", reader.GetName(fieldIndex), StringComparison.OrdinalIgnoreCase))
138-
{
139-
cachedSendCommand = Format(sqlServerConstants.SendTextWithRecoverable, qualifiedTableName);
140-
return cachedSendCommand;
141-
}
142-
}
136+
cachedSendCommand = Format(sqlServerConstants.SendTextWithRecoverable, qualifiedTableName);
137+
return cachedSendCommand;
143138
}
139+
else
140+
{
144141

145-
cachedSendCommand = Format(sqlServerConstants.SendTextWithoutRecoverable, qualifiedTableName);
146-
return cachedSendCommand;
142+
cachedSendCommand = Format(sqlServerConstants.SendTextWithoutRecoverable, qualifiedTableName);
143+
return cachedSendCommand;
144+
}
147145
}
148146
}
149147
finally
@@ -153,10 +151,11 @@ async Task<string> GetSendCommandText(DbConnection connection, DbTransaction tra
153151
}
154152

155153
string cachedSendCommand;
156-
string purgeExpiredCommand;
157-
string checkExpiresIndexCommand;
158-
string checkNonClusteredRowVersionIndexCommand;
159-
string checkHeadersColumnTypeCommand;
154+
readonly string purgeExpiredCommand;
155+
readonly string checkExpiresIndexCommand;
156+
readonly string checkNonClusteredRowVersionIndexCommand;
157+
readonly string checkHeadersColumnTypeCommand;
158+
readonly string checkRecoverableColumnCommand;
160159
readonly SemaphoreSlim sendCommandLock = new SemaphoreSlim(1, 1);
161160
readonly SqlServerConstants sqlServerConstants;
162161
}

0 commit comments

Comments
 (0)