backend/backend/engine/constants.py (83 changes: 69 additions & 14 deletions)
@@ -1,4 +1,4 @@
-from typing import TypedDict
+from typing import Literal, TypedDict
 
 from backend.config import FainderMode
 
@@ -11,25 +11,80 @@ class FilteringStopPointsConfig(TypedDict):
     num_hist_ids: int
 
 
-FILTERING_STOP_POINTS: dict[FainderMode, FilteringStopPointsConfig] = {
+# coefficients from linear model for pp result size
+INTERCEPT = 5.776e6
+COEF_LOG_THRESHOLD = 3.575e5
+COEF_PERCENTILE = -9.344e5
+
+FILTERING_STOP_POINTS: dict[FainderMode, dict[int, FilteringStopPointsConfig]] = {
     FainderMode.LOW_MEMORY: {
-        "num_doc_ids": 1000,
-        "num_col_ids": 10000,
-        "num_hist_ids": 10000,
+        0: {
+            "num_doc_ids": 1000,
+            "num_col_ids": 10000,
+            "num_hist_ids": 10000,
+        },
     },
     FainderMode.FULL_PRECISION: {
-        "num_doc_ids": 1000,
-        "num_col_ids": 10000,
-        "num_hist_ids": 10000,
+        0: {
+            "num_doc_ids": 1000,
+            "num_col_ids": 10000,
+            "num_hist_ids": 10000,
+        },
     },
     FainderMode.FULL_RECALL: {
-        "num_doc_ids": 1000,
-        "num_col_ids": 10000,
-        "num_hist_ids": 10000,
+        0: {
+            "num_doc_ids": 1000,
+            "num_col_ids": 10000,
+            "num_hist_ids": 10000,
+        },
     },
     FainderMode.EXACT: {
-        "num_doc_ids": 30000,
-        "num_col_ids": 3500000,
-        "num_hist_ids": 3500000,
+        0: {
+            "num_doc_ids": 30000,
+            "num_col_ids": 900000,
+            "num_hist_ids": 900000,
+        },
+        5: {
+            "num_doc_ids": 75000,
+            "num_col_ids": 3000000,
+            "num_hist_ids": 3000000,
+        },
+        11: {
+            "num_doc_ids": 75000,
+            "num_col_ids": 3000000,
+            "num_hist_ids": 3000000,
+        },
+        27: {
+            "num_doc_ids": 75000,
+            "num_col_ids": 3000000,
+            "num_hist_ids": 3000000,
+        },
     },
 }
+
+
+def get_filtering_stop_point(
+    mode: FainderMode,
+    num_workers: int,
+    filter_type: Literal["num_doc_ids", "num_col_ids", "num_hist_ids"],
+) -> int:
+    """Get the filtering stop point for a given Fainder mode and number of workers."""
+    if mode not in FILTERING_STOP_POINTS:
+        raise ValueError(f"Invalid Fainder mode: {mode}")
+
+    if num_workers not in FILTERING_STOP_POINTS[mode]:
+        # get nearest smaller key
+        available_keys = sorted(FILTERING_STOP_POINTS[mode].keys())
+        for key in reversed(available_keys):
+            if key <= num_workers:
+                num_workers = key
+                break
+        else:
+            raise ValueError(f"No available stop points for {mode} with {num_workers} workers")
+
+    stop_points = FILTERING_STOP_POINTS[mode][num_workers]
+
+    if filter_type not in stop_points:
+        raise ValueError(f"Invalid type: {filter_type}")
+
+    return stop_points[filter_type]
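
For review context, here is a quick sketch of the lookup semantics `get_filtering_stop_point` introduces: an exact worker count hits its entry directly, and any other count falls back to the nearest smaller key. The values below follow from the table above.

```python
# Sketch of the nearest-smaller-key fallback (values from the table above).
from backend.config import FainderMode
from backend.engine.constants import get_filtering_stop_point

get_filtering_stop_point(FainderMode.EXACT, 5, "num_doc_ids")   # exact key -> 75000
get_filtering_stop_point(FainderMode.EXACT, 8, "num_doc_ids")   # nearest smaller key 5 -> 75000
get_filtering_stop_point(FainderMode.EXACT, 3, "num_col_ids")   # nearest smaller key 0 -> 900000
get_filtering_stop_point(FainderMode.LOW_MEMORY, 16, "num_hist_ids")  # only key 0 -> 10000
```

Since every mode defines a `0` entry, the `for ... else` `ValueError` can only fire for a negative worker count.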
backend/backend/engine/execution/common.py (5 changes: 3 additions & 2 deletions)
@@ -9,7 +9,7 @@
 from numpy.typing import NDArray
 
 from backend.config import ColumnArray, DocumentArray, DocumentHighlights, FainderMode, Highlights
-from backend.engine.constants import FILTERING_STOP_POINTS
+from backend.engine.constants import get_filtering_stop_point
 from backend.engine.conversion import doc_to_col_ids
 
 DocResult = tuple[DocumentArray, Highlights]
@@ -197,9 +197,10 @@ def exceeds_filtering_limit(
     ids: DocumentArray | ColumnArray,
     id_type: Literal["num_hist_ids", "num_col_ids", "num_doc_ids"],
     fainder_mode: FainderMode,
+    num_workers: int,
 ) -> bool:
     """Check if the number of IDs exceeds the filtering limit for the current mode."""
-    return len(ids) > FILTERING_STOP_POINTS[fainder_mode][id_type]
+    return ids.size > get_filtering_stop_point(fainder_mode, num_workers, id_type)
 
 
 def is_doc_result(val: Sequence[Any]) -> TypeGuard[Sequence[DocResult]]:
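
A hypothetical call under the new signature, assuming `DocumentArray` is a 1-D NumPy `uint32` array (which is why `ids.size` replaces `len(ids)`; the two agree for 1-D arrays, `.size` just makes the NumPy expectation explicit):

```python
# Hypothetical call site; the array contents and worker count are illustrative.
import numpy as np

from backend.config import FainderMode
from backend.engine.execution.common import exceeds_filtering_limit

doc_ids = np.arange(50_000, dtype=np.uint32)
# EXACT mode with 8 workers resolves to the 5-worker stop point (75000),
# so 50000 document IDs are still within the limit.
exceeds_filtering_limit(doc_ids, "num_doc_ids", FainderMode.EXACT, num_workers=8)  # False
```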
backend/backend/engine/execution/prefiltering_executor.py (37 changes: 24 additions & 13 deletions)
@@ -2,7 +2,7 @@
 from collections.abc import Sequence
 
 import numpy as np
-from lark import ParseTree, Token, Transformer
+from lark import ParseTree, Token, Transformer_NonRecursive
 from loguru import logger
 from numpy.typing import NDArray
 
@@ -42,8 +42,10 @@ def __init__(
         fainder_mode: FainderMode,
         doc_ids: DocumentArray | None = None,
         col_ids: ColumnArray | None = None,
+        num_workers: int = 1,
     ) -> None:
         self.fainder_mode = fainder_mode
+        self.num_workers = num_workers
         if doc_ids is None and col_ids is None:
             raise ValueError("doc_ids and col_ids cannot both be None")
         if doc_ids is not None and col_ids is not None:
@@ -52,13 +54,13 @@
         self._doc_ids: DocumentArray | None = (
             None
             if doc_ids is not None
-            and exceeds_filtering_limit(doc_ids, "num_doc_ids", fainder_mode)
+            and exceeds_filtering_limit(doc_ids, "num_doc_ids", fainder_mode, num_workers)
             else doc_ids
         )
         self._col_ids: ColumnArray | None = (
             None
             if col_ids is not None
-            and exceeds_filtering_limit(col_ids, "num_col_ids", fainder_mode)
+            and exceeds_filtering_limit(col_ids, "num_col_ids", fainder_mode, num_workers)
             else col_ids
         )
 
@@ -83,11 +85,15 @@ def add_doc_ids(self, doc_ids: DocumentArray, col_to_doc: NDArray[np.uint32]) ->
     def build_hist_filter(self, metadata: Metadata) -> ColumnArray | None:
         """Build a histogram filter from the intermediate results."""
         if self._col_ids is not None:
-            if exceeds_filtering_limit(self._col_ids, "num_col_ids", self.fainder_mode):
+            if exceeds_filtering_limit(
+                self._col_ids, "num_col_ids", self.fainder_mode, self.num_workers
+            ):
                 return None
             return self._col_ids
         if self._doc_ids is not None:
-            if exceeds_filtering_limit(self._doc_ids, "num_doc_ids", self.fainder_mode):
+            if exceeds_filtering_limit(
+                self._doc_ids, "num_doc_ids", self.fainder_mode, self.num_workers
+            ):
                 return None
             return doc_to_col_ids(self._doc_ids, metadata.doc_to_cols)
         return None
@@ -104,11 +110,14 @@ def __str__(self) -> str:
 class IntermediateResultStore:
     """Store intermediate results for prefiltering per group."""
 
-    def __init__(self, fainder_mode: FainderMode, write_groups_used: dict[int, int]) -> None:
+    def __init__(
+        self, fainder_mode: FainderMode, write_groups_used: dict[int, int], num_workers: int
+    ) -> None:
         self.results: dict[int, IntermediateResult] = {}
         self.fainder_mode = fainder_mode
         self.write_groups_used = write_groups_used
         self.write_groups_actually_used: dict[int, int] = {}
+        self.num_workers = num_workers
 
     def add_col_id_results(
         self, write_group: int, col_ids: ColumnArray, doc_to_cols: list[NDArray[np.uint32]]
@@ -123,7 +132,7 @@ def add_col_id_results(
             logger.trace("Write group {} is not used, skipping adding column IDs", write_group)
             return
 
-        if exceeds_filtering_limit(col_ids, "num_col_ids", self.fainder_mode):
+        if exceeds_filtering_limit(col_ids, "num_col_ids", self.fainder_mode, self.num_workers):
             logger.trace("Column IDs exceed filtering limit, skipping adding column IDs")
             return
 
@@ -132,7 +141,7 @@ def add_col_id_results(
             self.results[write_group].add_col_ids(col_ids=col_ids, doc_to_cols=doc_to_cols)
         else:
             self.results[write_group] = IntermediateResult(
-                col_ids=col_ids, fainder_mode=self.fainder_mode
+                col_ids=col_ids, fainder_mode=self.fainder_mode, num_workers=self.num_workers
             )
 
     def add_doc_id_results(
@@ -150,7 +159,7 @@ def add_doc_id_results(
             logger.trace("Write group {} is not used, skipping adding document IDs", write_group)
             return
 
-        if exceeds_filtering_limit(doc_ids, "num_doc_ids", self.fainder_mode):
+        if exceeds_filtering_limit(doc_ids, "num_doc_ids", self.fainder_mode, self.num_workers):
             logger.trace("Document IDs exceed filtering limit, skipping adding document IDs")
             return
 
@@ -159,7 +168,7 @@ def add_doc_id_results(
             self.results[write_group].add_doc_ids(doc_ids=doc_ids, col_to_doc=col_to_doc)
         else:
             self.results[write_group] = IntermediateResult(
-                doc_ids=doc_ids, fainder_mode=self.fainder_mode
+                doc_ids=doc_ids, fainder_mode=self.fainder_mode, num_workers=self.num_workers
             )
 
     def build_hist_filter(self, read_groups: list[int], metadata: Metadata) -> ColumnArray | None:
@@ -205,7 +214,7 @@ def build_hist_filter(self, read_groups: list[int], metadata: Metadata) -> Colum
         return reduce_arrays(hist_filters, "and")
 
 
-class PrefilteringExecutor(Transformer[Token, DocResult], Executor):
+class PrefilteringExecutor(Transformer_NonRecursive[Token, DocResult], Executor):
     """Uses prefiltering to reduce the number of documents before executing the query."""
 
     def __init__(
@@ -239,7 +248,9 @@ def reset(
         self.scores: dict[int, float] = defaultdict(float)
         self.fainder_mode = fainder_mode
         self.enable_highlighting = enable_highlighting
-        self.intermediate_results = IntermediateResultStore(fainder_mode, {})
+        self.intermediate_results = IntermediateResultStore(
+            fainder_mode, {}, self.fainder_index.num_workers
+        )
         self.write_groups: dict[int, int] = {}
         self.read_groups: dict[int, list[int]] = {}
         self.parent_write_group: dict[int, int] = {}
@@ -287,7 +298,7 @@ def execute(self, tree: ParseTree) -> DocResult:
         self.read_groups = {}
         logger.trace(tree.pretty())
         groups = ResultGroupAnnotator()
-        groups.apply(tree, parallel=True)
+        groups.apply(tree)
         self.write_groups = groups.write_groups
         self.read_groups = groups.read_groups
         self.parent_write_group = groups.parent_write_group
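
For orientation, a hypothetical standalone construction of the store under its new signature; in the executor itself the worker count comes from `self.fainder_index.num_workers`, as the `reset` hunk above shows:

```python
# Hypothetical construction mirroring reset(); the arguments are illustrative.
from backend.config import FainderMode
from backend.engine.execution.prefiltering_executor import IntermediateResultStore

store = IntermediateResultStore(FainderMode.EXACT, write_groups_used={}, num_workers=8)
# Every add_col_id_results/add_doc_id_results call now checks limits against
# the 8-worker stop points (the 5-worker tier, per the constants table).
```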
backend/backend/engine/execution/simple_executor.py (4 changes: 2 additions & 2 deletions)
@@ -2,7 +2,7 @@
 from collections.abc import Sequence
 
 import numpy as np
-from lark import ParseTree, Token, Transformer
+from lark import ParseTree, Token, Transformer_NonRecursive
 from loguru import logger
 
 from backend.config import ColumnHighlights, DocumentHighlights, FainderMode, Metadata
@@ -13,7 +13,7 @@
 from .executor import Executor
 
 
-class SimpleExecutor(Transformer[Token, DocResult], Executor):
+class SimpleExecutor(Transformer_NonRecursive[Token, DocResult], Executor):
     """This transformer evaluates a parse tree bottom-up and computes the query result."""
 
     fainder_mode: FainderMode
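
Both `PrefilteringExecutor` and `SimpleExecutor` now extend lark's `Transformer_NonRecursive`, which evaluates the parse tree iteratively instead of through Python recursion, so deeply nested queries cannot overflow the interpreter's recursion limit. A minimal, self-contained sketch of the difference; the grammar and nesting depth are illustrative, not from this codebase:

```python
# Standalone sketch: Transformer_NonRecursive walks the tree iteratively,
# so input nested deeper than Python's default recursion limit (~1000)
# still transforms, where a recursive Transformer would raise RecursionError.
from lark import Lark, Transformer_NonRecursive

parser = Lark(
    """
    expr: "(" expr ")" | NUMBER
    %import common.NUMBER
    """,
    start="expr",
    parser="lalr",  # stack-based, so parsing itself is not recursion-bound
)

class Unwrap(Transformer_NonRecursive):
    def expr(self, children):  # "(" and ")" are filtered out by lark
        return children[0]

    def NUMBER(self, token):
        return int(token)

deep_input = "(" * 2000 + "1" + ")" * 2000
print(Unwrap().transform(parser.parse(deep_input)))  # 1
```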