Skip to content

Commit 6e680b1

Browse files
krassowskidlqqq
andauthored
Separate BaseProvider for faster import (#1338)
* Separate `BaseProvider` for faster import * Use `base_provider` imports across ai_magics package * Vendor the `import_attr` code * Support typing and add test to ensure integrity * Add an import for typing block integrity * Import `import_attr` as private * Rename variables as requested in review Co-authored-by: David L. Qiu <[email protected]> --------- Co-authored-by: David L. Qiu <[email protected]>
1 parent 9e535fd commit 6e680b1

File tree

17 files changed

+705
-519
lines changed

17 files changed

+705
-519
lines changed
Lines changed: 87 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,96 @@
1+
from typing import TYPE_CHECKING
2+
3+
from ._import_utils import import_attr as _import_attr
14
from ._version import __version__
25

3-
# expose embedding model providers on the package root
4-
from .embedding_providers import (
5-
BaseEmbeddingsProvider,
6-
GPT4AllEmbeddingsProvider,
7-
HfHubEmbeddingsProvider,
8-
QianfanEmbeddingsEndpointProvider,
9-
)
10-
from .exception import store_exception
11-
from .magics import AiMagics
12-
13-
# expose JupyternautPersona on the package root
14-
# required by `jupyter-ai`.
15-
from .models.persona import JupyternautPersona, Persona
16-
17-
# expose model providers on the package root
18-
from .providers import (
19-
AI21Provider,
20-
BaseProvider,
21-
GPT4AllProvider,
22-
HfHubProvider,
23-
QianfanProvider,
24-
TogetherAIProvider,
25-
)
6+
if TYPE_CHECKING:
7+
# same as dynamic imports but understood by mypy
8+
from .embedding_providers import (
9+
BaseEmbeddingsProvider,
10+
GPT4AllEmbeddingsProvider,
11+
HfHubEmbeddingsProvider,
12+
QianfanEmbeddingsEndpointProvider,
13+
)
14+
from .exception import store_exception
15+
from .magics import AiMagics
16+
from .models.persona import JupyternautPersona, Persona
17+
from .providers import (
18+
AI21Provider,
19+
BaseProvider,
20+
GPT4AllProvider,
21+
HfHubProvider,
22+
QianfanProvider,
23+
TogetherAIProvider,
24+
)
25+
else:
26+
_exports_by_module = {
27+
# expose embedding model providers on the package root
28+
"embedding_providers": [
29+
"BaseEmbeddingsProvider",
30+
"GPT4AllEmbeddingsProvider",
31+
"HfHubEmbeddingsProvider",
32+
"QianfanEmbeddingsEndpointProvider",
33+
],
34+
"exception": ["store_exception"],
35+
"magics": ["AiMagics"],
36+
# expose JupyternautPersona on the package root
37+
# required by `jupyter-ai`.
38+
"models.persona": ["JupyternautPersona", "Persona"],
39+
# expose model providers on the package root
40+
"providers": [
41+
"AI21Provider",
42+
"BaseProvider",
43+
"GPT4AllProvider",
44+
"HfHubProvider",
45+
"QianfanProvider",
46+
"TogetherAIProvider",
47+
],
48+
}
49+
50+
_modules_by_export = {
51+
import_name: module
52+
for module, imports in _exports_by_module.items()
53+
for import_name in imports
54+
}
55+
56+
def __getattr__(export_name: str) -> object:
57+
module_name = _modules_by_export.get(export_name)
58+
result = _import_attr(export_name, module_name, __spec__.parent)
59+
globals()[export_name] = result
60+
return result
2661

2762

2863
def load_ipython_extension(ipython):
29-
ipython.register_magics(AiMagics)
30-
ipython.set_custom_exc((BaseException,), store_exception)
64+
ipython.register_magics(__getattr__("AiMagics"))
65+
ipython.set_custom_exc((BaseException,), __getattr__("store_exception"))
3166

3267

3368
def unload_ipython_extension(ipython):
3469
ipython.set_custom_exc((BaseException,), ipython.CustomTB)
70+
71+
72+
# required to preserve backward compatibility with `from jupyter_ai_magics import *`
73+
__all__ = [
74+
"__version__",
75+
"load_ipython_extension",
76+
"unload_ipython_extension",
77+
"BaseEmbeddingsProvider",
78+
"GPT4AllEmbeddingsProvider",
79+
"HfHubEmbeddingsProvider",
80+
"QianfanEmbeddingsEndpointProvider",
81+
"store_exception",
82+
"AiMagics",
83+
"JupyternautPersona",
84+
"Persona",
85+
"AI21Provider",
86+
"BaseProvider",
87+
"GPT4AllProvider",
88+
"HfHubProvider",
89+
"QianfanProvider",
90+
"TogetherAIProvider",
91+
]
92+
93+
94+
def __dir__():
95+
# Allows more editors (e.g. IPython) to complete on `jupyter_ai_magics.<tab>`
96+
return list(__all__)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""
2+
MIT License
3+
4+
Copyright (c) LangChain, Inc.
5+
6+
Permission is hereby granted, free of charge, to any person obtaining a copy
7+
of this software and associated documentation files (the "Software"), to deal
8+
in the Software without restriction, including without limitation the rights
9+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
copies of the Software, and to permit persons to whom the Software is
11+
furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in all
14+
copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
SOFTWARE.
23+
"""
24+
25+
from importlib import import_module
26+
from typing import Union
27+
28+
29+
def import_attr(
30+
attr_name: str,
31+
module_name: Union[str, None],
32+
package: Union[str, None],
33+
) -> object:
34+
"""Import an attribute from a module located in a package.
35+
36+
This utility function is used in custom __getattr__ methods within __init__.py
37+
files to dynamically import attributes.
38+
39+
Args:
40+
attr_name: The name of the attribute to import.
41+
module_name: The name of the module to import from. If None, the attribute
42+
is imported from the package itself.
43+
package: The name of the package where the module is located.
44+
"""
45+
if module_name == "__module__" or module_name is None:
46+
try:
47+
result = import_module(f".{attr_name}", package=package)
48+
except ModuleNotFoundError:
49+
msg = f"module '{package!r}' has no attribute {attr_name!r}"
50+
raise AttributeError(msg) from None
51+
else:
52+
try:
53+
module = import_module(f".{module_name}", package=package)
54+
except ModuleNotFoundError:
55+
msg = f"module '{package!r}.{module_name!r}' not found"
56+
raise ImportError(msg) from None
57+
result = getattr(module, attr_name)
58+
return result

0 commit comments

Comments
 (0)