Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,25 +92,22 @@ public static void SqlLocalDbSharedInstanceConnectionTest()

#region NamedPipeTests

[Fact]
[ActiveIssue("20245")] //pending pipeline configuration
[ConditionalFact(nameof(IsLocalDBEnvironmentSet))]
public static void SqlLocalDbNamedPipeConnectionTest()
{
ConnectionTest(s_localDbNamedPipeConnectionString);
}

[Fact]
[ActiveIssue("20245")] //pending pipeline configuration
public static void LocalDBNamedPipeEncryptionNotSupportedTest()
[ConditionalFact(nameof(IsLocalDBEnvironmentSet))]
public static void LocalDbNamedPipeEncryptionNotSupportedTest()
{
// Encryption is not supported by SQL Local DB.
// But connection should succeed as encryption is disabled by driver.
ConnectionWithEncryptionTest(s_localDbNamedPipeConnectionString);
}

[Fact]
[ActiveIssue("20245")] //pending pipeline configuration
public static void LocalDBNamepipeMarsTest()
[ConditionalFact(nameof(IsLocalDBEnvironmentSet))]
public static void LocalDbNamedPipeMarsTest()
{
ConnectionWithMarsTest(s_localDbNamedPipeConnectionString);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,34 +53,42 @@ public void TestMain()
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureServer))]
[ActiveIssue("5531")]
public void TestPacketNumberWraparound()
{
// this test uses a specifically crafted sql record enumerator and data to put the TdsParserStateObject.WritePacket(byte,bool)
// into a state where it can't differentiate between a packet in the middle of a large packet-set after a byte counter wraparound
// and the first packet of the connection and in doing so trips over a check for packet length from the input which has been
// forced to tell it that there is no output buffer space left, this causes an uncancellable infinite loop

// if the enumerator is completely read to the end then the bug is no longer present and the packet creation task returns,
// if the timeout occurs it is probable (but not absolute) that the write is stuck

public async Task TestPacketNumberWraparound()
{
// This test uses a specifically crafted SQL record enumerator and data to put the
// TdsParserStateObject.WritePacket(byte,bool) into a state where it can't
// differentiate between a packet in the middle of a large packet-set after a byte
// counter wraparound and the first packet of the connection and in doing so trips over
// a check for packet length from the input which has been forced to tell it that there
// is no output buffer space left, this causes an uncancellable infinite loop.
//
// If the enumerator is completely read to the end then the bug is no longer present
// and the packet creation task returns, if the timeout occurs it is probable (but not
// absolute) that the write operation is stuck.

// Arrange
var enumerator = new WraparoundRowEnumerator(1000000);
using var cancellationTokenSource = new CancellationTokenSource();

// Act
Stopwatch stopwatch = new();
stopwatch.Start();
int returned = Task.WaitAny(
Task.Factory.StartNew(
() => RunPacketNumberWraparound(enumerator),
TaskCreationOptions.DenyChildAttach | TaskCreationOptions.LongRunning
),
Task.Delay(TimeSpan.FromSeconds(60))
);

Task actionTask = Task.Factory.StartNew(
async () => await RunPacketNumberWraparound(enumerator, cancellationTokenSource.Token),
TaskCreationOptions.DenyChildAttach | TaskCreationOptions.LongRunning);
Task timeoutTask = Task.Delay(TimeSpan.FromSeconds(60), cancellationTokenSource.Token);
await Task.WhenAny(actionTask, timeoutTask);

stopwatch.Stop();
if (enumerator.MaxCount != enumerator.Count)
{
Console.WriteLine($"enumerator.Count={enumerator.Count}, enumerator.MaxCount={enumerator.MaxCount}, elapsed={stopwatch.Elapsed.TotalSeconds}");
}
Assert.True(enumerator.MaxCount == enumerator.Count);
cancellationTokenSource.Cancel();

// Assert
Assert.True(
enumerator.MaxCount == enumerator.Count,
$"enumerator.Count={enumerator.Count}, " +
$"enumerator.MaxCount={enumerator.MaxCount}, " +
$"elapsed={stopwatch.Elapsed}");
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureServer))]
Expand Down Expand Up @@ -688,23 +696,27 @@ private void QueryHintsTest()
}
}

