Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,29 @@ public static bool IsSupportedDataClassification()
return true;
}

/// <summary>
/// Determines whether the SQL Server supports the 'vector' data type.
/// </summary>
/// <remarks>This method attempts to connect to the SQL Server and check for the existence of the
/// 'vector' data type. If a connection cannot be established or an error occurs during the query, the method
/// returns <see langword="false"/>.</remarks>
/// <returns><see langword="true"/> if the 'vector' data type is supported; otherwise, <see langword="false"/>.</returns>
public static bool IsSupportedSqlVector()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you convert this into a property, and only query database on its initialization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I've done this (along with IsSupportedDataClassification as you mentioned in the other comment, although I've changed this method's approach slightly to use OBJECT_ID rather than to wait for an error.)

I've also cleaned up the test conditions slightly, since both properties will return false if the TCP connection string is unspecified.

{
try
{
using var connection = new SqlConnection(TCPConnectionString);
using var command = new SqlCommand("SELECT COUNT(1) FROM SYS.TYPES WHERE [name] = 'vector'", connection);

connection.Open();
return (int)command.ExecuteScalar() > 0;
}
catch (SqlException)
{
return false;
}
}

public static bool IsDNSCachingSetup() => !string.IsNullOrEmpty(DNSCachingConnString);

// Synapse: Always Encrypted is not supported with Azure Synapse.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.VectorTest
public static class VectorFloat32TestData
{
public const int VectorHeaderSize = 8;
public static float[] testData = new float[] { 1.1f, 2.2f, 3.3f };
public static float[] testData = new float[] { 1.1f, 2.2f, 3.3f, 1.01f, float.MinValue, -0.0f };
public static int vectorColumnLength = testData.Length;
// Incorrect size for SqlParameter.Size
public static int IncorrectParamSize = 3234;
Expand Down Expand Up @@ -59,16 +59,17 @@ public sealed class NativeVectorFloat32Tests : IDisposable
private static readonly string s_connectionString = ManualTesting.Tests.DataTestUtility.TCPConnectionString;
private static readonly string s_tableName = DataTestUtility.GetShortName("VectorTestTable");
private static readonly string s_bulkCopySrcTableName = DataTestUtility.GetShortName("VectorBulkCopyTestTable");
private static readonly string s_bulkCopySrcTableDef = $@"(Id INT PRIMARY KEY IDENTITY, VectorData vector(3) NULL)";
private static readonly string s_tableDefinition = $@"(Id INT PRIMARY KEY IDENTITY, VectorData vector(3) NULL)";
private static readonly int s_vectorDimensions = VectorFloat32TestData.vectorColumnLength;
private static readonly string s_bulkCopySrcTableDef = $@"(Id INT PRIMARY KEY IDENTITY, VectorData vector({s_vectorDimensions}) NULL)";
private static readonly string s_tableDefinition = $@"(Id INT PRIMARY KEY IDENTITY, VectorData vector({s_vectorDimensions}) NULL)";
private static readonly string s_selectCmdString = $"SELECT VectorData FROM {s_tableName} ORDER BY Id DESC";
private static readonly string s_insertCmdString = $"INSERT INTO {s_tableName} (VectorData) VALUES (@VectorData)";
private static readonly string s_vectorParamName = $"@VectorData";
private static readonly string s_outputVectorParamName = $"@OutputVectorData";
private static readonly string s_storedProcName = DataTestUtility.GetShortName("VectorsAsVarcharSp");
private static readonly string s_storedProcBody = $@"
{s_vectorParamName} vector(3), -- Input: Serialized float[] as JSON string
{s_outputVectorParamName} vector(3) OUTPUT -- Output: Echoed back from latest inserted row
{s_vectorParamName} vector({s_vectorDimensions}), -- Input: Serialized float[] as JSON string
{s_outputVectorParamName} vector({s_vectorDimensions}) OUTPUT -- Output: Echoed back from latest inserted row
AS
BEGIN
SET NOCOUNT ON;
Expand Down Expand Up @@ -147,7 +148,7 @@ private void ValidateInsertedData(SqlConnection connection, float[] expectedData
}
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
[MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData), DisableDiscoveryEnumeration = true)]
public void TestSqlVectorFloat32ParameterInsertionAndReads(
int pattern,
Expand Down Expand Up @@ -213,7 +214,7 @@ private async Task ValidateInsertedDataAsync(SqlConnection connection, float[] e
}
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
[MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData), DisableDiscoveryEnumeration = true)]
public async Task TestSqlVectorFloat32ParameterInsertionAndReadsAsync(
int pattern,
Expand Down Expand Up @@ -247,7 +248,7 @@ public async Task TestSqlVectorFloat32ParameterInsertionAndReadsAsync(
await ValidateInsertedDataAsync(conn, expectedValues, expectedLength);
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
[MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData), DisableDiscoveryEnumeration = true)]
public void TestStoredProcParamsForVectorFloat32(
int pattern,
Expand Down Expand Up @@ -304,7 +305,7 @@ public void TestStoredProcParamsForVectorFloat32(
Assert.Throws<InvalidOperationException>(() => command.ExecuteNonQuery());
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
[MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData), DisableDiscoveryEnumeration = true)]
public async Task TestStoredProcParamsForVectorFloat32Async(
int pattern,
Expand Down Expand Up @@ -361,7 +362,7 @@ public async Task TestStoredProcParamsForVectorFloat32Async(
await Assert.ThrowsAsync<InvalidOperationException>(async () => await command.ExecuteNonQueryAsync());
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
[InlineData(1)]
[InlineData(2)]
public void TestBulkCopyFromSqlTable(int bulkCopySourceMode)
Expand Down Expand Up @@ -460,7 +461,7 @@ public void TestBulkCopyFromSqlTable(int bulkCopySourceMode)
Assert.Equal(VectorFloat32TestData.testData.Length, ((SqlVector<float>)verifyReader.GetSqlVector<float>(0)).Length);
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
[InlineData(1)]
[InlineData(2)]
public async Task TestBulkCopyFromSqlTableAsync(int bulkCopySourceMode)
Expand Down Expand Up @@ -560,7 +561,7 @@ public async Task TestBulkCopyFromSqlTableAsync(int bulkCopySourceMode)
Assert.Equal(VectorFloat32TestData.testData.Length, vector.Length);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
public void TestInsertVectorsFloat32WithPrepare()
{
SqlConnection conn = new SqlConnection(s_connectionString);
Expand All @@ -571,15 +572,15 @@ public void TestInsertVectorsFloat32WithPrepare()
command.Prepare();
for (int i = 0; i < 10; i++)
{
vectorParam.Value = new SqlVector<float>(new float[] { i + 0.1f, i + 0.2f, i + 0.3f });
vectorParam.Value = new SqlVector<float>(new float[] { i + 0.1f, i + 0.2f, i + 0.3f, i + 0.4f, i + 0.5f, i + 0.6f });
command.ExecuteNonQuery();
}
SqlCommand validateCommand = new SqlCommand($"SELECT VectorData FROM {s_tableName}", conn);
using SqlDataReader reader = validateCommand.ExecuteReader();
int rowcnt = 0;
while (reader.Read())
{
float[] expectedData = new float[] { rowcnt + 0.1f, rowcnt + 0.2f, rowcnt + 0.3f };
float[] expectedData = new float[] { rowcnt + 0.1f, rowcnt + 0.2f, rowcnt + 0.3f, rowcnt + 0.4f, rowcnt + 0.5f, rowcnt + 0.6f };
float[] dbData = reader.GetSqlVector<float>(0).Memory.ToArray();
Assert.Equal(expectedData, dbData);
rowcnt++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ private void ValidateInsertedData(SqlConnection connection, float[] expectedData
}
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
public void TestVectorDataInsertionAsVarchar()
{
float[] data = { 1.1f, 2.2f, 3.3f };
Expand Down Expand Up @@ -173,7 +173,7 @@ private async Task ValidateInsertedDataAsync(SqlConnection connection, float[] e
}
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
public async Task TestVectorParameterInitializationAsync()
{
float[] data = { 1.1f, 2.2f, 3.3f };
Expand Down Expand Up @@ -245,7 +245,7 @@ public async Task TestVectorParameterInitializationAsync()
await ValidateInsertedDataAsync(conn, null);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
public void TestVectorDataReadsAsVarchar()
{
float[] data = { 1.1f, 2.2f, 3.3f };
Expand Down Expand Up @@ -302,7 +302,7 @@ public void TestVectorDataReadsAsVarchar()
Assert.Throws<SqlNullValueException>(() => reader.GetFieldValue<string>(0));
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
public async Task TestVectorDataReadsAsVarcharAsync()
{
float[] data = { 1.1f, 2.2f, 3.3f };
Expand Down Expand Up @@ -359,7 +359,7 @@ public async Task TestVectorDataReadsAsVarcharAsync()
await Assert.ThrowsAsync<SqlNullValueException>(async () => await reader2.GetFieldValueAsync<string>(0));
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
public void TestStoredProcParamsForVectorAsVarchar()
{
// Test data
Expand Down Expand Up @@ -405,7 +405,7 @@ public void TestStoredProcParamsForVectorAsVarchar()
Assert.True(outputParam.Value == DBNull.Value);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
public async Task TestStoredProcParamsForVectorAsVarcharAsync()
{
// Test data
Expand Down Expand Up @@ -456,7 +456,7 @@ public async Task TestStoredProcParamsForVectorAsVarcharAsync()
Assert.True(outputParam.Value == DBNull.Value);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
public void TestSqlBulkCopyForVectorAsVarchar()
{
//Setup source with test data and create destination table for bulkcopy.
Expand Down Expand Up @@ -521,7 +521,7 @@ public void TestSqlBulkCopyForVectorAsVarchar()
Assert.True(verifyReader.IsDBNull(0));
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
public async Task TestSqlBulkCopyForVectorAsVarcharAsync()
{
//Setup source with test data and create destination table for bulkcopy.
Expand Down Expand Up @@ -586,7 +586,7 @@ public async Task TestSqlBulkCopyForVectorAsVarcharAsync()
Assert.True(await verifyReader.IsDBNullAsync(0));
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsAzureServer), nameof(DataTestUtility.IsNotManagedInstance))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsSupportedSqlVector))]
public void TestInsertVectorsAsVarcharWithPrepare()
{
SqlConnection conn = new SqlConnection(s_connectionString);
Expand Down
Loading