@@ -6,24 +6,25 @@
 import os
 import traceback
 from enum import Enum, unique
-from typing import List, Callable, Any, Tuple
+from typing import List
 
-from fitframework import fit_logger
-from fitframework.core.repo.fitable_register import register_fitable
+from fitframework import fit_logger, fitable
 from llama_index.core.base.base_selector import SingleSelection
 from llama_index.core.postprocessor import SimilarityPostprocessor, SentenceEmbeddingOptimizer, LLMRerank, \
     LongContextReorder, FixedRecencyPostprocessor
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
 from llama_index.core.prompts import PromptType, PromptTemplate
-from llama_index.core.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT_TMPL
 from llama_index.core.selectors import LLMSingleSelector, LLMMultiSelector
 from llama_index.core.selectors.prompts import DEFAULT_SINGLE_SELECT_PROMPT_TMPL, DEFAULT_MULTI_SELECT_PROMPT_TMPL
 from llama_index.embeddings.openai import OpenAIEmbedding
 from llama_index.llms.openai import OpenAI
 
-from .callable_registers import register_callable_tool
-from .node_utils import document_to_query_node, query_node_to_document
 from .types.document import Document
+from .types.llm_rerank_options import LLMRerankOptions
+from .types.embedding_options import EmbeddingOptions
+from .types.retriever_options import RetrieverOptions
+from .types.llm_choice_selector_options import LLMChoiceSelectorOptions
+from .node_utils import document_to_query_node, query_node_to_document
 
 os.environ["no_proxy"] = "*"
 
@@ -42,49 +43,50 @@ def __invoke_postprocessor(postprocessor: BaseNodePostprocessor, nodes: List[Doc
     return nodes
 
 
-def similarity_filter(nodes: List[Document], query_str: str, **kwargs) -> List[Document]:
+@fitable("llama.tools.similarity_filter", "default")
+def similarity_filter(nodes: List[Document], query_str: str, options: RetrieverOptions) -> List[Document]:
     """Remove documents that are below a similarity score threshold."""
-    similarity_cutoff = float(kwargs.get("similarity_cutoff") or 0.3)
-    postprocessor = SimilarityPostprocessor(similarity_cutoff=similarity_cutoff)
+    if options is None:
+        options = RetrieverOptions()
+    postprocessor = SimilarityPostprocessor(similarity_cutoff=options.similarity_cutoff)
     return __invoke_postprocessor(postprocessor, nodes, query_str)
 
 
-def sentence_embedding_optimizer(nodes: List[Document], query_str: str, **kwargs) -> List[Document]:
+@fitable("llama.tools.sentence_embedding_optimizer", "default")
+def sentence_embedding_optimizer(nodes: List[Document], query_str: str, options: EmbeddingOptions) -> List[Document]:
     """Optimization of a text chunk given the query by shortening the input text."""
-    api_key = kwargs.get("api_key") or "EMPTY"
-    model_name = kwargs.get("model_name") or "bce-embedding-base_v1"
-    api_base = kwargs.get("api_base") or ("http://51.36.139.24:8010/v1" if api_key == "EMPTY" else None)
-    percentile_cutoff = kwargs.get("percentile_cutoff")
-    threshold_cutoff = kwargs.get("threshold_cutoff")
-    percentile_cutoff = percentile_cutoff if percentile_cutoff is None else float(percentile_cutoff)
-    threshold_cutoff = threshold_cutoff if threshold_cutoff is None else float(threshold_cutoff)
-
-    embed_model = OpenAIEmbedding(model_name=model_name, api_base=api_base, api_key=api_key)
-    optimizer = SentenceEmbeddingOptimizer(embed_model=embed_model, percentile_cutoff=percentile_cutoff,
-                                           threshold_cutoff=threshold_cutoff)
+    if options is None:
+        options = EmbeddingOptions()
+    api_base = options.api_base
+    embed_model = OpenAIEmbedding(model_name=options.model_name, api_base=api_base, api_key=options.api_key)
+    optimizer = SentenceEmbeddingOptimizer(embed_model=embed_model, percentile_cutoff=options.percentile_cutoff,
+                                           threshold_cutoff=options.threshold_cutoff)
     return __invoke_postprocessor(optimizer, nodes, query_str)
 
 
-def llm_rerank(nodes: List[Document], query_str: str, **kwargs) -> List[Document]:
+@fitable("llama.tools.llm_rerank", "default")
+def llm_rerank(nodes: List[Document], query_str: str, options: LLMRerankOptions) -> List[Document]:
     """
     Re-order nodes by asking the LLM to return the relevant documents and a score of how relevant they are.
     Returns the top N ranked nodes.
     """
-    api_key = kwargs.get("api_key") or "EMPTY"
-    model_name = kwargs.get("model_name") or "Qwen1.5-14B-Chat"
-    api_base = kwargs.get("api_base") or ("http://80.11.128.62:8000/v1" if api_key == "EMPTY" else None)
-    prompt = kwargs.get("prompt") or DEFAULT_CHOICE_SELECT_PROMPT_TMPL
-    choice_batch_size = int(kwargs.get("choice_batch_size") or 10)
-    top_n = int(kwargs.get("top_n") or 10)
-
-    llm = OpenAI(model=model_name, api_base=api_base, api_key=api_key, max_tokens=4096)
+    if options is None:
+        options = LLMRerankOptions()
+
+    api_base = options.api_base
+
+    prompt = options.prompt
+
+    llm = OpenAI(model=options.model_name, api_base=api_base, api_key=options.api_key)
     choice_select_prompt = PromptTemplate(prompt, prompt_type=PromptType.CHOICE_SELECT)
-    llm_rerank_obj = LLMRerank(llm=llm, choice_select_prompt=choice_select_prompt, choice_batch_size=choice_batch_size,
-                               top_n=top_n)
+    llm_rerank_obj = LLMRerank(llm=llm, choice_select_prompt=choice_select_prompt,
+                               choice_batch_size=options.choice_batch_size,
+                               top_n=options.top_n)
     return __invoke_postprocessor(llm_rerank_obj, nodes, query_str)
 
 
-def long_context_rerank(nodes: List[Document], query_str: str, **kwargs) -> List[Document]:
+@fitable("llama.tools.long_context_rerank", "default")
+def long_context_rerank(nodes: List[Document], query_str: str) -> List[Document]:
     """Re-order the retrieved nodes, which can be helpful in cases where a large top-k is needed."""
     return __invoke_postprocessor(LongContextReorder(), nodes, query_str)
 
@@ -95,24 +97,23 @@ class SelectorMode(Enum):
     MULTI = "multi"
 
 
-def llm_choice_selector(choice: List[str], query_str: str, **kwargs) -> List[SingleSelection]:
+@fitable("llama.tools.llm_choice_selector", "default")
+def llm_choice_selector(choice: List[str], query_str: str, options: LLMChoiceSelectorOptions) -> List[SingleSelection]:
     """LLM-based selector that chooses one or multiple out of many options."""
     if len(choice) == 0:
         return []
-    api_key = kwargs.get("api_key") or "EMPTY"
-    model_name = kwargs.get("model_name") or "Qwen1.5-14B-Chat"
-    api_base = kwargs.get("api_base") or ("http://80.11.128.62:8000/v1" if api_key == "EMPTY" else None)
-    prompt = kwargs.get("prompt")
-    mode = str(kwargs.get("mode") or SelectorMode.SINGLE.value)
-    if mode.lower() not in [m.value for m in SelectorMode]:
-        raise ValueError(f"Invalid mode {mode}.")
-
-    llm = OpenAI(model=model_name, api_base=api_base, api_key=api_key, max_tokens=4096)
-    if mode.lower() == SelectorMode.SINGLE.value:
-        selector_prompt = prompt or DEFAULT_SINGLE_SELECT_PROMPT_TMPL
+    if options is None:
+        options = LLMChoiceSelectorOptions()
+    api_base = options.api_base
+    if options.mode.lower() not in [m.value for m in SelectorMode]:
+        raise ValueError(f"Invalid mode {options.mode}.")
+
+    llm = OpenAI(model=options.model_name, api_base=api_base, api_key=options.api_key, max_tokens=4096)
+    if options.mode.lower() == SelectorMode.SINGLE.value:
+        selector_prompt = options.prompt or DEFAULT_SINGLE_SELECT_PROMPT_TMPL
         selector = LLMSingleSelector.from_defaults(llm=llm, prompt_template_str=selector_prompt)
     else:
-        multi_selector_prompt = prompt or DEFAULT_MULTI_SELECT_PROMPT_TMPL
+        multi_selector_prompt = options.prompt or DEFAULT_MULTI_SELECT_PROMPT_TMPL
         selector = LLMMultiSelector.from_defaults(llm=llm, prompt_template_str=multi_selector_prompt)
     try:
         return selector.select(choice, query_str).selections
@@ -122,34 +123,10 @@ def llm_choice_selector(choice: List[str], query_str: str, **kwargs) -> List[Sin
         return []
 
 
-def fixed_recency(nodes: List[Document], tok_k: int, date_key: str, query_str: str, **kwargs) -> List[Document]:
+@fitable("llama.tools.fixed_recency", "default")
+def fixed_recency(nodes: List[Document], top_k: int, date_key: str, query_str: str) -> List[Document]:
     """This postprocessor returns the top K nodes sorted by date"""
     postprocessor = FixedRecencyPostprocessor(
-        tok_k=tok_k, date_key=date_key if date_key else "date"
+        top_k=top_k, date_key=date_key if date_key else "date"
     )
     return __invoke_postprocessor(postprocessor, nodes, query_str)
-
-
-# Tuple structure: (tool_func, config_args, return_description)
-rag_basic_toolkit: List[Tuple[Callable[..., Any], List[str], str]] = [
-    (similarity_filter, ["similarity_cutoff"], "The filtered documents."),
-    (sentence_embedding_optimizer, ["model_name", "api_key", "api_base", "percentile_cutoff", "threshold_cutoff"],
-     "The optimized documents."),
-    (llm_rerank, ["model_name", "api_key", "api_base", "prompt", "choice_batch_size", "top_n"],
-     "The re-ordered documents."),
-    (long_context_rerank, [], "The re-ordered documents."),
-    (llm_choice_selector, ["model_name", "api_key", "api_base", "prompt", "mode"], "The selected choice."),
-    (fixed_recency, ["nodes", "tok_k", "date_key", "query_str"], "The fixed recency postprocessor")
-]
-
-
-for tool in rag_basic_toolkit:
-    register_callable_tool(tool, llm_choice_selector.__module__, "llama_index.rag.toolkit")
-
-
-if __name__ == '__main__':
-    import time
-    from .llama_schema_helper import dump_llama_schema
-
-    current_timestamp = time.strftime('%Y%m%d%H%M%S')
-    dump_llama_schema(rag_basic_toolkit, f"./llama_tool_schema-{str(current_timestamp)}.json")
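Reviewer note: the options classes this commit imports from `.types.*` (`RetrieverOptions`, `EmbeddingOptions`, `LLMRerankOptions`, `LLMChoiceSelectorOptions`) are referenced but not included in this diff. A minimal sketch of what `RetrieverOptions` could look like, assuming it is a plain dataclass that carries the default the deleted kwargs handling used; the class name and import path come from the diff, while the field layout and default value are inferred from the removed `similarity_cutoff` fallback and are otherwise assumptions:

    # Hypothetical sketch of .types/retriever_options.py (not part of this commit).
    # The 0.3 default mirrors the removed `kwargs.get("similarity_cutoff") or 0.3`.
    from dataclasses import dataclass

    @dataclass
    class RetrieverOptions:
        # Minimum similarity score a document must reach to be kept.
        similarity_cutoff: float = 0.3

Under that assumption, a caller would invoke the refactored fitable as, e.g., `similarity_filter(nodes, query_str, RetrieverOptions(similarity_cutoff=0.5))`, and passing `options=None` falls back to the defaults via the new `if options is None` guard. The other options classes would presumably carry the defaults deleted here in the same way (e.g. `model_name="Qwen1.5-14B-Chat"`, `choice_batch_size=10`, `top_n=10` for `LLMRerankOptions`).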