|
| 1 | +""" |
| 2 | +Base clase for SQL adapters for data algebra. |
| 3 | +""" |
| 4 | + |
1 | 5 | import math |
2 | 6 | import re |
3 | 7 | from collections import OrderedDict |
@@ -53,6 +57,7 @@ def __init__( |
53 | 57 | def __str__(self): |
54 | 58 | return self.__repr__() |
55 | 59 |
|
| 60 | + # noinspection PyUnusedLocal |
56 | 61 | def _repr_pretty_(self, p, cycle): |
57 | 62 | """ |
58 | 63 | IPython pretty print, used at implicit print time |
@@ -103,6 +108,7 @@ def _db_mean_expr(dbmodel, expression): |
103 | 108 | ) |
104 | 109 |
|
105 | 110 |
|
| 111 | +# noinspection PyUnusedLocal |
106 | 112 | def _db_size_expr(dbmodel, expression): |
107 | 113 | return "SUM(1)" |
108 | 114 |
|
@@ -306,9 +312,6 @@ def _db_nunique_expr(dbmodel, expression): |
306 | 312 | ) |
307 | 313 |
|
308 | 314 |
|
309 | | -# fns that had been in bigquery_user_fns |
310 | | - |
311 | | - |
312 | 315 | def _as_int64(dbmodel, expression): |
313 | 316 | return ( |
314 | 317 | "CAST(" |
@@ -487,7 +490,7 @@ def _base_Sunday(dbmodel, expression): |
487 | 490 | "**": _db_pow_expr, |
488 | 491 | "nunique": _db_nunique_expr, |
489 | 492 | "mapv": _db_mapv, |
490 | | - # fns that had been in bigquery_user_fns |
| 493 | + # additional fns |
491 | 494 | "as_int64": _as_int64, |
492 | 495 | "as_str": _as_str, |
493 | 496 | "trimstr": _trimstr, |
@@ -588,8 +591,13 @@ def __init__( |
588 | 591 | self.union_all_term_start = union_all_term_start |
589 | 592 | self.union_all_term_end = union_all_term_end |
590 | 593 |
|
591 | | - def db_handle(self, conn): |
592 | | - return DBHandle(db_model=self, conn=conn) |
| 594 | + def db_handle(self, conn, db_engine=None): |
| 595 | + """ |
| 596 | +
|
| 597 | + :param conn: database connection |
| 598 | + :param db_engine: optional sqlalchemy style engine (for closing) |
| 599 | + """ |
| 600 | + return DBHandle(db_model=self, conn=conn, db_engine=db_engine) |
593 | 601 |
|
594 | 602 | def prepare_connection(self, conn): |
595 | 603 | pass |
@@ -630,8 +638,8 @@ def read_query(self, conn, q): |
630 | 638 | def table_exists(self, conn, table_name: str) -> bool: |
631 | 639 | assert isinstance(table_name, str) |
632 | 640 | q_table_name = self.quote_table_name(table_name) |
633 | | - # noinspection PyBroadException |
634 | 641 | table_exists = True |
| 642 | + # noinspection PyBroadException |
635 | 643 | try: |
636 | 644 | self.read_query(conn, "SELECT * FROM " + q_table_name + " LIMIT 1") |
637 | 645 | except Exception: |
@@ -1856,9 +1864,19 @@ def __repr__(self): |
1856 | 1864 |
|
1857 | 1865 |
|
1858 | 1866 | class DBHandle: |
1859 | | - def __init__(self, *, db_model: DBModel, conn): |
| 1867 | + """ |
| 1868 | + Container for database connection handles. |
| 1869 | + """ |
| 1870 | + def __init__(self, *, db_model: DBModel, conn, db_engine=None): |
| 1871 | + """ |
| 1872 | +
|
| 1873 | + :param db_model: associated database model |
| 1874 | + :param conn: database connection |
| 1875 | + :param db_engine: optional sqlalchemy style engine (for closing) |
| 1876 | + """ |
1860 | 1877 | assert isinstance(db_model, DBModel) |
1861 | 1878 | self.db_model = db_model |
| 1879 | + self.db_engine = db_engine |
1862 | 1880 | self.conn = conn |
1863 | 1881 |
|
1864 | 1882 | def __enter__(self): |
@@ -1924,8 +1942,21 @@ def __repr__(self): |
1924 | 1942 |
|
1925 | 1943 | def close(self) -> None: |
1926 | 1944 | if self.conn is not None: |
1927 | | - try: |
1928 | | - self.conn.close() |
1929 | | - except Exception: |
1930 | | - pass |
| 1945 | + caught = None |
| 1946 | + if self.db_engine is not None: |
| 1947 | + # sqlalchemy style handle |
| 1948 | + # noinspection PyBroadException |
| 1949 | + try: |
| 1950 | + self.db_engine.dispose() |
| 1951 | + except Exception as ex: |
| 1952 | + caught = ex |
| 1953 | + else: |
| 1954 | + # noinspection PyBroadException |
| 1955 | + try: |
| 1956 | + self.conn.close() |
| 1957 | + except Exception as ex: |
| 1958 | + caught = ex |
| 1959 | + self.db_engine = None |
1931 | 1960 | self.conn = None |
| 1961 | + if caught is not None: |
| 1962 | + raise ValueError('close caught: ' + str(caught)) |
0 commit comments