
Commit 22a27a1

[fel] update llama index tools meta (#244)
* [fit] Update plugins metadata registration
* Address review comments
* Further revisions per review comments
1 parent 7cd31de commit 22a27a1

File tree

10 files changed, +1365 -799 lines changed


framework/fel/python/plugins/fel_llama_index_tools/callable_registers.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

framework/fel/python/plugins/fel_llama_index_tools/llama_rag_basic_toolkit.py

Lines changed: 50 additions & 73 deletions
@@ -6,24 +6,25 @@
 import os
 import traceback
 from enum import Enum, unique
-from typing import List, Callable, Any, Tuple
+from typing import List
 
-from fitframework import fit_logger
-from fitframework.core.repo.fitable_register import register_fitable
+from fitframework import fit_logger, fitable
 from llama_index.core.base.base_selector import SingleSelection
 from llama_index.core.postprocessor import SimilarityPostprocessor, SentenceEmbeddingOptimizer, LLMRerank, \
     LongContextReorder, FixedRecencyPostprocessor
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
 from llama_index.core.prompts import PromptType, PromptTemplate
-from llama_index.core.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT_TMPL
 from llama_index.core.selectors import LLMSingleSelector, LLMMultiSelector
 from llama_index.core.selectors.prompts import DEFAULT_SINGLE_SELECT_PROMPT_TMPL, DEFAULT_MULTI_SELECT_PROMPT_TMPL
 from llama_index.embeddings.openai import OpenAIEmbedding
 from llama_index.llms.openai import OpenAI
 
-from .callable_registers import register_callable_tool
-from .node_utils import document_to_query_node, query_node_to_document
 from .types.document import Document
+from .types.llm_rerank_options import LLMRerankOptions
+from .types.embedding_options import EmbeddingOptions
+from .types.retriever_options import RetrieverOptions
+from .types.llm_choice_selector_options import LLMChoiceSelectorOptions
+from .node_utils import document_to_query_node, query_node_to_document
 
 os.environ["no_proxy"] = "*"
@@ -42,49 +43,50 @@ def __invoke_postprocessor(postprocessor: BaseNodePostprocessor, nodes: List[Doc
     return nodes
 
 
-def similarity_filter(nodes: List[Document], query_str: str, **kwargs) -> List[Document]:
+@fitable("llama.tools.similarity_filter", "default")
+def similarity_filter(nodes: List[Document], query_str: str, options: RetrieverOptions) -> List[Document]:
     """Remove documents that are below a similarity score threshold."""
-    similarity_cutoff = float(kwargs.get("similarity_cutoff") or 0.3)
-    postprocessor = SimilarityPostprocessor(similarity_cutoff=similarity_cutoff)
+    if options is None:
+        options = RetrieverOptions()
+    postprocessor = SimilarityPostprocessor(similarity_cutoff=options.similarity_cutoff)
     return __invoke_postprocessor(postprocessor, nodes, query_str)
 
 
-def sentence_embedding_optimizer(nodes: List[Document], query_str: str, **kwargs) -> List[Document]:
+@fitable("llama.tools.sentence_embedding_optimizer", "default")
+def sentence_embedding_optimizer(nodes: List[Document], query_str: str, options: EmbeddingOptions) -> List[Document]:
     """Optimization of a text chunk given the query by shortening the input text."""
-    api_key = kwargs.get("api_key") or "EMPTY"
-    model_name = kwargs.get("model_name") or "bce-embedding-base_v1"
-    api_base = kwargs.get("api_base") or ("http://51.36.139.24:8010/v1" if api_key == "EMPTY" else None)
-    percentile_cutoff = kwargs.get("percentile_cutoff")
-    threshold_cutoff = kwargs.get("threshold_cutoff")
-    percentile_cutoff = percentile_cutoff if percentile_cutoff is None else float(percentile_cutoff)
-    threshold_cutoff = threshold_cutoff if threshold_cutoff is None else float(threshold_cutoff)
-
-    embed_model = OpenAIEmbedding(model_name=model_name, api_base=api_base, api_key=api_key)
-    optimizer = SentenceEmbeddingOptimizer(embed_model=embed_model, percentile_cutoff=percentile_cutoff,
-                                           threshold_cutoff=threshold_cutoff)
+    if options is None:
+        options = EmbeddingOptions()
+    api_base = options.api_base
+    embed_model = OpenAIEmbedding(model_name=options.model_name, api_base=api_base, api_key=options.api_key)
+    optimizer = SentenceEmbeddingOptimizer(embed_model=embed_model, percentile_cutoff=options.percentile_cutoff,
+                                           threshold_cutoff=options.threshold_cutoff)
     return __invoke_postprocessor(optimizer, nodes, query_str)
 
 
-def llm_rerank(nodes: List[Document], query_str: str, **kwargs) -> List[Document]:
+@fitable("llama.tools.llm_rerank", "default")
+def llm_rerank(nodes: List[Document], query_str: str, options: LLMRerankOptions) -> List[Document]:
     """
     Re-order nodes by asking the LLM to return the relevant documents and a score of how relevant they are.
     Returns the top N ranked nodes.
     """
-    api_key = kwargs.get("api_key") or "EMPTY"
-    model_name = kwargs.get("model_name") or "Qwen1.5-14B-Chat"
-    api_base = kwargs.get("api_base") or ("http://80.11.128.62:8000/v1" if api_key == "EMPTY" else None)
-    prompt = kwargs.get("prompt") or DEFAULT_CHOICE_SELECT_PROMPT_TMPL
-    choice_batch_size = int(kwargs.get("choice_batch_size") or 10)
-    top_n = int(kwargs.get("top_n") or 10)
-
-    llm = OpenAI(model=model_name, api_base=api_base, api_key=api_key, max_tokens=4096)
+    if options is None:
+        options = LLMRerankOptions()
+
+    api_base = options.api_base
+
+    prompt = options.prompt
+
+    llm = OpenAI(model=options.model_name, api_base=api_base, api_key=options.api_key)
     choice_select_prompt = PromptTemplate(prompt, prompt_type=PromptType.CHOICE_SELECT)
-    llm_rerank_obj = LLMRerank(llm=llm, choice_select_prompt=choice_select_prompt, choice_batch_size=choice_batch_size,
-                               top_n=top_n)
+    llm_rerank_obj = LLMRerank(llm=llm, choice_select_prompt=choice_select_prompt,
+                               choice_batch_size=options.choice_batch_size,
+                               top_n=options.top_n)
     return __invoke_postprocessor(llm_rerank_obj, nodes, query_str)
 
 
-def long_context_rerank(nodes: List[Document], query_str: str, **kwargs) -> List[Document]:
+@fitable("llama.tools.long_context_rerank", "default")
+def long_context_rerank(nodes: List[Document], query_str: str) -> List[Document]:
     """Re-order the retrieved nodes, which can be helpful in cases where a large top-k is needed."""
     return __invoke_postprocessor(LongContextReorder(), nodes, query_str)
@@ -95,24 +97,23 @@ class SelectorMode(Enum):
     MULTI = "multi"
 
 
-def llm_choice_selector(choice: List[str], query_str: str, **kwargs) -> List[SingleSelection]:
+@fitable("llama.tools.llm_choice_selector", "default")
+def llm_choice_selector(choice: List[str], query_str: str, options: LLMChoiceSelectorOptions) -> List[SingleSelection]:
     """LLM-based selector that chooses one or multiple out of many options."""
     if len(choice) == 0:
         return []
-    api_key = kwargs.get("api_key") or "EMPTY"
-    model_name = kwargs.get("model_name") or "Qwen1.5-14B-Chat"
-    api_base = kwargs.get("api_base") or ("http://80.11.128.62:8000/v1" if api_key == "EMPTY" else None)
-    prompt = kwargs.get("prompt")
-    mode = str(kwargs.get("mode") or SelectorMode.SINGLE.value)
-    if mode.lower() not in [m.value for m in SelectorMode]:
-        raise ValueError(f"Invalid mode {mode}.")
-
-    llm = OpenAI(model=model_name, api_base=api_base, api_key=api_key, max_tokens=4096)
-    if mode.lower() == SelectorMode.SINGLE.value:
-        selector_prompt = prompt or DEFAULT_SINGLE_SELECT_PROMPT_TMPL
+    if options is None:
+        options = LLMChoiceSelectorOptions()
+    api_base = options.api_base
+    if options.mode.lower() not in [m.value for m in SelectorMode]:
+        raise ValueError(f"Invalid mode {options.mode}.")
+
+    llm = OpenAI(model=options.model_name, api_base=api_base, api_key=options.api_key, max_tokens=4096)
+    if options.mode.lower() == SelectorMode.SINGLE.value:
+        selector_prompt = options.prompt or DEFAULT_SINGLE_SELECT_PROMPT_TMPL
         selector = LLMSingleSelector.from_defaults(llm=llm, prompt_template_str=selector_prompt)
     else:
-        multi_selector_prompt = prompt or DEFAULT_MULTI_SELECT_PROMPT_TMPL
+        multi_selector_prompt = options.prompt or DEFAULT_MULTI_SELECT_PROMPT_TMPL
         selector = LLMMultiSelector.from_defaults(llm=llm, prompt_template_str=multi_selector_prompt)
     try:
         return selector.select(choice, query_str).selections
@@ -122,34 +123,10 @@ def llm_choice_selector(choice: List[str], query_str: str, **kwargs) -> List[Sin
         return []
 
 
-def fixed_recency(nodes: List[Document], tok_k: int, date_key: str, query_str: str, **kwargs) -> List[Document]:
+@fitable("llama.tools.fixed_recency", "default")
+def fixed_recency(nodes: List[Document], top_k: int, date_key: str, query_str: str) -> List[Document]:
     """This postprocessor returns the top K nodes sorted by date"""
     postprocessor = FixedRecencyPostprocessor(
-        tok_k=tok_k, date_key=date_key if date_key else "date"
+        top_k=top_k, date_key=date_key if date_key else "date"
     )
     return __invoke_postprocessor(postprocessor, nodes, query_str)
-
-
-# Tuple structure: (tool_func, config_args, return_description)
-rag_basic_toolkit: List[Tuple[Callable[..., Any], List[str], str]] = [
-    (similarity_filter, ["similarity_cutoff"], "The filtered documents."),
-    (sentence_embedding_optimizer, ["model_name", "api_key", "api_base", "percentile_cutoff", "threshold_cutoff"],
-     "The optimized documents."),
-    (llm_rerank, ["model_name", "api_key", "api_base", "prompt", "choice_batch_size", "top_n"],
-     "The re-ordered documents."),
-    (long_context_rerank, [], "The re-ordered documents."),
-    (llm_choice_selector, ["model_name", "api_key", "api_base", "prompt", "mode"], "The selected choice."),
-    (fixed_recency, ["nodes", "tok_k", "date_key", "query_str"], "The fixed recency postprocessor")
-]
-
-
-for tool in rag_basic_toolkit:
-    register_callable_tool(tool, llm_choice_selector.__module__, "llama_index.rag.toolkit")
-
-
-if __name__ == '__main__':
-    import time
-    from .llama_schema_helper import dump_llama_schema
-
-    current_timestamp = time.strftime('%Y%m%d%H%M%S')
-    dump_llama_schema(rag_basic_toolkit, f"./llama_tool_schema-{str(current_timestamp)}.json")
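The pattern in this diff replaces per-call **kwargs parsing with typed option objects, and registers each tool through the @fitable decorator instead of the deleted callable_registers helper. The option classes come from the new .types modules, which are not part of this commit's excerpt; as a rough illustration, RetrieverOptions presumably carries the old similarity_cutoff fallback along these lines (a hypothetical sketch assuming a plain dataclass, not the actual definition):

# Hypothetical sketch of what .types.retriever_options might contain; the real
# module is not shown in this excerpt. The field name comes from the diff
# above, and the 0.3 default mirrors the fallback the removed **kwargs code
# used (kwargs.get("similarity_cutoff") or 0.3).
from dataclasses import dataclass


@dataclass
class RetrieverOptions:
    similarity_cutoff: float = 0.3  # documents scoring below this are dropped

Called directly, a decorated tool then reads its configuration from the object, e.g. similarity_filter(nodes, query_str, RetrieverOptions(similarity_cutoff=0.5)); the "llama.tools.similarity_filter" id presumably lets the FIT runtime look up the same function for plugin invocation.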
