22
33from contextlib import contextmanager
44from dataclasses import dataclass
5- from typing import TYPE_CHECKING , Any
5+ from typing import TYPE_CHECKING , Any , cast
66
77from sqlspec .config import GenericDatabaseConfig
88from sqlspec .exceptions import ImproperConfigurationError
9- from sqlspec .utils .dataclass import simple_asdict
10- from sqlspec .utils .empty import Empty , EmptyType
9+ from sqlspec .typing import Empty , EmptyType , dataclass_to_dict
1110
1211if TYPE_CHECKING :
13- from collections .abc import Generator
12+ from collections .abc import Generator , Sequence
1413
1514 from duckdb import DuckDBPyConnection
1615
17- __all__ = ("DuckDBConfig" ,)
16+ __all__ = ("DuckDBConfig" , "ExtensionConfig" )
17+
18+
@dataclass
class ExtensionConfig:
    """Configuration for a DuckDB extension.

    This class provides configuration options for DuckDB extensions, including installation
    and post-install configuration settings.

    Args:
        name: The name of the extension to install
        config: Optional configuration settings to apply after installation
        force_install: Whether to force reinstall if already present
        repository: Optional repository name to install from
        repository_url: Optional repository URL to install from
        version: Optional version of the extension to install
    """

    name: str
    config: dict[str, Any] | None = None
    force_install: bool = False
    repository: str | None = None
    repository_url: str | None = None
    version: str | None = None

    @classmethod
    def from_dict(cls, name: str, config: dict[str, Any] | bool | None = None) -> ExtensionConfig:
        """Create an ExtensionConfig from a configuration dictionary.

        Only the keys ``force_install``, ``repository``, ``repository_url``, ``version``
        and ``config`` are read from the dictionary; any other key is ignored. A ``name``
        key inside the dictionary is ignored as well, because the explicit ``name``
        argument always wins. The dictionary passed in is never modified.

        Args:
            name: The name of the extension
            config: Either ``None`` (use defaults), a dictionary of settings, or any
                non-dict value whose truthiness is used as ``force_install``

        Returns:
            A new ExtensionConfig instance
        """
        if config is None:
            return cls(name=name)

        if not isinstance(config, dict):
            return cls(name=name, force_install=bool(config))

        # Read (rather than pop) so the caller's dictionary stays intact, and leave
        # "name" out of the forwarded keys so it cannot collide with the keyword below.
        install_args = {
            key: config[key]
            for key in ("force_install", "repository", "repository_url", "version", "config")
            if key in config
        }
        return cls(name=name, **install_args)
1865
1966
2067@dataclass
@@ -39,31 +86,100 @@ class DuckDBConfig(GenericDatabaseConfig):
3986 For details see: https://duckdb.org/docs/api/python/overview#connection-options
4087 """
4188
89+ extensions : Sequence [ExtensionConfig ] | EmptyType = Empty
90+ """A sequence of extension configurations to install and configure upon connection creation."""
91+
def __post_init__(self) -> None:
    """Post-initialization validation and processing.

    This method merges extension configurations from both the ``extensions`` field
    and the ``config`` dictionary, if present. The ``config['extensions']`` entry is
    removed from the stored config and can be either:
        - A dictionary mapping extension names to their configurations
        - A list or tuple of extension names (each registered with force_install=False)

    After this method runs, ``self.config`` is a dict and ``self.extensions`` is a list.
    Both are fresh copies, so the objects passed in by the caller are not modified.

    Raises:
        ImproperConfigurationError: If ``config['extensions']`` has an unsupported type,
            or if the same extension is configured in both places.
    """
    # Copy instead of mutating in place: the caller may reuse the same dict/sequence
    # for another config, and ``extensions`` may legitimately be a tuple.
    config: dict[str, Any] = {} if self.config is Empty else dict(cast("dict[str, Any]", self.config))
    extensions: list[ExtensionConfig] = (
        [] if self.extensions is Empty else list(cast("Sequence[ExtensionConfig]", self.extensions))
    )

    _e = config.pop("extensions", {})
    if not isinstance(_e, (dict, list, tuple)):
        msg = "When configuring extensions in the 'config' dictionary, the value must be a dictionary or sequence of extension names"
        raise ImproperConfigurationError(msg)
    if not isinstance(_e, dict):
        _e = {str(ext): {"force_install": False} for ext in _e}

    if _e.keys() & {ext.name for ext in extensions}:
        msg = "Configuring the same extension in both 'extensions' and as a key in 'config['extensions']' is not allowed"
        raise ImproperConfigurationError(msg)

    extensions.extend(ExtensionConfig.from_dict(name, ext_config) for name, ext_config in _e.items())

    self.config = config
    self.extensions = extensions
124+
def _configure_extensions(self, connection: DuckDBPyConnection) -> None:
    """Install, load and apply settings for each configured extension.

    For every entry in ``self.extensions`` the extension is installed first when its
    ``force_install`` flag is set, then loaded, and finally each key/value pair of its
    ``config`` mapping is applied with a ``SET`` statement.

    Args:
        connection: The DuckDB connection to configure extensions for.

    Raises:
        ImproperConfigurationError: If extension installation or configuration fails.
    """
    if self.extensions is Empty:
        return

    configured = cast("list[ExtensionConfig]", self.extensions)
    for extension in configured:
        try:
            if extension.force_install:
                connection.install_extension(
                    extension=extension.name,
                    force_install=extension.force_install,
                    repository=extension.repository,
                    repository_url=extension.repository_url,
                    version=extension.version,
                )
            connection.load_extension(extension.name)

            # A missing or empty settings mapping simply means there is nothing to SET.
            settings = extension.config or {}
            for key, value in settings.items():
                connection.execute(f"SET {key}={value}")
        except Exception as e:
            msg = f"Failed to configure extension {extension.name}. Error: {e!s}"
            raise ImproperConfigurationError(msg) from e
155+
@property
def connection_config_dict(self) -> dict[str, Any]:
    """Build the keyword arguments passed to ``duckdb.connect()``.

    The ``extensions`` field is left out of the result, and a missing or falsy
    ``database`` entry is replaced with ``":memory:"``.

    Returns:
        A string keyed dict of config kwargs for the duckdb.connect() function.
    """
    connect_kwargs = dataclass_to_dict(self, exclude_empty=True, exclude={"extensions"}, convert_nested=False)
    connect_kwargs["database"] = connect_kwargs.get("database") or ":memory:"
    return connect_kwargs
53167
def create_connection(self) -> DuckDBPyConnection:
    """Create and return a new database connection with configured extensions.

    The connection is opened with ``connection_config_dict`` and then passed to
    ``_configure_extensions``. If either step fails, any connection that was already
    opened is closed before the error is raised, so a failed call does not leak it.

    Returns:
        A new DuckDB connection instance with extensions installed and configured.

    Raises:
        ImproperConfigurationError: If the connection could not be established or extensions could not be configured.
    """
    from contextlib import suppress

    import duckdb

    connection = None
    try:
        connection = duckdb.connect(**self.connection_config_dict)
        self._configure_extensions(connection)
    except Exception as e:
        if connection is not None:
            # Best-effort cleanup: a failure while closing must not mask the original error.
            with suppress(Exception):
                connection.close()
        msg = f"Could not configure the DuckDB connection. Error: {e!s}"
        raise ImproperConfigurationError(msg) from e
    return connection
0 commit comments