diff --git a/components/retriever/qdrant/retriever.go b/components/retriever/qdrant/retriever.go
index ca4110091..8f3dcf7ca 100644
--- a/components/retriever/qdrant/retriever.go
+++ b/components/retriever/qdrant/retriever.go
@@ -40,14 +40,20 @@ type Config struct {
 	ScoreThreshold *float64
 	// Number of top results to retrieve from Qdrant.
 	TopK int
+	// ReturnFields lists the payload keys that the default document converter copies into the returned document.
+	// It is ignored when DocumentConverter is set. Defaults to the metadata and content keys (defaultMetadataKey, defaultContentKey).
+	ReturnFields []string
+	// DocumentConverter converts retrieved raw document to eino Document, default defaultResultParser.
+	DocumentConverter func(ctx context.Context, point *qdrant.ScoredPoint) (*schema.Document, error)
 }
 
 type Retriever struct {
-	client         *qdrant.Client
-	collection     string
-	embedding      embedding.Embedder
-	scoreThreshold *float64
-	topK           int
+	client            *qdrant.Client
+	collection        string
+	embedding         embedding.Embedder
+	scoreThreshold    *float64
+	topK              int
+	documentConverter func(ctx context.Context, point *qdrant.ScoredPoint) (*schema.Document, error)
 }
 
 func NewRetriever(ctx context.Context, config *Config) (*Retriever, error) {
@@ -69,12 +75,24 @@ func NewRetriever(ctx context.Context, config *Config) (*Retriever, error) {
 		topK = 5
 	}
 
+	// Resolve defaults into locals (as done for topK) so the caller's Config is never mutated.
+	returnFields := config.ReturnFields
+	if len(returnFields) == 0 {
+		returnFields = []string{defaultMetadataKey, defaultContentKey}
+	}
+
+	documentConverter := config.DocumentConverter
+	if documentConverter == nil {
+		documentConverter = defaultResultParser(returnFields)
+	}
+
 	return &Retriever{
-		client:         config.Client,
-		collection:     config.Collection,
-		embedding:      config.Embedding,
-		scoreThreshold: config.ScoreThreshold,
-		topK:           topK,
+		client:            config.Client,
+		collection:        config.Collection,
+		embedding:         config.Embedding,
+		scoreThreshold:    config.ScoreThreshold,
+		topK:              topK,
+		documentConverter: documentConverter,
 	}, nil
 }
 
@@ -134,19 +152,13 @@ func (r *Retriever) Retrieve(ctx context.Context, query string, opts ...retrieve
 	}
 	docs = make([]*schema.Document, 0, len(resp))
 	for _, pt := range resp {
-		doc := &schema.Document{
-			ID:       pt.Id.GetUuid(),
-			MetaData: map[string]any{},
-		}
-
-		if val, ok := pt.Payload[defaultContentKey]; ok {
-			doc.Content = val.GetStringValue()
+		doc, err := r.documentConverter(ctx, pt)
+		if err != nil {
+			return nil, err
 		}
-
-		if val, ok := pt.Payload[defaultMetadataKey]; ok {
-			doc.MetaData[defaultMetadataKey] = val.GetStructValue().Fields
+		if doc == nil {
+			return nil, fmt.Errorf("[qdrant retriever] document converter returned nil document")
 		}
-
 		doc.WithScore(float64(pt.Score))
 
 		docs = append(docs, doc)
@@ -171,6 +183,64 @@ func (r *Retriever) makeEmbeddingCtx(ctx context.Context, emb embedding.Embedder
 	return callbacks.ReuseHandlers(ctx, runInfo)
 }
 
+// defaultResultParser builds the converter used when Config.DocumentConverter is nil.
+// The returned function copies every key listed in returnFields from the point's payload into a
+// schema.Document, and returns an error when the point is nil or a listed key is absent from the payload.
+func defaultResultParser(returnFields []string) func(ctx context.Context, point *qdrant.ScoredPoint) (*schema.Document, error) {
+	return func(ctx context.Context, point *qdrant.ScoredPoint) (*schema.Document, error) {
+		if point == nil {
+			return nil, fmt.Errorf("[defaultResultParser] point is nil")
+		}
+
+		resp := &schema.Document{
+			Content:  "",
+			MetaData: map[string]any{},
+		}
+
+		// Prefer the UUID form of the point ID; fall back to the numeric form when the UUID is empty.
+		if point.Id != nil {
+			if uuid := point.Id.GetUuid(); uuid != "" {
+				resp.ID = uuid
+			} else {
+				resp.ID = fmt.Sprintf("%d", point.Id.GetNum())
+			}
+		}
+
+		for _, field := range returnFields {
+			val, found := point.Payload[field]
+			if !found {
+				return nil, fmt.Errorf("[defaultResultParser] field=%s not found in payload, point=%v", field, point)
+			}
+
+			if field == defaultContentKey {
+				resp.Content = val.GetStringValue()
+			} else if field == defaultMetadataKey {
+				// GetFields is nil-safe; reading .Fields directly panics when the value is not a struct.
+				resp.MetaData[defaultMetadataKey] = val.GetStructValue().GetFields()
+			} else {
+				switch val.GetKind().(type) {
+				case *qdrant.Value_NullValue:
+					resp.MetaData[field] = val.GetNullValue()
+				case *qdrant.Value_DoubleValue:
+					resp.MetaData[field] = val.GetDoubleValue()
+				case *qdrant.Value_IntegerValue:
+					resp.MetaData[field] = val.GetIntegerValue()
+				case *qdrant.Value_StringValue:
+					resp.MetaData[field] = val.GetStringValue()
+				case *qdrant.Value_BoolValue:
+					resp.MetaData[field] = val.GetBoolValue()
+				case *qdrant.Value_StructValue:
+					resp.MetaData[field] = val.GetStructValue().GetFields()
+				case *qdrant.Value_ListValue:
+					resp.MetaData[field] = val.GetListValue()
+				}
+			}
+		}
+
+		return resp, nil
+	}
+}
+
 func tryMarshalJsonString(input any) string {
 	if input == nil {
 		return ""