-
Notifications
You must be signed in to change notification settings - Fork 459
[VECTOR_FLOAT16] Implement feature extension and version negotiation for Vector v2 support #2868
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
88084ba
61c6f75
5779d28
f2aa7ce
e45a182
c8a97c2
e389d57
f09d152
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -762,9 +762,6 @@ private ActiveDirectoryAuthentication() { | |
| } | ||
| } | ||
|
|
||
| private static final String VECTOR_SUPPORT_OFF = "off"; | ||
| private static final String VECTOR_SUPPORT_V1 = "v1"; | ||
|
|
||
| final static int TNIR_FIRST_ATTEMPT_TIMEOUT_MS = 500; // fraction of timeout to use for fast failover connections | ||
|
|
||
| /** | ||
|
|
@@ -1105,16 +1102,26 @@ public void setBulkCopyForBatchInsertAllowEncryptedValueModifications( | |
|
|
||
| /** | ||
| * A string that indicates the vector type support during connection initialization. | ||
| * Valid values are "off" (vector types are returned as strings) and "v1" (vectors of type FLOAT32 are returned as vectors). | ||
| * Valid values are : | ||
| * - "off" (vector types are returned as strings) | ||
| * - "v1" (supports float32 vector type) | ||
| * - "v2" (supports float32 and float16 vector types) | ||
| * Default is "v1". | ||
| */ | ||
| private String vectorTypeSupport = VECTOR_SUPPORT_V1; | ||
| private String vectorTypeSupport = SQLServerDriverStringProperty.VECTOR_TYPE_SUPPORT.getDefaultValue(); | ||
|
|
||
| private VectorTypeSupport vectorTypeSupportEnum = VectorTypeSupport.V1; | ||
|
|
||
| /** | ||
| * Negotiated vector version between client and server | ||
| */ | ||
| private byte negotiatedVectorVersion = TDS.VECTORSUPPORT_NOT_SUPPORTED; | ||
|
|
||
| /** | ||
| * Returns the value of the vectorTypeSupport connection property. | ||
| * | ||
| * @return vectorTypeSupport | ||
| * The current vector type support setting ("off" or "v1"). | ||
| * The current vector type support setting ("off"|"v1"|"v2"). | ||
| */ | ||
| @Override | ||
| public String getVectorTypeSupport() { | ||
|
|
@@ -1126,7 +1133,10 @@ public String getVectorTypeSupport() { | |
| * | ||
| * @param vectorTypeSupport | ||
| * A string that indicates the vector type support during connection initialization. | ||
| * Valid values are "off" (vector types are returned as strings) and "v1" (vectors of type FLOAT32 are returned as vectors). | ||
| * Valid values are : | ||
| * - "off" (vector types are returned as strings) | ||
| * - "v1" (supports float32 vector type) | ||
| * - "v2" (supports float32 and float16 vector types) | ||
| * Default is "v1". | ||
| */ | ||
| @Override | ||
|
|
@@ -1136,15 +1146,13 @@ public void setVectorTypeSupport(String vectorTypeSupport) { | |
| Object[] msgArgs = { "null" }; | ||
| throw new IllegalArgumentException(form.format(msgArgs)); | ||
| } | ||
| switch (vectorTypeSupport.trim().toLowerCase()) { | ||
| case VECTOR_SUPPORT_OFF: | ||
| case VECTOR_SUPPORT_V1: | ||
| this.vectorTypeSupport = vectorTypeSupport.toLowerCase(); | ||
| break; | ||
| default: | ||
| MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_invalidVectorTypeSupport")); | ||
| Object[] msgArgs = { vectorTypeSupport }; | ||
| throw new IllegalArgumentException(form.format(msgArgs)); | ||
| try { | ||
| this.vectorTypeSupportEnum = VectorTypeSupport.valueOfString(vectorTypeSupport.trim()); | ||
| this.vectorTypeSupport = vectorTypeSupport.toLowerCase(); | ||
| } catch (SQLServerException e) { | ||
| MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_invalidVectorTypeSupport")); | ||
| Object[] msgArgs = { vectorTypeSupport }; | ||
| throw new IllegalArgumentException(form.format(msgArgs)); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -5899,14 +5907,28 @@ int writeUserAgentFeatureRequest(boolean write, /* if false just calculates the | |
| */ | ||
| int writeVectorSupportFeatureRequest(boolean write, | ||
| TDSWriter tdsWriter) throws SQLServerException { | ||
| if (VECTOR_SUPPORT_OFF.equalsIgnoreCase(vectorTypeSupport)) { | ||
|
|
||
| // Initialize vectorTypeSupportEnum if not already done. | ||
| if (vectorTypeSupportEnum == null) { | ||
| try { | ||
| vectorTypeSupportEnum = VectorTypeSupport.valueOfString(vectorTypeSupport); | ||
| } catch (SQLServerException e) { | ||
| // Fallback to value OFF if invalid value is provided. | ||
| vectorTypeSupportEnum = VectorTypeSupport.OFF; | ||
| } | ||
| } | ||
|
|
||
| if (vectorTypeSupportEnum == VectorTypeSupport.OFF) { | ||
| return 0; | ||
| } | ||
|
|
||
| int len = 6; // 1byte = featureID, 4bytes = featureData length, 1 bytes = Version | ||
| if (write) { | ||
| tdsWriter.writeByte(TDS.TDS_FEATURE_EXT_VECTORSUPPORT); | ||
| tdsWriter.writeInt(1); | ||
| tdsWriter.writeByte(TDS.MAX_VECTORSUPPORT_VERSION); | ||
|
|
||
| // write the vector type support version | ||
| tdsWriter.writeByte(vectorTypeSupportEnum.getTdsValue()); | ||
| } | ||
| return len; | ||
| } | ||
|
|
@@ -7118,7 +7140,20 @@ private void onFeatureExtAck(byte featureId, byte[] data) throws SQLServerExcept | |
| if (0 == serverSupportedVectorVersion || serverSupportedVectorVersion > TDS.MAX_VECTORSUPPORT_VERSION) { | ||
| throw new SQLServerException(SQLServerException.getErrString("R_InvalidVectorVersionNumber"), null); | ||
| } | ||
| serverSupportsVector = true; | ||
| // Negotiate the vector version between client and server | ||
| negotiatedVectorVersion = negotiateVectorVersion(vectorTypeSupportEnum, serverSupportedVectorVersion); | ||
|
|
||
| if (negotiatedVectorVersion > TDS.VECTORSUPPORT_NOT_SUPPORTED) { | ||
| serverSupportsVector = true; | ||
|
|
||
| if (connectionlogger.isLoggable(Level.FINE)) { | ||
| connectionlogger.fine(toString() + " Vector support negotiated. Client: " + vectorTypeSupport + | ||
| ", Server: " + serverSupportedVectorVersion + | ||
| ", Negotiated: " + negotiatedVectorVersion); | ||
| } | ||
| } else { | ||
| serverSupportsVector = false; | ||
| } | ||
| break; | ||
| } | ||
|
|
||
|
|
@@ -7154,6 +7189,45 @@ private void onFeatureExtAck(byte featureId, byte[] data) throws SQLServerExcept | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Negotiates the vector version between client and server based on the | ||
| * following rules: | ||
| * - If either client or server is "off", negotiated version is "off" | ||
| * - If both support v2, negotiated version is v2 | ||
| * - If both support v1, negotiated version is v1 | ||
| * - Otherwise, use the minimum supported version | ||
| * | ||
| * @param clientVectorSupportEnum The client's vector type support setting | ||
| * @param serverVersion The server's supported vector version | ||
| * @return The negotiated vector version | ||
| */ | ||
| private byte negotiateVectorVersion(VectorTypeSupport clientVectorSupportEnum, byte serverVersion) { | ||
|
|
||
| // If server doesn't support vectors, negotiation is off | ||
| if (serverVersion == TDS.VECTORSUPPORT_NOT_SUPPORTED) { | ||
| return TDS.VECTORSUPPORT_NOT_SUPPORTED; | ||
| } | ||
|
|
||
| if (clientVectorSupportEnum == VectorTypeSupport.OFF) { | ||
| return TDS.VECTORSUPPORT_NOT_SUPPORTED; | ||
| } | ||
|
|
||
| byte clientMaxVersion = clientVectorSupportEnum.getTdsValue(); | ||
|
|
||
| // Negotiate using the minimum supported version | ||
| return (byte) Math.min(clientMaxVersion, serverVersion); | ||
|
Comment on lines
+7204
to
+7218
|
||
|
|
||
| } | ||
|
|
||
| /** | ||
| * Returns the negotiated vector version between client and server. | ||
| * | ||
| * @return The negotiated vector version (0 = off, 1 = v1, 2 = v2) | ||
| */ | ||
| public byte getNegotiatedVectorVersion() { | ||
muskan124947 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return negotiatedVectorVersion; | ||
| } | ||
|
|
||
| /* | ||
| * Executes a DTC command | ||
| */ | ||
|
|
@@ -7452,6 +7526,7 @@ final boolean complete(LogonCommand logonCommand, TDSReader tdsReader) throws SQ | |
|
|
||
| // request vector support | ||
| len += writeVectorSupportFeatureRequest(false, tdsWriter); | ||
|
|
||
| // request JSON support | ||
| len += writeJSONSupportFeatureRequest(false, tdsWriter); | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.