diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/azure_auth_config.py b/metadata-ingestion/src/datahub/ingestion/source/unity/azure_auth_config.py new file mode 100644 index 0000000000000..ecc476925907a --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/azure_auth_config.py @@ -0,0 +1,13 @@ +from pydantic import Field + +from datahub.configuration import ConfigModel + + +class AzureAuthConfig(ConfigModel): + client_secret: str = Field(description="Azure client secret") + client_id: str = Field( + description="Azure client (Application) ID", + ) + tenant_id: str = Field( + description="Azure tenant (Directory) ID required when a `client_secret` is used as a credential.", + ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/connection.py b/metadata-ingestion/src/datahub/ingestion/source/unity/connection.py index 3f4c43c5e2cf6..8540332247b77 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/connection.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/connection.py @@ -8,6 +8,7 @@ from datahub.configuration.common import ConfigModel from datahub.ingestion.source.sql.sqlalchemy_uri import make_sqlalchemy_uri +from datahub.ingestion.source.unity.azure_auth_config import AzureAuthConfig DATABRICKS = "databricks" @@ -19,7 +20,12 @@ class UnityCatalogConnectionConfig(ConfigModel): """ scheme: str = DATABRICKS - token: str = pydantic.Field(description="Databricks personal access token") + token: Optional[str] = pydantic.Field( + default=None, description="Databricks personal access token" + ) + azure_auth: Optional[AzureAuthConfig] = Field( + default=None, description="Azure configuration" + ) workspace_url: str = pydantic.Field( description="Databricks workspace url. e.g. https://my-workspace.cloud.databricks.com" ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/connection_test.py b/metadata-ingestion/src/datahub/ingestion/source/unity/connection_test.py index 915a2f0601251..21b8a00f5c07b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/connection_test.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/connection_test.py @@ -16,10 +16,10 @@ def __init__(self, config: UnityCatalogSourceConfig): self.report = UnityCatalogReport() self.proxy = UnityCatalogApiProxy( self.config.workspace_url, - self.config.token, self.config.profiling.warehouse_id, report=self.report, databricks_api_page_size=self.config.databricks_api_page_size, + personal_access_token=self.config.token, ) def get_connection_test(self) -> TestConnectionReport: diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy.py b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy.py index eb8ad5302b746..fe158bde7827f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy.py @@ -38,6 +38,7 @@ from datahub._version import nice_version_name from datahub.api.entities.external.unity_catalog_external_entites import UnityCatalogTag from datahub.emitter.mce_builder import parse_ts_millis +from datahub.ingestion.source.unity.azure_auth_config import AzureAuthConfig from datahub.ingestion.source.unity.config import ( LineageDataSource, UsageDataSource, @@ -159,20 +160,31 @@ class UnityCatalogApiProxy(UnityCatalogProxyProfilingMixin): def __init__( self, workspace_url: str, - personal_access_token: str, warehouse_id: Optional[str], report: UnityCatalogReport, hive_metastore_proxy: Optional[HiveMetastoreProxy] = None, lineage_data_source: LineageDataSource = LineageDataSource.AUTO, usage_data_source: UsageDataSource = UsageDataSource.AUTO, databricks_api_page_size: int = 0, + personal_access_token: Optional[str] = None, + azure_auth: Optional[AzureAuthConfig] = None, ): - self._workspace_client = WorkspaceClient( - host=workspace_url, - token=personal_access_token, - product="datahub", - product_version=nice_version_name(), - ) + if azure_auth: + self._workspace_client = WorkspaceClient( + host=workspace_url, + azure_tenant_id=azure_auth.tenant_id, + azure_client_id=azure_auth.client_id, + azure_client_secret=azure_auth.client_secret, + product="datahub", + product_version=nice_version_name(), + ) + else: + self._workspace_client = WorkspaceClient( + host=workspace_url, + token=personal_access_token, + product="datahub", + product_version=nice_version_name(), + ) self.warehouse_id = warehouse_id or "" self.report = report self.hive_metastore_proxy = hive_metastore_proxy diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/source.py b/metadata-ingestion/src/datahub/ingestion/source/unity/source.py index 08ddeba4e3769..7f9eaea050d94 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/source.py @@ -205,17 +205,24 @@ def __init__(self, ctx: PipelineContext, config: UnityCatalogSourceConfig): self.config = config self.report: UnityCatalogReport = UnityCatalogReport() + # Validate that either azure_auth or personal_access_token is provided + if not (config.azure_auth or config.token): + raise ValueError( + "Either azure_auth or personal_access_token must be provided in the configuration." + ) + self.init_hive_metastore_proxy() self.unity_catalog_api_proxy = UnityCatalogApiProxy( config.workspace_url, - config.token, config.warehouse_id, report=self.report, hive_metastore_proxy=self.hive_metastore_proxy, lineage_data_source=config.lineage_data_source, usage_data_source=config.usage_data_source, databricks_api_page_size=config.databricks_api_page_size, + personal_access_token=config.token if config.token else None, + azure_auth=config.azure_auth if config.azure_auth else None, ) self.external_url_base = urljoin(self.config.workspace_url, "/explore/data") diff --git a/metadata-ingestion/tests/unit/test_unity_catalog_source.py b/metadata-ingestion/tests/unit/test_unity_catalog_source.py index 2b8de007e9acd..a79ed6305a69d 100644 --- a/metadata-ingestion/tests/unit/test_unity_catalog_source.py +++ b/metadata-ingestion/tests/unit/test_unity_catalog_source.py @@ -49,6 +49,42 @@ def config_with_ml_model_settings(self): } ) + @pytest.fixture + def config_with_azure_auth(self): + """Create a config with Azure authentication.""" + return UnityCatalogSourceConfig.parse_obj( + { + "workspace_url": "https://test.databricks.com", + "warehouse_id": "test_warehouse", + "include_hive_metastore": False, + "databricks_api_page_size": 150, + "azure_auth": { + "client_id": "test-client-id-12345", + "tenant_id": "test-tenant-id-67890", + "client_secret": "test-client-secret", + }, + } + ) + + @pytest.fixture + def config_with_azure_auth_and_ml_models(self): + """Create a config with Azure authentication and ML model settings.""" + return UnityCatalogSourceConfig.parse_obj( + { + "workspace_url": "https://test.databricks.com", + "warehouse_id": "test_warehouse", + "include_hive_metastore": False, + "include_ml_model_aliases": True, + "ml_model_max_results": 1000, + "databricks_api_page_size": 200, + "azure_auth": { + "client_id": "azure-client-id-789", + "tenant_id": "azure-tenant-id-123", + "client_secret": "azure-secret-456", + }, + } + ) + @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy") @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy") def test_source_constructor_passes_default_page_size_to_proxy( @@ -62,13 +98,14 @@ def test_source_constructor_passes_default_page_size_to_proxy( # Verify proxy was created with correct parameters including page size mock_unity_proxy.assert_called_once_with( minimal_config.workspace_url, - minimal_config.token, minimal_config.warehouse_id, report=source.report, hive_metastore_proxy=source.hive_metastore_proxy, lineage_data_source=minimal_config.lineage_data_source, usage_data_source=minimal_config.usage_data_source, databricks_api_page_size=0, # Default value + personal_access_token=minimal_config.token, + azure_auth=None, ) @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy") @@ -83,13 +120,14 @@ def test_source_constructor_passes_custom_page_size_to_proxy( # Verify proxy was created with correct parameters including custom page size mock_unity_proxy.assert_called_once_with( config_with_page_size.workspace_url, - config_with_page_size.token, config_with_page_size.warehouse_id, report=source.report, hive_metastore_proxy=source.hive_metastore_proxy, lineage_data_source=config_with_page_size.lineage_data_source, usage_data_source=config_with_page_size.usage_data_source, databricks_api_page_size=75, # Custom value + personal_access_token=config_with_page_size.token, + azure_auth=None, ) @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy") @@ -126,13 +164,14 @@ def test_source_with_hive_metastore_disabled( # Verify proxy was created with correct page size even when hive metastore is disabled mock_unity_proxy.assert_called_once_with( config.workspace_url, - config.token, config.warehouse_id, report=source.report, hive_metastore_proxy=None, # Should be None when disabled lineage_data_source=config.lineage_data_source, usage_data_source=config.usage_data_source, databricks_api_page_size=200, + personal_access_token=config.token, + azure_auth=None, ) def test_test_connection_with_page_size_config(self): @@ -225,6 +264,154 @@ def test_test_connection_with_ml_model_configs(self): assert connection_test_config.ml_model_max_results == 750 assert connection_test_config.databricks_api_page_size == 200 + @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy") + @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy") + def test_source_constructor_with_azure_auth( + self, mock_hive_proxy, mock_unity_proxy, config_with_azure_auth + ): + """Test that UnityCatalogSource passes Azure auth config to proxy.""" + ctx = PipelineContext(run_id="test_run") + source = UnityCatalogSource.create(config_with_azure_auth, ctx) + + # Verify proxy was created with Azure auth config + mock_unity_proxy.assert_called_once_with( + config_with_azure_auth.workspace_url, + config_with_azure_auth.warehouse_id, + report=source.report, + hive_metastore_proxy=source.hive_metastore_proxy, + lineage_data_source=config_with_azure_auth.lineage_data_source, + usage_data_source=config_with_azure_auth.usage_data_source, + databricks_api_page_size=150, + personal_access_token=None, # Should be None when using Azure auth + azure_auth=config_with_azure_auth.azure_auth, + ) + + @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy") + @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy") + def test_source_constructor_azure_auth_with_ml_models( + self, mock_hive_proxy, mock_unity_proxy, config_with_azure_auth_and_ml_models + ): + """Test that UnityCatalogSource with Azure auth and ML model settings works correctly.""" + ctx = PipelineContext(run_id="test_run") + source = UnityCatalogSource.create(config_with_azure_auth_and_ml_models, ctx) + + # Verify proxy was created with correct Azure auth and ML model configs + mock_unity_proxy.assert_called_once_with( + config_with_azure_auth_and_ml_models.workspace_url, + config_with_azure_auth_and_ml_models.warehouse_id, + report=source.report, + hive_metastore_proxy=source.hive_metastore_proxy, + lineage_data_source=config_with_azure_auth_and_ml_models.lineage_data_source, + usage_data_source=config_with_azure_auth_and_ml_models.usage_data_source, + databricks_api_page_size=200, + personal_access_token=None, + azure_auth=config_with_azure_auth_and_ml_models.azure_auth, + ) + + # Verify ML model settings are properly configured + assert source.config.include_ml_model_aliases is True + assert source.config.ml_model_max_results == 1000 + + def test_azure_auth_config_validation(self): + """Test that Azure auth config validates required fields.""" + # Test valid Azure auth config + valid_config_dict = { + "workspace_url": "https://test.databricks.com", + "warehouse_id": "test_warehouse", + "azure_auth": { + "client_id": "test-client-id", + "tenant_id": "test-tenant-id", + "client_secret": "test-secret", + }, + } + + config = UnityCatalogSourceConfig.parse_obj(valid_config_dict) + assert config.azure_auth is not None + assert config.azure_auth.client_id == "test-client-id" + assert config.azure_auth.tenant_id == "test-tenant-id" + assert config.azure_auth.client_secret == "test-secret" + + # Test that personal access token is not required when Azure auth is provided + assert config.token is None + + def test_test_connection_with_azure_auth(self): + """Test that test_connection properly handles Azure authentication.""" + config_dict = { + "workspace_url": "https://test.databricks.com", + "warehouse_id": "test_warehouse", + "databricks_api_page_size": 100, + "azure_auth": { + "client_id": "test-client-id", + "tenant_id": "test-tenant-id", + "client_secret": "test-secret", + }, + } + + with patch( + "datahub.ingestion.source.unity.source.UnityCatalogConnectionTest" + ) as mock_connection_test: + mock_connection_test.return_value.get_connection_test.return_value = ( + "azure_test_report" + ) + + result = UnityCatalogSource.test_connection(config_dict) + + # Verify connection test was created with Azure auth config + assert result == "azure_test_report" + mock_connection_test.assert_called_once() + + # Get the config that was passed to UnityCatalogConnectionTest + connection_test_config = mock_connection_test.call_args[0][0] + assert connection_test_config.azure_auth is not None + assert connection_test_config.azure_auth.client_id == "test-client-id" + assert connection_test_config.azure_auth.tenant_id == "test-tenant-id" + assert connection_test_config.azure_auth.client_secret == "test-secret" + assert connection_test_config.databricks_api_page_size == 100 + assert ( + connection_test_config.token is None + ) # Should be None with Azure auth + + def test_source_creation_fails_without_authentication(self): + """Test that UnityCatalogSource creation fails when neither token nor azure_auth are provided.""" + # Test with neither token nor azure_auth provided + config_without_auth = UnityCatalogSourceConfig.parse_obj( + { + "workspace_url": "https://test.databricks.com", + "warehouse_id": "test_warehouse", + "include_hive_metastore": False, + "databricks_api_page_size": 100, + # Neither token nor azure_auth provided + } + ) + + ctx = PipelineContext(run_id="test_run") + + # Should raise ValueError when neither authentication method is provided + with pytest.raises(ValueError) as exc_info: + UnityCatalogSource.create(config_without_auth, ctx) + + assert "Either azure_auth or personal_access_token must be provided" in str( + exc_info.value + ) + + def test_test_connection_fails_without_authentication(self): + """Test that test_connection fails when neither token nor azure_auth are provided.""" + config_dict_without_auth = { + "workspace_url": "https://test.databricks.com", + "warehouse_id": "test_warehouse", + "databricks_api_page_size": 100, + # Neither token nor azure_auth provided + } + + # Should raise ValueError due to Databricks authentication failure + with pytest.raises(ValueError) as exc_info: + UnityCatalogSource.test_connection(config_dict_without_auth) + + # The actual error is from Databricks SDK trying to authenticate without credentials + assert "default auth: cannot configure default credentials" in str( + exc_info.value + ) + @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy") @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy") def test_process_ml_model_generates_workunits(