Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions components/retriever/milvus2/auto_search.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package milvus2

import (
"context"
"fmt"

"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/milvus-io/milvus/client/v2/milvusclient"
)

// autoSearchMode is the search mode that infers the search strategy from the
// retriever configuration instead of requiring the caller to pick one.
// It selects Dense, Sparse, or Hybrid search depending on which of
// VectorField and SparseVectorField are set on the RetrieverConfig.
// For hybrid search, results are fused with an RRF reranker.
// The type carries no state; all inputs are taken from the arguments to Retrieve.
type autoSearchMode struct{}

// Retrieve picks a concrete search mode from the vector fields set on conf and
// forwards the query to it unchanged:
//
//   - VectorField and SparseVectorField both set: hybrid search, fused with an RRF reranker.
//   - only VectorField set: approximate (dense) search.
//   - only SparseVectorField set: sparse search.
//
// It returns an error when neither field is configured. The ctx, client, conf,
// query and opts arguments are passed through to the selected mode as-is.
func (a *autoSearchMode) Retrieve(ctx context.Context, client *milvusclient.Client, conf *RetrieverConfig, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	denseField, sparseField := conf.VectorField, conf.SparseVectorField

	switch {
	case denseField != "" && sparseField != "":
		// Hybrid: one sub-request per vector field, both bounded by conf.TopK.
		dense := &SubRequest{
			VectorField: denseField,
			VectorType:  DenseVector,
			// MetricType is intentionally left empty so Milvus falls back to
			// the metric type of the field's index.
			TopK:         conf.TopK,
			SearchParams: ExtractSearchParams(conf, denseField),
		}
		sparse := &SubRequest{
			VectorField:  sparseField,
			VectorType:   SparseVector,
			MetricType:   BM25,
			TopK:         conf.TopK,
			SearchParams: ExtractSearchParams(conf, sparseField),
		}

		// RRF is the fusion strategy used for auto-selected hybrid search.
		return NewHybrid(milvusclient.NewRRFReranker(), dense, sparse).
			Retrieve(ctx, client, conf, query, opts...)

	case denseField != "":
		// Dense only: hand off to the approximate search mode.
		// NOTE(review): the empty argument presumably means "no explicit metric type" — confirm against NewApproximate.
		return NewApproximate("").Retrieve(ctx, client, conf, query, opts...)

	case sparseField != "":
		// Sparse only: hand off to the sparse search mode.
		// NOTE(review): the empty argument presumably means "no explicit metric type" — confirm against NewSparse.
		return NewSparse("").Retrieve(ctx, client, conf, query, opts...)

	default:
		return nil, fmt.Errorf("[AutoSearch] no vector fields configured; set VectorField or SparseVectorField")
	}
}
106 changes: 106 additions & 0 deletions components/retriever/milvus2/auto_search_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package milvus2

import (
"context"
"testing"

. "github.com/bytedance/mockey"
"github.com/cloudwego/eino/schema"
"github.com/milvus-io/milvus/client/v2/milvusclient"
)

// TestAutoSearchMode_Retrieve_Std exercises autoSearchMode.Retrieve for the
// dense-only, hybrid, and sparse-only configurations. The Milvus client
// methods are patched with mockey, so no server is contacted.
func TestAutoSearchMode_Retrieve_Std(t *testing.T) {
	var (
		ctx    = context.Background()
		client = &milvusclient.Client{}
		emb    = &mockEmbedding{dims: 128}
		mode   = &autoSearchMode{}
	)

	// emptyDocs is a converter stub that ignores the result set.
	emptyDocs := func(ctx context.Context, result milvusclient.ResultSet) ([]*schema.Document, error) {
		return []*schema.Document{}, nil
	}

	// oneHit builds the canned response returned by the patched client methods.
	oneHit := func() []milvusclient.ResultSet {
		return []milvusclient.ResultSet{
			{ResultCount: 1},
		}
	}

	// retrieveOK runs a retrieval with the given config and fails the
	// subtest if an error comes back.
	retrieveOK := func(t *testing.T, conf *RetrieverConfig) {
		if _, err := mode.Retrieve(ctx, client, conf, "query"); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	}

	t.Run("Dense with SearchParams", func(t *testing.T) {
		PatchConvey("mock", t, func() {
			Mock(GetMethod(client, "Search")).Return(oneHit(), nil).Build()

			denseConf := &RetrieverConfig{
				VectorField:       "dense_vec",
				Embedding:         emb,
				DocumentConverter: emptyDocs,
				SearchParams: map[string]map[string]interface{}{
					"dense_vec": {"nprobe": 10},
				},
			}

			// Safety net: turn an unexpected panic into an ordinary test failure.
			defer func() {
				if r := recover(); r != nil {
					t.Fatalf("Recovered from panic: %v", r)
				}
			}()

			retrieveOK(t, denseConf)
		})
	})

	t.Run("Hybrid", func(t *testing.T) {
		PatchConvey("mock", t, func() {
			Mock(GetMethod(client, "HybridSearch")).Return(oneHit(), nil).Build()

			// Both fields set, so auto mode is expected to take the hybrid path.
			retrieveOK(t, &RetrieverConfig{
				VectorField:       "dense_vec",
				SparseVectorField: "sparse_vec",
				Embedding:         emb,
				DocumentConverter: emptyDocs,
			})
		})
	})

	t.Run("Sparse Only", func(t *testing.T) {
		PatchConvey("mock", t, func() {
			Mock(GetMethod(client, "Search")).Return(oneHit(), nil).Build()

			// Only SparseVectorField is set; no VectorField and no Embedding,
			// since sparse-only search is expected to work without one.
			retrieveOK(t, &RetrieverConfig{
				SparseVectorField: "sparse_vec",
				DocumentConverter: emptyDocs,
			})
		})
	})
}
108 changes: 108 additions & 0 deletions components/retriever/milvus2/params.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package milvus2

