|
| 1 | +import importlib |
| 2 | +from importlib import import_module |
| 3 | + |
| 4 | +from dbt.adapters import factory |
| 5 | +from dbt.adapters.factory import Adapter |
| 6 | +from dbt.events.functions import fire_event |
| 7 | +from dbt.events.types import AdapterRegistered |
| 8 | +from dbt.semver import VersionSpecifier |
| 9 | + |
| 10 | + |
| 11 | +class OpenDbtAdapterContainer(factory.AdapterContainer): |
| 12 | + DBT_CUSTOM_ADAPTER_VAR = 'dbt_custom_adapter' |
| 13 | + def register_adapter(self, config: 'AdapterRequiredConfig') -> None: |
| 14 | + # ==== CUSTOM CODE ==== |
| 15 | + # ==== END CUSTOM CODE ==== |
| 16 | + adapter_name = config.credentials.type |
| 17 | + adapter_type = self.get_adapter_class_by_name(adapter_name) |
| 18 | + adapter_version = import_module(f".{adapter_name}.__version__", "dbt.adapters").version |
| 19 | + # ==== CUSTOM CODE ==== |
| 20 | + custom_adapter_class_name: str = self.get_custom_adapter_config_value(config) |
| 21 | + if custom_adapter_class_name and custom_adapter_class_name.strip(): |
| 22 | + # OVERRIDE DEFAULT ADAPTER BY USER GIVEN ADAPTER CLASS |
| 23 | + adapter_type = self.get_custom_adapter_class_by_name(custom_adapter_class_name) |
| 24 | + # ==== END CUSTOM CODE ==== |
| 25 | + adapter_version_specifier = VersionSpecifier.from_version_string( |
| 26 | + adapter_version |
| 27 | + ).to_version_string() |
| 28 | + fire_event( |
| 29 | + AdapterRegistered(adapter_name=adapter_name, adapter_version=adapter_version_specifier) |
| 30 | + ) |
| 31 | + with self.lock: |
| 32 | + if adapter_name in self.adapters: |
| 33 | + # this shouldn't really happen... |
| 34 | + return |
| 35 | + |
| 36 | + adapter: Adapter = adapter_type(config) # type: ignore |
| 37 | + self.adapters[adapter_name] = adapter |
| 38 | + |
| 39 | + def get_custom_adapter_config_value(self, config: 'AdapterRequiredConfig') -> str: |
| 40 | + # FIRST: it's set as cli value: dbt run --vars {'dbt_custom_adapter': 'custom_adapters.DuckDBAdapterV1Custom'} |
| 41 | + if hasattr(config, 'cli_vars') and self.DBT_CUSTOM_ADAPTER_VAR in config.cli_vars: |
| 42 | + custom_adapter_class_name: str = config.cli_vars[self.DBT_CUSTOM_ADAPTER_VAR] |
| 43 | + if custom_adapter_class_name and custom_adapter_class_name.strip(): |
| 44 | + return custom_adapter_class_name |
| 45 | + # SECOND: it's set inside dbt_project.yml |
| 46 | + if hasattr(config, 'vars') and self.DBT_CUSTOM_ADAPTER_VAR in config.vars.to_dict(): |
| 47 | + custom_adapter_class_name: str = config.vars.to_dict()[self.DBT_CUSTOM_ADAPTER_VAR] |
| 48 | + if custom_adapter_class_name and custom_adapter_class_name.strip(): |
| 49 | + return custom_adapter_class_name |
| 50 | + |
| 51 | + return None |
| 52 | + |
| 53 | + def get_custom_adapter_class_by_name(self, custom_adapter_class_name: str): |
| 54 | + if "." not in custom_adapter_class_name: |
| 55 | + raise ValueError(f"Unexpected adapter class name: `{custom_adapter_class_name}` ," |
| 56 | + f"Expecting something like:`my.sample.library.MyAdapterClass`") |
| 57 | + |
| 58 | + __module, __class = custom_adapter_class_name.rsplit('.', 1) |
| 59 | + try: |
| 60 | + user_adapter_module = importlib.import_module(__module) |
| 61 | + user_adapter_class = getattr(user_adapter_module, __class) |
| 62 | + return user_adapter_class |
| 63 | + except ModuleNotFoundError as mnfe: |
| 64 | + raise Exception(f"Module of provided adapter not found, provided: {custom_adapter_class_name}") from mnfe |
0 commit comments