diff --git a/sdk/data/azcosmos/CHANGELOG.md b/sdk/data/azcosmos/CHANGELOG.md
index 4ccf5eabf2fa..dda571b2163d 100644
--- a/sdk/data/azcosmos/CHANGELOG.md
+++ b/sdk/data/azcosmos/CHANGELOG.md
@@ -4,6 +4,8 @@
 
 ### Features Added
 
+* Added support for the `float16` data type in the vector embedding policy. See [PR 25707](https://github.com/Azure/azure-sdk-for-go/pull/25707)
+
 ### Breaking Changes
 
 ### Bugs Fixed
diff --git a/sdk/data/azcosmos/emulator_cosmos_container_test.go b/sdk/data/azcosmos/emulator_cosmos_container_test.go
index e76480d5e7c6..344db1e60fff 100644
--- a/sdk/data/azcosmos/emulator_cosmos_container_test.go
+++ b/sdk/data/azcosmos/emulator_cosmos_container_test.go
@@ -354,6 +354,101 @@ func TestContainerVectorSearch(t *testing.T) {
 	}
 }
 
+// TestCreateValidVectorEmbeddingPolicy creates one container for each vector
+// data type listed below and verifies that the vector embedding policy read
+// back from the service matches the policy that was submitted.
+func TestCreateValidVectorEmbeddingPolicy(t *testing.T) {
+	emulatorTests := newEmulatorTests(t)
+	client := emulatorTests.getClient(t, newSpanValidator(t, &spanMatcher{
+		ExpectedSpans: []string{},
+	}))
+
+	database := emulatorTests.createDatabase(t, context.TODO(), client, "vectorDataTypes")
+	defer emulatorTests.deleteDatabase(t, context.TODO(), database)
+
+	// Data types that are valid in a vector embedding policy.
+	dataTypes := []struct {
+		name     string
+		dataType VectorDataType
+	}{
+		{"float32", VectorDataTypeFloat32},
+		{"float16", VectorDataTypeFloat16},
+		{"int8", VectorDataTypeInt8},
+		{"uint8", VectorDataTypeUint8},
+	}
+
+	for _, dt := range dataTypes {
+		t.Run(dt.name, func(t *testing.T) {
+			containerID := "vector_container_" + dt.name
+
+			properties := ContainerProperties{
+				ID: containerID,
+				PartitionKeyDefinition: PartitionKeyDefinition{
+					Paths: []string{"/id"},
+				},
+				VectorEmbeddingPolicy: &VectorEmbeddingPolicy{
+					VectorEmbeddings: []VectorEmbedding{
+						{
+							Path:             "/vector1",
+							DataType:         dt.dataType,
+							Dimensions:       256,
+							DistanceFunction: VectorDistanceFunctionEuclidean,
+						},
+					},
+				},
+			}
+
+			if _, err := database.CreateContainer(context.TODO(), properties, nil); err != nil {
+				t.Fatalf("Failed to create container with %s data type: %v", dt.name, err)
+			}
+
+			container, err := database.NewContainer(containerID)
+			if err != nil {
+				t.Fatalf("Failed to get container client for %s: %v", containerID, err)
+			}
+
+			// Register the delete before asserting so the container is removed
+			// even when a check below aborts the subtest.
+			t.Cleanup(func() {
+				if _, err := container.Delete(context.TODO(), nil); err != nil {
+					t.Errorf("Failed to delete container %s: %v", containerID, err)
+				}
+			})
+
+			readResp, err := container.Read(context.TODO(), nil)
+			if err != nil {
+				t.Fatalf("Failed to read container: %v", err)
+			}
+
+			policy := readResp.ContainerProperties.VectorEmbeddingPolicy
+			if policy == nil {
+				t.Fatalf("Expected VectorEmbeddingPolicy to be set")
+			}
+
+			if len(policy.VectorEmbeddings) != 1 {
+				t.Fatalf("Expected 1 vector embedding, got %d", len(policy.VectorEmbeddings))
+			}
+
+			embedding := policy.VectorEmbeddings[0]
+			if embedding.DataType != dt.dataType {
+				t.Errorf("Expected data type %s, got %s", dt.dataType, embedding.DataType)
+			}
+
+			if embedding.Path != "/vector1" {
+				t.Errorf("Expected path /vector1, got %s", embedding.Path)
+			}
+
+			if embedding.Dimensions != 256 {
+				t.Errorf("Expected dimensions 256, got %d", embedding.Dimensions)
+			}
+
+			if embedding.DistanceFunction != VectorDistanceFunctionEuclidean {
+				t.Errorf("Expected distance function euclidean, got %s", embedding.DistanceFunction)
+			}
+		})
+	}
+}
+
 func TestContainerFullTextSearch(t *testing.T) {
 	emulatorTests := newEmulatorTests(t)
 	client := emulatorTests.getClient(t, newSpanValidator(t, &spanMatcher{
diff --git a/sdk/data/azcosmos/vector_embedding_policy.go b/sdk/data/azcosmos/vector_embedding_policy.go
index 3ff109426620..18356f807943 100644
--- a/sdk/data/azcosmos/vector_embedding_policy.go
+++ b/sdk/data/azcosmos/vector_embedding_policy.go
@@ -38,6 +38,9 @@ const (
 	// VectorDataTypeFloat32 represents 32-bit floating point numbers (default).
 	VectorDataTypeFloat32 VectorDataType = "float32"
 
+	// VectorDataTypeFloat16 represents 16-bit floating point numbers.
+	VectorDataTypeFloat16 VectorDataType = "float16"
+
 	// VectorDataTypeInt8 represents 8-bit signed integers.
 	VectorDataTypeInt8 VectorDataType = "int8"
 