Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
File renamed without changes.
8 changes: 8 additions & 0 deletions fixtures/only_uri/meltano.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
requires_meltano: ">=3.7"
default_environment: dev
project_id: 1ec76bd8-4499-4bad-a974-b27225d75f12
send_anonymous_usage_stats: false
environments:
- name: dev
state_backend:
uri: snowflake://test_user:test_password@my-account/test_database/test_schema?warehouse=test_warehouse&role=test_role
1 change: 0 additions & 1 deletion src/meltano_state_backend_snowflake/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ class SnowflakeStateBackendError(MeltanoError):
label="Snowflake Schema",
description="Snowflake schema name",
kind=SettingKind.STRING, # ty: ignore[invalid-argument-type]
value="PUBLIC",
env_specific=True,
)

Expand Down
62 changes: 55 additions & 7 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from decimal import Decimal
from typing import TYPE_CHECKING
from unittest import mock
from urllib.parse import urlparse

import pytest
from cryptography.hazmat.primitives import serialization
Expand All @@ -27,11 +28,22 @@
def project(tmp_path: Path) -> Project:
path = tmp_path / "project"
shutil.copytree(
"fixtures/project",
"fixtures/explicit",
path,
ignore=shutil.ignore_patterns(".meltano/**"),
)
return Project.find(path.resolve()) # type: ignore[no-any-return]
return Project(path.resolve())


@pytest.fixture
def project_with_uri(tmp_path: Path) -> Project:
path = tmp_path / "project"
shutil.copytree(
"fixtures/only_uri",
path,
ignore=shutil.ignore_patterns(".meltano/**"),
)
return Project(path.resolve())


@pytest.fixture
Expand All @@ -53,7 +65,18 @@ def test_get_manager(project: Project) -> None:

mock_ensure_tables.assert_called_once()
assert isinstance(manager, SnowflakeStateStoreManager)
assert manager.uri == "snowflake://my-account"

parsed = urlparse(manager.uri)
assert parsed.scheme == "snowflake"
assert parsed.hostname == "my-account"

# Parameters are not included in the URI
assert not parsed.username
assert not parsed.password
assert not parsed.path
assert not parsed.query

# Parameters are passed explicitly to the manager
assert manager.account == "my-account"
assert manager.user == "test_user"
assert manager.password == "test_password" # noqa: S105
Expand All @@ -63,6 +86,30 @@ def test_get_manager(project: Project) -> None:
assert manager.role == "test_role"


def test_get_manager_from_uri(project_with_uri: Project) -> None:
with mock.patch(
"meltano_state_backend_snowflake.backend.SnowflakeStateStoreManager._ensure_tables",
) as mock_ensure_tables:
manager = state_store_manager_from_project_settings(project_with_uri.settings)

mock_ensure_tables.assert_called_once()
assert isinstance(manager, SnowflakeStateStoreManager)

parsed = urlparse(manager.uri)
assert parsed.scheme == "snowflake"
assert parsed.hostname == "my-account"

# Parameters are only included in the URI
assert parsed.username == manager.user == "test_user"
assert parsed.password == manager.password == "test_password" # noqa: S105
assert parsed.path == "/test_database/test_schema"
assert manager.database == "test_database"
assert manager.schema == "test_schema"
assert parsed.query == "warehouse=test_warehouse&role=test_role"
assert manager.warehouse == "test_warehouse"
assert manager.role == "test_role"


@pytest.mark.parametrize(
("setting_name", "env_var_name"),
(
Expand All @@ -86,6 +133,11 @@ def test_get_manager(project: Project) -> None:
"MELTANO_STATE_BACKEND_SNOWFLAKE_ROLE",
id="role",
),
pytest.param(
"state_backend.snowflake.private_key_base64",
"MELTANO_STATE_BACKEND_SNOWFLAKE_PRIVATE_KEY_BASE64",
id="private_key_base64",
),
),
)
def test_settings(project: Project, setting_name: str, env_var_name: str) -> None:
Expand Down Expand Up @@ -393,10 +445,6 @@ def test_acquire_lock_retry(
assert mock_cursor.execute.call_count >= 2


# class TestURIQueryParams:
# """Tests for URI query parameter parsing."""


@pytest.fixture
def base_uri() -> str:
return "snowflake://myuser:mypass@myaccount/mydb/myschema?warehouse=mywh"
Expand Down