Skip to content

Commit 514d752

Browse files
authored
Prepare PyPI release 0.1.0 (#11)
* Update PyGaggle version and clean up * Update docs
1 parent a6b1180 commit 514d752

File tree

8 files changed

+49
-52
lines changed

8 files changed

+49
-52
lines changed

chatty_goose/pipeline/retrieval_pipeline.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,63 +15,63 @@ class RetrievalPipeline:
1515
1616
Parameters:
1717
searcher (SimpleSearcher): Pyserini searcher for Lucene index
18-
retrievers (List[CQR]): List of CQR retrievers to use for first-stage retrieval
18+
reformulators (List[CQR]): List of CQR methods to use for first-stage retrieval
1919
searcher_num_hits (int): number of hits returned by searcher - default 10
2020
early_fusion (bool): flag to perform fusion before second-stage retrieval - default True
2121
reranker (Reranker): optional reranker for second-stage retrieval
2222
reranker_query_index (int): index of the CQR method whose rewritten query is used for reranking - defaults to the last CQR method
23-
reranker_query_retriever (CQR): retriever for generating reranker query,
24-
overrides reranker_query_index if provided
23+
reranker_query_reformulator (CQR): CQR method for generating reranker query,
24+
overrides reranker_query_index if provided
2525
"""
2626

2727
def __init__(
2828
self,
2929
searcher: SimpleSearcher,
30-
retrievers: List[CQR],
30+
reformulators: List[CQR],
3131
searcher_num_hits: int = 10,
3232
early_fusion: bool = True,
3333
reranker: Reranker = None,
3434
reranker_query_index: int = -1,
35-
reranker_query_retriever: CQR = None,
35+
reranker_query_reformulator: CQR = None,
3636
):
3737
self.searcher = searcher
38-
self.retrievers = retrievers
38+
self.reformulators = reformulators
3939
self.searcher_num_hits = int(searcher_num_hits)
4040
self.early_fusion = early_fusion
4141
self.reranker = reranker
4242
self.reranker_query_index = reranker_query_index
43-
self.reranker_query_retriever = reranker_query_retriever
43+
self.reranker_query_reformulator = reranker_query_reformulator
4444

4545
def retrieve(self, query) -> List[JSimpleSearcherResult]:
46-
retriever_hits = []
47-
retriever_queries = []
48-
for retriever in self.retrievers:
49-
new_query = retriever.rewrite(query)
46+
cqr_hits = []
47+
cqr_queries = []
48+
for cqr in self.reformulators:
49+
new_query = cqr.rewrite(query)
5050
hits = self.searcher.search(new_query, k=self.searcher_num_hits)
51-
retriever_hits.append(hits)
52-
retriever_queries.append(new_query)
51+
cqr_hits.append(hits)
52+
cqr_queries.append(new_query)
5353

54-
# Merge results from multiple retrievers if required
54+
# Merge results from multiple CQR methods if required
5555
if self.early_fusion or self.reranker is None:
56-
retriever_hits = reciprocal_rank_fusion(retriever_hits)
56+
cqr_hits = reciprocal_rank_fusion(cqr_hits)
5757

5858
# Return results if no reranker
5959
if self.reranker is None:
60-
return retriever_hits
60+
return cqr_hits
6161

6262
# Get query for reranker
63-
if self.reranker_query_retriever is None:
64-
rerank_query = retriever_queries[self.reranker_query_index]
63+
if self.reranker_query_reformulator is None:
64+
rerank_query = cqr_queries[self.reranker_query_index]
6565
else:
66-
rerank_query = self.reranker_query_retriever.rewrite(query)
66+
rerank_query = self.reranker_query_reformulator.rewrite(query)
6767

6868
# Rerank results
6969
if self.early_fusion:
70-
results = self.rerank(rerank_query, retriever_hits[:self.searcher_num_hits])
70+
results = self.rerank(rerank_query, cqr_hits[:self.searcher_num_hits])
7171
else:
72-
# Rerank all retriever results and fuse together
72+
# Rerank all CQR results and fuse together
7373
results = []
74-
for hits in retriever_hits:
74+
for hits in cqr_hits:
7575
results = self.rerank(rerank_query, hits)
7676
results = reciprocal_rank_fusion(results)
7777
return results
@@ -91,8 +91,8 @@ def rerank(self, query, hits):
9191
return reranked_hits
9292

9393
def reset_history(self):
94-
for retriever in self.retrievers:
95-
retriever.reset_history()
94+
for cqr in self.reformulators:
95+
cqr.reset_history()
9696

97-
if self.reranker_query_retriever:
98-
self.reranker_query_retriever.reset_history()
97+
if self.reranker_query_reformulator:
98+
self.reranker_query_reformulator.reset_history()

chatty_goose/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
class SearcherSettings(BaseSettings):
99
"""Settings for Anserini searcher"""
1010

11-
index_path: str # Lucene index path
11+
index_path: str # Pre-built index name or path to Lucene index
1212
k1: float = 0.82 # BM25 k parameter
1313
b: float = 0.68 # BM25 b parameter
1414
rm3: bool = False # use RM3

chatty_goose/util.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from os import path
23
from typing import Dict, List, Tuple
34

45
from pygaggle.rerank.transformer import MonoBERT
@@ -49,7 +50,10 @@ def build_bert_reranker(
4950

5051

5152
def build_searcher(settings: SearcherSettings) -> SimpleSearcher:
52-
searcher = SimpleSearcher(settings.index_path)
53+
if path.isdir(settings.index_path):
54+
searcher = SimpleSearcher(settings.index_path)
55+
else:
56+
searcher = SimpleSearcher.from_prebuilt_index(settings.index_path)
5357
searcher.set_bm25(float(settings.k1), float(settings.b))
5458
logging.info(
5559
"Initializing BM25, setting k1={} and b={}".format(settings.k1, settings.b)

docs/cqr_experiments.md

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,24 @@
22

33
## Data Preparation
44

5-
1. Download the pre-built CAsT 2019 index using Pyserini. This will download the entire index to `~/.cache/pyserini`.
5+
1. Download either the [training](https://github.com/daltonj/treccastweb/blob/master/2019/data/training/train_topics_v1.0.json) or [evaluation](https://github.com/daltonj/treccastweb/blob/master/2019/data/evaluation/evaluation_topics_v1.0.json) input query JSON file. These files can be found under `data/treccastweb/2019/data` if you cloned the submodules for this repo.
66

7-
```
8-
from pyserini.search import SimpleSearcher
9-
SimpleSearcher.from_prebuilt_index('cast2019')
10-
```
11-
12-
2. Download either the [training](https://github.com/daltonj/treccastweb/blob/master/2019/data/training/train_topics_v1.0.json) and [evaluation](https://github.com/daltonj/treccastweb/blob/master/2019/data/evaluation/evaluation_topics_v1.0.json) input query JSON files. These files can be found under `data/treccastweb/2019/data` if you cloned the submodules for this repo.
13-
14-
3. Download the evaluation answer files for [training](https://github.com/daltonj/treccastweb/blob/master/2019/data/training/train_topics_mod.qrel) or [evaluation](https://trec.nist.gov/data/cast/2019qrels.txt). The training answer file is found under `data/treccastweb/2019/data`.
7+
2. Download the evaluation answer files for [training](https://github.com/daltonj/treccastweb/blob/master/2019/data/training/train_topics_mod.qrel) or [evaluation](https://trec.nist.gov/data/cast/2019qrels.txt). The training answer file is found under `data/treccastweb/2019/data`.
158

169
## Run CQR retrieval
1710

18-
The following command is for HQE, but you can also run other CQR methods using `t5` or `fusion` instead of `hqe` as the input to the `--experiment` flag.
11+
The following command is for HQE, but you can also run other CQR methods using `t5` or `fusion` instead of `hqe` as the input to the `--experiment` flag. Running the command for the first time will download the pre-built CAsT 2019 index (or whichever pre-built index name is passed to the `--index` flag). Alternatively, you can pass `--index` a path to a local directory containing a Lucene index.
1912

2013
```shell=bash
2114
python -m experiments.run_retrieval \
2215
--experiment hqe \
2316
--hits 1000 \
24-
--index $anserini_index_path \
17+
--index cast2019 \
2518
--qid_queries $input_query_json \
2619
--output ./output/hqe_bm25 \
2720
```
2821

29-
Running the experiment will output the retrieval results at the specified location in TSV format. By default, this will perform retrieval using only BM25, but you can add the `--rerank` flag to further rerank these results using BERT. For other command line arguments, see [run_retrieval.py](experiments/run_retrieval.py).
22+
The experiment will output the retrieval results at the specified location in TSV format. By default, this will perform retrieval using only BM25, but you can add the `--rerank` flag to further rerank these results using BERT. For other command line arguments, see [run_retrieval.py](../experiments/run_retrieval.py).
3023

3124
## Evaluate CQR results
3225

examples/messenger/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ This guide is based on ParlAI's [chat service tutorial](https://parl.ai/docs/tut
99
3. Run the webhook server and Chatty Goose agent using our provided configuration. This assumes you have the ParlAI Python package installed and are inside the `chatty-goose` root repository folder.
1010

1111
```
12-
python3.7 -m parlai.chat_service.services.messenger.run --config-path examples/messenger/config.yml
12+
python -m parlai.chat_service.services.messenger.run --config-path examples/messenger/config.yml
1313
```
1414

1515
4. Add the webhook URL printed by the above command as a callback URL in the Messenger App settings, and set the verify token to `Messenger4ParlAI`. For Heroku, this URL should look like `https://firstname-parlai-messenger-chatbot.herokuapp.com/webhook`.

experiments/run_retrieval.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def run_experiment(rp: RetrievalPipeline):
116116
)
117117
searcher = build_searcher(searcher_settings)
118118

119-
# Initialize retrievers and reranker
120-
retrievers = []
121-
reranker_query_retriever = None
119+
# Initialize CQR and reranker
120+
reformulators = []
121+
reranker_query_reformulator = None
122122
reranker = build_bert_reranker(device=args.reranker_device) if args.rerank else None
123123

124124
if experiment == CQRType.HQE or experiment == CQRType.FUSION:
@@ -131,7 +131,7 @@ def run_experiment(rp: RetrievalPipeline):
131131
verbose=args.verbose,
132132
)
133133
hqe_bm25 = HQE(searcher, hqe_bm25_settings)
134-
retrievers.append(hqe_bm25)
134+
reformulators.append(hqe_bm25)
135135

136136
if experiment == CQRType.T5 or experiment == CQRType.FUSION:
137137
# Initialize T5
@@ -143,7 +143,7 @@ def run_experiment(rp: RetrievalPipeline):
143143
verbose=args.verbose,
144144
)
145145
t5 = T5_NTR(t5_settings, device=args.t5_device)
146-
retrievers.append(t5)
146+
reformulators.append(t5)
147147

148148
if experiment == CQRType.HQE:
149149
hqe_bert_settings = HQESettings(
@@ -153,14 +153,14 @@ def run_experiment(rp: RetrievalPipeline):
153153
R_sub=args.R1_sub,
154154
filter=PosFilter(args.filter),
155155
)
156-
reranker_query_retriever = HQE(searcher, hqe_bert_settings)
156+
reranker_query_reformulator = HQE(searcher, hqe_bert_settings)
157157

158158
rp = RetrievalPipeline(
159159
searcher,
160-
retrievers,
160+
reformulators,
161161
searcher_num_hits=args.hits,
162162
early_fusion=not args.late_fusion,
163163
reranker=reranker,
164-
reranker_query_retriever=reranker_query_retriever,
164+
reranker_query_reformulator=reranker_query_reformulator,
165165
)
166166
run_experiment(rp)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
coloredlogs
22
parlai==1.1.0
33
pydantic>=1.5
4-
pygaggle==0.0.2
4+
pygaggle==0.0.3.1
55
spacy>=2.2.4

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
with open("requirements.txt") as f:
77
requirements = f.read().splitlines()
88

9-
excluded = ["data*", "experiments*"]
9+
excluded = ["data*", "examples*", "experiments*"]
1010

1111

1212
setuptools.setup(
@@ -17,7 +17,7 @@
1717
description="A conversational passage retrieval toolkit",
1818
long_description=long_description,
1919
long_description_content_type="text/markdown",
20-
url="https://github.com/jacklin64/chatty-goose",
20+
url="https://github.com/castorini/chatty-goose",
2121
install_requires=requirements,
2222
packages=setuptools.find_packages(exclude=excluded),
2323
classifiers=[

0 commit comments

Comments
 (0)