private static async Task RunPacketNumberWraparound(WraparoundRowEnumerator enumerator)
private static async Task RunPacketNumberWraparound(
WraparoundRowEnumerator enumerator,
CancellationToken cancellationToken)
{
using var connection = new SqlConnection(DataTestUtility.TCPConnectionString);
using var cmd = new SqlCommand("unimportant")
{
CommandType = CommandType.StoredProcedure,
Connection = connection,
};
await cmd.Connection.OpenAsync();
cmd.Parameters.Add(new SqlParameter("@rows", SqlDbType.Structured)
await connection.OpenAsync(cancellationToken);

using var cmd = connection.CreateCommand();
cmd.CommandType = CommandType.StoredProcedure;
cmd.CommandText = "unimportant";

var parameter = new SqlParameter("@rows", SqlDbType.Structured)
{
TypeName = "unimportant",
Value = enumerator,
});
Value = enumerator
};
cmd.Parameters.Add(parameter);

try
{
await cmd.ExecuteNonQueryAsync();
await cmd.ExecuteNonQueryAsync(cancellationToken);
}
catch (Exception)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,10 @@
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using Xunit;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
/// Define the SQL command type by filtering purpose.
[Flags]
public enum FilterSqlStatements
{
/// Don't filter any SQL commands
None = 0,
/// Filter INSERT or INSERT INTO
Insert = 1,
/// Filter UPDATE
Update = 2,
/// Filter DELETE
Delete = 1 << 2,
/// Filter EXECUTE or EXEC
Execute = 1 << 3,
/// Filter ALTER
Alter = 1 << 4,
/// Filter CREATE
Create = 1 << 5,
/// Filter DROP
Drop = 1 << 6,
/// Filter TRUNCATE
Truncate = 1 << 7,
/// Filter SELECT
Select = 1 << 8,
/// Filter data manipulation commands consist of INSERT, INSERT INTO, UPDATE, and DELETE
DML = Insert | Update | Delete | Truncate,
/// Filter data definition commands consist of ALTER, CREATE, and DROP
DDL = Alter | Create | Drop,
/// Filter any SQL command types
All = DML | DDL | Execute | Select
}

public class RetryLogicTestHelper
{
private static readonly HashSet<int> s_defaultTransientErrors
Expand Down Expand Up @@ -75,152 +44,87 @@ private static readonly HashSet<int> s_defaultTransientErrors
18456 // Using managed identity in Azure Sql Server throws 18456 for non-existent database instead of 4060.
};

public static readonly Regex FilterDmlStatements = new Regex(
@"\b(INSERT( +INTO)|UPDATE|DELETE|TRUNCATE)\b",
RegexOptions.Compiled | RegexOptions.IgnoreCase);

internal static readonly string s_exceedErrMsgPattern = SystemDataResourceManager.Instance.SqlRetryLogic_RetryExceeded;
internal static readonly string s_cancelErrMsgPattern = SystemDataResourceManager.Instance.SqlRetryLogic_RetryCanceled;

public static IEnumerable<object[]> GetConnectionStrings()
public static TheoryData<string, SqlRetryLogicBaseProvider> GetConnectionStringAndRetryProviders(
int numberOfRetries,
TimeSpan maxInterval,
TimeSpan? deltaTime = null,
IEnumerable<int> transientErrorCodes = null,
Regex unauthorizedStatementRegex = null)
{
var builder = new SqlConnectionStringBuilder();

foreach (var cnnString in DataTestUtility.GetConnectionStrings(withEnclave: false))
{
builder.Clear();
builder.ConnectionString = cnnString;
builder.ConnectTimeout = 5;
builder.Pooling = false;
yield return new object[] { builder.ConnectionString };

builder.Pooling = true;
yield return new object[] { builder.ConnectionString };
}
}

