Skip to content

Commit f653903

Browse files
authored
Allowing more general types in Settings (#818)
1 parent 8bc993c commit f653903

File tree

6 files changed

+18
-14
lines changed

6 files changed

+18
-14
lines changed

paperqa/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44
import re
5-
from collections.abc import Callable
5+
from collections.abc import Callable, Sequence
66
from typing import Any
77

88
from paperqa.llms import PromptRunner
@@ -41,7 +41,7 @@ async def map_fxn_summary(
4141
prompt_runner: PromptRunner | None,
4242
extra_prompt_data: dict[str, str] | None = None,
4343
parser: Callable[[str], dict[str, Any]] | None = None,
44-
callbacks: list[Callable[[str], None]] | None = None,
44+
callbacks: Sequence[Callable[[str], None]] | None = None,
4545
) -> tuple[Context, LLMResult]:
4646
"""Parses the given text and returns a context object with the parser and prompt runner.
4747

paperqa/docs.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
import tempfile
88
import urllib.request
9-
from collections.abc import Callable
9+
from collections.abc import Callable, Sequence
1010
from datetime import datetime
1111
from functools import partial
1212
from io import BytesIO
@@ -549,7 +549,7 @@ def get_evidence(
549549
query: PQASession | str,
550550
exclude_text_filter: set[str] | None = None,
551551
settings: MaybeSettings = None,
552-
callbacks: list[Callable] | None = None,
552+
callbacks: Sequence[Callable] | None = None,
553553
embedding_model: EmbeddingModel | None = None,
554554
summary_llm_model: LLMModel | None = None,
555555
partitioning_fn: Callable[[Embeddable], int] | None = None,
@@ -571,7 +571,7 @@ async def aget_evidence(
571571
query: PQASession | str,
572572
exclude_text_filter: set[str] | None = None,
573573
settings: MaybeSettings = None,
574-
callbacks: list[Callable] | None = None,
574+
callbacks: Sequence[Callable] | None = None,
575575
embedding_model: EmbeddingModel | None = None,
576576
summary_llm_model: LLMModel | None = None,
577577
partitioning_fn: Callable[[Embeddable], int] | None = None,
@@ -668,7 +668,7 @@ def query(
668668
self,
669669
query: PQASession | str,
670670
settings: MaybeSettings = None,
671-
callbacks: list[Callable] | None = None,
671+
callbacks: Sequence[Callable] | None = None,
672672
llm_model: LLMModel | None = None,
673673
summary_llm_model: LLMModel | None = None,
674674
embedding_model: EmbeddingModel | None = None,
@@ -690,12 +690,16 @@ async def aquery( # noqa: PLR0912
690690
self,
691691
query: PQASession | str,
692692
settings: MaybeSettings = None,
693-
callbacks: list[Callable] | None = None,
693+
callbacks: Sequence[Callable] | None = None,
694694
llm_model: LLMModel | None = None,
695695
summary_llm_model: LLMModel | None = None,
696696
embedding_model: EmbeddingModel | None = None,
697697
partitioning_fn: Callable[[Embeddable], int] | None = None,
698698
) -> PQASession:
699+
# TODO: remove list cast after release of https://github.com/Future-House/llm-client/pull/36
700+
callbacks = cast(
701+
list[Callable] | None, list(callbacks) if callbacks else callbacks
702+
)
699703

700704
query_settings = get_settings(settings)
701705
answer_config = query_settings.answer

paperqa/llms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
qdrant_installed = False
4848

4949
PromptRunner = Callable[
50-
[dict, list[Callable[[str], None]] | None, str | None],
50+
[dict, Sequence[Callable[[str], None]] | None, str | None],
5151
Awaitable[LLMResult],
5252
]
5353

paperqa/settings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import pathlib
55
import warnings
6-
from collections.abc import Callable, Mapping
6+
from collections.abc import Callable, Mapping, Sequence
77
from enum import StrEnum
88
from pydoc import locate
99
from typing import Any, ClassVar, Self, TypeAlias, assert_never, cast
@@ -194,7 +194,7 @@ class ParsingSettings(BaseModel):
194194
),
195195
)
196196
chunking_algorithm: ChunkingOptions = ChunkingOptions.SIMPLE_OVERLAP
197-
doc_filters: list[dict] | None = Field(
197+
doc_filters: Sequence[Mapping[str, Any]] | None = Field(
198198
default=None,
199199
description=(
200200
"Optional filters to only allow documents that match this filter. This is a"
@@ -498,7 +498,7 @@ class AgentSettings(BaseModel):
498498
description="If set to true, run the search tool before invoking agent.",
499499
)
500500

501-
tool_names: set[str] | None = Field(
501+
tool_names: set[str] | Sequence[str] | None = Field(
502502
default=None,
503503
description=(
504504
"Optional override on the tools to provide the agent. Leaving as the"
@@ -521,7 +521,7 @@ class AgentSettings(BaseModel):
521521
)
522522
index: IndexSettings = Field(default_factory=IndexSettings)
523523

524-
callbacks: Mapping[str, list[Callable[[_EnvironmentState], Any]]] = Field(
524+
callbacks: Mapping[str, Sequence[Callable[[_EnvironmentState], Any]]] = Field(
525525
default_factory=dict,
526526
description="""
527527
A mapping that associates callback names with lists of corresponding callable functions.

paperqa/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __hash__(self) -> int:
6666
def formatted_citation(self) -> str:
6767
return self.citation
6868

69-
def matches_filter_criteria(self, filter_criteria: dict) -> bool:
69+
def matches_filter_criteria(self, filter_criteria: Mapping[str, Any]) -> bool:
7070
"""Returns True if the doc matches the filter criteria, False otherwise."""
7171
data_dict = self.model_dump()
7272
for key, value in filter_criteria.items():

tests/test_agents.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ async def test_agent_sharing_state(
492492
"gather_evidence_completed": [gather_evidence_completed_callback],
493493
}
494494

495-
agent_test_settings.agent.callbacks = callbacks # type: ignore[assignment]
495+
agent_test_settings.agent.callbacks = callbacks
496496

497497
session = PQASession(question="What is is a self-explanatory model?")
498498
env_state = EnvironmentState(docs=Docs(), session=session)

0 commit comments

Comments
 (0)