Skip to content

Commit db8c4ad

Browse files
committed
Athena refactor
1 parent 0c7dab6 commit db8c4ad

File tree

1 file changed

+41
-25
lines changed

1 file changed

+41
-25
lines changed

sqlit/db/schema.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,8 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
439439
)
440440

441441

442-
def _get_supabase_region_options() -> tuple[SelectOption, ...]:
442+
def _get_aws_region_options() -> tuple[SelectOption, ...]:
443+
"""AWS regions shared by Supabase, Athena, and other AWS-based services."""
443444
regions = (
444445
"us-east-1",
445446
"us-east-2",
@@ -470,7 +471,7 @@ def _get_supabase_region_options() -> tuple[SelectOption, ...]:
470471
name="supabase_region",
471472
label="Region",
472473
field_type=FieldType.DROPDOWN,
473-
options=_get_supabase_region_options(),
474+
options=_get_aws_region_options(),
474475
required=True,
475476
default="us-east-1",
476477
),
@@ -553,38 +554,38 @@ def _get_supabase_region_options() -> tuple[SelectOption, ...]:
553554
)
554555

555556

557+
def _get_athena_auth_options() -> tuple[SelectOption, ...]:
558+
return (
559+
SelectOption("profile", "AWS Profile"),
560+
SelectOption("keys", "Access Keys"),
561+
)
562+
563+
564+
def _athena_auth_is_profile(v: dict) -> bool:
565+
return v.get("athena_auth_method", "profile") == "profile"
566+
567+
568+
def _athena_auth_is_keys(v: dict) -> bool:
569+
return v.get("athena_auth_method") == "keys"
570+
571+
556572
ATHENA_SCHEMA = ConnectionSchema(
557573
db_type="athena",
558574
display_name="AWS Athena",
559575
fields=(
560576
SchemaField(
561577
name="athena_region_name",
562578
label="Region",
579+
field_type=FieldType.DROPDOWN,
580+
options=_get_aws_region_options(),
563581
required=True,
564582
default="us-east-1",
565583
),
566-
SchemaField(
567-
name="athena_work_group",
568-
label="WorkGroup",
569-
required=True,
570-
default="primary",
571-
description="Athena WorkGroup",
572-
),
573-
SchemaField(
574-
name="athena_s3_staging_dir",
575-
label="S3 Staging Dir",
576-
placeholder="s3://your-bucket/path/",
577-
required=True,
578-
description="S3 location for query results",
579-
),
580584
SchemaField(
581585
name="athena_auth_method",
582-
label="Auth Method",
583-
field_type=FieldType.SELECT,
584-
options=(
585-
SelectOption("profile", "AWS Profile"),
586-
SelectOption("keys", "Access Keys"),
587-
),
586+
label="Authentication",
587+
field_type=FieldType.DROPDOWN,
588+
options=_get_athena_auth_options(),
588589
default="profile",
589590
),
590591
SchemaField(
@@ -594,15 +595,15 @@ def _get_supabase_region_options() -> tuple[SelectOption, ...]:
594595
required=True,
595596
default="default",
596597
description="AWS CLI profile name",
597-
visible_when=lambda v: v.get("athena_auth_method") == "profile",
598+
visible_when=_athena_auth_is_profile,
598599
),
599600
SchemaField(
600601
name="username",
601602
label="Access Key",
602603
placeholder="AWS Access Key ID",
603604
required=True,
604605
group="credentials",
605-
visible_when=lambda v: v.get("athena_auth_method") == "keys",
606+
visible_when=_athena_auth_is_keys,
606607
),
607608
SchemaField(
608609
name="password",
@@ -611,10 +612,25 @@ def _get_supabase_region_options() -> tuple[SelectOption, ...]:
611612
placeholder="AWS Secret Access Key",
612613
required=True,
613614
group="credentials",
614-
visible_when=lambda v: v.get("athena_auth_method") == "keys",
615+
visible_when=_athena_auth_is_keys,
616+
),
617+
SchemaField(
618+
name="athena_work_group",
619+
label="WorkGroup",
620+
required=True,
621+
default="primary",
622+
description="Athena WorkGroup",
623+
),
624+
SchemaField(
625+
name="athena_s3_staging_dir",
626+
label="S3 Staging Dir",
627+
placeholder="s3://your-bucket/path/",
628+
required=True,
629+
description="S3 location for query results",
615630
),
616631
),
617632
supports_ssh=False,
633+
has_advanced_auth=True,
618634
)
619635

620636

0 commit comments

Comments
 (0)