public static IEnumerable<object[]> GetConnectionAndRetryStrategy(int numberOfRetries,
TimeSpan maxInterval,
FilterSqlStatements unauthorizedStatemets,
IEnumerable<int> transientErrors,
int deltaTimeMillisecond = 10,
bool custom = true)
{
var option = new SqlRetryLogicOption()
var option = new SqlRetryLogicOption
{
NumberOfTries = numberOfRetries,
DeltaTime = TimeSpan.FromMilliseconds(deltaTimeMillisecond),
DeltaTime = deltaTime ?? TimeSpan.FromMilliseconds(10),
MaxTimeInterval = maxInterval,
TransientErrors = transientErrors ?? (custom ? s_defaultTransientErrors : null),
AuthorizedSqlCondition = custom ? RetryPreConditon(unauthorizedStatemets) : null
TransientErrors = transientErrorCodes ?? s_defaultTransientErrors,
AuthorizedSqlCondition = RetryPreCondition(unauthorizedStatementRegex)
};

foreach (var item in GetRetryStrategies(option))
foreach (var cnn in GetConnectionStrings())
yield return new object[] { cnn[0], item[0] };
}

public static IEnumerable<object[]> GetConnectionAndRetryStrategyInvalidCatalog(int numberOfRetries)
{
return GetConnectionAndRetryStrategy(numberOfRetries, TimeSpan.FromSeconds(1), FilterSqlStatements.None, null, 250, true);
}
var result = new TheoryData<string, SqlRetryLogicBaseProvider>();
foreach (var connectionString in GetConnectionStringsTyped())
{
foreach (var retryProvider in GetRetryStrategiesTyped(option))
{
result.Add(connectionString, retryProvider);
}
}

public static IEnumerable<object[]> GetConnectionAndRetryStrategyInvalidCommand(int numberOfRetries)
{
return GetConnectionAndRetryStrategy(numberOfRetries, TimeSpan.FromMilliseconds(100), FilterSqlStatements.None, null);
return result;
}

public static IEnumerable<object[]> GetConnectionAndRetryStrategyFilterDMLStatements(int numberOfRetries)
{
return GetConnectionAndRetryStrategy(numberOfRetries, TimeSpan.FromMilliseconds(100), FilterSqlStatements.DML, new int[] { 207, 102, 2812 });
}
public static TheoryData<string, SqlRetryLogicBaseProvider> GetNonRetriableCases() =>
new TheoryData<string, SqlRetryLogicBaseProvider>
{
{ DataTestUtility.TCPConnectionString, null },
{ DataTestUtility.TCPConnectionString, SqlConfigurableRetryFactory.CreateNoneRetryProvider() }
};

//40613: Database '%.*ls' on server '%.*ls' is not currently available. Please retry the connection later. If the problem persists, contact customer support, and provide them the session tracing ID of '%.*ls'.
public static IEnumerable<object[]> GetConnectionAndRetryStrategyLongRunner(int numberOfRetries)
private static IEnumerable<string> GetConnectionStringsTyped()
{
return GetConnectionAndRetryStrategy(numberOfRetries, TimeSpan.FromSeconds(120), FilterSqlStatements.None, null, 20 * 1000);
}
var builder = new SqlConnectionStringBuilder();
foreach (var connectionString in DataTestUtility.GetConnectionStrings(withEnclave: false))
{
builder.Clear();
builder.ConnectionString = connectionString;
builder.ConnectTimeout = 5;
builder.Pooling = false;
yield return builder.ConnectionString;

public static IEnumerable<object[]> GetConnectionAndRetryStrategyDropDB(int numberOfRetries)
{
List<int> faults = s_defaultTransientErrors.ToList();
faults.Add(3702); // Cannot drop database because it is currently in use.
return GetConnectionAndRetryStrategy(numberOfRetries, TimeSpan.FromMilliseconds(2000), FilterSqlStatements.None, faults, 500);
builder.Pooling = true;
yield return builder.ConnectionString;
}
}