import "fmt"

// Common Milvus search parameters.
// See Milvus documentation for full list: https://milvus.io/docs/index_selection.md
const (
	// ParamNProbe is for IVF indices. Specifies the number of units to query.
	ParamNProbe = "nprobe"
	// ParamEF is for HNSW indices. Specifies the search scope.
	ParamEF = "ef"
	// ParamRadius is for Range Search. Specifies the radius distance.
	ParamRadius = "radius"
	// ParamRangeFilter is for Range Search. Used together with ParamRadius to
	// bound the distance range of returned results.
	ParamRangeFilter = "range_filter"
	// ParamLevel specifies the pruning level.
	// NOTE(review): originally documented as SCANN-specific; confirm which
	// index types actually accept this parameter.
	ParamLevel = "level"
	// ParamDropRatioSearch is for sparse vector indices. Ignores small unrelated values.
	ParamDropRatioSearch = "drop_ratio_search"
)

// NewSearchParams creates a helper to build the SearchParams map.
func NewSearchParams() *SearchParamsBuilder {
	return &SearchParamsBuilder{
		m: make(map[string]interface{}),
	}
}

// SearchParamsBuilder helps construct the search parameter map in a typed way.
//
// The zero value is ready to use: the underlying map is allocated on first
// access, so a builder that was not created through NewSearchParams does not
// panic when a parameter is set.
type SearchParamsBuilder struct {
	m map[string]interface{}
}

// params returns the underlying map, allocating it on first use so that a
// zero-value builder never writes to a nil map.
func (b *SearchParamsBuilder) params() map[string]interface{} {
	if b.m == nil {
		b.m = make(map[string]interface{})
	}
	return b.m
}

// WithNProbe sets the "nprobe" parameter (for IVF indices).
func (b *SearchParamsBuilder) WithNProbe(nprobe int) *SearchParamsBuilder {
	return b.With(ParamNProbe, nprobe)
}

// WithEF sets the "ef" parameter (for HNSW indices).
func (b *SearchParamsBuilder) WithEF(ef int) *SearchParamsBuilder {
	return b.With(ParamEF, ef)
}

// WithRadius sets the "radius" parameter (for Range Search).
func (b *SearchParamsBuilder) WithRadius(radius float64) *SearchParamsBuilder {
	return b.With(ParamRadius, radius)
}

// WithRangeFilter sets the "range_filter" parameter (for Range Search).
func (b *SearchParamsBuilder) WithRangeFilter(filter float64) *SearchParamsBuilder {
	return b.With(ParamRangeFilter, filter)
}

// WithDropRatioSearch sets the "drop_ratio_search" parameter (for Sparse indices).
func (b *SearchParamsBuilder) WithDropRatioSearch(ratio float64) *SearchParamsBuilder {
	return b.With(ParamDropRatioSearch, ratio)
}

// With sets a custom parameter key-value pair, overwriting any value
// previously stored under the same key. It returns the builder for chaining.
func (b *SearchParamsBuilder) With(key string, value interface{}) *SearchParamsBuilder {
	b.params()[key] = value
	return b
}

// Build returns the constructed map. The result is never nil.
// The returned map is the builder's own storage, not a copy: later calls on
// the same builder are visible through it.
func (b *SearchParamsBuilder) Build() map[string]interface{} {
	return b.params()
}

