44from contextlib import contextmanager
55from dataclasses import dataclass
66from types import MappingProxyType
7- from typing import TypeVar
87
98from sqlmesh .core .config import CategorizerConfig
109from sqlmesh .core .console import set_console
2726
2827logger = logging .getLogger (__name__ )
2928
30- T = TypeVar ("T" , bound = "SQLMeshController" )
29+ T = t .TypeVar ("T" , bound = "SQLMeshController" )
30+ ContextCls = t .TypeVar ("ContextCls" , bound = Context )
31+ ContextFactory = t .Callable [..., ContextCls ]
3132
33+ def default_context_factory (** kwargs : t .Any ) -> Context :
34+ return Context (** kwargs )
35+
36+ DEFAULT_CONTEXT_FACTORY : ContextFactory [Context ] = default_context_factory
3237
3338class PlanOptions (t .TypedDict ):
3439 start : t .NotRequired [TimeLike ]
@@ -88,7 +93,7 @@ def parse_fqn(self) -> SQLMeshParsedFQN:
8893 return parse_fqn (self .fqn )
8994
9095
91- class SQLMeshInstance :
96+ class SQLMeshInstance ( t . Generic [ ContextCls ]) :
9297 """
9398 A class that manages sqlmesh operations and context within a specific
9499 environment. This class will run sqlmesh in a separate thread.
@@ -110,15 +115,15 @@ class SQLMeshInstance:
110115 config : SQLMeshContextConfig
111116 console : EventConsole
112117 logger : logging .Logger
113- context : Context
118+ context : ContextCls
114119 environment : str
115120
116121 def __init__ (
117122 self ,
118123 environment : str ,
119124 console : EventConsole ,
120125 config : SQLMeshContextConfig ,
121- context : Context ,
126+ context : ContextCls ,
122127 logger : logging .Logger ,
123128 ):
124129 self .environment = environment
@@ -167,7 +172,7 @@ def plan(
167172 def run_sqlmesh_thread (
168173 logger : logging .Logger ,
169174 context : Context ,
170- controller : SQLMeshController ,
175+ controller : SQLMeshController [ ContextCls ] ,
171176 environment : str ,
172177 plan_options : PlanOptions ,
173178 default_catalog : str ,
@@ -251,7 +256,7 @@ def run(self, **run_options: t.Unpack[RunOptions]) -> t.Iterator[ConsoleEvent]:
251256 def run_sqlmesh_thread (
252257 logger : logging .Logger ,
253258 context : Context ,
254- controller : SQLMeshController ,
259+ controller : SQLMeshController [ ContextCls ] ,
255260 environment : str ,
256261 run_options : RunOptions ,
257262 ) -> None :
@@ -364,8 +369,7 @@ def non_external_models_dag(self) -> t.Iterable[tuple[Model, set[str]]]:
364369 continue
365370 yield (model , deps )
366371
367-
368- class SQLMeshController :
372+ class SQLMeshController (t .Generic [ContextCls ]):
369373 """Allows control of sqlmesh via a python interface. It is not suggested to
370374 use the constructor of this class directly, but instead use the provided
371375 `setup` or `setup_with_config` class methods.
@@ -405,37 +409,45 @@ class SQLMeshController:
405409 def setup (
406410 cls ,
407411 path : str ,
412+ * ,
413+ context_factory : ContextFactory [ContextCls ],
408414 gateway : str = "local" ,
409415 log_override : logging .Logger | None = None ,
410- ) -> "SQLMeshController" :
416+ ) -> t . Self :
411417 return cls .setup_with_config (
412418 config = SQLMeshContextConfig (path = path , gateway = gateway ),
413419 log_override = log_override ,
420+ context_factory = context_factory ,
414421 )
415422
416423 @classmethod
417424 def setup_with_config (
418- cls : type [T ],
425+ cls ,
426+ * ,
419427 config : SQLMeshContextConfig ,
428+ context_factory : ContextFactory [ContextCls ] = DEFAULT_CONTEXT_FACTORY ,
420429 log_override : logging .Logger | None = None ,
421- ) -> T :
430+ ) -> t . Self :
422431 console = EventConsole (log_override = log_override ) # type: ignore
423432 controller = cls (
424433 console = console ,
425434 config = config ,
426435 log_override = log_override ,
436+ context_factory = context_factory ,
427437 )
428438 return controller
429439
430440 def __init__ (
431441 self ,
432442 config : SQLMeshContextConfig ,
433443 console : EventConsole ,
444+ context_factory : ContextFactory [ContextCls ],
434445 log_override : logging .Logger | None = None ,
435446 ) -> None :
436447 self .config = config
437448 self .console = console
438449 self .logger = log_override or logger
450+ self ._context_factory = context_factory
439451 self ._context_open = False
440452
441453 def set_logger (self , logger : logging .Logger ) -> None :
@@ -448,20 +460,20 @@ def add_event_handler(self, handler: ConsoleEventHandler) -> str:
448460 def remove_event_handler (self , handler_id : str ) -> None :
449461 self .console .remove_handler (handler_id )
450462
451- def _create_context (self ) -> Context :
463+ def _create_context (self ) -> ContextCls :
452464 options : dict [str , t .Any ] = dict (
453465 paths = self .config .path ,
454466 gateway = self .config .gateway ,
455467 )
456468 if self .config .sqlmesh_config :
457469 options ["config" ] = self .config .sqlmesh_config
458470 set_console (self .console )
459- return Context (** options )
471+ return self . _context_factory (** options )
460472
461473 @contextmanager
462474 def instance (
463475 self , environment : str , component : str = "unknown"
464- ) -> t .Iterator [SQLMeshInstance ]:
476+ ) -> t .Iterator [SQLMeshInstance [ ContextCls ] ]:
465477 self .logger .info (
466478 f"Opening sqlmesh instance for env={ environment } component={ component } "
467479 )
0 commit comments