|
51 | 51 | HAS_GOOGLE_BIGQUERY = False |
52 | 52 |
|
53 | 53 |
|
| 54 | +try: |
| 55 | + import redshift_connector |
| 56 | + |
| 57 | + HAS_REDSHIFT = "REDSHIFT_HOST" in os.environ |
| 58 | +except ImportError: |
| 59 | + HAS_REDSHIFT = False |
| 60 | + |
54 | 61 | from positron.access_keys import encode_access_key |
55 | 62 | from positron.connections import ConnectionsService |
56 | 63 |
|
@@ -828,3 +835,135 @@ def _view_in_connections_pane(self, variables_comm: DummyComm, path): |
828 | 835 | assert variables_comm.messages == [json_rpc_response({})] |
829 | 836 | variables_comm.messages.clear() |
830 | 837 | return tuple(encoded_paths) |
| 838 | + |
| 839 | + |
| 840 | +@pytest.mark.skipif(not HAS_REDSHIFT, reason="Redshift not available") |
| 841 | +class TestRedshiftConnectionsService: |
| 842 | + REDSHIFT_HOST = os.environ.get("REDSHIFT_HOST") |
| 843 | + REDSHIFT_PROFILE = os.environ.get("REDSHIFT_PROFILE", "default") |
| 844 | + REDSHIFT_DATABASE = "dev" |
| 845 | + REDSHIFT_SCHEMA = "public" |
| 846 | + REDSHIFT_TABLE = "airlines" |
| 847 | + |
| 848 | + def _connect(self): |
| 849 | + return redshift_connector.connect( |
| 850 | + iam=True, |
| 851 | + host=self.REDSHIFT_HOST, |
| 852 | + database=self.REDSHIFT_DATABASE, |
| 853 | + profile=self.REDSHIFT_PROFILE, |
| 854 | + ) |
| 855 | + |
| 856 | + def _open_comm(self, connections_service: ConnectionsService): |
| 857 | + con = self._connect() |
| 858 | + comm_id = connections_service.register_connection(con) |
| 859 | + dummy_comm = DummyComm(TARGET_NAME, comm_id=comm_id) |
| 860 | + connections_service.on_comm_open(dummy_comm) |
| 861 | + dummy_comm.messages.clear() |
| 862 | + return dummy_comm, comm_id |
| 863 | + |
| 864 | + def _database_path(self): |
| 865 | + return [{"kind": "database", "name": self.REDSHIFT_DATABASE}] |
| 866 | + |
| 867 | + def _schema_path(self): |
| 868 | + return [*self._database_path(), {"kind": "schema", "name": self.REDSHIFT_SCHEMA}] |
| 869 | + |
| 870 | + def _table_path(self): |
| 871 | + return [*self._schema_path(), {"kind": "table", "name": self.REDSHIFT_TABLE}] |
| 872 | + |
| 873 | + def _resolve_path(self, kind: str): |
| 874 | + if kind == "root": |
| 875 | + return [] |
| 876 | + if kind == "database": |
| 877 | + return self._database_path() |
| 878 | + if kind == "schema": |
| 879 | + return self._schema_path() |
| 880 | + if kind == "table": |
| 881 | + return self._table_path() |
| 882 | + raise ValueError(f"Unknown path kind: {kind}") |
| 883 | + |
| 884 | + def test_register_connection(self, connections_service: ConnectionsService): |
| 885 | + con = self._connect() |
| 886 | + comm_id = connections_service.register_connection(con) |
| 887 | + assert comm_id in connections_service.comms |
| 888 | + |
| 889 | + @pytest.mark.parametrize( |
| 890 | + ("path_kind"), |
| 891 | + [ |
| 892 | + pytest.param("root", id="root"), |
| 893 | + pytest.param("database", id="database"), |
| 894 | + pytest.param("schema", id="schema"), |
| 895 | + pytest.param("table", id="table"), |
| 896 | + ], |
| 897 | + ) |
| 898 | + def test_contains_data(self, connections_service: ConnectionsService, path_kind: str): |
| 899 | + dummy_comm, comm_id = self._open_comm(connections_service) |
| 900 | + path = self._resolve_path(path_kind) |
| 901 | + |
| 902 | + msg = _make_msg(params={"path": path}, method="contains_data", comm_id=comm_id) |
| 903 | + dummy_comm.handle_msg(msg) |
| 904 | + result = dummy_comm.messages[0]["data"]["result"] |
| 905 | + assert result == (path_kind == "table") |
| 906 | + |
| 907 | + @pytest.mark.parametrize( |
| 908 | + ("path_kind", "expected"), |
| 909 | + [ |
| 910 | + pytest.param("root", "data:image", id="root"), |
| 911 | + pytest.param("database", "", id="database"), |
| 912 | + pytest.param("schema", "", id="schema"), |
| 913 | + pytest.param("table", "", id="table"), |
| 914 | + ], |
| 915 | + ) |
| 916 | + def test_get_icon(self, connections_service: ConnectionsService, path_kind: str, expected: str): |
| 917 | + dummy_comm, comm_id = self._open_comm(connections_service) |
| 918 | + path = self._resolve_path(path_kind) |
| 919 | + |
| 920 | + msg = _make_msg(params={"path": path}, method="get_icon", comm_id=comm_id) |
| 921 | + dummy_comm.handle_msg(msg) |
| 922 | + result = dummy_comm.messages[0]["data"]["result"] |
| 923 | + if expected: |
| 924 | + assert expected in result |
| 925 | + else: |
| 926 | + assert result == "" |
| 927 | + |
| 928 | + @pytest.mark.parametrize( |
| 929 | + "path_kind", |
| 930 | + [ |
| 931 | + pytest.param("root", id="databases"), |
| 932 | + pytest.param("database", id="schemas"), |
| 933 | + pytest.param("schema", id="tables"), |
| 934 | + ], |
| 935 | + ) |
| 936 | + def test_list_objects(self, connections_service: ConnectionsService, path_kind: str): |
| 937 | + dummy_comm, comm_id = self._open_comm(connections_service) |
| 938 | + path = self._resolve_path(path_kind) |
| 939 | + expected = { |
| 940 | + "root": self.REDSHIFT_DATABASE, |
| 941 | + "database": self.REDSHIFT_SCHEMA, |
| 942 | + "schema": self.REDSHIFT_TABLE, |
| 943 | + }[path_kind] |
| 944 | + |
| 945 | + msg = _make_msg(params={"path": path}, method="list_objects", comm_id=comm_id) |
| 946 | + dummy_comm.handle_msg(msg) |
| 947 | + result = dummy_comm.messages[0]["data"]["result"] |
| 948 | + names = [item["name"] for item in result] |
| 949 | + assert expected in names |
| 950 | + |
| 951 | + def test_list_fields(self, connections_service: ConnectionsService): |
| 952 | + dummy_comm, comm_id = self._open_comm(connections_service) |
| 953 | + path = self._table_path() |
| 954 | + |
| 955 | + msg = _make_msg(params={"path": path}, method="list_fields", comm_id=comm_id) |
| 956 | + dummy_comm.handle_msg(msg) |
| 957 | + result = dummy_comm.messages[0]["data"]["result"] |
| 958 | + field_names = {field["name"].lower() for field in result} |
| 959 | + assert {"carrier", "name"}.issubset(field_names) |
| 960 | + |
| 961 | + def test_preview_object(self, connections_service: ConnectionsService): |
| 962 | + dummy_comm, comm_id = self._open_comm(connections_service) |
| 963 | + path = self._table_path() |
| 964 | + |
| 965 | + msg = _make_msg(params={"path": path}, method="preview_object", comm_id=comm_id) |
| 966 | + dummy_comm.handle_msg(msg) |
| 967 | + connections_service._kernel.data_explorer_service.shutdown() # noqa: SLF001 |
| 968 | + result = dummy_comm.messages[0]["data"]["result"] |
| 969 | + assert result is None |
0 commit comments