// ExtractSearchParams extracts and stringifies search parameters for a specific field from the configuration.
//
// It looks up conf.SearchParams[field] and returns a new map in which every
// value has been formatted with the "%v" verb. It returns nil when conf is
// nil, when no SearchParams are configured, or when there is no entry for
// field. An entry that exists but is empty yields an empty, non-nil map.
func ExtractSearchParams(conf *RetrieverConfig, field string) map[string]string {
	// Guard against a nil config so this exported helper cannot panic on a
	// nil pointer dereference.
	if conf == nil {
		return nil
	}

	// Reading from a nil map is safe in Go and reports "not found", so no
	// separate nil check on conf.SearchParams is needed.
	params, ok := conf.SearchParams[field]
	if !ok {
		return nil
	}

	// Milvus SDK expects search parameters (like "nprobe", "ef") to be strings.
	// We allow users to pass them as appropriate types (int, float) in configuration
	// and convert them to strings here.
	out := make(map[string]string, len(params))
	for k, v := range params {
		out[k] = fmt.Sprintf("%v", v)
	}
	return out
}
26 changes: 19 additions & 7 deletions components/retriever/milvus2/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,15 @@ type RetrieverConfig struct {
ConsistencyLevel ConsistencyLevel

// SearchMode defines the search strategy.
// Required.
// Optional. If nil, defaults to Auto Search (infers from VectorField/SparseVectorField).
SearchMode SearchMode

// SearchParams allows passing extra parameters (e.g. "nprobe", "ef") to the search.
// Key is the vector field name, Value is the parameter map.
// Used primarily by Auto Search mode.
// Tip: Use NewSearchParams() builder to construct the value map easily.
SearchParams map[string]map[string]interface{}

// DocumentConverter converts Milvus search results to EINO documents.
// If nil, uses default conversion.
DocumentConverter func(ctx context.Context, result milvusclient.ResultSet) ([]*schema.Document, error)
Expand Down Expand Up @@ -188,18 +194,24 @@ func (c *RetrieverConfig) validate() error {
return fmt.Errorf("[NewRetriever] milvus client or client config not provided")
}
if c.SearchMode == nil {
return fmt.Errorf("[NewRetriever] search mode not provided")
c.SearchMode = &autoSearchMode{}
}
// Embedding validation is delegated to the specific SearchMode implementation.
if c.Collection == "" {
c.Collection = defaultCollection
}
if c.VectorField == "" {
c.VectorField = defaultVectorField
}
if c.SparseVectorField == "" {
c.SparseVectorField = defaultSparseVectorField

// Apply defaults for VectorField and SparseVectorField ONLY if not using AutoSearch.
// AutoSearch relies on empty fields to infer user intent (Dense vs Sparse vs Hybrid).
if _, isAuto := c.SearchMode.(*autoSearchMode); !isAuto {
if c.VectorField == "" {
c.VectorField = defaultVectorField
}
if c.SparseVectorField == "" {
c.SparseVectorField = defaultSparseVectorField
}
}

if len(c.OutputFields) == 0 {
c.OutputFields = []string{"*"}
}
Expand Down
23 changes: 15 additions & 8 deletions components/retriever/milvus2/retriever_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,17 @@ func TestRetrieverConfig_validate(t *testing.T) {
convey.So(err, convey.ShouldBeNil)
})

convey.Convey("test missing search mode", func() {
convey.Convey("test missing search mode (defaults to auto)", func() {
config := &RetrieverConfig{
ClientConfig: &milvusclient.ClientConfig{Address: "localhost:19530"},
Collection: "test_collection",
Embedding: mockEmb,
SearchMode: nil,
SearchMode: nil, // Should trigger default
}
err := config.validate()
convey.So(err, convey.ShouldNotBeNil)
convey.So(err.Error(), convey.ShouldContainSubstring, "search mode")
convey.So(err, convey.ShouldBeNil)
convey.So(config.SearchMode, convey.ShouldNotBeNil)
// Optional: check type if we export it or use reflection, but checking not nil is enough for validate logic
})
})
}
Expand All @@ -132,15 +133,21 @@ func TestNewRetriever(t *testing.T) {
convey.So(err.Error(), convey.ShouldContainSubstring, "client")
})

PatchConvey("test missing search mode", func() {
_, err := NewRetriever(ctx, &RetrieverConfig{
PatchConvey("test missing search mode (defaults to auto)", func() {
mockClient := &milvusclient.Client{}
Mock(milvusclient.New).Return(mockClient, nil).Build()
Mock(GetMethod(mockClient, "HasCollection")).Return(true, nil).Build()
Mock(GetMethod(mockClient, "GetLoadState")).Return(entity.LoadState{State: entity.LoadStateLoaded}, nil).Build()

r, err := NewRetriever(ctx, &RetrieverConfig{
ClientConfig: &milvusclient.ClientConfig{Address: "localhost:19530"},
Collection: "test_collection",
Embedding: mockEmb,
SearchMode: nil,
})
convey.So(err, convey.ShouldNotBeNil)
convey.So(err.Error(), convey.ShouldContainSubstring, "search mode")
convey.So(err, convey.ShouldBeNil)
convey.So(r, convey.ShouldNotBeNil)
convey.So(r.config.SearchMode, convey.ShouldNotBeNil)
})
})
}
Expand Down
Loading
Loading