|
3 | 3 | from contextlib import contextmanager, nullcontext |
4 | 4 | from pathlib import Path |
5 | 5 | from tempfile import NamedTemporaryFile |
6 | | -from typing import Any, Dict, List, Literal, Optional, Union, overload |
| 6 | +from typing import Any, Dict, Generator, List, Literal, Optional, Union, overload |
7 | 7 | from uuid import uuid4 |
8 | 8 |
|
9 | 9 | from data_seeder import DbtDataSeeder |
| 10 | +from dbt_utils import get_database_and_schema_properties |
10 | 11 | from elementary.clients.dbt.base_dbt_runner import BaseDbtRunner |
11 | 12 | from elementary.clients.dbt.factory import create_dbt_runner |
12 | 13 | from logger import get_logger |
@@ -42,7 +43,7 @@ def get_dbt_runner(target: str, project_dir: str) -> BaseDbtRunner: |
42 | 43 | class DbtProject: |
43 | 44 | def __init__(self, target: str, project_dir: str): |
44 | 45 | self.dbt_runner = get_dbt_runner(target, project_dir) |
45 | | - |
| 46 | + self.target = target |
46 | 47 | self.project_dir_path = Path(project_dir) |
47 | 48 | self.models_dir_path = self.project_dir_path / "models" |
48 | 49 | self.tmp_models_dir_path = self.models_dir_path / "tmp" |
@@ -189,12 +190,16 @@ def test( |
189 | 190 | test_id, materialization |
190 | 191 | ) |
191 | 192 | else: |
| 193 | + database_property, schema_property = get_database_and_schema_properties( |
| 194 | + self.target |
| 195 | + ) |
192 | 196 | props_yaml = { |
193 | 197 | "version": 2, |
194 | 198 | "sources": [ |
195 | 199 | { |
196 | 200 | "name": "test_data", |
197 | | - "schema": f"{{{{ target.schema }}}}{SCHEMA_NAME_SUFFIX}", |
| 201 | + "schema": f"{{{{ target.{schema_property} }}}}{SCHEMA_NAME_SUFFIX}", |
| 202 | + "database": f"{{{{ target.{database_property} }}}}", |
198 | 203 | "tables": [table_yaml], |
199 | 204 | } |
200 | 205 | ], |
@@ -232,9 +237,19 @@ def test( |
232 | 237 | return [test_result] if multiple_results else test_result |
233 | 238 |
|
234 | 239 | def seed(self, data: List[dict], table_name: str): |
235 | | - return DbtDataSeeder( |
| 240 | + with DbtDataSeeder( |
| 241 | + self.dbt_runner, self.project_dir_path, self.seeds_dir_path |
| 242 | + ).seed(data, table_name): |
| 243 | + return |
| 244 | + |
| 245 | + @contextmanager |
| 246 | + def seed_context( |
| 247 | + self, data: List[dict], table_name: str |
| 248 | + ) -> Generator[None, None, None]: |
| 249 | + with DbtDataSeeder( |
236 | 250 | self.dbt_runner, self.project_dir_path, self.seeds_dir_path |
237 | | - ).seed(data, table_name) |
| 251 | + ).seed(data, table_name): |
| 252 | + yield |
238 | 253 |
|
239 | 254 | @contextmanager |
240 | 255 | def create_temp_model_for_existing_table( |
|
0 commit comments