From c82be08866d5ef13f8fd162668641f0e05d76f90 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 13 Jan 2025 13:20:13 -0800 Subject: [PATCH 01/12] remove importliner from project --- .github/workflows/lint.yml | 34 ---------------------------------- pyproject.toml | 26 ++++---------------------- 2 files changed, 4 insertions(+), 56 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 8c44b6264..09bfff6da 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,37 +17,3 @@ jobs: run: jlpm - name: Lint TypeScript source run: jlpm lerna run lint:check - - lint_py_imports: - name: Lint Python imports - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Echo environment details - run: | - which python - which pip - python --version - pip --version - - # see #546 for context on why this is necessary - - name: Create venv - run: | - python -m venv lint_py_imports - - - name: Install job dependencies - run: | - source ./lint_py_imports/bin/activate - pip install jupyterlab~=4.0 - pip install import-linter~=1.12.1 - - - name: Install Jupyter AI packages from source - run: | - source ./lint_py_imports/bin/activate - jlpm install - jlpm install-from-src - - - name: Lint Python imports - run: | - source ./lint_py_imports/bin/activate - lint-imports diff --git a/pyproject.toml b/pyproject.toml index 322177455..4e0a917bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,8 +7,8 @@ name = "jupyter_ai_monorepo" dynamic = ["version", "description", "authors", "urls", "keywords"] requires-python = ">=3.9" dependencies = [ - "jupyter-ai-magics @ {root:uri}/packages/jupyter-ai-magics", - "jupyter-ai @ {root:uri}/packages/jupyter-ai" + "jupyter-ai-magics @ {root:uri}/packages/jupyter-ai-magics", + "jupyter-ai @ {root:uri}/packages/jupyter-ai", ] [project.optional-dependencies] @@ -26,10 +26,7 @@ source = "nodejs" path = "package.json" [tool.hatch.build] -packages = [ - "packages/jupyter-ai-magics", - "packages/jupyter-ai" -] +packages = ["packages/jupyter-ai-magics", "packages/jupyter-ai"] [tool.hatch.metadata] allow-direct-references = true @@ -40,23 +37,8 @@ ignore = [".*"] [tool.check-wheel-contents] ignore = ["W002"] -[tool.importlinter] -root_packages = ["jupyter_ai", "jupyter_ai_magics"] -include_external_packages = true - -[[tool.importlinter.contracts]] -key = "pydantic" -name = "Forbid `pydantic`. (note: Developers should import Pydantic from `langchain.pydantic_v1` instead for compatibility.)" -type = "forbidden" -source_modules = ["jupyter_ai", "jupyter_ai_magics"] -forbidden_modules = ["pydantic"] -# TODO: get `langchain` to export `ModelMetaclass` to avoid needing this statement -ignore_imports = ["jupyter_ai_magics.providers -> pydantic"] - [tool.pytest.ini_options] addopts = "--ignore packages/jupyter-ai-module-cookiecutter" [tool.mypy] -exclude = [ - "tests" -] +exclude = ["tests"] From b354c2471d7cbb972d038e0670e3e42add14cdfe Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 13 Jan 2025 14:31:05 -0800 Subject: [PATCH 02/12] initial upgrade to langchain~=0.3, pydantic~=2.0 --- .../jupyter_ai_magics/embedding_providers.py | 2 +- .../jupyter_ai_magics/models/completion.py | 2 +- .../jupyter_ai_magics/models/persona.py | 2 +- .../jupyter_ai_magics/parsers.py | 2 +- .../partner_providers/openrouter.py | 4 +- .../jupyter_ai_magics/providers.py | 81 +++++-------------- packages/jupyter-ai-magics/pyproject.toml | 5 +- .../jupyter_ai/chat_handlers/base.py | 2 +- .../jupyter_ai/chat_handlers/generate.py | 2 +- .../jupyter_ai/completions/handlers/base.py | 2 +- .../jupyter_ai/context_providers/base.py | 2 +- packages/jupyter-ai/jupyter_ai/handlers.py | 2 +- packages/jupyter-ai/jupyter_ai/history.py | 2 +- packages/jupyter-ai/jupyter_ai/models.py | 10 +-- .../jupyter_ai/tests/test_config_manager.py | 2 +- packages/jupyter-ai/pyproject.toml | 1 + 16 files changed, 43 insertions(+), 80 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py index 695465488..97ce937a6 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py @@ -6,7 +6,7 @@ Field, MultiEnvAuthStrategy, ) -from langchain.pydantic_v1 import BaseModel, Extra +from pydantic import BaseModel, Extra from langchain_community.embeddings import ( GPT4AllEmbeddings, HuggingFaceHubEmbeddings, diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py index f2ee0cd54..17e680264 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py @@ -1,6 +1,6 @@ from typing import List, Literal, Optional -from langchain.pydantic_v1 import BaseModel +from pydantic import BaseModel class InlineCompletionRequest(BaseModel): diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py index fe25397b0..da2a39204 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py @@ -1,4 +1,4 @@ -from langchain.pydantic_v1 import BaseModel +from pydantic import BaseModel class Persona(BaseModel): diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py index 07e26e875..a5bc93439 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py @@ -2,7 +2,7 @@ from typing import Literal, Optional, get_args import click -from langchain.pydantic_v1 import BaseModel +from pydantic import BaseModel FORMAT_CHOICES_TYPE = Literal[ "code", "html", "image", "json", "markdown", "math", "md", "text" diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py index 81c2d7ab1..bbaf9deb4 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py @@ -2,7 +2,7 @@ from jupyter_ai_magics import BaseProvider from jupyter_ai_magics.providers import EnvAuthStrategy, TextField -from langchain_core.pydantic_v1 import root_validator +from pydantic import model_validator from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_openai import ChatOpenAI @@ -42,7 +42,7 @@ def __init__(self, **kwargs): **kwargs, ) - @root_validator(pre=False, skip_on_failure=True, allow_reuse=True) + @model_validator(mode="after") def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["openai_api_key"] = convert_to_secret_str( diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index fac868229..6ed9082bf 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -24,7 +24,7 @@ PromptTemplate, SystemMessagePromptTemplate, ) -from langchain.pydantic_v1 import BaseModel, Extra +from pydantic import BaseModel, ConfigDict from langchain.schema import LLMResult from langchain.schema.output_parser import StrOutputParser from langchain.schema.runnable import Runnable @@ -33,13 +33,6 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.llms import BaseLLM -# this is necessary because `langchain.pydantic_v1.main` does not include -# `ModelMetaclass`, as it is not listed in `__all__` by the `pydantic.main` -# subpackage. -try: - from pydantic.v1.main import ModelMetaclass -except: - from pydantic.main import ModelMetaclass from . import completion_utils as completion from .models.completion import ( @@ -122,7 +115,7 @@ class EnvAuthStrategy(BaseModel): name: str """The name of the environment variable, e.g. `'ANTHROPIC_API_KEY'`.""" - keyword_param: Optional[str] + keyword_param: Optional[str] = None """ If unset (default), the authentication token is provided as a keyword argument with the parameter equal to the environment variable name in @@ -177,51 +170,10 @@ class IntegerField(BaseModel): Field = Union[TextField, MultilineTextField, IntegerField] -class ProviderMetaclass(ModelMetaclass): - """ - A metaclass that ensures all class attributes defined inline within the - class definition are accessible and included in `Class.__dict__`. - - This is necessary because Pydantic drops any ClassVars that are defined as - an instance field by a parent class, even if they are defined inline within - the class definition. We encountered this case when `langchain` added a - `name` attribute to a parent class shared by all `Provider`s, which caused - `Provider.name` to be inaccessible. See #558 for more info. - """ - - def __new__(mcs, name, bases, namespace, **kwargs): - cls = super().__new__(mcs, name, bases, namespace, **kwargs) - for key in namespace: - # skip private class attributes - if key.startswith("_"): - continue - # skip class attributes already listed in `cls.__dict__` - if key in cls.__dict__: - continue - - setattr(cls, key, namespace[key]) - - return cls - - @property - def server_settings(cls): - return cls._server_settings - - @server_settings.setter - def server_settings(cls, value): - if cls._server_settings is not None: - raise AttributeError("'server_settings' attribute was already set") - cls._server_settings = value - - _server_settings = None - - -class BaseProvider(BaseModel, metaclass=ProviderMetaclass): - # - # pydantic config - # - class Config: - extra = Extra.allow +class BaseProvider(BaseModel): + # pydantic v2 model config + # upstream docs: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra + model_config = ConfigDict(extra="allow") # # class attrs @@ -236,15 +188,25 @@ class Config: """List of supported models by their IDs. For registry providers, this will be just ["*"].""" - help: ClassVar[str] = None + help: ClassVar[Optional[str]] = None """Text to display in lieu of a model list for a registry provider that does not provide a list of models.""" - model_id_key: ClassVar[str] = ... - """Kwarg expected by the upstream LangChain provider.""" + model_id_key: ClassVar[Optional[str]] = None + """ + Optional field which specifies the key under which `model_id` is passed to + the parent LangChain class. - model_id_label: ClassVar[str] = "" - """Human-readable label of the model ID.""" + If unset, this defaults to "model_id". + """ + + model_id_label: ClassVar[Optional[str]] = None + """ + Optional field which sets the label shown in the UI allowing users to + select/type a model ID. + + If unset, the label shown in the UI defaults to "Model ID". + """ pypi_package_deps: ClassVar[List[str]] = [] """List of PyPi package dependencies.""" @@ -586,7 +548,6 @@ def __init__(self, **kwargs): id = "gpt4all" name = "GPT4All" - docs = "https://docs.gpt4all.io/gpt4all_python.html" models = [ "ggml-gpt4all-j-v1.2-jazzy", "ggml-gpt4all-j-v1.3-groovy", diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 6819ad989..66299a7a5 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -24,8 +24,9 @@ dynamic = ["version", "description", "authors", "urls", "keywords"] dependencies = [ "ipython", "importlib_metadata>=5.2.0", - "langchain>=0.2.17,<0.3.0", - "langchain_community>=0.2.19,<0.3.0", + "langchain>=0.3.0,<0.4.0", + "langchain_community>=0.3.0,<0.4.0", + "pydantic~=2.0", "typing_extensions>=4.5.0", "click~=8.0", "jsonpath-ng>=1.5.3,<2", diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index c844650ad..99c02a2c2 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -36,7 +36,7 @@ ) from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import BaseProvider -from langchain.pydantic_v1 import BaseModel +from pydantic import BaseModel from langchain_core.messages import AIMessageChunk from langchain_core.runnables import Runnable from langchain_core.runnables.config import RunnableConfig diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index 6318e0979..36222b9fd 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -13,7 +13,7 @@ from langchain.chains import LLMChain from langchain.llms import BaseLLM from langchain.output_parsers import PydanticOutputParser -from langchain.pydantic_v1 import BaseModel +from pydantic import BaseModel from langchain.schema.output_parser import BaseOutputParser from langchain_core.prompts import PromptTemplate diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py index 32920dc83..bc75f950c 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py @@ -14,7 +14,7 @@ InlineCompletionStreamChunk, ) from jupyter_server.base.handlers import JupyterHandler -from langchain.pydantic_v1 import ValidationError +from pydantic import ValidationError class BaseInlineCompletionHandler( diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/base.py b/packages/jupyter-ai/jupyter_ai/context_providers/base.py index 1b0953e84..3468bcd78 100644 --- a/packages/jupyter-ai/jupyter_ai/context_providers/base.py +++ b/packages/jupyter-ai/jupyter_ai/context_providers/base.py @@ -7,7 +7,7 @@ from jupyter_ai.chat_handlers.base import get_preferred_dir from jupyter_ai.config_manager import ConfigManager, Logger from jupyter_ai.models import ChatMessage, HumanChatMessage, ListOptionsEntry -from langchain.pydantic_v1 import BaseModel +from pydantic import BaseModel if TYPE_CHECKING: from jupyter_ai.chat_handlers import BaseChatHandler diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 28b169c00..869eec944 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -12,7 +12,7 @@ from jupyter_ai.context_providers import BaseCommandContextProvider, ContextCommand from jupyter_server.base.handlers import APIHandler as BaseAPIHandler from jupyter_server.base.handlers import JupyterHandler -from langchain.pydantic_v1 import ValidationError +from pydantic import ValidationError from tornado import web, websocket from tornado.web import HTTPError diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index 0f1ba7dc0..0197f4f6c 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -3,7 +3,7 @@ from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage -from langchain_core.pydantic_v1 import BaseModel, PrivateAttr +from pydantic import BaseModel, PrivateAttr from .models import HumanChatMessage diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 6bd7d4e06..2166dee60 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -3,7 +3,7 @@ from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import AuthStrategy, Field -from langchain.pydantic_v1 import BaseModel, validator +from pydantic import BaseModel, validator DEFAULT_CHUNK_SIZE = 2000 DEFAULT_CHUNK_OVERLAP = 100 @@ -213,14 +213,14 @@ class ListProvidersEntry(BaseModel): id: str name: str - model_id_label: Optional[str] + model_id_label: Optional[str] = None models: List[str] - help: Optional[str] + help: Optional[str] = None auth_strategy: AuthStrategy registry: bool fields: List[Field] - chat_models: Optional[List[str]] - completion_models: Optional[List[str]] + chat_models: Optional[List[str]] = None + completion_models: Optional[List[str]] = None class ListProvidersResponse(BaseModel): diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index 4a739f6e5..38212f995 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -12,7 +12,7 @@ ) from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest from jupyter_ai_magics.utils import get_em_providers, get_lm_providers -from langchain.pydantic_v1 import ValidationError +from pydantic import ValidationError @pytest.fixture diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index c9d1b5d53..a1ab4a46a 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "jupyterlab~=4.0", "aiosqlite>=0.18", "importlib_metadata>=5.2.0", + "pydantic~=2.0", "jupyter_ai_magics>=2.13.0", "dask[distributed]", # faiss-cpu is not distributed by the official repo. From d5d5d4c36be61a8f5adf7d64d0275a39ce863864 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 13 Jan 2025 14:53:13 -0800 Subject: [PATCH 03/12] default to `None` for all `Optional` fields explicitly --- .../jupyter_ai_magics/models/completion.py | 16 ++++---- .../jupyter_ai_magics/parsers.py | 18 ++++----- .../jupyter_ai/chat_handlers/base.py | 2 +- .../jupyter-ai/jupyter_ai/config_manager.py | 6 +-- packages/jupyter-ai/jupyter_ai/models.py | 38 +++++++++---------- .../jupyter_ai/tests/test_handlers.py | 2 +- 6 files changed, 41 insertions(+), 41 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py index 17e680264..2949c06ec 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py @@ -21,12 +21,12 @@ class InlineCompletionRequest(BaseModel): # whether to stream the response (if supported by the model) stream: bool # path to the notebook of file for which the completions are generated - path: Optional[str] + path: Optional[str] = None # language inferred from the document mime type (if possible) - language: Optional[str] + language: Optional[str] = None # identifier of the cell for which the completions are generated if in a notebook # previous cells and following cells can be used to learn the wider context - cell_id: Optional[str] + cell_id: Optional[str] = None class InlineCompletionItem(BaseModel): @@ -36,9 +36,9 @@ class InlineCompletionItem(BaseModel): """ insertText: str - filterText: Optional[str] - isIncomplete: Optional[bool] - token: Optional[str] + filterText: Optional[str] = None + isIncomplete: Optional[bool] = None + token: Optional[str] = None class CompletionError(BaseModel): @@ -59,7 +59,7 @@ class InlineCompletionReply(BaseModel): list: InlineCompletionList # number of request for which we are replying reply_to: int - error: Optional[CompletionError] + error: Optional[CompletionError] = None class InlineCompletionStreamChunk(BaseModel): @@ -69,7 +69,7 @@ class InlineCompletionStreamChunk(BaseModel): response: InlineCompletionItem reply_to: int done: bool - error: Optional[CompletionError] + error: Optional[CompletionError] = None __all__ = [ diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py index a5bc93439..de99fc8bd 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py @@ -46,11 +46,11 @@ class CellArgs(BaseModel): type: Literal["root"] = "root" model_id: str format: FORMAT_CHOICES_TYPE - model_parameters: Optional[str] + model_parameters: Optional[str] = None # The following parameters are required only for SageMaker models - region_name: Optional[str] - request_schema: Optional[str] - response_path: Optional[str] + region_name: Optional[str] = None + request_schema: Optional[str] = None + response_path: Optional[str] = None # Should match CellArgs @@ -58,11 +58,11 @@ class ErrorArgs(BaseModel): type: Literal["error"] = "error" model_id: str format: FORMAT_CHOICES_TYPE - model_parameters: Optional[str] + model_parameters: Optional[str] = None # The following parameters are required only for SageMaker models - region_name: Optional[str] - request_schema: Optional[str] - response_path: Optional[str] + region_name: Optional[str] = None + request_schema: Optional[str] = None + response_path: Optional[str] = None class HelpArgs(BaseModel): @@ -75,7 +75,7 @@ class VersionArgs(BaseModel): class ListArgs(BaseModel): type: Literal["list"] = "list" - provider_id: Optional[str] + provider_id: Optional[str] = None class RegisterArgs(BaseModel): diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 99c02a2c2..5ba2e4b86 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -68,7 +68,7 @@ class HandlerRoutingType(BaseModel): class SlashCommandRoutingType(HandlerRoutingType): routing_method = "slash_command" - slash_id: Optional[str] + slash_id: Optional[str] = None """Slash ID for routing a chat command to this handler. Only one handler may declare a particular slash ID. Must contain only alphanumerics and underscores.""" diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 79b710e3a..44879fe32 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -98,9 +98,9 @@ class ConfigManager(Configurable): config=True, ) - model_provider_id: Optional[str] - embeddings_provider_id: Optional[str] - completions_model_provider_id: Optional[str] + model_provider_id: Optional[str] = None + embeddings_provider_id: Optional[str] = None + completions_model_provider_id: Optional[str] = None def __init__( self, diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 2166dee60..4b011cf90 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -37,7 +37,7 @@ class CellWithErrorSelection(BaseModel): # the type of message used to chat with the agent class ChatRequest(BaseModel): prompt: str - selection: Optional[Selection] + selection: Optional[Selection] = None class StopRequest(BaseModel): @@ -54,7 +54,7 @@ class StopRequest(BaseModel): class ClearRequest(BaseModel): type: Literal["clear"] = "clear" - target: Optional[str] + target: Optional[str] = None """ Message ID of the HumanChatMessage to delete an exchange at. If not provided, this requests the backend to clear all messages. @@ -67,8 +67,8 @@ class ChatUser(BaseModel): initials: str name: str display_name: str - color: Optional[str] - avatar_url: Optional[str] + color: Optional[str] = None + avatar_url: Optional[str] = None class ChatClient(ChatUser): @@ -148,7 +148,7 @@ class HumanChatMessage(BaseModel): `prompt` and `selection`.""" prompt: str """The prompt typed into the chat input by the user.""" - selection: Optional[Selection] + selection: Optional[Selection] = None """The selection included with the prompt, if any.""" client: ChatClient @@ -238,8 +238,8 @@ class IndexMetadata(BaseModel): class DescribeConfigResponse(BaseModel): - model_provider_id: Optional[str] - embeddings_provider_id: Optional[str] + model_provider_id: Optional[str] = None + embeddings_provider_id: Optional[str] = None send_with_shift_enter: bool fields: Dict[str, Dict[str, Any]] # when sending config over REST API, do not include values of the API keys, @@ -248,7 +248,7 @@ class DescribeConfigResponse(BaseModel): # timestamp indicating when the configuration file was last read. should be # passed to the subsequent UpdateConfig request. last_read: int - completions_model_provider_id: Optional[str] + completions_model_provider_id: Optional[str] = None completions_fields: Dict[str, Dict[str, Any]] @@ -258,16 +258,16 @@ def forbid_none(cls, v): class UpdateConfigRequest(BaseModel): - model_provider_id: Optional[str] - embeddings_provider_id: Optional[str] - send_with_shift_enter: Optional[bool] - api_keys: Optional[Dict[str, str]] - fields: Optional[Dict[str, Dict[str, Any]]] + model_provider_id: Optional[str] = None + embeddings_provider_id: Optional[str] = None + send_with_shift_enter: Optional[bool] = None + api_keys: Optional[Dict[str, str]] = None + fields: Optional[Dict[str, Dict[str, Any]]] = None # if passed, this will raise an Error if the config was written to after the # time specified by `last_read` to prevent write-write conflicts. - last_read: Optional[int] - completions_model_provider_id: Optional[str] - completions_fields: Optional[Dict[str, Dict[str, Any]]] + last_read: Optional[int] = None + completions_model_provider_id: Optional[str] = None + completions_fields: Optional[Dict[str, Dict[str, Any]]] = None _validate_send_wse = validator("send_with_shift_enter", allow_reuse=True)( forbid_none @@ -280,12 +280,12 @@ class GlobalConfig(BaseModel): """Model used to represent the config by ConfigManager. This is exclusive to the backend and should never be sent to the client.""" - model_provider_id: Optional[str] - embeddings_provider_id: Optional[str] + model_provider_id: Optional[str] = None + embeddings_provider_id: Optional[str] = None send_with_shift_enter: bool fields: Dict[str, Dict[str, Any]] api_keys: Dict[str, str] - completions_model_provider_id: Optional[str] + completions_model_provider_id: Optional[str] = None completions_fields: Dict[str, Dict[str, Any]] diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py index 81108bdb7..ec8b43278 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py @@ -34,7 +34,7 @@ class MockProvider(BaseProvider, FakeListLLM): name = "My Provider" model_id_key = "model" models = ["model"] - should_raise: Optional[bool] + should_raise: Optional[bool] = None def __init__(self, **kwargs): if "responses" not in kwargs: From b246404ea2de6afca400af3f44eed3e944d4ccaa Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Tue, 14 Jan 2025 09:31:14 -0800 Subject: [PATCH 04/12] fix history impl for Pydantic v2, fixes chat --- packages/jupyter-ai/jupyter_ai/history.py | 33 ++++++++++++++++------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index 0197f4f6c..24a08245d 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -1,16 +1,15 @@ import time -from typing import List, Optional, Sequence, Set, Union +from typing import List, Optional, Sequence, Set from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage -from pydantic import BaseModel, PrivateAttr from .models import HumanChatMessage HUMAN_MSG_ID_KEY = "_jupyter_ai_human_msg_id" -class BoundedChatHistory(BaseChatMessageHistory, BaseModel): +class BoundedChatHistory(BaseChatMessageHistory): """ An in-memory implementation of `BaseChatMessageHistory` that stores up to `k` exchanges between a user and an LLM. @@ -19,10 +18,16 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel): messages and 2 AI messages. If `k` is set to `None` all messages are kept. """ - k: Union[int, None] - clear_time: float = 0.0 - cleared_msgs: Set[str] = set() - _all_messages: List[BaseMessage] = PrivateAttr(default_factory=list) + def __init__( + self, + k: Optional[int] = None, + clear_time: float = 0.0, + cleared_msgs: Set[str] = set(), + ): + self.k = k + self.clear_time = clear_time + self.cleared_msgs = cleared_msgs + self._all_messages = [] @property def messages(self) -> List[BaseMessage]: # type:ignore[override] @@ -67,7 +72,7 @@ async def aclear(self) -> None: self.clear() -class WrappedBoundedChatHistory(BaseChatMessageHistory, BaseModel): +class WrappedBoundedChatHistory(BaseChatMessageHistory): """ Wrapper around `BoundedChatHistory` that only appends an `AgentChatMessage` if the `HumanChatMessage` it is replying to was not cleared. If a chat @@ -88,8 +93,16 @@ class WrappedBoundedChatHistory(BaseChatMessageHistory, BaseModel): Reference: https://python.langchain.com/v0.1/docs/expression_language/how_to/message_history/ """ - history: BoundedChatHistory - last_human_msg: HumanChatMessage + def __init__( + self, + history: BoundedChatHistory, + last_human_msg: HumanChatMessage, + *args, + **kwargs, + ): + self.history = history + self.last_human_msg = last_human_msg + super().__init__(*args, **kwargs) @property def messages(self) -> List[BaseMessage]: # type:ignore[override] From 4a0bcb153aa37f690d4b153816497e5b0ce02285 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Tue, 14 Jan 2025 09:35:19 -0800 Subject: [PATCH 05/12] prefer `.model_dump_json()` over `.json()` Addresses a Pydantic v2 deprecation warning, as `BaseModel.json()` is now deprecated in favor of `BaseModel.model_dump_json()`. --- .../jupyter-ai/jupyter_ai/chat_handlers/learn.py | 2 +- packages/jupyter-ai/jupyter_ai/handlers.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index e0c6139c0..4968063a4 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -360,7 +360,7 @@ def save(self): def save_metadata(self): with open(METADATA_SAVE_PATH, "w") as f: - f.write(self.metadata.json()) + f.write(self.metadata.model_dump_json()) def load_metadata(self): if not os.path.exists(METADATA_SAVE_PATH): diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 869eec944..1c38369e3 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -66,7 +66,7 @@ async def get(self): history = ChatHistory( messages=self.chat_history, pending_messages=self.pending_messages ) - self.finish(history.json()) + self.finish(history.model_dump_json()) class RootChatHandler(JupyterHandler, websocket.WebSocketHandler): @@ -494,7 +494,7 @@ def get(self): # Finally, yield response. response = ListProvidersResponse(providers=providers) - self.finish(response.json()) + self.finish(response.model_dump_json()) class EmbeddingsModelProviderHandler(ProviderHandler): @@ -517,7 +517,7 @@ def get(self): providers = sorted(providers, key=lambda p: p.name) response = ListProvidersResponse(providers=providers) - self.finish(response.json()) + self.finish(response.model_dump_json()) class GlobalConfigHandler(BaseAPIHandler): @@ -535,7 +535,7 @@ def get(self): if not config: raise HTTPError(500, "No config found.") - self.finish(config.json()) + self.finish(config.model_dump_json()) @web.authenticated def post(self): @@ -587,7 +587,7 @@ def get(self): # if no selected LLM, return an empty response if not self.config_manager.lm_provider: - self.finish(response.json()) + self.finish(response.model_dump_json()) return for id, chat_handler in self.chat_handlers.items(): @@ -616,7 +616,7 @@ def get(self): # sort slash commands by slash id and deliver the response response.slash_commands.sort(key=lambda sc: sc.slash_id) - self.finish(response.json()) + self.finish(response.model_dump_json()) class AutocompleteOptionsHandler(BaseAPIHandler): @@ -640,7 +640,7 @@ def get(self): # if no selected LLM, return an empty response if not self.config_manager.lm_provider: - self.finish(response.json()) + self.finish(response.model_dump_json()) return partial_cmd = self.get_query_argument("partialCommand", None) @@ -666,7 +666,7 @@ def get(self): response.options = ( self._get_slash_command_options() + self._get_context_provider_options() ) - self.finish(response.json()) + self.finish(response.model_dump_json()) def _get_slash_command_options(self) -> List[ListOptionsEntry]: options = [] From 5716f5e1c95008b69f48d80e399d7f1ad76bb690 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Tue, 14 Jan 2025 09:42:35 -0800 Subject: [PATCH 06/12] replace `.dict()` with `.model_dump()`. `BaseModel.dict()` is deprecated in favor of `BaseModel.model_dump()` in Pydantic v2. --- .../jupyter-ai-magics/jupyter_ai_magics/magics.py | 2 +- .../jupyter_ai/callback_handlers/metadata.py | 2 +- .../jupyter-ai/jupyter_ai/chat_handlers/generate.py | 4 ++-- .../jupyter_ai/completions/handlers/base.py | 2 +- packages/jupyter-ai/jupyter_ai/config_manager.py | 12 ++++++------ packages/jupyter-ai/jupyter_ai/handlers.py | 6 +++--- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index 199544f58..ebfdf7b4a 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -442,7 +442,7 @@ def handle_error(self, args: ErrorArgs): prompt = f"Explain the following error:\n\n{last_error}" # Set CellArgs based on ErrorArgs - values = args.dict() + values = args.model_dump() values["type"] = "root" cell_args = CellArgs(**values) diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py index c409a9633..819f421a7 100644 --- a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py +++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py @@ -20,7 +20,7 @@ def requires_no_arguments(func): def convert_to_serializable(obj): """Convert an object to a JSON serializable format""" if hasattr(obj, "dict") and callable(obj.dict) and requires_no_arguments(obj.dict): - return obj.dict() + return obj.model_dump() if hasattr(obj, "__dict__"): return obj.__dict__ return str(obj) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index 36222b9fd..b471a7915 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -13,9 +13,9 @@ from langchain.chains import LLMChain from langchain.llms import BaseLLM from langchain.output_parsers import PydanticOutputParser -from pydantic import BaseModel from langchain.schema.output_parser import BaseOutputParser from langchain_core.prompts import PromptTemplate +from pydantic import BaseModel class OutlineSection(BaseModel): @@ -55,7 +55,7 @@ async def generate_outline(description, llm=None, verbose=False): chain = NotebookOutlineChain.from_llm(llm=llm, parser=parser, verbose=verbose) outline = await chain.apredict(description=description) outline = parser.parse(outline) - return outline.dict() + return outline.model_dump() class CodeImproverChain(LLMChain): diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py index bc75f950c..5d2e16e1c 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py @@ -61,7 +61,7 @@ def loop(self) -> AbstractEventLoop: def reply(self, reply: Union[InlineCompletionReply, InlineCompletionStreamChunk]): """Write a reply object to the WebSocket connection.""" - message = reply.dict() + message = reply.model_dump() super().write_message(message) def initialize(self): diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 44879fe32..7a6d44b2f 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -263,7 +263,7 @@ def _validate_config(self, config: GlobalConfig): read and before every write to the config file. Guarantees that the config file conforms to the JSON Schema, and that the language and embedding models have authn credentials if specified.""" - self.validator.validate(config.dict()) + self.validator.validate(config.model_dump()) # validate language model config if config.model_provider_id: @@ -352,10 +352,10 @@ def _write_config(self, new_config: GlobalConfig): self._validate_config(new_config) with open(self.config_path, "w") as f: - json.dump(new_config.dict(), f, indent=self.indentation_depth) + json.dump(new_config.model_dump(), f, indent=self.indentation_depth) def delete_api_key(self, key_name: str): - config_dict = self._read_config().dict() + config_dict = self._read_config().model_dump() required_keys = [] for provider in [ self.lm_provider, @@ -389,15 +389,15 @@ def update_config(self, config_update: UpdateConfigRequest): # type:ignore if not api_key_value: raise KeyEmptyError("API key value cannot be empty.") - config_dict = self._read_config().dict() - Merger.merge(config_dict, config_update.dict(exclude_unset=True)) + config_dict = self._read_config().model_dump() + Merger.merge(config_dict, config_update.model_dump(exclude_unset=True)) self._write_config(GlobalConfig(**config_dict)) # this cannot be a property, as the parent Configurable already defines the # self.config attr. def get_config(self): config = self._read_config() - config_dict = config.dict(exclude_unset=True) + config_dict = config.model_dump(exclude_unset=True) api_key_names = list(config_dict.pop("api_keys").keys()) return DescribeConfigResponse( **config_dict, api_keys=api_key_names, last_read=self._last_read diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 1c38369e3..6d08b6f51 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -200,7 +200,7 @@ def open(self): """Handles opening of a WebSocket connection. Client ID can be retrieved from `self.client_id`.""" - current_user = self.get_chat_user().dict() + current_user = self.get_chat_user().model_dump() client_id = self.generate_client_id() self.root_chat_handlers[client_id] = self @@ -212,7 +212,7 @@ def open(self): history=ChatHistory( messages=self.chat_history, pending_messages=self.pending_messages ), - ).dict() + ).model_dump() ) self.log.info(f"Client connected. ID: {client_id}") @@ -238,7 +238,7 @@ def broadcast_message(self, message: Message): for client_id in client_ids: client = self.root_chat_handlers[client_id] if client: - client.write_message(message.dict()) + client.write_message(message.model_dump()) # append all messages of type `ChatMessage` directly to the chat history if isinstance( From 588c34bb9492419511e47c3c6bfef60bffd7fa82 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Tue, 14 Jan 2025 10:48:50 -0800 Subject: [PATCH 07/12] fix BaseProvider.server_settings --- .../tests/test_base_provider.py | 27 ++++++++++++++ .../tests/test_provider_metaclass.py | 35 ------------------- .../jupyter_ai/tests/test_extension.py | 2 +- 3 files changed, 28 insertions(+), 36 deletions(-) create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_base_provider.py delete mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_provider_metaclass.py diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_base_provider.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_base_provider.py new file mode 100644 index 000000000..4029cd5bd --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_base_provider.py @@ -0,0 +1,27 @@ +from typing import ClassVar, Optional + +from pydantic import BaseModel + +from ..providers import BaseProvider + + +def test_provider_classvars(): + """ + Asserts that class attributes are not omitted due to parent classes defining + an instance field of the same name. This was a bug present in Pydantic v1, + which led to an issue documented in #558. + + This bug is fixed as of `pydantic==2.10.2`, but we will keep this test in + case this behavior changes in future releases. + """ + + class Parent(BaseModel): + test: Optional[str] = None + + class Base(BaseModel): + test: ClassVar[str] + + class Child(Base, Parent): + test: ClassVar[str] = "expected" + + assert Child.test == "expected" diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_provider_metaclass.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_provider_metaclass.py deleted file mode 100644 index 359fe3774..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_provider_metaclass.py +++ /dev/null @@ -1,35 +0,0 @@ -from types import MappingProxyType -from typing import ClassVar, Optional - -from langchain.pydantic_v1 import BaseModel -from pytest import raises - -from ..providers import BaseProvider, ProviderMetaclass - - -def test_provider_metaclass(): - """ - Asserts that the metaclass prevents class attributes from being omitted due - to parent classes defining an instance field of the same name. - - You can reproduce the original issue by removing the - `metaclass=ProviderMetaclass` argument from the definition of `Child`. - """ - - class Parent(BaseModel): - test: Optional[str] - - class Base(BaseModel): - test: ClassVar[str] - - class Child(Base, Parent, metaclass=ProviderMetaclass): - test: ClassVar[str] = "expected" - - assert Child.test == "expected" - - -def test_base_provider_server_settings_read_only(): - BaseProvider.server_settings = MappingProxyType({}) - - with raises(AttributeError, match="'server_settings' attribute was already set"): - BaseProvider.server_settings = MappingProxyType({}) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py index 9ae52d8a0..ff33a6934 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py @@ -63,7 +63,7 @@ def ai_extension(jp_serverapp): # may run in parallel setting it with race condition; because we are not testing # the `BaseProvider.server_settings` here, we can just mock the setter settings_mock = mock.PropertyMock() - with mock.patch.object(BaseProvider.__class__, "server_settings", settings_mock): + with mock.patch.object(BaseProvider, "server_settings", settings_mock): yield ai From 6e1e26282c4083c85c2278d22e78e31aaaba277f Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Tue, 14 Jan 2025 10:49:40 -0800 Subject: [PATCH 08/12] fix OpenRouterProvider --- .../partner_providers/openrouter.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py index bbaf9deb4..83083ec64 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py @@ -2,8 +2,7 @@ from jupyter_ai_magics import BaseProvider from jupyter_ai_magics.providers import EnvAuthStrategy, TextField -from pydantic import model_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import get_from_dict_or_env from langchain_openai import ChatOpenAI @@ -31,7 +30,9 @@ class OpenRouterProvider(BaseProvider, ChatOpenRouter): ] def __init__(self, **kwargs): - openrouter_api_key = kwargs.pop("openrouter_api_key", None) + openrouter_api_key = get_from_dict_or_env( + kwargs, key="openrouter_api_key", env_key="OPENROUTER_API_KEY", default=None + ) openrouter_api_base = kwargs.pop( "openai_api_base", "https://openrouter.ai/api/v1" ) @@ -42,14 +43,6 @@ def __init__(self, **kwargs): **kwargs, ) - @model_validator(mode="after") - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" - values["openai_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "openai_api_key", "OPENROUTER_API_KEY") - ) - return super().validate_environment(values) - @classmethod def is_api_key_exc(cls, e: Exception): import openai From 6ebfc0a263edac714cbcb015d6f2cad2abbaa3ec Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Tue, 14 Jan 2025 10:54:25 -0800 Subject: [PATCH 09/12] fix remaining unit tests --- packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py | 2 +- packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py | 2 +- packages/jupyter-ai/jupyter_ai/tests/test_handlers.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index 38212f995..1f1baab4b 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -233,7 +233,7 @@ def configure_with_fields(cm: ConfigManager, completions: bool = False): def test_snapshot_default_config(cm: ConfigManager, snapshot): config_from_cm: DescribeConfigResponse = cm.get_config() - assert config_from_cm == snapshot(exclude=lambda prop, path: prop == "last_read") + assert config_from_cm.model_dump() == snapshot(exclude=lambda prop, path: prop == "last_read") def test_init_with_existing_config(cm: ConfigManager, common_cm_kwargs): diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py b/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py index 132dcf871..f21086c06 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py @@ -11,7 +11,7 @@ @pytest.fixture def human_chat_message() -> HumanChatMessage: chat_client = ChatClient( - id=0, username="test", initials="test", name="test", display_name="test" + id="test-client-uuid", username="test", initials="test", name="test", display_name="test" ) prompt = ( "@file:test1.py @file @file:dir/test2.md test test\n" diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py index ec8b43278..8a67d6c3b 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py @@ -88,7 +88,7 @@ class TestException(Exception): @pytest.fixture def chat_client(): return ChatClient( - id=0, username="test", initials="test", name="test", display_name="test" + id="test-client-uuid", username="test", initials="test", name="test", display_name="test" ) From af6f707197944b8bd15141423b37c1f6def55983 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Tue, 14 Jan 2025 13:03:46 -0800 Subject: [PATCH 10/12] address all Pydantic v1 deprecation warnings --- .../jupyter_ai_magics/embedding_providers.py | 7 +++--- .../jupyter-ai/jupyter_ai/config_manager.py | 2 +- packages/jupyter-ai/jupyter_ai/models.py | 23 +++++++++---------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py index 97ce937a6..b3cbd9af7 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py @@ -6,7 +6,7 @@ Field, MultiEnvAuthStrategy, ) -from pydantic import BaseModel, Extra +from pydantic import BaseModel, ConfigDict from langchain_community.embeddings import ( GPT4AllEmbeddings, HuggingFaceHubEmbeddings, @@ -17,8 +17,9 @@ class BaseEmbeddingsProvider(BaseModel): """Base class for embedding providers""" - class Config: - extra = Extra.allow + # pydantic v2 model config + # upstream docs: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra + model_config = ConfigDict(extra="allow") id: ClassVar[str] = ... """ID for this provider class.""" diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 7a6d44b2f..7c7c2351b 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -225,7 +225,7 @@ def _create_default_config(self, default_config): self._write_config(GlobalConfig(**default_config)) def _init_defaults(self): - config_keys = GlobalConfig.__fields__.keys() + config_keys = GlobalConfig.model_fields.keys() schema_properties = self.validator.schema.get("properties", {}) default_config = { field: schema_properties.get(field).get("default") for field in config_keys diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 4b011cf90..9dd3375cf 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -3,7 +3,7 @@ from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import AuthStrategy, Field -from pydantic import BaseModel, validator +from pydantic import BaseModel, field_validator DEFAULT_CHUNK_SIZE = 2000 DEFAULT_CHUNK_OVERLAP = 100 @@ -129,7 +129,8 @@ class AgentStreamChunkMessage(BaseModel): on `BaseAgentMessage.metadata` for information. """ - @validator("metadata") + @field_validator("metadata", mode="before") + @classmethod def validate_metadata(cls, v): """Ensure metadata values are JSON serializable""" try: @@ -252,11 +253,6 @@ class DescribeConfigResponse(BaseModel): completions_fields: Dict[str, Dict[str, Any]] -def forbid_none(cls, v): - assert v is not None, "size may not be None" - return v - - class UpdateConfigRequest(BaseModel): model_provider_id: Optional[str] = None embeddings_provider_id: Optional[str] = None @@ -269,11 +265,14 @@ class UpdateConfigRequest(BaseModel): completions_model_provider_id: Optional[str] = None completions_fields: Optional[Dict[str, Dict[str, Any]]] = None - _validate_send_wse = validator("send_with_shift_enter", allow_reuse=True)( - forbid_none - ) - _validate_api_keys = validator("api_keys", allow_reuse=True)(forbid_none) - _validate_fields = validator("fields", allow_reuse=True)(forbid_none) + @field_validator("send_with_shift_enter", "api_keys", "fields", mode="before") + @classmethod + def ensure_not_none_if_passed(cls, field_val: Any) -> Any: + """ + Field validator ensuring that certain fields are never `None` if set. + """ + assert field_val is not None, "size may not be None" + return field_val class GlobalConfig(BaseModel): From 1fb1948edfedcfe077ac8847bc2a68a7913630fa Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Tue, 14 Jan 2025 13:19:24 -0800 Subject: [PATCH 11/12] pre-commit --- .../jupyter_ai_magics/embedding_providers.py | 2 +- packages/jupyter-ai-magics/jupyter_ai_magics/providers.py | 3 +-- packages/jupyter-ai/jupyter_ai/chat_handlers/base.py | 2 +- packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py | 4 +++- .../jupyter-ai/jupyter_ai/tests/test_context_providers.py | 6 +++++- packages/jupyter-ai/jupyter_ai/tests/test_handlers.py | 6 +++++- 6 files changed, 16 insertions(+), 7 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py index b3cbd9af7..f24a9b917 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py @@ -6,12 +6,12 @@ Field, MultiEnvAuthStrategy, ) -from pydantic import BaseModel, ConfigDict from langchain_community.embeddings import ( GPT4AllEmbeddings, HuggingFaceHubEmbeddings, QianfanEmbeddingsEndpoint, ) +from pydantic import BaseModel, ConfigDict class BaseEmbeddingsProvider(BaseModel): diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 6ed9082bf..ac9d937a9 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -24,7 +24,6 @@ PromptTemplate, SystemMessagePromptTemplate, ) -from pydantic import BaseModel, ConfigDict from langchain.schema import LLMResult from langchain.schema.output_parser import StrOutputParser from langchain.schema.runnable import Runnable @@ -32,7 +31,7 @@ from langchain_community.llms import AI21, GPT4All, HuggingFaceEndpoint, Together from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.llms import BaseLLM - +from pydantic import BaseModel, ConfigDict from . import completion_utils as completion from .models.completion import ( diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 5ba2e4b86..4ddacb2f7 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -36,12 +36,12 @@ ) from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import BaseProvider -from pydantic import BaseModel from langchain_core.messages import AIMessageChunk from langchain_core.runnables import Runnable from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import merge_configs as merge_runnable_configs from langchain_core.runnables.utils import Input +from pydantic import BaseModel if TYPE_CHECKING: from jupyter_ai.context_providers import BaseCommandContextProvider diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index 1f1baab4b..4c24db001 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -233,7 +233,9 @@ def configure_with_fields(cm: ConfigManager, completions: bool = False): def test_snapshot_default_config(cm: ConfigManager, snapshot): config_from_cm: DescribeConfigResponse = cm.get_config() - assert config_from_cm.model_dump() == snapshot(exclude=lambda prop, path: prop == "last_read") + assert config_from_cm.model_dump() == snapshot( + exclude=lambda prop, path: prop == "last_read" + ) def test_init_with_existing_config(cm: ConfigManager, common_cm_kwargs): diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py b/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py index f21086c06..b9cd11ad8 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py @@ -11,7 +11,11 @@ @pytest.fixture def human_chat_message() -> HumanChatMessage: chat_client = ChatClient( - id="test-client-uuid", username="test", initials="test", name="test", display_name="test" + id="test-client-uuid", + username="test", + initials="test", + name="test", + display_name="test", ) prompt = ( "@file:test1.py @file @file:dir/test2.md test test\n" diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py index 8a67d6c3b..f641ccec2 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py @@ -88,7 +88,11 @@ class TestException(Exception): @pytest.fixture def chat_client(): return ChatClient( - id="test-client-uuid", username="test", initials="test", name="test", display_name="test" + id="test-client-uuid", + username="test", + initials="test", + name="test", + display_name="test", ) From 603df9f5b66543df770dd99e1a7b51b2859b84b8 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Tue, 14 Jan 2025 13:33:33 -0800 Subject: [PATCH 12/12] fix mypy --- packages/jupyter-ai/jupyter_ai/history.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index 24a08245d..f2685cdf5 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -27,7 +27,7 @@ def __init__( self.k = k self.clear_time = clear_time self.cleared_msgs = cleared_msgs - self._all_messages = [] + self._all_messages: List[BaseMessage] = [] @property def messages(self) -> List[BaseMessage]: # type:ignore[override]