From a92e65e68878bdf5bfae9d12c72e004aaa34ca50 Mon Sep 17 00:00:00 2001 From: gutt02 Date: Mon, 13 Jan 2025 09:29:55 +0100 Subject: [PATCH] Bulk download of Databricks users and service principals. --- azure_dbr_scim_sync/cli.py | 9 ++++++++- azure_dbr_scim_sync/scim.py | 37 +++++++++++++++++++++++++++++++------ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/azure_dbr_scim_sync/cli.py b/azure_dbr_scim_sync/cli.py index 41a23ee..c8a7ea6 100644 --- a/azure_dbr_scim_sync/cli.py +++ b/azure_dbr_scim_sync/cli.py @@ -20,6 +20,12 @@ help="verbose information about changes", show_default=True) @click.option('--debug', default=False, is_flag=True, help="more verbose, shows API calls", show_default=True) +@click.option( + '--bulk-download', + default=False, + is_flag=True, + help="bulk download of users and service principals from the Databricks account, better performance with a large number of users and service principals", + show_default=True) @click.option( '--dry-run-security-principals', default=False, @@ -72,7 +78,7 @@ help="include mail-enabled Entra groups in the sync") def sync_cli(groups_json_file, verbose, debug, dry_run_security_principals, dry_run_members, worker_threads, save_graph_response_json, query_graph_only, group_search_depth, full_sync, - graph_change_feed_grace_time, include_non_security_groups, include_mail_enabled_groups): + graph_change_feed_grace_time, include_non_security_groups, include_mail_enabled_groups, bulk_download): install_logger() logger = logging.getLogger('sync') @@ -128,6 +134,7 @@ def sync_cli(groups_json_file, verbose, debug, dry_run_security_principals, dry_ groups=[x.to_sdk_group() for x in stuff_to_sync.groups.values()], service_principals=[x.to_sdk_service_principal() for x in stuff_to_sync.service_principals.values()], deep_sync_group_names=list(stuff_to_sync.deep_sync_group_names), + bulk_download=bulk_download, dry_run_security_principals=dry_run_security_principals, dry_run_members=dry_run_members, worker_threads=worker_threads) diff --git a/azure_dbr_scim_sync/scim.py b/azure_dbr_scim_sync/scim.py index 1e421ed..1b9ac95 100644 --- a/azure_dbr_scim_sync/scim.py +++ b/azure_dbr_scim_sync/scim.py @@ -126,17 +126,20 @@ def effecitve_change_count(self): 'user': { 'key_obj_field': 'user_name', 'key_api_field': 'userName', - 'cache': user_cache + 'cache': user_cache, + 'objs_by_id': None }, 'group': { 'key_obj_field': 'display_name', 'key_api_field': 'displayName', - 'cache': group_cache + 'cache': group_cache, + 'objs_by_id': None }, 'spn': { 'key_obj_field': 'application_id', 'key_api_field': 'applicationId', - 'cache': spn_cache + 'cache': spn_cache, + 'objs_by_id': None } } @@ -145,6 +148,7 @@ def _generic_get_by_human_name(mapper, sdk_module, search_name): cache = mapper['cache'] key_obj_field = mapper['key_obj_field'] key_api_field = mapper['key_api_field'] + objs_by_id = mapper['objs_by_id'] cached_id = cache[search_name] obj = None @@ -152,8 +156,11 @@ def _generic_get_by_human_name(mapper, sdk_module, search_name): if cached_id: # verify cache try: - obj = sdk_module.get(cached_id) - if obj.__dict__[key_obj_field] == search_name: + if objs_by_id is not None: + obj = objs_by_id[cached_id] + else: + obj = sdk_module.get(cached_id) + if obj.__dict__[key_obj_field].casefold() == search_name.casefold(): # hit! logger.debug(f"Cache hit: {search_name=}, {obj.id=}") return obj @@ -196,7 +203,6 @@ def _generic_create_or_update(mapper, desired: T, actual: T, compare_fields: Lis ResultClass = MergeResult[T] cache = mapper['cache'] key_obj_field = mapper['key_obj_field'] - mapper['key_api_field'] desired = deepcopy(desired) desired_dict = deepcopy(desired.as_dict()) @@ -453,10 +459,29 @@ def sync(*, groups: Iterable[iam.Group], service_principals: Iterable[iam.ServicePrincipal], deep_sync_group_names: Iterable[str], + bulk_download=False, dry_run_security_principals=False, dry_run_members=False, worker_threads: int = 10): + global _generic_type_map + + if not bulk_download: + logger.info("Skipping bulk downloading of Databricks groups, users and service principals") + else: + # API does not include group members! Could only be used if dry_run_members is true. + logger.info("Skipping bulk downloading of Databricks groups...") + # _generic_type_map['group']['objs_by_id'] = {x.id : x for x in account_client.groups.list()} + # logger.info(f"Downloaded: groups={_generic_type_map['group']['objs_by_id']}") + + logger.info("Bulk downloading of Databricks users...") + _generic_type_map['user']['objs_by_id'] = {x.id : x for x in account_client.users.list()} + logger.info(f"Downloaded: users={len(_generic_type_map['user']['objs_by_id'])}") + + logger.info("Bulk downloading of Databricks service principals...") + _generic_type_map['spn']['objs_by_id'] = {x.id : x for x in account_client.service_principals.list()} + logger.info(f"Downloaded: service_principals={len(_generic_type_map['spn']['objs_by_id'])}") + logger.info("Starting creating or updating users, groups and service principals...") result = ScimSyncObject(users=create_or_update_users(account_client, users,