public static IEnumerable<object[]> GetConnectionAndRetryStrategyLockedTable(int numberOfRetries)
private static IEnumerable<SqlRetryLogicBaseProvider> GetRetryStrategiesTyped(SqlRetryLogicOption option)
{
return GetConnectionAndRetryStrategy(numberOfRetries, TimeSpan.FromMilliseconds(100), FilterSqlStatements.None, null);
yield return SqlConfigurableRetryFactory.CreateExponentialRetryProvider(option);
yield return SqlConfigurableRetryFactory.CreateIncrementalRetryProvider(option);
yield return SqlConfigurableRetryFactory.CreateFixedRetryProvider(option);
}

public static IEnumerable<object[]> GetNoneRetriableCondition()
public static IEnumerable<int> GetDefaultTransientErrorCodes(params int[] additionalCodes)
{
yield return new object[] { DataTestUtility.TCPConnectionString, null };
yield return new object[] { DataTestUtility.TCPConnectionString, SqlConfigurableRetryFactory.CreateNoneRetryProvider() };
}
var transientErrorCodes = new HashSet<int>(s_defaultTransientErrors);
foreach (int additionalCode in additionalCodes)
{
transientErrorCodes.Add(additionalCode);
}

private static IEnumerable<object[]> GetRetryStrategies(SqlRetryLogicOption retryLogicOption)
{
yield return new object[] { SqlConfigurableRetryFactory.CreateExponentialRetryProvider(retryLogicOption) };
yield return new object[] { SqlConfigurableRetryFactory.CreateIncrementalRetryProvider(retryLogicOption) };
yield return new object[] { SqlConfigurableRetryFactory.CreateFixedRetryProvider(retryLogicOption) };
return transientErrorCodes;
}

/// Generate a predicate function to skip unauthorized SQL commands.
private static Predicate<string> RetryPreConditon(FilterSqlStatements unauthorizedSqlStatements)
{
var pattern = GetRegexPattern(unauthorizedSqlStatements);
return (commandText) => string.IsNullOrEmpty(pattern)
|| !Regex.IsMatch(commandText, pattern, RegexOptions.Compiled | RegexOptions.IgnoreCase);
}

/// Provide a regex pattern regarding to the SQL statement.
private static string GetRegexPattern(FilterSqlStatements sqlStatements)
private static Predicate<string> RetryPreCondition(Regex unauthorizedStatementRegex)
{
if (sqlStatements == FilterSqlStatements.None)
{
return string.Empty;
}

var pattern = new StringBuilder();

if (sqlStatements.HasFlag(FilterSqlStatements.Insert))
{
pattern.Append(@"INSERT( +INTO){0,1}|");
}
if (sqlStatements.HasFlag(FilterSqlStatements.Update))
{
pattern.Append(@"UPDATE|");
}
if (sqlStatements.HasFlag(FilterSqlStatements.Delete))
{
pattern.Append(@"DELETE|");
}
if (sqlStatements.HasFlag(FilterSqlStatements.Execute))
{
pattern.Append(@"EXEC(UTE){0,1}|");
}
if (sqlStatements.HasFlag(FilterSqlStatements.Alter))
{
pattern.Append(@"ALTER|");
}
if (sqlStatements.HasFlag(FilterSqlStatements.Create))
{
pattern.Append(@"CREATE|");
}
if (sqlStatements.HasFlag(FilterSqlStatements.Drop))
{
pattern.Append(@"DROP|");
}
if (sqlStatements.HasFlag(FilterSqlStatements.Truncate))
{
pattern.Append(@"TRUNCATE|");
}
if (sqlStatements.HasFlag(FilterSqlStatements.Select))
{
pattern.Append(@"SELECT|");
}
if (pattern.Length > 0)
{
pattern.Remove(pattern.Length - 1, 1);
}
return string.Format(@"\b({0})\b", pattern.ToString());
return commandText => unauthorizedStatementRegex is null ||
!unauthorizedStatementRegex.IsMatch(commandText);
}
}
}
Loading
Loading