Skip to content

Commit 78aa210

Browse files
committed
add test files
1 parent 8a79ce7 commit 78aa210

File tree

1 file changed

+139
-0
lines changed

1 file changed

+139
-0
lines changed

extensions/positron-python/python_files/posit/positron/tests/test_connections.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@
5151
HAS_GOOGLE_BIGQUERY = False
5252

5353

54+
try:
55+
import redshift_connector
56+
57+
HAS_REDSHIFT = "REDSHIFT_HOST" in os.environ
58+
except ImportError:
59+
HAS_REDSHIFT = False
60+
5461
from positron.access_keys import encode_access_key
5562
from positron.connections import ConnectionsService
5663

@@ -828,3 +835,135 @@ def _view_in_connections_pane(self, variables_comm: DummyComm, path):
828835
assert variables_comm.messages == [json_rpc_response({})]
829836
variables_comm.messages.clear()
830837
return tuple(encoded_paths)
838+
839+
840+
@pytest.mark.skipif(not HAS_REDSHIFT, reason="Redshift not available")
841+
class TestRedshiftConnectionsService:
842+
REDSHIFT_HOST = os.environ.get("REDSHIFT_HOST")
843+
REDSHIFT_PROFILE = os.environ.get("REDSHIFT_PROFILE", "default")
844+
REDSHIFT_DATABASE = "dev"
845+
REDSHIFT_SCHEMA = "public"
846+
REDSHIFT_TABLE = "airlines"
847+
848+
def _connect(self):
849+
return redshift_connector.connect(
850+
iam=True,
851+
host=self.REDSHIFT_HOST,
852+
database=self.REDSHIFT_DATABASE,
853+
profile=self.REDSHIFT_PROFILE,
854+
)
855+
856+
def _open_comm(self, connections_service: ConnectionsService):
857+
con = self._connect()
858+
comm_id = connections_service.register_connection(con)
859+
dummy_comm = DummyComm(TARGET_NAME, comm_id=comm_id)
860+
connections_service.on_comm_open(dummy_comm)
861+
dummy_comm.messages.clear()
862+
return dummy_comm, comm_id
863+
864+
def _database_path(self):
865+
return [{"kind": "database", "name": self.REDSHIFT_DATABASE}]
866+
867+
def _schema_path(self):
868+
return [*self._database_path(), {"kind": "schema", "name": self.REDSHIFT_SCHEMA}]
869+
870+
def _table_path(self):
871+
return [*self._schema_path(), {"kind": "table", "name": self.REDSHIFT_TABLE}]
872+
873+
def _resolve_path(self, kind: str):
874+
if kind == "root":
875+
return []
876+
if kind == "database":
877+
return self._database_path()
878+
if kind == "schema":
879+
return self._schema_path()
880+
if kind == "table":
881+
return self._table_path()
882+
raise ValueError(f"Unknown path kind: {kind}")
883+
884+
def test_register_connection(self, connections_service: ConnectionsService):
885+
con = self._connect()
886+
comm_id = connections_service.register_connection(con)
887+
assert comm_id in connections_service.comms
888+
889+
@pytest.mark.parametrize(
890+
("path_kind"),
891+
[
892+
pytest.param("root", id="root"),
893+
pytest.param("database", id="database"),
894+
pytest.param("schema", id="schema"),
895+
pytest.param("table", id="table"),
896+
],
897+
)
898+
def test_contains_data(self, connections_service: ConnectionsService, path_kind: str):
899+
dummy_comm, comm_id = self._open_comm(connections_service)
900+
path = self._resolve_path(path_kind)
901+
902+
msg = _make_msg(params={"path": path}, method="contains_data", comm_id=comm_id)
903+
dummy_comm.handle_msg(msg)
904+
result = dummy_comm.messages[0]["data"]["result"]
905+
assert result == (path_kind == "table")
906+
907+
@pytest.mark.parametrize(
908+
("path_kind", "expected"),
909+
[
910+
pytest.param("root", "data:image", id="root"),
911+
pytest.param("database", "", id="database"),
912+
pytest.param("schema", "", id="schema"),
913+
pytest.param("table", "", id="table"),
914+
],
915+
)
916+
def test_get_icon(self, connections_service: ConnectionsService, path_kind: str, expected: str):
917+
dummy_comm, comm_id = self._open_comm(connections_service)
918+
path = self._resolve_path(path_kind)
919+
920+
msg = _make_msg(params={"path": path}, method="get_icon", comm_id=comm_id)
921+
dummy_comm.handle_msg(msg)
922+
result = dummy_comm.messages[0]["data"]["result"]
923+
if expected:
924+
assert expected in result
925+
else:
926+
assert result == ""
927+
928+
@pytest.mark.parametrize(
929+
"path_kind",
930+
[
931+
pytest.param("root", id="databases"),
932+
pytest.param("database", id="schemas"),
933+
pytest.param("schema", id="tables"),
934+
],
935+
)
936+
def test_list_objects(self, connections_service: ConnectionsService, path_kind: str):
937+
dummy_comm, comm_id = self._open_comm(connections_service)
938+
path = self._resolve_path(path_kind)
939+
expected = {
940+
"root": self.REDSHIFT_DATABASE,
941+
"database": self.REDSHIFT_SCHEMA,
942+
"schema": self.REDSHIFT_TABLE,
943+
}[path_kind]
944+
945+
msg = _make_msg(params={"path": path}, method="list_objects", comm_id=comm_id)
946+
dummy_comm.handle_msg(msg)
947+
result = dummy_comm.messages[0]["data"]["result"]
948+
names = [item["name"] for item in result]
949+
assert expected in names
950+
951+
def test_list_fields(self, connections_service: ConnectionsService):
952+
dummy_comm, comm_id = self._open_comm(connections_service)
953+
path = self._table_path()
954+
955+
msg = _make_msg(params={"path": path}, method="list_fields", comm_id=comm_id)
956+
dummy_comm.handle_msg(msg)
957+
result = dummy_comm.messages[0]["data"]["result"]
958+
field_names = {field["name"].lower() for field in result}
959+
assert {"carrier", "name"}.issubset(field_names)
960+
961+
def test_preview_object(self, connections_service: ConnectionsService):
962+
dummy_comm, comm_id = self._open_comm(connections_service)
963+
path = self._table_path()
964+
965+
msg = _make_msg(params={"path": path}, method="preview_object", comm_id=comm_id)
966+
dummy_comm.handle_msg(msg)
967+
connections_service._kernel.data_explorer_service.shutdown() # noqa: SLF001
968+
result = dummy_comm.messages[0]["data"]["result"]
969+
assert result is None

0 commit comments

Comments
 (0)