Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 0 additions & 34 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,3 @@ jobs:
run: jlpm
- name: Lint TypeScript source
run: jlpm lerna run lint:check

lint_py_imports:
name: Lint Python imports
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Echo environment details
run: |
which python
which pip
python --version
pip --version

# see #546 for context on why this is necessary
- name: Create venv
run: |
python -m venv lint_py_imports

- name: Install job dependencies
run: |
source ./lint_py_imports/bin/activate
pip install jupyterlab~=4.0
pip install import-linter~=1.12.1

- name: Install Jupyter AI packages from source
run: |
source ./lint_py_imports/bin/activate
jlpm install
jlpm install-from-src

- name: Lint Python imports
run: |
source ./lint_py_imports/bin/activate
lint-imports
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@
Field,
MultiEnvAuthStrategy,
)
from langchain.pydantic_v1 import BaseModel, Extra
from langchain_community.embeddings import (
GPT4AllEmbeddings,
HuggingFaceHubEmbeddings,
QianfanEmbeddingsEndpoint,
)
from pydantic import BaseModel, ConfigDict


class BaseEmbeddingsProvider(BaseModel):
"""Base class for embedding providers"""

class Config:
extra = Extra.allow
# pydantic v2 model config
# upstream docs: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra
model_config = ConfigDict(extra="allow")

id: ClassVar[str] = ...
"""ID for this provider class."""
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def handle_error(self, args: ErrorArgs):

prompt = f"Explain the following error:\n\n{last_error}"
# Set CellArgs based on ErrorArgs
values = args.dict()
values = args.model_dump()
values["type"] = "root"
cell_args = CellArgs(**values)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Literal, Optional

from langchain.pydantic_v1 import BaseModel
from pydantic import BaseModel


class InlineCompletionRequest(BaseModel):
Expand All @@ -21,12 +21,12 @@ class InlineCompletionRequest(BaseModel):
# whether to stream the response (if supported by the model)
stream: bool
# path to the notebook of file for which the completions are generated
path: Optional[str]
path: Optional[str] = None
# language inferred from the document mime type (if possible)
language: Optional[str]
language: Optional[str] = None
# identifier of the cell for which the completions are generated if in a notebook
# previous cells and following cells can be used to learn the wider context
cell_id: Optional[str]
cell_id: Optional[str] = None


class InlineCompletionItem(BaseModel):
Expand All @@ -36,9 +36,9 @@ class InlineCompletionItem(BaseModel):
"""

insertText: str
filterText: Optional[str]
isIncomplete: Optional[bool]
token: Optional[str]
filterText: Optional[str] = None
isIncomplete: Optional[bool] = None
token: Optional[str] = None


class CompletionError(BaseModel):
Expand All @@ -59,7 +59,7 @@ class InlineCompletionReply(BaseModel):
list: InlineCompletionList
# number of request for which we are replying
reply_to: int
error: Optional[CompletionError]
error: Optional[CompletionError] = None


class InlineCompletionStreamChunk(BaseModel):
Expand All @@ -69,7 +69,7 @@ class InlineCompletionStreamChunk(BaseModel):
response: InlineCompletionItem
reply_to: int
done: bool
error: Optional[CompletionError]
error: Optional[CompletionError] = None


__all__ = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain.pydantic_v1 import BaseModel
from pydantic import BaseModel


class Persona(BaseModel):
Expand Down
20 changes: 10 additions & 10 deletions packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Literal, Optional, get_args

import click
from langchain.pydantic_v1 import BaseModel
from pydantic import BaseModel

FORMAT_CHOICES_TYPE = Literal[
"code", "html", "image", "json", "markdown", "math", "md", "text"
Expand Down Expand Up @@ -46,23 +46,23 @@ class CellArgs(BaseModel):
type: Literal["root"] = "root"
model_id: str
format: FORMAT_CHOICES_TYPE
model_parameters: Optional[str]
model_parameters: Optional[str] = None
# The following parameters are required only for SageMaker models
region_name: Optional[str]
request_schema: Optional[str]
response_path: Optional[str]
region_name: Optional[str] = None
request_schema: Optional[str] = None
response_path: Optional[str] = None


# Should match CellArgs
class ErrorArgs(BaseModel):
type: Literal["error"] = "error"
model_id: str
format: FORMAT_CHOICES_TYPE
model_parameters: Optional[str]
model_parameters: Optional[str] = None
# The following parameters are required only for SageMaker models
region_name: Optional[str]
request_schema: Optional[str]
response_path: Optional[str]
region_name: Optional[str] = None
request_schema: Optional[str] = None
response_path: Optional[str] = None


class HelpArgs(BaseModel):
Expand All @@ -75,7 +75,7 @@ class VersionArgs(BaseModel):

class ListArgs(BaseModel):
type: Literal["list"] = "list"
provider_id: Optional[str]
provider_id: Optional[str] = None


class RegisterArgs(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from jupyter_ai_magics import BaseProvider
from jupyter_ai_magics.providers import EnvAuthStrategy, TextField
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import get_from_dict_or_env
from langchain_openai import ChatOpenAI


Expand Down Expand Up @@ -31,7 +30,9 @@ class OpenRouterProvider(BaseProvider, ChatOpenRouter):
]

def __init__(self, **kwargs):
openrouter_api_key = kwargs.pop("openrouter_api_key", None)
openrouter_api_key = get_from_dict_or_env(
kwargs, key="openrouter_api_key", env_key="OPENROUTER_API_KEY", default=None
)
openrouter_api_base = kwargs.pop(
"openai_api_base", "https://openrouter.ai/api/v1"
)
Expand All @@ -42,14 +43,6 @@ def __init__(self, **kwargs):
**kwargs,
)

@root_validator(pre=False, skip_on_failure=True, allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "openai_api_key", "OPENROUTER_API_KEY")
)
return super().validate_environment(values)

@classmethod
def is_api_key_exc(cls, e: Exception):
import openai
Expand Down
82 changes: 21 additions & 61 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,14 @@
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.pydantic_v1 import BaseModel, Extra
from langchain.schema import LLMResult
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain_community.chat_models import QianfanChatEndpoint
from langchain_community.llms import AI21, GPT4All, HuggingFaceEndpoint, Together
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM

# this is necessary because `langchain.pydantic_v1.main` does not include
# `ModelMetaclass`, as it is not listed in `__all__` by the `pydantic.main`
# subpackage.
try:
from pydantic.v1.main import ModelMetaclass
except:
from pydantic.main import ModelMetaclass
from pydantic import BaseModel, ConfigDict

from . import completion_utils as completion
from .models.completion import (
Expand Down Expand Up @@ -122,7 +114,7 @@ class EnvAuthStrategy(BaseModel):
name: str
"""The name of the environment variable, e.g. `'ANTHROPIC_API_KEY'`."""

keyword_param: Optional[str]
keyword_param: Optional[str] = None
"""
If unset (default), the authentication token is provided as a keyword
argument with the parameter equal to the environment variable name in
Expand Down Expand Up @@ -177,51 +169,10 @@ class IntegerField(BaseModel):
Field = Union[TextField, MultilineTextField, IntegerField]


