|
4 | 4 | """ |
5 | 5 |
|
6 | 6 | import logging |
7 | | -from typing import TYPE_CHECKING, Any, Union |
| 7 | +from collections.abc import Mapping, Sequence |
| 8 | +from typing import TYPE_CHECKING, Any, Union, cast |
8 | 9 |
|
9 | 10 | import sqlglot |
10 | 11 | from sqlglot import exp |
|
46 | 47 | from sqlspec.exceptions import SQLBuilderError |
47 | 48 |
|
48 | 49 | if TYPE_CHECKING: |
| 50 | + from collections.abc import Mapping, Sequence |
| 51 | + |
49 | 52 | from sqlspec.builder._expression_wrappers import ExpressionWrapper |
50 | 53 |
|
51 | 54 |
|
|
73 | 76 | "Truncate", |
74 | 77 | "Update", |
75 | 78 | "WindowFunctionBuilder", |
| 79 | + "build_copy_from_statement", |
| 80 | + "build_copy_statement", |
| 81 | + "build_copy_to_statement", |
76 | 82 | "sql", |
77 | 83 | ) |
78 | 84 |
|
|
108 | 114 | } |
109 | 115 |
|
110 | 116 |
|
| 117 | +def _normalize_copy_dialect(dialect: DialectType | None) -> str: |
| 118 | + if dialect is None: |
| 119 | + return "postgres" |
| 120 | + if isinstance(dialect, str): |
| 121 | + return dialect |
| 122 | + return str(dialect) |
| 123 | + |
| 124 | + |
| 125 | +def _to_copy_schema(table: str, columns: "Sequence[str] | None") -> exp.Expression: |
| 126 | + base = exp.table_(table) |
| 127 | + if not columns: |
| 128 | + return base |
| 129 | + column_nodes = [exp.column(column_name) for column_name in columns] |
| 130 | + return exp.Schema(this=base, expressions=column_nodes) |
| 131 | + |
| 132 | + |
| 133 | +def _build_copy_expression( |
| 134 | + *, direction: str, table: str, location: str, columns: "Sequence[str] | None", options: "Mapping[str, Any] | None" |
| 135 | +) -> exp.Copy: |
| 136 | + copy_args: dict[str, Any] = {"this": _to_copy_schema(table, columns), "files": [exp.Literal.string(location)]} |
| 137 | + |
| 138 | + if direction == "from": |
| 139 | + copy_args["kind"] = True |
| 140 | + elif direction == "to": |
| 141 | + copy_args["kind"] = False |
| 142 | + |
| 143 | + if options: |
| 144 | + params: list[exp.CopyParameter] = [] |
| 145 | + for key, value in options.items(): |
| 146 | + identifier = exp.Var(this=str(key).upper()) |
| 147 | + value_expression: exp.Expression |
| 148 | + if isinstance(value, bool): |
| 149 | + value_expression = exp.Boolean(this=value) |
| 150 | + elif value is None: |
| 151 | + value_expression = exp.null() |
| 152 | + elif isinstance(value, (int, float)): |
| 153 | + value_expression = exp.Literal.number(value) |
| 154 | + elif isinstance(value, (list, tuple)): |
| 155 | + elements = [exp.Literal.string(str(item)) for item in value] |
| 156 | + value_expression = exp.Array(expressions=elements) |
| 157 | + else: |
| 158 | + value_expression = exp.Literal.string(str(value)) |
| 159 | + params.append(exp.CopyParameter(this=identifier, expression=value_expression)) |
| 160 | + copy_args["params"] = params |
| 161 | + |
| 162 | + return exp.Copy(**copy_args) |
| 163 | + |
| 164 | + |
| 165 | +def build_copy_statement( |
| 166 | + *, |
| 167 | + direction: str, |
| 168 | + table: str, |
| 169 | + location: str, |
| 170 | + columns: "Sequence[str] | None" = None, |
| 171 | + options: "Mapping[str, Any] | None" = None, |
| 172 | + dialect: DialectType | None = None, |
| 173 | +) -> SQL: |
| 174 | + expression = _build_copy_expression( |
| 175 | + direction=direction, table=table, location=location, columns=columns, options=options |
| 176 | + ) |
| 177 | + rendered = expression.sql(dialect=_normalize_copy_dialect(dialect)) |
| 178 | + return SQL(rendered) |
| 179 | + |
| 180 | + |
| 181 | +def build_copy_from_statement( |
| 182 | + table: str, |
| 183 | + source: str, |
| 184 | + *, |
| 185 | + columns: "Sequence[str] | None" = None, |
| 186 | + options: "Mapping[str, Any] | None" = None, |
| 187 | + dialect: DialectType | None = None, |
| 188 | +) -> SQL: |
| 189 | + return build_copy_statement( |
| 190 | + direction="from", table=table, location=source, columns=columns, options=options, dialect=dialect |
| 191 | + ) |
| 192 | + |
| 193 | + |
| 194 | +def build_copy_to_statement( |
| 195 | + table: str, |
| 196 | + target: str, |
| 197 | + *, |
| 198 | + columns: "Sequence[str] | None" = None, |
| 199 | + options: "Mapping[str, Any] | None" = None, |
| 200 | + dialect: DialectType | None = None, |
| 201 | +) -> SQL: |
| 202 | + return build_copy_statement( |
| 203 | + direction="to", table=table, location=target, columns=columns, options=options, dialect=dialect |
| 204 | + ) |
| 205 | + |
| 206 | + |
111 | 207 | class SQLFactory: |
112 | 208 | """Factory for creating SQL builders and column expressions.""" |
113 | 209 |
|
@@ -479,6 +575,56 @@ def comment_on(self, dialect: DialectType = None) -> "CommentOn": |
479 | 575 | """ |
480 | 576 | return CommentOn(dialect=dialect or self.dialect) |
481 | 577 |
|
| 578 | + def copy_from( |
| 579 | + self, |
| 580 | + table: str, |
| 581 | + source: str, |
| 582 | + *, |
| 583 | + columns: "Sequence[str] | None" = None, |
| 584 | + options: "Mapping[str, Any] | None" = None, |
| 585 | + dialect: DialectType | None = None, |
| 586 | + ) -> SQL: |
| 587 | + """Build a COPY ... FROM statement.""" |
| 588 | + |
| 589 | + effective_dialect = dialect or self.dialect |
| 590 | + return build_copy_from_statement(table, source, columns=columns, options=options, dialect=effective_dialect) |
| 591 | + |
| 592 | + def copy_to( |
| 593 | + self, |
| 594 | + table: str, |
| 595 | + target: str, |
| 596 | + *, |
| 597 | + columns: "Sequence[str] | None" = None, |
| 598 | + options: "Mapping[str, Any] | None" = None, |
| 599 | + dialect: DialectType | None = None, |
| 600 | + ) -> SQL: |
| 601 | + """Build a COPY ... TO statement.""" |
| 602 | + |
| 603 | + effective_dialect = dialect or self.dialect |
| 604 | + return build_copy_to_statement(table, target, columns=columns, options=options, dialect=effective_dialect) |
| 605 | + |
| 606 | + def copy( |
| 607 | + self, |
| 608 | + table: str, |
| 609 | + *, |
| 610 | + source: str | None = None, |
| 611 | + target: str | None = None, |
| 612 | + columns: "Sequence[str] | None" = None, |
| 613 | + options: "Mapping[str, Any] | None" = None, |
| 614 | + dialect: DialectType | None = None, |
| 615 | + ) -> SQL: |
| 616 | + """Build a COPY statement, inferring direction from provided arguments.""" |
| 617 | + |
| 618 | + if (source is None and target is None) or (source is not None and target is not None): |
| 619 | + msg = "Provide either 'source' or 'target' (but not both) to sql.copy()." |
| 620 | + raise SQLBuilderError(msg) |
| 621 | + |
| 622 | + if source is not None: |
| 623 | + return self.copy_from(table, source, columns=columns, options=options, dialect=dialect) |
| 624 | + |
| 625 | + target_value = cast("str", target) |
| 626 | + return self.copy_to(table, target_value, columns=columns, options=options, dialect=dialect) |
| 627 | + |
482 | 628 | @staticmethod |
483 | 629 | def _looks_like_sql(candidate: str, expected_type: str | None = None) -> bool: |
484 | 630 | """Determine if a string looks like SQL. |
|
0 commit comments