11import json
22import os
3- import tempfile
43import uuid
54from contextlib import contextmanager
65from dataclasses import dataclass
76from pathlib import Path
7+ from unittest import mock
88
99import pytest
1010from databricks .sdk import WorkspaceClient
3131
3232
@dataclass
class BaseEnvData:
    """Connection settings shared by every auth-specific env-data class.

    Subclasses add the credential fields for one authentication method and
    provide ``get_connection_config``.
    """

    # Databricks host, read from DATABRICKS_HOST by the env-data helpers.
    host: str
    # Catalog name, read from DATABRICKS_CATALOG by the env-data helpers.
    catalog: str
37+
38+
39+ @dataclass
40+ class BasicAuthEnvData (BaseEnvData ):
3641 client_id : str
3742 client_secret : str
38- catalog : str
3943
4044 def get_connection_config (self ) -> DatabricksNativeVolumesConnectionConfig :
4145 return DatabricksNativeVolumesConnectionConfig (
@@ -47,32 +51,52 @@ def get_connection_config(self) -> DatabricksNativeVolumesConnectionConfig:
4751 )
4852
4953
50- def get_env_data () -> EnvData :
51- return EnvData (
@dataclass
class PATEnvData(BaseEnvData):
    """Environment data for token-based (PAT) authentication."""

    # Access token, read from DATABRICKS_PAT by get_pat_env_data().
    token: str

    def get_connection_config(self) -> DatabricksNativeVolumesConnectionConfig:
        """Return a connection config that authenticates with ``self.token``."""
        return DatabricksNativeVolumesConnectionConfig(
            host=self.host,
            access_config=DatabricksNativeVolumesAccessConfig(
                token=self.token,
            ),
        )
65+
66+
def get_basic_auth_env_data() -> BasicAuthEnvData:
    """Build client-id/client-secret env data from the process environment.

    Raises:
        KeyError: if DATABRICKS_HOST, DATABRICKS_CLIENT_ID,
            DATABRICKS_CLIENT_SECRET or DATABRICKS_CATALOG is not set.
    """
    return BasicAuthEnvData(
        host=os.environ["DATABRICKS_HOST"],
        client_id=os.environ["DATABRICKS_CLIENT_ID"],
        client_secret=os.environ["DATABRICKS_CLIENT_SECRET"],
        catalog=os.environ["DATABRICKS_CATALOG"],
    )
5774
5875
def get_pat_env_data() -> PATEnvData:
    """Build token-based env data from the process environment.

    Raises:
        KeyError: if DATABRICKS_HOST, DATABRICKS_CATALOG or DATABRICKS_PAT
            is not set.
    """
    return PATEnvData(
        host=os.environ["DATABRICKS_HOST"],
        catalog=os.environ["DATABRICKS_CATALOG"],
        token=os.environ["DATABRICKS_PAT"],
    )
82+
83+
5984@pytest .mark .asyncio
6085@pytest .mark .tags (CONNECTOR_TYPE , SOURCE_TAG )
6186@requires_env (
6287 "DATABRICKS_HOST" , "DATABRICKS_CLIENT_ID" , "DATABRICKS_CLIENT_SECRET" , "DATABRICKS_CATALOG"
6388)
64- async def test_volumes_native_source ():
65- env_data = get_env_data ()
66- indexer_config = DatabricksNativeVolumesIndexerConfig (
67- recursive = True ,
68- volume = "test-platform" ,
69- volume_path = "databricks-volumes-test-input" ,
70- catalog = env_data .catalog ,
71- )
72- connection_config = env_data .get_connection_config ()
73- with tempfile .TemporaryDirectory () as tempdir :
74- tempdir_path = Path (tempdir )
75- download_config = DatabricksNativeVolumesDownloaderConfig (download_dir = tempdir_path )
89+ async def test_volumes_native_source (tmp_path : Path ):
90+ env_data = get_basic_auth_env_data ()
91+ with mock .patch .dict (os .environ , clear = True ):
92+ indexer_config = DatabricksNativeVolumesIndexerConfig (
93+ recursive = True ,
94+ volume = "test-platform" ,
95+ volume_path = "databricks-volumes-test-input" ,
96+ catalog = env_data .catalog ,
97+ )
98+ connection_config = env_data .get_connection_config ()
99+ download_config = DatabricksNativeVolumesDownloaderConfig (download_dir = tmp_path )
76100 indexer = DatabricksNativeVolumesIndexer (
77101 connection_config = connection_config , index_config = indexer_config
78102 )
@@ -89,12 +113,44 @@ async def test_volumes_native_source():
89113 )
90114
91115
@pytest.mark.asyncio
@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
@requires_env("DATABRICKS_HOST", "DATABRICKS_PAT", "DATABRICKS_CATALOG")
async def test_volumes_native_source_pat(tmp_path: Path):
    """Index and download from the test volume using token authentication.

    Args:
        tmp_path: pytest-provided temporary directory used as the download dir.
    """
    # Capture the settings first: the environment is emptied below.
    env_data = get_pat_env_data()
    # patch.dict(..., clear=True) empties os.environ for the duration of the
    # block and restores it on exit — presumably so the connector can only
    # authenticate through the explicit connection config (TODO confirm).
    with mock.patch.dict(os.environ, clear=True):
        indexer_config = DatabricksNativeVolumesIndexerConfig(
            recursive=True,
            volume="test-platform",
            volume_path="databricks-volumes-test-input",
            catalog=env_data.catalog,
        )
        connection_config = env_data.get_connection_config()
        download_config = DatabricksNativeVolumesDownloaderConfig(download_dir=tmp_path)
        indexer = DatabricksNativeVolumesIndexer(
            connection_config=connection_config, index_config=indexer_config
        )
        downloader = DatabricksNativeVolumesDownloader(
            connection_config=connection_config, download_config=download_config
        )
        await source_connector_validation(
            indexer=indexer,
            downloader=downloader,
            configs=SourceValidationConfigs(
                test_id="databricks_volumes_native_pat",
                expected_num_files=1,
            ),
        )
144+
145+
92146def _get_volume_path (catalog : str , volume : str , volume_path : str ):
93147 return f"/Volumes/{ catalog } /default/{ volume } /{ volume_path } "
94148
95149
96150@contextmanager
97- def databricks_destination_context (env_data : EnvData , volume : str , volume_path ) -> WorkspaceClient :
151+ def databricks_destination_context (
152+ env_data : BasicAuthEnvData , volume : str , volume_path
153+ ) -> WorkspaceClient :
98154 client = WorkspaceClient (
99155 host = env_data .host , client_id = env_data .client_id , client_secret = env_data .client_secret
100156 )
@@ -137,7 +193,7 @@ def validate_upload(client: WorkspaceClient, catalog: str, volume: str, volume_p
137193 "DATABRICKS_HOST" , "DATABRICKS_CLIENT_ID" , "DATABRICKS_CLIENT_SECRET" , "DATABRICKS_CATALOG"
138194)
139195async def test_volumes_native_destination (upload_file : Path ):
140- env_data = get_env_data ()
196+ env_data = get_basic_auth_env_data ()
141197 volume_path = f"databricks-volumes-test-output-{ uuid .uuid4 ()} "
142198 file_data = FileData (
143199 source_identifiers = SourceIdentifiers (fullpath = upload_file .name , filename = upload_file .name ),
0 commit comments