Skip to content

Commit f4afa85

Browse files
author
Timon Viola
committed
feat: add poc type casting reflect op
1 parent 6738805 commit f4afa85

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any
4+
5+
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
6+
from sqlalchemy import ReflectedColumn, cast, inspect, select
7+
8+
from dagcellent.data_utils.sql_reflection import (
9+
reflect_meta_data,
10+
safe_add_database_to_connection,
11+
)
12+
13+
if TYPE_CHECKING:
14+
from sqlalchemy.engine.interfaces import ReflectedColumn
15+
16+
17+
class SQLReflectOperator(SQLExecuteQueryOperator):
    """Operator to perform SQLAlchemy-like database reflection.

    The target table is reflected and returned as a SQLAlchemy ``SELECT``
    construct listing every reflected column of the table, each one wrapped
    in a ``CAST`` to its reflected type.

    Example:
        The example below illustrates a PostgreSQL database and the
        general shape of the returned SELECT query (the per-column
        ``CAST(... AS <type>)`` wrappers are omitted for brevity).

        ```sql
        CREATE TABLE IF NOT EXISTS ats
        (
            departure_id varchar(40) COLLATE pg_catalog."default" NOT NULL,
            route_leg_code varchar(40) COLLATE pg_catalog."default" NOT NULL,
            planned_departure_date_time timestamp without time zone NOT NULL,
            ferry_name varchar(40) COLLATE pg_catalog."default" NOT NULL,
            cnv_outlet varchar(40) COLLATE pg_catalog."default" NOT NULL,
            store_name varchar(40) COLLATE pg_catalog."default" NOT NULL,
            store_item varchar(200) COLLATE pg_catalog."default" NOT NULL,
            predicted_sales double precision NOT NULL,
            good boolean DEFAULT false,
            CONSTRAINT ats_pkey PRIMARY KEY (departure_id, route_leg_code, ferry_name, cnv_outlet, store_name, store_item)
        );
        ```

        ```python
        reflect_table = SQLReflectOperator(
            table_name="ats",
            task_id="reflect_database",
            conn_id=CONN_ID,
        )
        ```

        ```sql
        SELECT
            ats.departure_id,
            ats.route_leg_code,
            ats.planned_departure_date_time,
            ats.ferry_name,
            ats.cnv_outlet,
            ats.store_name,
            ats.store_item,
            ats.predicted_sales,
            ats.good
        FROM ats
        ```
    """

    def __init__(
        self,
        *,
        table_name: str,
        database: str | None = None,
        schema: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Init.

        Args:
            table_name: target table name
            database: optional database name; stored on the operator and
                also forwarded to SQLExecuteQueryOperator
            schema: optional schema that contains the target table
            kwargs: additional arguments to pass to SQLExecuteQueryOperator
        """
        # TODO: deprecate this, for now inheritance needs debugging
        self.database_name = database
        kwargs["database"] = database
        self.table_name = table_name
        self.schema = schema
        # The parent requires a `sql` argument; this operator builds its
        # statement in `execute`, so an empty placeholder is passed.
        super().__init__(sql="", **kwargs)  # type: ignore

    def execute(self, context: Any):
        """Reflect the target table and build a type-casting SELECT.

        Args:
            context: Airflow task context (not used by this method).

        Returns:
            A SQLAlchemy ``select`` construct with one
            ``CAST(<column> AS <reflected type>)`` expression per column
            reported by the inspector, labelled with the column name.

        Raises:
            ValueError: if ``reflect_meta_data`` returns ``None`` for the
                requested table.
        """
        hook = self.get_db_hook()
        engine = hook.get_sqlalchemy_engine()  # type: ignore
        self.log.debug("%s", f"{self.database_name=}")
        if self.database_name:
            # inject database name if not defined in connection URI
            self.log.debug("Target connection: %s", f"{engine.url.database=}")
            engine = safe_add_database_to_connection(engine, self.database_name)
            self.log.debug("Target connection: %s", engine.url)

        table = reflect_meta_data(engine, schema=self.schema, table=self.table_name)
        if table is None:  # type: ignore[reportUnnecessaryCondition]
            raise ValueError(f"Table {self.table_name} not found in the database.")

        self.log.debug("::group::🦆")
        self.log.debug("Table: %s", table.__dict__)
        self.log.debug("::endgroup::")

        # Look the columns up in the same schema that was used for the table
        # reflection above; otherwise the inspector falls back to the
        # connection's default schema.
        reflected_columns: list[ReflectedColumn] = inspect(engine).get_columns(
            self.table_name, schema=self.schema
        )
        # `cast` coerces a plain string into a bound literal value, so the
        # column *object* must be passed to select the column's data rather
        # than a constant holding its name.
        # NOTE(review): assumes `reflect_meta_data` returns a SQLAlchemy
        # `Table` (exposing the `.c` column collection) — confirm.
        select_ddl = select(
            *[
                cast(table.c[col["name"]], col["type"]).label(col["name"])
                for col in reflected_columns
            ]
        )

        return select_ddl

0 commit comments

Comments
 (0)