class ProviderMetaclass(ModelMetaclass):
"""
A metaclass that ensures all class attributes defined inline within the
class definition are accessible and included in `Class.__dict__`.

This is necessary because Pydantic drops any ClassVars that are defined as
an instance field by a parent class, even if they are defined inline within
the class definition. We encountered this case when `langchain` added a
`name` attribute to a parent class shared by all `Provider`s, which caused
`Provider.name` to be inaccessible. See #558 for more info.
"""

def __new__(mcs, name, bases, namespace, **kwargs):
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
for key in namespace:
# skip private class attributes
if key.startswith("_"):
continue
# skip class attributes already listed in `cls.__dict__`
if key in cls.__dict__:
continue

setattr(cls, key, namespace[key])

return cls

@property
def server_settings(cls):
return cls._server_settings

@server_settings.setter
def server_settings(cls, value):
if cls._server_settings is not None:
raise AttributeError("'server_settings' attribute was already set")
cls._server_settings = value

_server_settings = None


class BaseProvider(BaseModel, metaclass=ProviderMetaclass):
#
# pydantic config
#
class Config:
extra = Extra.allow
class BaseProvider(BaseModel):
# pydantic v2 model config
# upstream docs: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra
model_config = ConfigDict(extra="allow")

#
# class attrs
Expand All @@ -236,15 +187,25 @@ class Config:
"""List of supported models by their IDs. For registry providers, this will
be just ["*"]."""

help: ClassVar[str] = None
help: ClassVar[Optional[str]] = None
"""Text to display in lieu of a model list for a registry provider that does
not provide a list of models."""

model_id_key: ClassVar[str] = ...
"""Kwarg expected by the upstream LangChain provider."""
model_id_key: ClassVar[Optional[str]] = None
"""
Optional field which specifies the key under which `model_id` is passed to
the parent LangChain class.

If unset, this defaults to "model_id".
"""

model_id_label: ClassVar[str] = ""
"""Human-readable label of the model ID."""
model_id_label: ClassVar[Optional[str]] = None
"""
Optional field which sets the label shown in the UI allowing users to
select/type a model ID.

If unset, the label shown in the UI defaults to "Model ID".
"""

pypi_package_deps: ClassVar[List[str]] = []
"""List of PyPi package dependencies."""
Expand Down Expand Up @@ -586,7 +547,6 @@ def __init__(self, **kwargs):

id = "gpt4all"
name = "GPT4All"
docs = "https://docs.gpt4all.io/gpt4all_python.html"
models = [
"ggml-gpt4all-j-v1.2-jazzy",
"ggml-gpt4all-j-v1.3-groovy",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import ClassVar, Optional

from pydantic import BaseModel

from ..providers import BaseProvider


def test_provider_classvars():
"""
Asserts that class attributes are not omitted due to parent classes defining
an instance field of the same name. This was a bug present in Pydantic v1,
which led to an issue documented in #558.

This bug is fixed as of `pydantic==2.10.2`, but we will keep this test in
case this behavior changes in future releases.
"""

class Parent(BaseModel):
test: Optional[str] = None

class Base(BaseModel):
test: ClassVar[str]

class Child(Base, Parent):
test: ClassVar[str] = "expected"

assert Child.test == "expected"
Loading
Loading