Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ Supports:
| `rank_zephyr_7b_v1_full` | 4-bit-quantised GGUF | ~4GB | [Model card](https://huggingface.co/castorini/rank_zephyr_7b_v1_full) |
| `miniReranker_arabic_v1` | `Only dedicated Arabic Reranker` | - | [Model card](https://huggingface.co/prithivida/miniReranker_arabic_v1) |


- HuggingFace models — models for which an ONNX export is already available will also work. Please take into account that all models saved as ONNX are optimized for size and speed. For instance:
* mixedbread-ai/mxbai-rerank-xsmall-v1
* jinaai/jina-reranker-v1-tiny-en
- Models in roadmap:
* InRanker
- Why are sleeker models preferred? Reranking is the final leg of larger retrieval pipelines, and the idea is to avoid any extra overhead, especially in user-facing scenarios. To that end, models with a really small footprint that don't need any specialised hardware and yet offer competitive performance are chosen. Feel free to raise issues to request support for new models as you see fit.
Expand Down Expand Up @@ -131,6 +133,13 @@ ranker = Ranker(model_name="ms-marco-MultiBERT-L-12", cache_dir="/opt")
or

ranker = Ranker(model_name="rank_zephyr_7b_v1_full", max_length=1024) # adjust max_length based on your passage length

or

# Medium (~90MB), fast modern model with competitive performance on diverse datasets
ranker = Ranker(model_name="mixedbread-ai/mxbai-rerank-xsmall-v1", cache_dir="/opt")


```

```python
Expand Down
3 changes: 2 additions & 1 deletion flashrank/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

# Base Hugging Face endpoint; overridable (e.g. for mirrors) via HF_ENDPOINT.
hf_endpoint = os.environ.get('HF_ENDPOINT', default='https://huggingface.co')
# URL template for FlashRank's own pre-packaged zip archives ({} = model name).
model_url = urljoin(hf_endpoint, 'prithivida/flashrank/resolve/main/{}.zip')
# URL template for a raw file in an arbitrary HF repo ({} = repo id, {} = file path).
hf_model_url = urljoin(hf_endpoint, '{}/resolve/main/{}')
# Models that rerank listwise (whole candidate list in one prompt) rather than pairwise.
listwise_rankers = {'rank_zephyr_7b_v1_full'}

# Tokenizer/config artifacts that must accompany any ONNX model downloaded from HF.
required_files = ['config.json', 'tokenizer.json', 'tokenizer_config.json', 'special_tokens_map.json']
default_cache_dir = "/tmp"
default_model = "ms-marco-TinyBERT-L-2-v2"
model_file_map = {
Expand Down
62 changes: 50 additions & 12 deletions flashrank/Ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@
import zipfile
import requests
from tqdm import tqdm
from flashrank.Config import default_model, default_cache_dir, model_url, model_file_map, listwise_rankers
from flashrank.Config import default_model, default_cache_dir, model_url, model_file_map, listwise_rankers, hf_model_url, required_files
import collections
from typing import Optional, List, Dict, Any
import logging

def download_file(local_file, download_url, timeout=None):
    """ Streams a file from `download_url` into `local_file`, showing a tqdm progress bar.

    Args:
        local_file (Path): Destination path; its `.name` labels the progress bar.
        download_url (str): Fully-qualified URL to fetch.
        timeout (float, optional): Connect/read timeout in seconds passed to
            `requests.get`. Defaults to None (wait indefinitely, the previous
            behaviour); pass a value to avoid hanging on a stalled server.

    Raises:
        requests.exceptions.HTTPError: If the server answers with a 4xx/5xx status.
    """
    with requests.get(download_url, stream=True, timeout=timeout) as r:
        # Fail fast on error statuses before creating/overwriting the local file.
        r.raise_for_status()
        # content-length may be absent (chunked responses); tqdm then shows an indeterminate bar.
        total_size = int(r.headers.get('content-length', 0))
        with open(local_file, 'wb') as f, tqdm(desc=local_file.name, total=total_size, unit='iB', unit_scale=True, unit_divisor=1024) as bar:
            for chunk in r.iter_content(chunk_size=8192):
                size = f.write(chunk)
                bar.update(size)

class RerankRequest:
""" Represents a reranking request with a query and a list of passages.

Expand Down Expand Up @@ -49,8 +58,11 @@ def __init__(self, model_name: str = default_model, cache_dir: str = default_cac
self.logger = logging.getLogger(__name__)

self.cache_dir: Path = Path(cache_dir)
self.model_dir: Path = self.cache_dir / model_name
model_file = model_file_map[model_name]
self.model_dir: Path = self.cache_dir / model_name.replace("/", "-")
if model_name in model_file_map:
model_file = model_file_map[model_name]
else:
model_file = f"{model_name.split('/')[-1]}.onnx"
self.model_path = self.model_dir / model_file
self._prepare_model_dir(model_name)

Expand Down Expand Up @@ -81,7 +93,11 @@ def _prepare_model_dir(self, model_name: str):

if not self.model_path.exists():
self.logger.info(f"Downloading {model_name}...")
self._download_model_files(model_name)
if model_name in model_file_map:
self._download_model_files(model_name)
else:
self.logger.info(f"Model {model_name} not found in model map, downloading from custom URL...")
self._download_hf_model_files(model_name)

def _download_model_files(self, model_name: str):
    """ Fetches the zipped archive for a FlashRank-hosted model and unpacks it.

    The zip is streamed into the cache directory, extracted in place
    (producing `<cache_dir>/<model_name>/...`), and then deleted.

    Args:
        model_name (str): The name of the model to download.
    """
    archive_path = self.cache_dir / f"{model_name}.zip"
    archive_url = model_url.format(model_name)

    download_file(archive_path, archive_url)

    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(self.cache_dir)
    # The extracted files are all we need; drop the archive to save disk space.
    os.remove(archive_path)

def _download_hf_model_files(self, model_name: str):
    """ Downloads ONNX weights plus tokenizer artifacts from a Hugging Face repository.

    Tries a list of known ONNX export filenames (quantized/optimized variants
    first, full-precision last) and keeps the first one that downloads
    successfully, then fetches the tokenizer/config files the runtime needs.

    Args:
        model_name (str): Hugging Face repo id, e.g. "mixedbread-ai/mxbai-rerank-xsmall-v1".

    Raises:
        FileNotFoundError: If none of the candidate ONNX files exist in the repo.
        requests.exceptions.HTTPError: If a required tokenizer/config artifact
            (see `required_files`) cannot be downloaded.
    """
    # exist_ok=True already tolerates a pre-existing directory, so no need
    # to check model_dir.exists() first.
    self.model_dir.mkdir(parents=True, exist_ok=True)

    # NOTE: must match the filename __init__ resolves into self.model_path.
    local_model_path = self.model_dir / f"{model_name.split('/')[-1]}.onnx"

    # Preference order: smaller quantized/optimized exports before the full model.
    onnx_candidates = ["onnx/model_quantized.onnx", "onnx/model_uint8.onnx",
                       "onnx/model_O4.onnx", "onnx/model_fp16.onnx", "onnx/model.onnx"]
    for candidate in onnx_candidates:
        formatted_model_url = hf_model_url.format(model_name, candidate)
        try:
            download_file(local_model_path, formatted_model_url)
            break
        except requests.exceptions.HTTPError as e:
            self.logger.warning(f"Could not download {candidate} for model {model_name}. Error: {e}")
    else:
        # Loop completed without a break: every candidate failed.
        raise FileNotFoundError(f"ONNX model file not found for model {model_name} in the Hugging Face repository.")

    # Tokenizer/config artifacts; a failure here is fatal since inference needs them.
    for req_file in required_files:
        formatted_artifacts_url = hf_model_url.format(model_name, req_file)
        local_file_path = self.model_dir / req_file
        download_file(local_file_path, formatted_artifacts_url)


def _get_tokenizer(self, max_length: int = 512) -> Tokenizer:
""" Initializes and configures the tokenizer with padding and truncation.

Expand Down Expand Up @@ -227,7 +265,7 @@ def rerank(self, request: RerankRequest) -> List[Dict[str, Any]]:
token_type_ids = np.array([e.type_ids for e in input_text])
attention_mask = np.array([e.attention_mask for e in input_text])

use_token_type_ids = token_type_ids is not None and not np.all(token_type_ids == 0)
use_token_type_ids = token_type_ids is not None and not np.all(token_type_ids == 0) and "token_type_ids" in {inp.name for inp in self.session.get_inputs()}

onnx_input = {"input_ids": input_ids.astype(np.int64), "attention_mask": attention_mask.astype(np.int64)}
if use_token_type_ids:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='FlashRank',
version='0.2.9',
version='0.3',
packages=find_packages(),
install_requires=[
'tokenizers',
Expand Down