diff --git a/pipelines/examples/agents/react_example.py b/pipelines/examples/agents/react_example.py index 4496ad66eda6..75279eb5eb5a 100644 --- a/pipelines/examples/agents/react_example.py +++ b/pipelines/examples/agents/react_example.py @@ -82,7 +82,7 @@ # yapf: disable parser = argparse.ArgumentParser() -parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev or SerpAPI key.") +parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev, SerpAPI or SearchApi.io key.") parser.add_argument('--llm_name', choices=['THUDM/chatglm-6b', "THUDM/chatglm-6b-v1.1", "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b-v1.1", help="The chatbot models ") parser.add_argument("--api_key", default=None, type=str, help="The API Key.") args = parser.parse_args() diff --git a/pipelines/examples/agents/react_example_cn.py b/pipelines/examples/agents/react_example_cn.py index 967381e0e104..5d249d010c2a 100644 --- a/pipelines/examples/agents/react_example_cn.py +++ b/pipelines/examples/agents/react_example_cn.py @@ -60,7 +60,7 @@ parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run dense_qa system, defaults to gpu.") parser.add_argument("--index_name", default='dureader_index', type=str, help="The ann index name of ANN.") parser.add_argument("--search_engine", choices=['faiss', 'milvus'], default="faiss", help="The type of ANN search engine.") -parser.add_argument("--retriever", choices=['dense', 'SerperDev', 'SerpAPI'], default="dense", help="The type of Retriever.") +parser.add_argument("--retriever", choices=['dense', 'SerperDev', 'SerpAPI', 'SearchApi'], default="dense", help="The type of Retriever.") parser.add_argument("--max_seq_len_query", default=64, type=int, help="The maximum total length of query after tokenization.") parser.add_argument("--max_seq_len_passage", default=256, type=int, help="The maximum total length of passage after tokenization.") parser.add_argument("--retriever_batch_size", default=16, type=int, help="The batch size of retriever to extract passage embedding for building ANN index.") @@ -68,7 +68,7 @@ parser.add_argument("--passage_embedding_model", default="rocketqa-zh-base-query-encoder", type=str, help="The passage_embedding_model path") parser.add_argument("--params_path", default="checkpoints/model_40/model_state.pdparams", type=str, help="The checkpoint path") parser.add_argument("--embedding_dim", default=768, type=int, help="The embedding_dim of index") -parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev or SerpAPI key.") +parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev, SerpAPI or SearchApi.io key.") parser.add_argument('--embed_title', default=False, type=bool, help="The title to be embedded into embedding") parser.add_argument('--model_type', choices=['ernie_search', 'ernie', 'bert', 'neural_search'], default="ernie", help="the ernie model types") parser.add_argument('--llm_name', choices=['ernie-bot', 'THUDM/chatglm-6b', "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b", help="The chatbot models ") diff --git a/pipelines/pipelines/nodes/search_engine/providers.py b/pipelines/pipelines/nodes/search_engine/providers.py index 9e8833968bbe..2f2405bc5c8b 100644 --- a/pipelines/pipelines/nodes/search_engine/providers.py +++ b/pipelines/pipelines/nodes/search_engine/providers.py @@ -239,3 +239,110 @@ def search(self, query: str, **kwargs) -> List[Document]: logger.debug("Serper.dev API returned %s documents for the query '%s'", len(documents), query) result_docs = documents[:top_k] return self.score_results(result_docs, len(answer_box) > 0) + + +class SearchApi(SearchEngine): + """ + SearchApi is a real-time search engine that provides an API to access search results from Google, Google Scholar, YouTube, + YouTube transcripts and more. See the [SearchApi website](https://www.searchapi.io/) for more details. + """ + + def __init__( + self, + api_key: str, + top_k: Optional[int] = 10, + engine: Optional[str] = "google", + search_engine_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + :param api_key: API key for SearchApi. + :param top_k: Number of results to return. + :param engine: Search engine to use, for example google, google_scholar, youtube, youtube_transcripts. + See the [SearchApi documentation](https://www.searchapi.io/docs/google) for the full list of supported engines. + :param search_engine_kwargs: Additional parameters passed to the SearchApi. + See the [SearchApi documentation](https://www.searchapi.io/docs/google) for the full list of supported parameters. + """ + super().__init__() + self.params_dict: Dict[str, Union[str, int, float]] = {} + self.api_key = api_key + self.kwargs = search_engine_kwargs if search_engine_kwargs else {} + self.engine = engine + self.top_k = top_k + + def search(self, query: str, **kwargs) -> List[Document]: + """ + :param query: Query string. + :param kwargs: Additional parameters passed to the SearchApi. For example, you can set 'location' to 'New York,United States' + to localize search to the specific location. + :return: List[Document] + """ + kwargs = {**self.kwargs, **kwargs} + top_k = kwargs.pop("top_k", self.top_k) + url = "https://www.searchapi.io/api/v1/search" + + params = {"q": query, **kwargs} + headers = {"Authorization": f"Bearer {self.api_key}", "X-SearchApi-Source": "PaddleNLP"} + + if self.engine: + params["engine"] = self.engine + response = requests.get(url, params=params, headers=headers, timeout=90) + + if response.status_code != 200: + raise Exception(f"Error while querying {self.__class__.__name__}: {response.text}") + + json_content = json.loads(response.text) + documents = [] + has_answer_box = False + + if json_content.get("answer_box"): + if json_content["answer_box"].get("organic_result"): + title = json_content["answer_box"].get("organic_result").get("title", "") + link = json_content["answer_box"].get("organic_result").get("link", "") + if json_content["answer_box"].get("type") == "population_graph": + title = json_content["answer_box"].get("place", "") + link = json_content["answer_box"].get("explore_more_link", "") + + title = json_content["answer_box"].get("title", "") + link = json_content["answer_box"].get("link") + content = json_content["answer_box"].get("answer") or json_content["answer_box"].get("snippet") + + if link and content: + has_answer_box = True + documents.append(Document.from_dict({"title": title, "content": content, "link": link})) + + if json_content.get("knowledge_graph"): + if json_content["knowledge_graph"].get("source"): + link = json_content["knowledge_graph"].get("source").get("link", "") + + link = json_content["knowledge_graph"].get("website", "") + content = json_content["knowledge_graph"].get("description") + + if link and content: + documents.append( + Document.from_dict( + {"title": json_content["knowledge_graph"].get("title", ""), "content": content, "link": link} + ) + ) + + documents += [ + Document.from_dict({"title": c["title"], "content": c.get("snippet", ""), "link": c["link"]}) + for c in json_content["organic_results"] + ] + + if json_content.get("related_questions"): + for question in json_content["related_questions"]: + if question.get("source"): + link = question.get("source").get("link", "") + else: + link = "" + + content = question.get("answer", "") + + if link and content: + documents.append( + Document.from_dict({"title": question.get("question", ""), "content": content, "link": link}) + ) + + logger.debug("SearchApi returned %s documents for the query '%s'", len(documents), query) + result_docs = documents[:top_k] + return self.score_results(result_docs, has_answer_box) diff --git a/pipelines/pipelines/nodes/search_engine/web.py b/pipelines/pipelines/nodes/search_engine/web.py index 573756f58527..b0c62df6fefb 100644 --- a/pipelines/pipelines/nodes/search_engine/web.py +++ b/pipelines/pipelines/nodes/search_engine/web.py @@ -28,6 +28,7 @@ class WebSearch(BaseComponent): WebSerach currently supports the following search engines providers (bridges): - SerperDev (default) + - SearchApi - SerpAPI - BingAPI