diff --git a/multi_region_serving/.gitignore b/multi_region_serving/.gitignore
new file mode 100644
index 0000000..364994d
--- /dev/null
+++ b/multi_region_serving/.gitignore
@@ -0,0 +1,3 @@
+.vscode
+.databricks
+scratch
diff --git a/multi_region_serving/README.md b/multi_region_serving/README.md
new file mode 100644
index 0000000..f6a904e
--- /dev/null
+++ b/multi_region_serving/README.md
@@ -0,0 +1,60 @@
+# Multi-region Serving
+
+This Databricks Asset Bundle (DAB) is an example tool that syncs resources between a main
+workspace and remote workspaces, simplifying the workflow for serving models or features
+across multiple regions.
+
+## How to use this example
+1. Download this example.
+
+2. Make changes as needed. Some files to highlight:
+   * databricks.yml - DAB bundle configuration, including variable names and default values.
+   * src/manage_endpoint.py - Notebook for creating / updating serving endpoints.
+   * src/manage_share.py - Notebook for syncing the dependencies of a shared model.
+
+## How to trigger the workflows
+
+1. Install the Databricks CLI from https://docs.databricks.com/dev-tools/cli/databricks-cli.html
+
+2. Authenticate to your Databricks workspaces, if you have not done so already:
+   ```
+   $ databricks configure
+   ```
+
+3. Validate bundle variables.
+
+   If you don't set a default value for a variable defined in `databricks.yml`, you need to
+   provide that variable when running any command. You can validate that all variables are
+   provided:
+   ```
+   $ MY_BUNDLE_VARS="share_name=,model_name=,remote_model_name=,model_version=,endpoint_name=,notification_email="
+   $ databricks bundle validate --var=$MY_BUNDLE_VARS
+   ```
+
+4. To deploy a copy to your main workspace:
+   ```
+   $ databricks bundle deploy --target main --var=$MY_BUNDLE_VARS
+   ```
+   (Note that "main" is the target name defined in databricks.yml.)
+
+   This deploys everything that's defined for this project.
+   For example, the default configuration deploys a job called
+   `[dev yourname] manage_serving_job` to your workspace.
+   You can find that job by opening your workspace and clicking on **Workflows**.
+
+5. Similarly, to deploy to a remote workspace, type:
+   ```
+   $ databricks bundle deploy --target remote1 -p <profile-name> --var=$MY_BUNDLE_VARS
+   ```
+
+   Use `-p` to specify the Databricks profile used by this command. The profile needs to be
+   configured in `~/.databrickscfg`.
+
+6. To run the workflow that syncs a share, use the "run" command:
+   ```
+   $ databricks bundle run manage_share_job -t main --var=$MY_BUNDLE_VARS
+   ```
+
+7. For documentation on the Databricks Asset Bundles format used
+   for this project, and for CI/CD configuration, see
+   https://docs.databricks.com/dev-tools/bundles/index.html.
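+
+## Example: end-to-end walkthrough
+
+For reference, the full flow described above might look like the sketch below. All variable
+values and the `remote` profile name are illustrative placeholders; substitute your own share,
+model, endpoint, email, and profile names.
+```
+# Illustrative values only -- replace them with your own.
+$ MY_BUNDLE_VARS="share_name=my_share,model_name=main.ml.my_model,remote_model_name=main_remote.ml.my_model,model_version=1,endpoint_name=my_model_endpoint,notification_email=me@example.com"
+$ databricks bundle validate --var=$MY_BUNDLE_VARS
+$ databricks bundle deploy -t main --var=$MY_BUNDLE_VARS
+$ databricks bundle run manage_share_job -t main --var=$MY_BUNDLE_VARS
+$ databricks bundle deploy -t remote1 -p remote --var=$MY_BUNDLE_VARS
+$ databricks bundle run manage_serving_job -t remote1 -p remote --var=$MY_BUNDLE_VARS
+```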
diff --git a/multi_region_serving/databricks.yml b/multi_region_serving/databricks.yml
new file mode 100644
index 0000000..0fb576e
--- /dev/null
+++ b/multi_region_serving/databricks.yml
@@ -0,0 +1,40 @@
+# This is a Databricks asset bundle definition for manage_serving.
+# See https://docs.databricks.com/dev-tools/bundles/index.html for documentation.
+bundle:
+  name: manage_serving
+
+variables:
+  notification_email:
+    description: Email address that receives notifications when a job fails.
+  model_name:
+    description: Full name of the model to share, as registered in the main workspace.
+  remote_model_name:
+    description: The model name in the recipient workspace. This is usually the original model name with the catalog name used in the recipient workspace.
+  model_version:
+    description: Version of the model to deploy.
+  endpoint_name:
+    description: Name of the endpoint to deploy.
+  share_name:
+    description: Name of the share.
+
+include:
+  - resources/*.yml
+
+targets:
+  main:
+    # The default target uses 'mode: development' to create a development copy.
+    # - Deployed resources get prefixed with '[dev my_user_name]'
+    # - Any job schedules and triggers are paused by default.
+    # See also https://docs.databricks.com/dev-tools/bundles/deployment-modes.html.
+    mode: development
+    default: true
+    workspace:
+      host: https://myworkspace.databricks.com
+
+  remote1:
+    # The remote workspace that serves the model.
+    mode: development
+    workspace:
+      host: https://myworkspace-remote.databricks.com
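+
+  # To serve the model in additional regions, add one target per remote workspace.
+  # The block below is an illustrative sketch (the host URL is a placeholder), not part
+  # of the working configuration; uncomment and adapt it as needed.
+  # remote2:
+  #   mode: development
+  #   workspace:
+  #     host: https://myworkspace-remote2.databricks.com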
diff --git a/multi_region_serving/requirements-dev.txt b/multi_region_serving/requirements-dev.txt
new file mode 100644
index 0000000..0ffbf6a
--- /dev/null
+++ b/multi_region_serving/requirements-dev.txt
@@ -0,0 +1,29 @@
+## requirements-dev.txt: dependencies for local development.
+##
+## For defining dependencies used by jobs in Databricks Workflows, see
+## https://docs.databricks.com/dev-tools/bundles/library-dependencies.html
+
+## Add code completion support for DLT
+databricks-dlt
+
+## pytest is the default package used for testing
+pytest
+
+## Dependencies for building wheel files
+setuptools
+wheel
+
+## databricks-connect can be used to run parts of this project locally.
+## See https://docs.databricks.com/dev-tools/databricks-connect.html.
+##
+## databricks-connect is automatically installed if you're using the Databricks
+## extension for Visual Studio Code
+## (https://docs.databricks.com/dev-tools/vscode-ext/dev-tasks/databricks-connect.html).
+##
+## To manually install databricks-connect, either follow the instructions
+## at https://docs.databricks.com/dev-tools/databricks-connect.html
+## to install the package system-wide, or uncomment the line below to install a
+## version of databricks-connect that corresponds to the Databricks Runtime version
+## used for this project.
+#
+# databricks-connect>=15.4,<15.5
diff --git a/multi_region_serving/resources/manage_serving.job.yml b/multi_region_serving/resources/manage_serving.job.yml
new file mode 100644
index 0000000..bd7e2d7
--- /dev/null
+++ b/multi_region_serving/resources/manage_serving.job.yml
@@ -0,0 +1,18 @@
+resources:
+  jobs:
+    manage_serving_job:
+      name: manage_serving_job
+      email_notifications:
+        on_failure:
+          - ${var.notification_email}
+      tasks:
+        - task_key: notebook_task
+          notebook_task:
+            notebook_path: ../src/manage_endpoint.py
+      parameters:
+        - name: endpoint_name
+          default: ${var.endpoint_name}
+        - name: model_name
+          default: ${var.remote_model_name}
+        - name: model_version
+          default: ${var.model_version}
diff --git a/multi_region_serving/resources/manage_share.job.yml b/multi_region_serving/resources/manage_share.job.yml
new file mode 100644
index 0000000..75c15af
--- /dev/null
+++ b/multi_region_serving/resources/manage_share.job.yml
@@ -0,0 +1,18 @@
+resources:
+  jobs:
+    manage_share_job:
+      name: manage_share_job
+      email_notifications:
+        on_failure:
+          - ${var.notification_email}
+      tasks:
+        - task_key: notebook_task
+          notebook_task:
+            notebook_path: ../src/manage_share.py
+      parameters:
+        - name: model_name
+          default: ${var.model_name}
+        - name: max_number_of_versions_to_sync
+          default: '10'
+        - name: share_name
+          default: ${var.share_name}
diff --git a/multi_region_serving/scratch/README.md b/multi_region_serving/scratch/README.md
new file mode 100644
index 0000000..e6cfb81
--- /dev/null
+++ b/multi_region_serving/scratch/README.md
@@ -0,0 +1,4 @@
+# scratch
+
+This folder is reserved for personal, exploratory notebooks.
+By default these are not committed to Git, as 'scratch' is listed in .gitignore.
diff --git a/multi_region_serving/src/lib/rest_client.py b/multi_region_serving/src/lib/rest_client.py
new file mode 100644
index 0000000..44beac7
--- /dev/null
+++ b/multi_region_serving/src/lib/rest_client.py
@@ -0,0 +1,25 @@
+import json
+import urllib.error
+import urllib.request
+
+from databricks.sdk.runtime import spark
+
+
+class RestClient:
+    def __init__(self, context):
+        self.base_url = "https://" + spark.conf.get("spark.databricks.workspaceUrl")
+        self.token = context.apiToken().get()
+
+    def get_share_info(self, share_name: str):
+        return self._get(
+            f"api/2.1/unity-catalog/shares/{share_name}?include_shared_data=true"
+        )
+
+    def _get(self, uri):
+        url = f"{self.base_url}/{uri}"
+        headers = {"Authorization": f"Bearer {self.token}"}
+        req = urllib.request.Request(url, headers=headers)
+        try:
+            response = urllib.request.urlopen(req)
+            return json.load(response)
+        except urllib.error.HTTPError as e:
+            # Print the error details and return None so callers can decide how to proceed.
+            result = e.read().decode()
+            print((e.code, result))
+            return None
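+
+# For reference, `get_share_info` returns the JSON payload of the Unity Catalog
+# "get share" API. An abbreviated, illustrative response (names are made up) looks like:
+#
+#   {
+#     "name": "my_share",
+#     "objects": [
+#       {"name": "main.schema.table_a", "data_object_type": "TABLE", ...},
+#       {"name": "main.schema.model_b", "data_object_type": "MODEL", ...}
+#     ]
+#   }
+#
+# manage_share.py relies only on the "objects", "name", and "data_object_type" fields
+# shown above; the real response contains additional fields.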
diff --git a/multi_region_serving/src/manage_endpoint.py b/multi_region_serving/src/manage_endpoint.py
new file mode 100644
index 0000000..98ab224
--- /dev/null
+++ b/multi_region_serving/src/manage_endpoint.py
@@ -0,0 +1,79 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Create or Update Model Serving Endpoint
+# MAGIC
+# MAGIC Create or update the deployed serving endpoint with a new model version.
+# MAGIC
+# MAGIC * Make sure you've created online tables for all the required feature tables.
+# MAGIC * Run this job on the workspace where you want to serve the model.
+
+# COMMAND ----------
+
+# MAGIC %pip install "databricks-sdk>=0.38.0"
+# MAGIC %restart_python
+
+# COMMAND ----------
+
+dbutils.widgets.text("endpoint_name", defaultValue="")
+dbutils.widgets.text("model_name", defaultValue="")
+dbutils.widgets.text("model_version", defaultValue="")
+
+# COMMAND ----------
+
+ARGS = dbutils.widgets.getAll()
+
+endpoint_name = ARGS["endpoint_name"]
+model_name = ARGS["model_name"]
+model_version = ARGS["model_version"]
+
+# COMMAND ----------
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.serving import ServedEntityInput, EndpointCoreConfigInput
+from databricks.sdk.errors import ResourceDoesNotExist
+
+workspace = WorkspaceClient()
+
+# COMMAND ----------
+
+try:
+    endpoint = workspace.serving_endpoints.get(name=endpoint_name)
+except ResourceDoesNotExist:
+    endpoint = None
+
+if endpoint is None:
+    # Create a new endpoint that serves the requested model version.
+    workspace.serving_endpoints.create(
+        name=endpoint_name,
+        config=EndpointCoreConfigInput(
+            served_entities=[
+                ServedEntityInput(
+                    entity_name=model_name,
+                    entity_version=model_version,
+                    scale_to_zero_enabled=True,
+                    workload_size="Small",
+                )
+            ]
+        ),
+    )
+    print(f"Created endpoint {endpoint_name}")
+elif endpoint.pending_config is not None:
+    print(f"A pending update for endpoint {endpoint_name} is being processed.")
+elif (
+    endpoint.config.served_entities[0].entity_name != model_name
+    or endpoint.config.served_entities[0].entity_version != model_version
+):
+    # Update the existing endpoint so it serves the requested model version.
+    workspace.serving_endpoints.update_config(
+        name=endpoint_name,
+        served_entities=[
+            ServedEntityInput(
+                entity_name=model_name,
+                entity_version=model_version,
+                scale_to_zero_enabled=True,
+                workload_size="Small",
+            )
+        ],
+    )
+    print(f"Updated endpoint {endpoint_name}")
+else:
+    print("Endpoint already up-to-date")
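+
+# COMMAND ----------
+
+# (Optional, illustrative) Inspect the endpoint state after the create/update call above.
+# This cell only reads and prints the endpoint status; adjust or remove it as needed.
+status = workspace.serving_endpoints.get(name=endpoint_name)
+print(f"Endpoint: {status.name}")
+if status.state is not None:
+    print(f"Ready: {status.state.ready}")
+    print(f"Config update: {status.state.config_update}")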
diff --git a/multi_region_serving/src/manage_share.py b/multi_region_serving/src/manage_share.py
new file mode 100644
index 0000000..c7fd5cf
--- /dev/null
+++ b/multi_region_serving/src/manage_share.py
@@ -0,0 +1,198 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Sync the share and add all required resources
+# MAGIC
+# MAGIC Add a model and all of its dependencies to the given share.
+# MAGIC
+# MAGIC Prerequisites:
+# MAGIC
+# MAGIC * Create a share in [delta-sharing](https://docs.databricks.com/en/delta-sharing/create-share.html).
+# MAGIC * Configure the default parameters in resources/manage_share.job.yml
+
+# COMMAND ----------
+
+# MAGIC %pip install "databricks-sdk>=0.38.0"
+# MAGIC %restart_python
+
+# COMMAND ----------
+
+dbutils.widgets.text("model_name", defaultValue="")
+dbutils.widgets.text("share_name", defaultValue="")
+dbutils.widgets.text("max_number_of_versions_to_sync", defaultValue="10")
+
+# COMMAND ----------
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.serving import ServedEntityInput
+from databricks.sdk.service.sharing import (
+    SharedDataObjectUpdate,
+    SharedDataObjectUpdateAction,
+    SharedDataObjectDataObjectType,
+    SharedDataObject,
+    SharedDataObjectHistoryDataSharingStatus,
+)
+
+workspace = WorkspaceClient()
+
+# COMMAND ----------
+
+model_name = dbutils.widgets.get("model_name")
+share_name = dbutils.widgets.get("share_name")
+max_number_of_versions_to_sync = int(
+    dbutils.widgets.get("max_number_of_versions_to_sync")
+)
+
+print("~~~ parameters ~~~")
+print(f"Model name: {model_name}")
+print(f"Share name: {share_name}")
+print(f"Max number of versions to sync: {max_number_of_versions_to_sync}")
+
+# COMMAND ----------
+
+
+def getLatestVersions(model_name: str, max_number_of_versions: int):
+    versions = workspace.model_versions.list(
+        full_name=model_name,
+    )
+    result = []
+    for version in versions:
+        # Stop once the requested number of versions has been collected.
+        if len(result) >= max_number_of_versions:
+            break
+        result.append(
+            workspace.model_versions.get(full_name=model_name, version=version.version)
+        )
+    return result
+
+
+def getDependencies(model_versions):
+    tables = set()
+    functions = set()
+    for version in model_versions:
+        # Skip versions that have no recorded dependencies.
+        if (
+            version.model_version_dependencies is None
+            or version.model_version_dependencies.dependencies is None
+        ):
+            continue
+        for dependency in version.model_version_dependencies.dependencies:
+            if dependency.table is not None:
+                tables.add(dependency.table.table_full_name)
+            elif dependency.function is not None:
+                functions.add(dependency.function.function_full_name)
+    return tables, functions
+
+
+# COMMAND ----------
+
+versions = getLatestVersions(model_name, max_number_of_versions_to_sync)
+tableDependencies, functionDependencies = getDependencies(versions)
+
+# COMMAND ----------
+
+from lib.rest_client import RestClient
+
+notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
+rc = RestClient(notebook_context)
+
+# COMMAND ----------
+
+shareInfo = rc.get_share_info(share_name)
+
+# COMMAND ----------
+
+sharedTables, sharedFunctions, sharedSchemas, sharedModels = set(), set(), set(), set()
+model_is_shared = False
+
+if "objects" in shareInfo:
+    sharedTables = {
+        obj["name"]
+        for obj in shareInfo["objects"]
+        if obj["data_object_type"] == "TABLE"
+    }
+    sharedFunctions = {
+        obj["name"]
+        for obj in shareInfo["objects"]
+        if obj["data_object_type"] == "FUNCTION"
+    }
+    sharedSchemas = {
+        obj["name"]
+        for obj in shareInfo["objects"]
+        if obj["data_object_type"] == "SCHEMA"
+    }
+    sharedModels = {
+        obj["name"]
+        for obj in shareInfo["objects"]
+        if obj["data_object_type"] == "MODEL"
+    }
+    model_is_shared = model_name in sharedModels
+
+# COMMAND ----------
+
+
+def getSchema(full_name):
+    name_sections = full_name.split(".")
+    return f"{name_sections[0]}.{name_sections[1]}"
+
+
+def getObjectsToAdd(dependencies, sharedObjects, sharedSchemas):
+    # Keep only dependencies that are not yet shared, either directly or via their schema.
+    newDependencies = dependencies - sharedObjects
+    return list(filter(lambda x: getSchema(x) not in sharedSchemas, newDependencies))
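+
+
+# COMMAND ----------
+
+# (Illustrative, safe to remove) The names below are made up. A dependency is skipped
+# when it is already shared directly, or when its whole schema is already shared:
+#
+#   getObjectsToAdd(
+#       dependencies={"main.features.user_fts", "main.features.item_fts"},
+#       sharedObjects={"main.features.user_fts"},
+#       sharedSchemas={"main.features"},
+#   )
+#   # -> []  (the only new dependency lives in a schema that is already shared)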
+
+
+# COMMAND ----------
+
+tablesToAdd = getObjectsToAdd(tableDependencies, sharedTables, sharedSchemas)
+functionsToAdd = getObjectsToAdd(functionDependencies, sharedFunctions, sharedSchemas)
+
+updates = []
+
+for table in tablesToAdd:
+    updates.append(
+        SharedDataObjectUpdate(
+            action=SharedDataObjectUpdateAction.ADD,
+            data_object=SharedDataObject(
+                name=table,
+                data_object_type=SharedDataObjectDataObjectType.TABLE,
+                history_data_sharing_status=SharedDataObjectHistoryDataSharingStatus.ENABLED,
+            ),
+        )
+    )
+
+
+for function in functionsToAdd:
+    updates.append(
+        SharedDataObjectUpdate(
+            action=SharedDataObjectUpdateAction.ADD,
+            data_object=SharedDataObject(
+                name=function, data_object_type=SharedDataObjectDataObjectType.FUNCTION
+            ),
+        )
+    )
+
+if not model_is_shared:
+    updates.append(
+        SharedDataObjectUpdate(
+            action=SharedDataObjectUpdateAction.ADD,
+            data_object=SharedDataObject(
+                name=model_name, data_object_type=SharedDataObjectDataObjectType.MODEL
+            ),
+        )
+    )
+
+
+def print_update_summary(updates):
+    for update in updates:
+        print(
+            f"{update.action.value} {update.data_object.data_object_type} {update.data_object.name}"
+        )
+
+
+if updates:
+    print_update_summary(updates)
+    workspace.shares.update(name=share_name, updates=updates)
+else:
+    print("The share is already up-to-date.")
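+
+# COMMAND ----------
+
+# (Optional, illustrative) Re-fetch the share and confirm that the model and its
+# dependencies are now listed. This cell only reads the share; it makes no further changes.
+updated_share = rc.get_share_info(share_name)
+shared_names = {obj["name"] for obj in (updated_share or {}).get("objects", [])}
+print(f"Model shared: {model_name in shared_names}")
+missing_tables = set(tablesToAdd) - shared_names
+print(f"Tables still missing from the share: {missing_tables or 'none'}")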