Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 83 additions & 21 deletions components/retriever/qdrant/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,20 @@ type Config struct {
ScoreThreshold *float64
// Number of top results to retrieve from Qdrant.
TopK int
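// ReturnFields lists the payload fields that the default DocumentConverter copies from each retrieved point into the document.
// Every listed field must be present in the point's payload, otherwise the default converter returns an error.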
// Default []string{"content", "metadata"}
ReturnFields []string
// DocumentConverter converts a retrieved Qdrant scored point into an eino Document. Defaults to defaultResultParser(ReturnFields).
DocumentConverter func(ctx context.Context, point *qdrant.ScoredPoint) (*schema.Document, error)
}

type Retriever struct {
client *qdrant.Client
collection string
embedding embedding.Embedder
scoreThreshold *float64
topK int
client *qdrant.Client
collection string
embedding embedding.Embedder
scoreThreshold *float64
topK int
documentConverter func(ctx context.Context, point *qdrant.ScoredPoint) (*schema.Document, error)
}

func NewRetriever(ctx context.Context, config *Config) (*Retriever, error) {
Expand All @@ -69,12 +75,21 @@ func NewRetriever(ctx context.Context, config *Config) (*Retriever, error) {
topK = 5
}

if len(config.ReturnFields) == 0 {
config.ReturnFields = []string{defaultMetadataKey, defaultContentKey}
}

if config.DocumentConverter == nil {
config.DocumentConverter = defaultResultParser(config.ReturnFields)
}

return &Retriever{
client: config.Client,
collection: config.Collection,
embedding: config.Embedding,
scoreThreshold: config.ScoreThreshold,
topK: topK,
client: config.Client,
collection: config.Collection,
embedding: config.Embedding,
scoreThreshold: config.ScoreThreshold,
topK: topK,
documentConverter: config.DocumentConverter,
}, nil
}

Expand Down Expand Up @@ -134,19 +149,13 @@ func (r *Retriever) Retrieve(ctx context.Context, query string, opts ...retrieve
}
docs = make([]*schema.Document, 0, len(resp))
for _, pt := range resp {
doc := &schema.Document{
ID: pt.Id.GetUuid(),
MetaData: map[string]any{},
}

if val, ok := pt.Payload[defaultContentKey]; ok {
doc.Content = val.GetStringValue()
doc, err := r.documentConverter(ctx, pt)
if err != nil {
return nil, err
}

if val, ok := pt.Payload[defaultMetadataKey]; ok {
doc.MetaData[defaultMetadataKey] = val.GetStructValue().Fields
if doc == nil {
return nil, fmt.Errorf("[qdrant retriever] document converter returned nil document")
}

doc.WithScore(float64(pt.Score))

docs = append(docs, doc)
Expand All @@ -171,6 +180,59 @@ func (r *Retriever) makeEmbeddingCtx(ctx context.Context, emb embedding.Embedder
return callbacks.ReuseHandlers(ctx, runInfo)
}

// defaultResultParser builds the converter that NewRetriever installs when
// Config.DocumentConverter is nil.
//
// The returned function turns one Qdrant scored point into a schema.Document:
//   - ID is the point's UUID when it has one, otherwise its numeric ID in
//     decimal; it stays empty when the point carries no ID.
//   - Every name in returnFields must be present in the point's payload,
//     otherwise an error is returned and no document is produced.
//   - The defaultContentKey field is read as a string into Content.
//   - The defaultMetadataKey field is read as a struct and its fields are
//     stored under MetaData[defaultMetadataKey].
//   - Any other field is stored under MetaData[field], unwrapped according to
//     its payload kind; a value with no recognised kind is skipped.
//
// The document's score is not set here.
func defaultResultParser(returnFields []string) func(ctx context.Context, point *qdrant.ScoredPoint) (*schema.Document, error) {
	return func(ctx context.Context, point *qdrant.ScoredPoint) (*schema.Document, error) {
		if point == nil {
			return nil, fmt.Errorf("[defaultResultParser] point is nil")
		}

		doc := &schema.Document{
			MetaData: make(map[string]any, len(returnFields)),
		}

		if point.Id != nil {
			if uuid := point.Id.GetUuid(); uuid != "" {
				doc.ID = uuid
			} else {
				doc.ID = fmt.Sprintf("%d", point.Id.GetNum())
			}
		}

		for _, field := range returnFields {
			val, found := point.Payload[field]
			if !found {
				return nil, fmt.Errorf("[defaultResultParser] field=%s not found in payload, point=%v", field, point)
			}

			switch {
			case field == defaultContentKey:
				doc.Content = val.GetStringValue()
			case field == defaultMetadataKey:
				// GetStructValue returns nil when the payload value is not a
				// struct (e.g. null or a string). The generated GetFields
				// getter tolerates a nil receiver, whereas reading .Fields
				// directly would panic.
				doc.MetaData[defaultMetadataKey] = val.GetStructValue().GetFields()
			default:
				switch val.GetKind().(type) {
				case *qdrant.Value_NullValue:
					doc.MetaData[field] = val.GetNullValue()
				case *qdrant.Value_DoubleValue:
					doc.MetaData[field] = val.GetDoubleValue()
				case *qdrant.Value_IntegerValue:
					doc.MetaData[field] = val.GetIntegerValue()
				case *qdrant.Value_StringValue:
					doc.MetaData[field] = val.GetStringValue()
				case *qdrant.Value_BoolValue:
					doc.MetaData[field] = val.GetBoolValue()
				case *qdrant.Value_StructValue:
					doc.MetaData[field] = val.GetStructValue().GetFields()
				case *qdrant.Value_ListValue:
					doc.MetaData[field] = val.GetListValue()
				}
			}
		}

		return doc, nil
	}
}

func tryMarshalJsonString(input any) string {
if input == nil {
return ""
Expand Down
Loading