Skip to content

Commit 0c7dab6

Browse files
floscha and Maxteabag
authored
Add AWS Athena adapter (#38)
* Add AWS Athena adapter * Add properties and add ephemeral infra tests --------- Co-authored-by: Peter Adams <18162810+Maxteabag@users.noreply.github.com>
1 parent 4831ded commit 0c7dab6

File tree

14 files changed

+1228
-2
lines changed

14 files changed

+1228
-2
lines changed

CONTRIBUTING.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,13 @@ The database tests can be configured with these environment variables:
102102
| `FIREBIRD_PORT` | `3050` | Firebird port |
103103
| `FIREBIRD_USER` | `testuser` | Firebird username |
104104
| `FIREBIRD_PASSWORD` | `TestPassword123!` | Firebird password |
105-
| `FIREBIRD_DATABASE` | `/var/lib/firebird/data/test_sqlit.fdb` | Firebird database path or alias |
105+
| `FIREBIRD_DATABASE` | `/var/lib/firebird/data/test_sqlit.fdb` | Firebird database path or alias |
106+
107+
**AWS Athena:**
108+
| Variable | Default | Description |
109+
|----------|---------|-------------|
110+
| `AWS_PROFILE` | `default` | AWS CLI profile to use (must be configured in `~/.aws/credentials`) |
111+
| `AWS_REGION` | `us-east-1` | AWS Region |
106112

107113
### CockroachDB Quickstart (Docker)
108114

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
---
2626

2727
### Connect
28-
Supports all major databases: SQL Server, PostgreSQL, MySQL, SQLite, MariaDB, FirebirdSQL, Oracle, DuckDB, CockroachDB, ClickHouse, Snowflake, Supabase, CloudFlare D1, and Turso.
28+
Supports all major databases: SQL Server, PostgreSQL, MySQL, SQLite, MariaDB, FirebirdSQL, Oracle, DuckDB, CockroachDB, ClickHouse, Snowflake, Supabase, Cloudflare D1, Turso, and Athena.
2929

3030
![Database Providers](demos/demo-providers.gif)
3131

@@ -138,6 +138,8 @@ sqlit connections add cockroachdb --name "MyCockroach" --server "localhost" --po
138138
sqlit connections add sqlite --name "MyLocalDB" --file-path "/path/to/database.db"
139139
sqlit connections add turso --name "MyTurso" --server "libsql://your-db.turso.io" --password "your-auth-token"
140140
sqlit connections add firebird --name "MyFirebird" --server "localhost" --username "user" --password "pass" --database "employee"
141+
sqlit connections add athena --name "MyAthena" --athena-region-name "us-east-1" --athena-s3-staging-dir "s3://my-bucket/results/" --athena-auth-method "profile" --athena-profile-name "default"
142+
sqlit connections add athena --name "MyAthenaKeys" --athena-region-name "us-east-1" --athena-s3-staging-dir "s3://my-bucket/results/" --athena-auth-method "keys" --username "ACCESS_KEY" --password "SECRET_KEY"
141143

142144
# Connect via SSH tunnel
143145
sqlit connections add postgresql --name "RemoteDB" --server "db-host" --username "dbuser" --password "dbpass" \
@@ -249,6 +251,7 @@ Most of the time you can just run `sqlit` and connect. If a Python driver is mis
249251
| Cloudflare D1 | `requests` | `pipx inject sqlit-tui requests` | `python -m pip install requests` |
250252
| Snowflake | `snowflake-connector-python` | `pipx inject sqlit-tui snowflake-connector-python` | `python -m pip install snowflake-connector-python` |
251253
| Firebird | `firebirdsql` | `pipx inject sqlit-tui firebirdsql` | `python -m pip install firebirdsql` |
254+
| Athena | `pyathena` | `pipx inject sqlit-tui pyathena` | `python -m pip install pyathena` |
252255

253256
### SSH Tunnel Support
254257

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ all = [
5050
"sshtunnel>=0.4.0",
5151
"paramiko>=2.0.0,<4.0.0",
5252
"snowflake-connector-python>=3.7.0",
53+
"pyathena>=3.22.0",
5354
]
5455
postgres = ["psycopg2-binary>=2.9.0"]
5556
cockroachdb = ["psycopg2-binary>=2.9.0"]
@@ -63,6 +64,7 @@ d1 = ["requests>=2.32.4"] # min avoids known CVEs
6364
turso = ["libsql>=0.1.0"]
6465
firebird = ["firebirdsql>=1.3.5"]
6566
snowflake = ["snowflake-connector-python>=3.7.0"]
67+
athena = ["pyathena>=3.22.0"]
6668
ssh = [
6769
"sshtunnel>=0.4.0",
6870
"paramiko>=2.0.0,<4.0.0",

sqlit/db/adapters/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"SQLServerAdapter",
3232
"SupabaseAdapter",
3333
"TursoAdapter",
34+
"AthenaAdapter",
3435
# Factory helpers
3536
"get_adapter",
3637
"get_supported_adapter_db_types",
@@ -49,6 +50,7 @@
4950
from .sqlite import SQLiteAdapter
5051
from .supabase import SupabaseAdapter
5152
from .turso import TursoAdapter
53+
from .athena import AthenaAdapter
5254

5355

5456
def get_adapter(db_type: str) -> DatabaseAdapter:

sqlit/db/adapters/athena.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
"""AWS Athena adapter using pyathena."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Any
6+
7+
from .base import CursorBasedAdapter, import_driver_module, ColumnInfo, IndexInfo, TriggerInfo, SequenceInfo, TableInfo
8+
9+
if TYPE_CHECKING:
10+
from ...config import ConnectionConfig
11+
12+
13+
class AthenaAdapter(CursorBasedAdapter):
    """Adapter for AWS Athena, using the ``pyathena`` driver.

    Metadata is read with Athena's ``SHOW`` / ``DESCRIBE`` statements.
    Stored procedures, indexes, triggers and sequences are always reported
    as empty lists.
    """

    @classmethod
    def badge_label(cls) -> str:
        """Return the badge label for this adapter."""
        return "Athena"

    @property
    def name(self) -> str:
        """Human-readable adapter name."""
        return "Athena"

    @property
    def install_extra(self) -> str:
        """Name of the optional-dependency extra that installs the driver."""
        return "athena"

    @property
    def install_package(self) -> str:
        """Name of the package that provides the driver."""
        return "pyathena"

    @property
    def driver_import_names(self) -> tuple[str, ...]:
        """Module names imported to load the driver."""
        return ("pyathena",)

    @property
    def supports_multiple_databases(self) -> bool:
        return True

    @property
    def supports_stored_procedures(self) -> bool:
        return False

    @property
    def supports_triggers(self) -> bool:
        return False

    @property
    def supports_indexes(self) -> bool:
        return False

    @property
    def supports_cross_database_queries(self) -> bool:
        """Athena supports cross-database queries using database.table syntax."""
        return True

    @property
    def system_databases(self) -> frozenset[str]:
        """Athena system databases to exclude from user listings."""
        return frozenset({"information_schema"})

    @property
    def default_schema(self) -> str:
        """Database name used when the caller does not specify one."""
        return "default"

    def connect(self, config: ConnectionConfig) -> Any:
        """Open a pyathena connection described by *config*.

        Reads ``athena_region_name``, ``athena_s3_staging_dir``,
        ``athena_auth_method``, ``athena_profile_name`` and
        ``athena_work_group`` from ``config.options``. With the ``"keys"``
        auth method, ``config.username`` / ``config.password`` are passed as
        the AWS access key id / secret access key; any other value selects a
        named AWS profile.

        Returns:
            The object returned by ``pyathena.connect``.
        """
        pyathena = import_driver_module(
            "pyathena",
            driver_name=self.name,
            extra_name=self.install_extra,
            package_name=self.install_package,
        )

        options = config.options
        connect_args: dict[str, Any] = {
            "region_name": options.get("athena_region_name", "us-east-1"),
            "s3_staging_dir": options.get("athena_s3_staging_dir"),
            "schema_name": config.database or "default",
        }

        if options.get("athena_auth_method", "profile") == "keys":
            connect_args["aws_access_key_id"] = config.username
            connect_args["aws_secret_access_key"] = config.password
        else:
            connect_args["profile_name"] = options.get("athena_profile_name", "default")

        # Optional WorkGroup: only forward a value that is actually set, so an
        # empty/None option falls back to the driver's default instead of
        # being sent as an invalid work group name.
        work_group = options.get("athena_work_group")
        if work_group:
            connect_args["work_group"] = work_group

        return pyathena.connect(**connect_args)

    def get_databases(self, conn: Any) -> list[str]:
        """Get list of databases (schemas in Athena)."""
        cursor = conn.cursor()
        cursor.execute("SHOW DATABASES")
        return [row[0] for row in cursor.fetchall()]

    def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]:
        """Get list of tables as ``(database, table_name)`` pairs.

        When *database* is not given, the connection's current database is
        listed and the pairs are labelled with ``default_schema``.
        """
        cursor = conn.cursor()
        if database:
            cursor.execute(f"SHOW TABLES IN {database}")
        else:
            cursor.execute("SHOW TABLES")
        return [(database or self.default_schema, row[0]) for row in cursor.fetchall()]

    def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]:
        """Get list of views as ``(database, view_name)`` pairs.

        When *database* is not given, the connection's current database is
        listed and the pairs are labelled with ``default_schema``.
        """
        cursor = conn.cursor()
        if database:
            cursor.execute(f"SHOW VIEWS IN {database}")
        else:
            cursor.execute("SHOW VIEWS")
        return [(database or self.default_schema, row[0]) for row in cursor.fetchall()]

    def get_columns(
        self, conn: Any, table: str, database: str | None = None, schema: str | None = None
    ) -> list[ColumnInfo]:
        """Get columns for a table or view.

        The target database is *database*, else *schema*, else
        ``default_schema``. Each column is reported once, in the order
        ``DESCRIBE`` first lists it, with ``is_primary_key=False``.
        """
        cursor = conn.cursor()

        target_db = database or schema or self.default_schema
        cursor.execute(f"DESCRIBE {target_db}.{table}")

        columns: list[ColumnInfo] = []
        seen: set[str] = set()
        for row in cursor.fetchall():
            # A table row element looks like this: ('col_name \tstring \tfrom deserializer ',)
            # A view row element looks like this: ('col_name \tstring ',)
            if not row or not row[0]:
                continue
            fields = [part.strip() for part in row[0].split("\t")]
            col_name = fields[0]
            data_type = fields[1] if len(fields) > 1 else ""

            # Hive-style DESCRIBE output for partitioned tables also contains
            # blank separator rows and "# ..." header rows; those are not columns.
            if not col_name or col_name.startswith("#"):
                continue
            # Partition columns are listed a second time after the
            # "# Partition Information" header; keep only the first occurrence.
            if col_name in seen:
                continue

            seen.add(col_name)
            columns.append(ColumnInfo(name=col_name, data_type=data_type, is_primary_key=False))

        return columns

    def quote_identifier(self, name: str) -> str:
        """Return *name* unchanged; no identifier quoting is applied."""
        return name

    def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str:
        """Build SELECT LIMIT query."""
        target_db = database or schema or self.default_schema
        return f"SELECT * FROM {target_db}.{table} LIMIT {limit}"

    def get_procedures(self, conn: Any, database: str | None = None) -> list[str]:
        """Return an empty list (``supports_stored_procedures`` is False)."""
        return []

    def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]:
        """Return an empty list (``supports_indexes`` is False)."""
        return []

    def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]:
        """Return an empty list (``supports_triggers`` is False)."""
        return []

    def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]:
        """Return an empty list; this adapter reports no sequences."""
        return []

sqlit/db/providers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
SQLITE_SCHEMA,
3131
SUPABASE_SCHEMA,
3232
TURSO_SCHEMA,
33+
ATHENA_SCHEMA,
3334
ConnectionSchema,
3435
)
3536

@@ -101,6 +102,10 @@ class ProviderSpec:
101102
schema=SNOWFLAKE_SCHEMA,
102103
adapter_path=("sqlit.db.adapters.snowflake", "SnowflakeAdapter"),
103104
),
105+
"athena": ProviderSpec(
106+
schema=ATHENA_SCHEMA,
107+
adapter_path=("sqlit.db.adapters.athena", "AthenaAdapter"),
108+
),
104109
}
105110

106111

sqlit/db/schema.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,71 @@ def _get_supabase_region_options() -> tuple[SelectOption, ...]:
553553
)
554554

555555

556+
def _athena_auth_method_is(expected):
    """Build a ``visible_when`` predicate matching the selected auth method."""

    def _predicate(v):
        return v.get("athena_auth_method") == expected

    return _predicate


# Connection form for AWS Athena: region, work group and S3 staging location,
# followed by an auth-method selector that reveals either the profile field
# or the access-key credential fields.
ATHENA_SCHEMA = ConnectionSchema(
    db_type="athena",
    display_name="AWS Athena",
    fields=(
        SchemaField(
            name="athena_region_name",
            label="Region",
            required=True,
            default="us-east-1",
        ),
        SchemaField(
            name="athena_work_group",
            label="WorkGroup",
            required=True,
            default="primary",
            description="Athena WorkGroup",
        ),
        SchemaField(
            name="athena_s3_staging_dir",
            label="S3 Staging Dir",
            placeholder="s3://your-bucket/path/",
            required=True,
            description="S3 location for query results",
        ),
        SchemaField(
            name="athena_auth_method",
            label="Auth Method",
            field_type=FieldType.SELECT,
            options=(
                SelectOption("profile", "AWS Profile"),
                SelectOption("keys", "Access Keys"),
            ),
            default="profile",
        ),
        # Shown only for profile-based authentication.
        SchemaField(
            name="athena_profile_name",
            label="Profile Name",
            placeholder="default",
            required=True,
            default="default",
            description="AWS CLI profile name",
            visible_when=_athena_auth_method_is("profile"),
        ),
        # Shown only for access-key authentication.
        SchemaField(
            name="username",
            label="Access Key",
            placeholder="AWS Access Key ID",
            required=True,
            group="credentials",
            visible_when=_athena_auth_method_is("keys"),
        ),
        SchemaField(
            name="password",
            label="Secret Key",
            field_type=FieldType.PASSWORD,
            placeholder="AWS Secret Access Key",
            required=True,
            group="credentials",
            visible_when=_athena_auth_method_is("keys"),
        ),
    ),
    supports_ssh=False,
)
619+
620+
556621
def get_connection_schema(db_type: str) -> ConnectionSchema:
557622
from .providers import get_connection_schema as _get_connection_schema
558623

tests/athena/infra/.gitignore

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Terraform
2+
.terraform/
3+
.terraform.lock.hcl
4+
*.tfstate
5+
*.tfstate.*
6+
*.tfplan
7+
crash.log
8+
override.tf
9+
override.tf.json
10+
*_override.tf
11+
*_override.tf.json
12+
.terraformrc
13+
terraform.rc

tests/athena/infra/create_view.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#!/usr/bin/env python3
2+
"""Create Athena view using boto3.
3+
4+
Usage:
5+
python create_view.py <database> <workgroup> <s3_staging_dir> <region>
6+
"""
7+
8+
import sys
9+
import time
10+
import boto3
11+
12+
13+
def create_view(
    database: str,
    workgroup: str,
    s3_staging_dir: str,
    region: str,
    timeout: float = 300.0,
    poll_interval: float = 1.0,
) -> None:
    """Create ``test_view`` in Athena and wait for the query to finish.

    Starts a ``CREATE OR REPLACE VIEW`` query selecting everything from
    ``<database>.test_hive_table`` and polls its execution state.

    Args:
        database: Athena database used as the query execution context.
        workgroup: Athena work group the query runs in.
        s3_staging_dir: S3 output location for the query results.
        region: AWS region of the Athena client.
        timeout: Maximum number of seconds to wait for a terminal state.
        poll_interval: Seconds to sleep between status checks.

    Exits the process with status 1 (after printing to stderr) if the query
    ends as FAILED or CANCELLED, or does not finish within *timeout*.
    """
    athena = boto3.client("athena", region_name=region)

    query = f"CREATE OR REPLACE VIEW test_view AS SELECT * FROM {database}.test_hive_table"

    # Start query
    response = athena.start_query_execution(
        QueryString=query,
        QueryExecutionContext={"Database": database},
        WorkGroup=workgroup,
        ResultConfiguration={"OutputLocation": s3_staging_dir},
    )
    execution_id = response["QueryExecutionId"]

    # Wait for completion, but never forever: a query stuck in QUEUED/RUNNING
    # would otherwise hang the calling infra setup indefinitely.
    deadline = time.monotonic() + timeout
    while True:
        result = athena.get_query_execution(QueryExecutionId=execution_id)
        status = result["QueryExecution"]["Status"]
        state = status["State"]

        if state == "SUCCEEDED":
            print(f"View created successfully (execution_id={execution_id})")
            return
        if state in ("FAILED", "CANCELLED"):
            reason = status.get("StateChangeReason", "Unknown")
            print(f"Failed to create view: {reason}", file=sys.stderr)
            sys.exit(1)
        if time.monotonic() >= deadline:
            print(
                f"Timed out after {timeout}s waiting for view creation "
                f"(execution_id={execution_id}, last state={state})",
                file=sys.stderr,
            )
            sys.exit(1)

        time.sleep(poll_interval)
42+
43+
44+
if __name__ == "__main__":
    # Exactly four positional arguments are expected after the script name.
    cli_args = sys.argv[1:]
    if len(cli_args) != 4:
        print(f"Usage: {sys.argv[0]} <database> <workgroup> <s3_staging_dir> <region>", file=sys.stderr)
        sys.exit(1)

    create_view(*cli_args)

0 commit comments

Comments
 (0)