diff --git a/.github/workflows/extensions.yml b/.github/workflows/extensions.yml index 97218060..d38674d1 100644 --- a/.github/workflows/extensions.yml +++ b/.github/workflows/extensions.yml @@ -59,6 +59,7 @@ jobs: simple-shiny-chat-with-mcp: extensions/simple-shiny-chat-with-mcp/** chat-with-content: extensions/chat-with-content/** pqr: extensions/pqr/** + mlflow-tracking-server: extensions/mlflow-tracking-server/** # Runs for each extension that has changed from `simple-extension-changes` # Lints and packages in preparation for tests and and release. diff --git a/extensions.json b/extensions.json index 0952b70d..552f4a10 100644 --- a/extensions.json +++ b/extensions.json @@ -19,7 +19,13 @@ "llm", "chat", "quarto", - "r" + "r", + "mlflow", + "mlops", + "tracking", + "experiments", + "models", + "model serving" ], "requiredFeatures": [ "API Publishing", @@ -1471,4 +1477,4 @@ "category": "example" } ] -} \ No newline at end of file +} diff --git a/extensions/mlflow-python-model-server-example/.gitignore b/extensions/mlflow-python-model-server-example/.gitignore new file mode 100644 index 00000000..431d3059 --- /dev/null +++ b/extensions/mlflow-python-model-server-example/.gitignore @@ -0,0 +1,4 @@ +rsconnect-python +.venv +.env +.python-version diff --git a/extensions/mlflow-python-model-server-example/README.md b/extensions/mlflow-python-model-server-example/README.md new file mode 100644 index 00000000..c8879e97 --- /dev/null +++ b/extensions/mlflow-python-model-server-example/README.md @@ -0,0 +1,332 @@ +# MLflow Model API Server for Posit Connect + +Deploy your MLflow models to Posit Connect as a REST API with automatic documentation and multiple input format support. + +## Overview + +This example demonstrates how to serve MLflow models on Posit Connect through a FastAPI application. 
It uses MLflow's PyFunc scoring server and provides standard MLflow endpoints with enhanced OpenAPI documentation, making it easy to deploy and consume machine learning models in production on Connect. + +## Features + +✨ **Posit Connect Integration** - Seamless deployment with automatic authentication +🔄 **Multiple Input Formats** - JSON, CSV, and TensorFlow Serving formats +📚 **Interactive Docs** - Auto-generated Swagger UI and ReDoc +🎯 **Dynamic Schema** - Examples based on your model's signature +🔐 **Automatic Authentication** - Uses publisher's API key for MLflow access + +## Quick Start for Connect Deployment + +### Prerequisites + +- Python 3.8+ +- MLflow model (trained and logged to MLflow) +- Access to a Posit Connect server +- Access to an MLflow tracking server (can be hosted on the same Connect instance) + +### 1. Install API Dependencies + +```bash +pip install -r requirements.txt +``` + +### 2. ⚠️ **Critical: Update requirements.txt with Model Dependencies** + +**Before deploying to Connect, you must add your model's dependencies to `requirements.txt`.** + +```bash +# Extract and append model requirements +mlflow artifacts download -u "models:/my-model/Production" -d ./model_artifacts +cat ./model_artifacts/requirements.txt >> requirements.txt + +# Clean up duplicates if needed +sort -u requirements.txt -o requirements.txt +``` + +**Why is this critical for Connect?** +- Connect installs dependencies from `requirements.txt` when deploying +- Your model needs libraries like scikit-learn, TensorFlow, PyTorch, etc. +- Without these, the deployment will fail with `ImportError` + +### 3. Test Locally (Optional but Recommended) + +```bash +# Install model dependencies in your local environment +./install_model_deps.sh "models:/my-model/Production" + +# Test locally +export MODEL_URI="models:/my-model/Production" +python app.py +``` + +Visit `http://localhost:8000/docs` to verify everything works. + +### 4. 
Deploy to Posit Connect + +```bash +# Deploy using rsconnect-python +rsconnect deploy fastapi \ + --server https://your-connect-server.com \ + --api-key your-api-key \ + --title "My Model API" \ + . +``` + +### 5. Configure in Connect + +After deployment: +1. Navigate to your content in Connect +2. Go to the **Vars** panel in content settings +3. Add environment variable: `MODEL_URI` = `models:/my-model/Production` +4. Save and restart the content + +**That's it!** Connect will automatically: +- Use your API key to authenticate with MLflow +- Load the model on startup +- Serve predictions at the `/invocations` endpoint +- Provide interactive documentation at `/docs` + +## Configuration for Connect + +The API automatically detects when running on Posit Connect and configures itself accordingly. + +### Required Environment Variable (Set in Connect) + +| Variable | Description | Example | +|----------|-------------|---------| +| `MODEL_URI` | MLflow model location | `models:/my-model/Production` | + +### Optional Variables (Auto-configured on Connect) + +| Variable | Default on Connect | Description | +|----------|-------------------|-------------| +| `MLFLOW_TRACKING_URI` | `{CONNECT_SERVER}/mlflow` | MLflow server URL | +| `MLFLOW_TRACKING_TOKEN` | Publisher's API key | Authentication token | +| `CONNECT_SERVER` | Auto-detected | Connect server URL | + +**Model URI Formats:** +- Registry: `models:/my-model/Production` or `models:/my-model/1` +- Run: `runs://model` + +## API Endpoints + +Once deployed, your Connect content will expose: + +### Interactive Documentation + +| URL | Description | +|-----|-------------| +| `https://connect-server/content/{guid}/docs` | Swagger UI - Try the API interactively | +| `https://connect-server/content/{guid}/redoc` | ReDoc - Clean documentation view | + +### Making Predictions + +**POST** `/invocations` - Multiple input formats supported: + +```bash +# JSON (dataframe_split) +curl -X POST 
https://connect-server/content/{guid}/invocations \ + -H "Content-Type: application/json" \ + -d '{ + "dataframe_split": { + "columns": ["feature1", "feature2"], + "data": [[1.0, 2.0]] + } + }' + +# JSON (dataframe_records) +curl -X POST https://connect-server/content/{guid}/invocations \ + -H "Content-Type: application/json" \ + -d '{ + "dataframe_records": [ + {"feature1": 1.0, "feature2": 2.0} + ] + }' + +# CSV +curl -X POST https://connect-server/content/{guid}/invocations \ + -H "Content-Type: text/csv" \ + --data-binary @input.csv +``` + +### Health Checks + +```bash +curl https://connect-server/content/{guid}/ping # Quick check +curl https://connect-server/content/{guid}/health # Detailed status +curl https://connect-server/content/{guid}/version # MLflow version +``` + +## Updating requirements.txt for Connect Deployment + +**This is the most important step for successful Connect deployment.** + +The `requirements.txt` file initially contains only API dependencies (FastAPI, uvicorn, MLflow client). Before deploying to Connect, you must add your model's specific dependencies. + +### Recommended Approach + +```bash +# 1. Download model artifacts to see what your model needs +mlflow artifacts download -u "$MODEL_URI" -d ./model_artifacts + +# 2. Append model requirements to your requirements.txt +cat ./model_artifacts/requirements.txt >> requirements.txt + +# 3. Remove duplicates and sort +sort -u requirements.txt -o requirements.txt + +# 4. Verify the file looks correct +cat requirements.txt +``` + +### What to Include + +Your final `requirements.txt` should contain: +- **API dependencies** (already included): `fastapi`, `uvicorn`, `mlflow` +- **Model dependencies** (you must add): `scikit-learn`, `tensorflow`, `pytorch`, etc. 
+- **Any custom libraries** your model needs + +### Example Final requirements.txt + +``` +fastapi>=0.104.0 +uvicorn>=0.24.0 +mlflow>=2.9.0 +scikit-learn==1.3.0 +pandas==2.0.0 +numpy==1.24.0 +``` + +**⚠️ Critical:** Every time you change models or model versions, verify the dependencies haven't changed and update `requirements.txt` before redeploying to Connect. + +## Testing Before Deployment + +Always test locally before deploying to Connect: + +```bash +# 1. Install all dependencies (API + model) +pip install -r requirements.txt + +# 2. Set environment variables +export MODEL_URI="models:/my-model/Production" +export MLFLOW_TRACKING_URI="https://your-mlflow-server.com" +export MLFLOW_TRACKING_TOKEN="your-token" + +# 3. Run locally +python app.py + +# 4. Test in another terminal +curl http://localhost:8000/ping +curl -X POST http://localhost:8000/invocations \ + -H "Content-Type: application/json" \ + -d '{"dataframe_split": {"columns": ["feature1"], "data": [[1.0]]}}' +``` + +## Testing with Sample Model + +Want to test the deployment process with a sample model first? + +```bash +# 1. Train and log a sample model +python example_test.py + +# 2. Update requirements.txt with the sample model's dependencies +./install_model_deps.sh 'runs://model' +mlflow artifacts download -u 'runs://model' -d ./model_artifacts +cat ./model_artifacts/requirements.txt >> requirements.txt + +# 3. Test locally +export MODEL_URI='runs://model' +python app.py + +# 4. Deploy to Connect +rsconnect deploy fastapi --server https://connect.example.com --api-key your-key . + +# 5. 
Set MODEL_URI in Connect settings to 'runs://model' +``` + +## Connect-Specific Features + +### Automatic Authentication + +When deployed to Connect: +- `MLFLOW_TRACKING_TOKEN` automatically uses your Connect API key +- No additional configuration needed if MLflow is on the same Connect server +- Seamless integration with Connect-hosted MLflow tracking servers + +### Content Settings + +Configure your deployment in Connect: +- **Vars**: Set `MODEL_URI` and optional environment variables +- **Access**: Control who can access your API +- **Runtime**: Adjust min/max processes based on load +- **Logs**: Monitor model loading and prediction requests + +### Vanity URLs + +For cleaner API endpoints, set up a vanity URL in Connect: +``` +https://connect-server/my-model-api/docs +``` + +## Updating Models in Connect + +When you promote a new model version in MLflow: + +1. **Update requirements.txt if dependencies changed**: +```bash +mlflow artifacts download -u "$MODEL_URI" -d ./model_artifacts +cat ./model_artifacts/requirements.txt >> requirements.txt +sort -u requirements.txt -o requirements.txt +``` + +2. **Redeploy to Connect**: +```bash +rsconnect deploy fastapi --server https://connect.example.com --api-key your-key . +``` + +3. **Or just update MODEL_URI in Connect settings** (if dependencies unchanged): + - Go to Vars panel + - Update `MODEL_URI` to point to new version + - Restart the content + +## Troubleshooting Connect Deployments + +| Error | Solution | +|-------|----------| +| `ImportError` during deployment | Update `requirements.txt` with model dependencies | +| `Model not found` | Verify `MODEL_URI` in Connect settings | +| Authentication fails | Ensure MLflow server is accessible from Connect | +| Deployment bundle too large | Use `models:/` URI instead of including model files | + +### Checking Logs in Connect + +1. Navigate to your content in Connect +2. Click on the **Logs** tab +3. 
Look for: + - Model loading messages + - Dependency installation + - Error messages with stack traces + +## Best Practices for Connect + +1. **Pin dependency versions** in `requirements.txt` for reproducibility +2. **Test locally** before deploying to Connect +3. **Use registered models** (`models:/`) rather than runs for production +4. **Monitor logs** in Connect after deployment +5. **Set appropriate process limits** based on model size and load +6. **Use vanity URLs** for cleaner, more stable API endpoints + +## Resources + +- **Posit Connect Docs**: https://docs.posit.co/connect/ +- **MLflow Docs**: https://mlflow.org/docs/latest/ +- **FastAPI Docs**: https://fastapi.tiangolo.com +- **rsconnect-python**: https://github.com/rstudio/rsconnect-python + +## Support + +For Connect deployment issues: +- Check the troubleshooting section above +- Review Connect logs for error details +- Verify `requirements.txt` includes all model dependencies +- Test locally before deploying diff --git a/extensions/mlflow-python-model-server-example/app.py b/extensions/mlflow-python-model-server-example/app.py new file mode 100644 index 00000000..e9b8cba5 --- /dev/null +++ b/extensions/mlflow-python-model-server-example/app.py @@ -0,0 +1,175 @@ +""" +MLflow Model API Server + +This module provides a FastAPI application that serves MLflow models with automatic +OpenAPI documentation generation. 
+ +Usage: + Set the MODEL_URI environment variable to your MLflow model location: + + export MODEL_URI="models:/my-model/production" + python app.py + + Or with a local model: + + export MODEL_URI="./my-model" + python app.py + +Environment Variables: + MODEL_URI (required): The MLflow model URI to load + Examples: + - "models:/my-model/production" + - "runs:/abc123/model" + - "./local-model-path" + + CONNECT_SERVER (optional): Posit Connect server URL + MLFLOW_TRACKING_URI (optional): MLflow tracking server URI + MLFLOW_TRACKING_TOKEN (optional): Authentication token for MLflow + HOST (optional): Server host (default: "0.0.0.0") + PORT (optional): Server port (default: 8000) + +Endpoints: + GET /ping - Health check endpoint + GET /health - Detailed health status + GET /version - Model version information + POST /invocations - Model prediction endpoint +""" + +import os +import logging +import mlflow +from mlflow.pyfunc.scoring_server import load_model, init +from openapi_schema import ( + extract_model_signature, + configure_openapi_metadata, + generate_openapi_schema +) + + +# Configure logging for production use +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Configure MLflow tracking connection +CONNECT_SERVER = os.getenv('CONNECT_SERVER') +if CONNECT_SERVER: + # Default to Connect server's MLflow endpoint and API key if available + MLFLOW_TRACKING_URI = os.getenv('MLFLOW_TRACKING_URI', f"{CONNECT_SERVER}/mlflow") + os.environ["MLFLOW_TRACKING_TOKEN"] = os.getenv( + 'MLFLOW_TRACKING_TOKEN', + os.getenv('CONNECT_API_KEY', "") + ) +else: + MLFLOW_TRACKING_URI = os.getenv('MLFLOW_TRACKING_URI', "") + +if MLFLOW_TRACKING_URI: + mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) + logger.info(f"MLflow tracking URI set to: {MLFLOW_TRACKING_URI}") + + +def get_model_uri() -> str: + """ + Retrieve the model URI from environment variables. 
+ + Returns: + str: The MLflow model URI + + Raises: + ValueError: If MODEL_URI environment variable is not set + """ + model_uri = os.environ.get("MODEL_URI") + if not model_uri: + raise ValueError( + "MODEL_URI environment variable must be set. " + "Example: export MODEL_URI='models:/my-model/production'" + ) + return model_uri + + +def create_app(): + """ + Create and configure the FastAPI application. + + This function: + 1. Loads the MLflow model from the specified URI + 2. Extracts model signature for API documentation + 3. Initializes FastAPI with standard MLflow endpoints + 4. Configures OpenAPI schema with model-specific metadata + + Returns: + FastAPI: Configured FastAPI application instance + + Raises: + ValueError: If MODEL_URI is not set + Exception: If model loading fails + """ + try: + model_uri = get_model_uri() + logger.info(f"Initializing API with model URI: {model_uri}") + + # Load the MLflow model + logger.info("Loading MLflow model...") + pyfunc_model = load_model(model_uri) + logger.info("Model loaded successfully") + + # Extract model signature for OpenAPI documentation + model_signature_info, example_input, example_output = extract_model_signature( + pyfunc_model + ) + + # Initialize the FastAPI app with MLflow's standard endpoints + app = init(pyfunc_model) + + # Enhance OpenAPI documentation with model-specific metadata + configure_openapi_metadata( + app, + model_uri, + MLFLOW_TRACKING_URI, + model_signature_info + ) + + # Log available endpoints for debugging + logger.info("API endpoints registered:") + for route in app.routes: + if hasattr(route, "path") and hasattr(route, "methods"): + logger.info(f" {', '.join(route.methods):6s} {route.path}") + + # Generate custom OpenAPI schema with model signature information + # Note: MLflow uses @app.route() which doesn't integrate with FastAPI's + # automatic OpenAPI generation, so we build the schema manually + app.openapi = generate_openapi_schema( + app, + model_uri, + MLFLOW_TRACKING_URI, + 
model_signature_info, + example_input, + example_output + ) + + logger.info("FastAPI application initialized successfully") + return app + + except Exception as e: + logger.error(f"Failed to initialize application: {str(e)}") + raise + + +# Create the FastAPI application instance +# This is the ASGI application entry point for servers like uvicorn +app = create_app() + + +if __name__ == "__main__": + import uvicorn + + # Get server configuration from environment variables + host = os.environ.get("HOST", "0.0.0.0") + port = int(os.environ.get("PORT", "8000")) + + logger.info(f"Starting MLflow Model API server on {host}:{port}") + logger.info(f"API documentation available at http://{host}:{port}/docs") + + uvicorn.run(app, host=host, port=port) diff --git a/extensions/mlflow-python-model-server-example/install_model_deps.sh b/extensions/mlflow-python-model-server-example/install_model_deps.sh new file mode 100755 index 00000000..b1f673a3 --- /dev/null +++ b/extensions/mlflow-python-model-server-example/install_model_deps.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Helper script to install model dependencies from MLflow + +set -e + +if [ -z "$1" ]; then + echo "Usage: $0 " + echo "" + echo "Examples:" + echo " $0 models:/my-model/Production" + echo " $0 runs:/abc123/model" + echo "" + echo "This script will:" + echo " 1. Download the model artifacts from MLflow" + echo " 2. Extract the model's requirements.txt" + echo " 3. Install the dependencies" + exit 1 +fi + +MODEL_URI=$1 +TEMP_DIR=$(mktemp -d) + +echo "==> Downloading model artifacts from: $MODEL_URI" +mlflow artifacts download -u "$MODEL_URI" -d "$TEMP_DIR" + +if [ -f "$TEMP_DIR/requirements.txt" ]; then + echo "==> Found model requirements.txt" + echo "==> Installing model dependencies..." + pip install -r "$TEMP_DIR/requirements.txt" + echo "==> Model dependencies installed successfully!" 
+else + echo "WARNING: No requirements.txt found in model artifacts" + echo "The model may not have any additional dependencies, or they may be embedded differently" +fi + +# Cleanup +rm -rf "$TEMP_DIR" +echo "==> Done!" diff --git a/extensions/mlflow-python-model-server-example/manifest.json b/extensions/mlflow-python-model-server-example/manifest.json new file mode 100644 index 00000000..ca5a6b3e --- /dev/null +++ b/extensions/mlflow-python-model-server-example/manifest.json @@ -0,0 +1,55 @@ +{ + "version": 1, + "locale": "en_US.UTF-8", + "metadata": { + "appmode": "python-fastapi", + "entrypoint": "app" + }, + "python": { + "version": "3.12.7", + "package_manager": { + "name": "pip", + "version": "24.2", + "package_file": "requirements.txt" + } + }, + "environment": { + "python": { + "requires": ">=3.11.0" + } + }, + "extension": { + "name": "mlflow-python-model-server-example", + "title": "MLflow Python Model Server Example", + "description": "An example extension for serving MLflow models using FastAPI.", + "homepage": "https://github.com/posit-dev/connect-extensions/tree/main/extensions/mlflow-python-model-server-example", + "category": "extension", + "tags": ["mlops", "mlflow", "python", "fastapi", "model serving"], + "minimumConnectVersion": "2025.04.0", + "requiredFeatures": [], + "version": "0.0.0" + }, + "files": { + "requirements.txt": { + "checksum": "81e2f809ed393d7c2a9d01269aad37dc" + }, + ".gitignore": { + "checksum": "693ec79eaa892babde62587aaacf0d8b" + }, + ".python-version": { + "checksum": "61939823fbcc7e8de3b61ab105331baf" + }, + "README.md": { + "checksum": "4449932182cbea9e598c5db783c8c83a" + }, + "app.py": { + "checksum": "2fa4fdb34daad9df2fa2cf6c4e14e56f" + }, + "install_model_deps.sh": { + "checksum": "81f8684f8ef26d9af5e07cd30f18a75c" + }, + "openapi_schema.py": { + "checksum": "1f919dd6dd45043ceb8c7f1b1c937fe1" + } + } +} diff --git a/extensions/mlflow-python-model-server-example/openapi_schema.py 
b/extensions/mlflow-python-model-server-example/openapi_schema.py new file mode 100644 index 00000000..3e82821f --- /dev/null +++ b/extensions/mlflow-python-model-server-example/openapi_schema.py @@ -0,0 +1,697 @@ +""" +OpenAPI Schema Generator for MLflow Model API on Posit Connect + +This module generates comprehensive OpenAPI documentation for MLflow models deployed +to Posit Connect by: +- Extracting input/output schemas from model metadata +- Creating dynamic examples based on data types +- Generating complete OpenAPI 3.1.0 specifications +- Providing interactive documentation support via /docs and /redoc + +The generated schema powers the interactive API documentation available to consumers +of your deployed model on Connect. +""" + +import os +import logging + +logger = logging.getLogger(__name__) + + +def extract_model_signature(pyfunc_model): + """ + Extract input and output schemas from an MLflow model's metadata. + + This function analyzes the model's signature to understand: + - What columns/features the model expects as input + - The data types of those inputs + - What the model outputs + - The data types of outputs + + The extracted schema is used to generate helpful examples in the API documentation + that consumers will see when accessing the deployed API on Connect. 
+ + Args: + pyfunc_model: Loaded MLflow PyFuncModel instance + + Returns: + tuple: A 3-element tuple containing: + - model_signature_info (str): Human-readable schema description + - example_input (dict): Sample input data in multiple formats + - example_output (dict): Sample output data + + Example: + >>> signature, input_ex, output_ex = extract_model_signature(model) + >>> print(signature) + **Model Input Schema:** + ```json + { + "feature1": "double", + "feature2": "double" + } + ``` + """ + # Extract schemas from model metadata + input_schema = pyfunc_model.metadata.get_input_schema() + output_schema = pyfunc_model.metadata.get_output_schema() if hasattr( + pyfunc_model.metadata, 'get_output_schema') else None + + model_signature_info = None + example_input = None + example_output = None + + # Process input schema if available + if input_schema: + try: + column_names = input_schema.input_names() if hasattr(input_schema, 'input_names') else [] + column_types = input_schema.input_types() if hasattr(input_schema, 'input_types') else [] + + if column_names: + logger.info(f"Extracted model input schema with {len(column_names)} columns") + logger.info(f"Column names: {column_names}") + logger.info(f"Column types: {column_types}") + + # Check if column names are integers (unnamed columns) + has_integer_columns = all(isinstance(name, (int, str)) and str(name).isdigit() for name in column_names) + + if has_integer_columns: + logger.info("Model uses integer/index-based column names (model may not have named features)") + + # Generate realistic example data based on column types + example_data = [] + example_record = {} + + for i, col_name in enumerate(column_names): + # Get type-appropriate example value + if i < len(column_types): + col_type = str(column_types[i]) + example_value = _generate_example_value(col_type) + else: + # Default to float if type unknown + example_value = 1.0 + + example_record[str(col_name)] = example_value + example_data.append(example_value) + + 
# Create examples in formats that MLflow accepts + # Convert column names to strings to handle integer column names + string_column_names = [str(name) for name in column_names] + + example_input = { + 'dataframe_split': { + 'columns': string_column_names, + 'data': [example_data] + }, + 'dataframe_records': [example_record], + 'column_names': string_column_names, + 'column_types': [str(t) for t in column_types] if column_types else [], + 'has_integer_columns': has_integer_columns, + 'num_features': len(column_names) + } + + # Format schema as human-readable JSON with note about integer columns + model_signature_info = _format_schema_as_json( + "Input", + column_names, + column_types, + has_integer_columns + ) + + except Exception as e: + logger.warning(f"Could not extract input schema: {e}") + else: + logger.info("Model has no input schema defined") + + # Process output schema if available + if output_schema: + try: + output_names = output_schema.input_names() if hasattr(output_schema, 'input_names') else [] + output_types = output_schema.input_types() if hasattr(output_schema, 'input_types') else [] + + if output_names: + logger.info(f"Extracted model output schema with {len(output_names)} outputs") + + # Generate example predictions + example_predictions = [] + for i, col_name in enumerate(output_names): + if i < len(output_types): + col_type = str(output_types[i]) + example_value = _generate_example_value(col_type, is_output=True) + else: + example_value = 0.95 # Default prediction value + + example_predictions.append(example_value) + + example_output = { + 'predictions': example_predictions + } + + # Append output schema to documentation + has_integer_output = all(isinstance(name, (int, str)) and str(name).isdigit() for name in output_names) + output_schema_info = _format_schema_as_json( + "Output", + output_names, + output_types, + has_integer_output + ) + if model_signature_info: + model_signature_info += f"\n\n{output_schema_info}" + else: + 
model_signature_info = output_schema_info + + except Exception as e: + logger.warning(f"Could not extract output schema: {e}") + else: + logger.info("Model has no output schema defined") + + return model_signature_info, example_input, example_output + + +def _generate_example_value(col_type, is_output=False): + """ + Generate an appropriate example value based on the column's data type. + + Args: + col_type (str): The data type (e.g., "int64", "float", "string") + is_output (bool): Whether this is for output (affects default values) + + Returns: + Appropriate example value for the data type + """ + col_type_lower = col_type.lower() + + # Integer types + if 'int' in col_type_lower or 'long' in col_type_lower: + return 0 if is_output else 1 + + # Floating point types + elif 'float' in col_type_lower or 'double' in col_type_lower: + return 0.95 if is_output else 1.0 + + # Boolean types + elif 'bool' in col_type_lower: + return True + + # String types + elif 'string' in col_type_lower or 'str' in col_type_lower: + return "prediction" if is_output else "example" + + # Default to float + else: + return 0.95 if is_output else 1.0 + + +def _format_schema_as_json(schema_type, names, types, has_integer_columns=False): + """ + Format a schema as a JSON code block for documentation. + + Args: + schema_type (str): "Input" or "Output" + names (list): Column/field names + types (list): Corresponding data types + has_integer_columns (bool): Whether columns are integer indices + + Returns: + str: Formatted JSON schema string + """ + schema_lines = [f"**Model {schema_type} Schema:**"] + + if has_integer_columns: + schema_lines.append("") + schema_lines.append("*Note: This model uses integer-based column names (0, 1, 2, ...), indicating it may not have named features. When sending data, use these integer indices as column names.*") + schema_lines.append("") + + schema_lines.extend(["```json", "{"]) + + for i, name in enumerate(names): + # Clean up type names (remove "DataType." 
prefix) + col_type = str(types[i]).replace('DataType.', '') if i < len(types) else 'any' + # Add comma except for last item + comma = "," if i < len(names) - 1 else "" + schema_lines.append(f' "{name}": "{col_type}"{comma}') + + schema_lines.append("}") + schema_lines.append("```") + + return "\n".join(schema_lines) + + +def configure_openapi_metadata(app, model_uri, tracking_uri, model_signature_info=None): + """ + Configure OpenAPI metadata for the FastAPI application deployed on Posit Connect. + + This adds essential information to the API documentation including: + - Title and version + - Model URI and tracking server information + - Input/output schema descriptions + - Connect-specific authentication details + + Args: + app: FastAPI application instance + model_uri (str): The MLflow model URI + tracking_uri (str): The MLflow tracking server URI + model_signature_info (str, optional): Formatted model schema information + """ + # Set basic API metadata + app.title = "MLflow Model Serving API" + app.version = "1.0.0" + + # Build the correct API base URL for Connect + connect_server = os.getenv('CONNECT_SERVER', '') + content_guid = os.getenv('CONNECT_CONTENT_GUID', '') + + if connect_server and content_guid: + # When deployed to Connect, use the full content URL + api_base_url = f"{connect_server}content/{content_guid}" + else: + # For local development + api_base_url = "http://localhost:8000" + + app.contact = { + "name": "MLflow Model API", + "url": tracking_uri, + } + + # Include model schema in description if available + schema_section = "" + if model_signature_info: + schema_section = f"\n\n{model_signature_info}\n" + + # Build comprehensive API description with Connect-specific information + app.description = f""" +## MLflow Model Serving API on Posit Connect + +Serve machine learning models deployed through MLflow's PyFunc flavor with a RESTful API on Posit Connect. 
+ +### Features + +* **Health Monitoring**: Check service health and availability +* **Version Information**: Get MLflow version details +* **Real-time Predictions**: Make predictions using your deployed model +* **Multiple Input Formats**: Support for JSON, CSV, and TensorFlow formats +* **Interactive Documentation**: This Swagger UI interface for testing + +### Model Information + +* **Model URI**: `{model_uri}` +* **MLflow Tracking Server**: `{tracking_uri}` +* **API Base URL**: `{api_base_url}` +{schema_section} +### Supported Input Formats + +The `/invocations` endpoint accepts several formats: + +1. **JSON (dataframe_split)**: Pandas DataFrame in split-orient format +2. **JSON (dataframe_records)**: Pandas DataFrame in records-orient format +3. **JSON (instances)**: TensorFlow Serving compatible format (array of arrays) +4. **JSON (inputs)**: Alternative TensorFlow format +5. **CSV**: Standard comma-separated values + +**Note**: If your model uses integer-based column names (0, 1, 2, ...), you must use these integers as strings in column names or object keys. + +### Authentication on Posit Connect + +**This API is deployed on Posit Connect:** +- Authentication is handled by Connect's content access controls +- The API uses the publisher's Connect API key to access MLflow automatically +- No additional authentication headers required when accessing through Connect +- Access is controlled via Connect's content settings + +### Response Format + +Predictions return JSON with a `predictions` array containing model outputs. 
+ """ + + # Define endpoint categories for better organization + app.openapi_tags = [ + { + "name": "health", + "description": "Health check and monitoring endpoints" + }, + { + "name": "model", + "description": "Model prediction and inference endpoints" + }, + { + "name": "info", + "description": "Service information and metadata endpoints" + } + ] + + +def generate_openapi_schema(app, model_uri, tracking_uri, model_signature_info=None, + example_input=None, example_output=None): + """ + Generate a complete OpenAPI 3.1.0 schema for the MLflow Model API. + + This creates a comprehensive API specification that includes: + - All standard MLflow endpoints + - Model-specific schemas and examples + - Input format documentation + - Response format specifications + - Proper server URLs for Connect deployment + + The schema is used by FastAPI to power /docs and /redoc endpoints. + + Args: + app: FastAPI application instance + model_uri (str): The MLflow model URI + tracking_uri (str): The MLflow tracking server URI + model_signature_info (str, optional): Model schema description + example_input (dict, optional): Sample input data + example_output (dict, optional): Sample output data + + Returns: + function: A function that returns the OpenAPI schema dictionary + """ + + def custom_openapi(): + """Generate the OpenAPI schema on demand.""" + # Return cached schema if available + if app.openapi_schema: + return app.openapi_schema + + # Build the correct server URL for Connect + connect_server = os.getenv('CONNECT_SERVER', '') + content_guid = os.getenv('CONNECT_CONTENT_GUID', '') + + servers = [] + if connect_server and content_guid: + # When deployed to Connect, use the full content URL + servers.append({ + "url": f"{connect_server}content/{content_guid}", + "description": "Posit Connect Deployment" + }) + else: + # For local development + servers.append({ + "url": "http://localhost:8000", + "description": "Local Development Server" + }) + + # Build complete OpenAPI 3.1.0 
specification + openapi_schema = { + "openapi": "3.1.0", + "info": { + "title": app.title, + "version": app.version, + "description": app.description, + "contact": app.contact + }, + "servers": servers, + "tags": app.openapi_tags, + "paths": _generate_paths(model_signature_info, example_input, example_output) + } + + # Cache the schema + app.openapi_schema = openapi_schema + return app.openapi_schema + + return custom_openapi + + +def _generate_paths(model_signature_info, example_input, example_output): + """ + Generate the paths section of the OpenAPI schema. + + Creates specifications for all API endpoints with proper request/response schemas. + """ + return { + "/ping": _generate_ping_path(), + "/health": _generate_health_path(), + "/version": _generate_version_path(), + "/invocations": _generate_invocations_path(model_signature_info, example_input, example_output) + } + + +def _generate_ping_path(): + """Generate OpenAPI specification for the /ping health check endpoint.""" + return { + "get": { + "tags": ["health"], + "summary": "Health Check (Ping)", + "description": "Quick health check that returns immediately if the service is running. Returns a newline character if healthy, 404 if the model failed to load.", + "operationId": "ping", + "responses": { + "200": { + "description": "Service is healthy", + "content": { + "application/json": { + "schema": {"type": "string"}, + "example": "\n" + } + } + }, + "404": { + "description": "Service unhealthy - model could not be loaded" + } + } + } + } + + +def _generate_health_path(): + """Generate OpenAPI specification for the /health endpoint.""" + return { + "get": { + "tags": ["health"], + "summary": "Detailed Health Check", + "description": "Comprehensive health check for the model service. 
Returns HTTP 200 if healthy, 404 if unhealthy.", + "operationId": "health", + "responses": { + "200": { + "description": "Service is healthy and ready to serve predictions", + "content": { + "application/json": { + "schema": {"type": "string"}, + "example": "\n" + } + } + }, + "404": { + "description": "Service unhealthy - model is not available" + } + } + } + } + + +def _generate_version_path(): + """Generate OpenAPI specification for the /version endpoint.""" + return { + "get": { + "tags": ["info"], + "summary": "Get MLflow Version", + "description": "Returns the version of MLflow running this service. Useful for debugging and compatibility checks.", + "operationId": "version", + "responses": { + "200": { + "description": "MLflow version string", + "content": { + "application/json": { + "schema": {"type": "string"}, + "example": "2.9.2" + } + } + } + } + } + } + + +def _generate_invocations_path(model_signature_info, example_input, example_output): + """ + Generate OpenAPI specification for the /invocations prediction endpoint. + + Includes model-specific examples and schemas when available. + """ + # Check if model has integer columns + has_integer_columns = example_input and example_input.get('has_integer_columns', False) + num_features = example_input.get('num_features', 2) if example_input else 2 + + # Build comprehensive description with examples + integer_column_note = "" + if has_integer_columns: + integer_column_note = """ + +**Important**: This model uses integer-based column indices (0, 1, 2, ...) instead of named features. This typically occurs when: +- The model was trained without explicit feature names +- Data was provided as numpy arrays rather than DataFrames with column names + +When sending data, use integer strings as column names in `dataframe_split` and `dataframe_records` formats, or send a simple array for `instances` format. +""" + + description = f"""Make predictions using the deployed MLflow model. 
+ +{model_signature_info if model_signature_info else ''} +{integer_column_note} + +**Input Format Examples:** + +1. **JSON with dataframe_split:** +```json +{{ + "dataframe_split": {{ + "columns": ["feature1", "feature2"], + "data": [[1.0, 2.0], [3.0, 4.0]] + }} +}} +``` + +2. **JSON with dataframe_records:** +```json +{{ + "dataframe_records": [ + {{"feature1": 1.0, "feature2": 2.0}}, + {{"feature1": 3.0, "feature2": 4.0}} + ] +}} +``` + +3. **JSON with instances (for models with integer columns):** +```json +{{ + "instances": [ + [1.0, 2.0], + [3.0, 4.0] + ] +}} +``` + +4. **CSV format:** +``` +feature1,feature2 +1.0,2.0 +3.0,4.0 +``` + +**Response Format:** +```json +{{ + "predictions": [result1, result2, ...] +}} +```""" + + # Prepare default examples + if has_integer_columns: + # For integer columns, provide simpler examples + default_split = { + "columns": [str(i) for i in range(num_features)], + "data": [[1.0] * num_features, [2.0] * num_features] + } + default_records = [ + {str(i): 1.0 for i in range(num_features)}, + {str(i): 2.0 for i in range(num_features)} + ] + default_instances = [[1.0] * num_features, [2.0] * num_features] + default_csv = ','.join(str(i) for i in range(num_features)) + '\n' + ','.join(['1.0'] * num_features) + else: + # Standard named column examples + default_split = { + "columns": ["feature1", "feature2"], + "data": [[1.0, 2.0], [3.0, 4.0]] + } + default_records = [ + {"feature1": 1.0, "feature2": 2.0}, + {"feature1": 3.0, "feature2": 4.0} + ] + default_instances = [[1.0, 2.0], [3.0, 4.0]] + default_csv = "feature1,feature2\n1.0,2.0\n3.0,4.0" + + default_output = {"predictions": [0.95, 0.87]} + + # Use model-based examples when available + split_example = example_input['dataframe_split'] if example_input else default_split + records_example = example_input['dataframe_records'] if example_input else default_records + instances_example = [example_input['dataframe_split']['data'][0]] if example_input else default_instances + + # 
Generate CSV example from model schema + if example_input: + csv_example = ( + ','.join(str(name) for name in example_input['column_names']) + '\n' + + ','.join(str(v) for v in example_input['dataframe_split']['data'][0]) + ) + else: + csv_example = default_csv + + output_example = example_output if example_output else default_output + + # Build examples section + json_examples = { + "dataframe_split": { + "summary": "DataFrame Split Format" + (" (Model Schema)" if example_input else ""), + "value": { + "dataframe_split": split_example + } + }, + "dataframe_records": { + "summary": "DataFrame Records Format" + (" (Model Schema)" if example_input else ""), + "value": { + "dataframe_records": records_example + } + } + } + + # Add instances format if model has integer columns + if has_integer_columns: + json_examples["instances"] = { + "summary": "Instances Format (Recommended for models with integer columns)", + "value": { + "instances": instances_example + } + } + + # Build complete endpoint specification + return { + "post": { + "tags": ["model"], + "summary": "Make Predictions", + "description": description, + "operationId": "invocations", + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object" + }, + "examples": json_examples + }, + "text/csv": { + "schema": { + "type": "string" + }, + "example": csv_example + } + }, + "required": True + }, + "responses": { + "200": { + "description": "Successful prediction" + (" (based on model schema)" if example_output else ""), + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "predictions": { + "type": "array", + "description": "Array of prediction results" + } + } + }, + "example": output_example + } + } + }, + "400": { + "description": "Bad request - invalid input format or missing required fields" + }, + "415": { + "description": "Unsupported media type - use application/json or text/csv" + }, + "500": { + "description": "Server error during 
prediction" + } + } + } + } diff --git a/extensions/mlflow-python-model-server-example/requirements.txt b/extensions/mlflow-python-model-server-example/requirements.txt new file mode 100644 index 00000000..1985831d --- /dev/null +++ b/extensions/mlflow-python-model-server-example/requirements.txt @@ -0,0 +1,3 @@ +fastapi +uvicorn +mlflow diff --git a/extensions/mlflow-tracking-server/.gitignore b/extensions/mlflow-tracking-server/.gitignore new file mode 100644 index 00000000..431d3059 --- /dev/null +++ b/extensions/mlflow-tracking-server/.gitignore @@ -0,0 +1,4 @@ +rsconnect-python +.venv +.env +.python-version diff --git a/extensions/mlflow-tracking-server/README.md b/extensions/mlflow-tracking-server/README.md new file mode 100644 index 00000000..bb37a2b8 --- /dev/null +++ b/extensions/mlflow-tracking-server/README.md @@ -0,0 +1,434 @@ +# MLflow Tracking Server Extension for Posit Connect + +This extension enables MLflow tracking server capabilities within Posit Connect. It provides a fully-featured MLflow tracking server that can be deployed as a Connect extension, allowing data scientists to track experiments, log models, and manage the ML lifecycle directly within their Connect environment. + +## ⚠️ Alpha Status + +**This extension is currently alpha.** While fully functional, there are important considerations: + +- **Manual Management**: MLflow version upgrades must be performed manually by redeploying the extension to the existing app GUID +- **No Automatic Updates**: Unlike managed services, you are responsible for keeping MLflow up-to-date with security patches and new features +- **Schema Migrations**: Database schema upgrades require careful coordination during redeployment + +Despite these limitations, this approach offers significant benefits for teams already invested in the Posit Connect ecosystem. + +## Why Deploy MLflow to Posit Connect? + +### Key Benefits + +1. 
**Unified Authentication & Access Control** + - Leverage your existing Connect API keys - no additional credential management + - Data scientists use the same authentication mechanism they already know + - Centralized user management through Connect's existing access controls + - No separate MLflow authentication system to configure or maintain + +2. **Simplified Infrastructure Management** + - Deploy MLflow where your data science workflows already live + - No separate Kubernetes cluster or VM infrastructure to provision + - Benefit from Connect's built-in high availability and monitoring capabilities + - Single platform for all data science deliverables: models, dashboards, APIs, and now MLflow tracking + +3. **Zero-Configuration Local Storage Option** + - Use Connect's persistent app storage (2025.10.0+) - no cloud resources needed for development/testing + - No AWS/Azure/GCP accounts required to get started + - Eliminate external infrastructure costs for smaller teams or proof-of-concepts + - Data stays within your organization's Connect environment + +4. **Seamless Integration with Existing Workflows** + - MLflow server runs alongside your deployed models and applications + - Direct integration with Connect-hosted Python environments + - Natural fit for teams already standardized on Posit Connect + - Reduced context switching between tools and platforms + +## Deployment Options + +This extension supports two deployment modes: + +1. **Local Storage** - Artifacts and metadata stored on Connect's local filesystem (Connect 2025.10.0+) +2. **External Storage** - Artifacts and metadata stored in external services (supports any MLflow-compatible backend storage and artifact store) + +Both deployment modes now support **OAuth integrations** for seamless authentication to cloud resources without managing long-lived credentials. 
+ +## Prerequisites + +- Posit Connect 2025.10.0 or later +- For external storage: Access to compatible external services (e.g., AWS S3/RDS, Azure Blob Storage/SQL, GCP Cloud Storage/SQL, etc.) +- For OAuth integrations: Appropriate OAuth integration configured in Connect (AWS, Azure, or GCP) + +## Deployment Scenarios + +### Option 1: Local Storage (Recommended for Development/Testing) + +This configuration stores all MLflow data on Connect's local filesystem. No external dependencies required. + +#### Setup Steps + +1. Create a new Python extension in Connect +2. Deploy the MLflow server extension with default settings +3. Set the minimum number of processes to 1 to ensure the server is always running +4. No additional configuration required - the extension will use Connect's local storage (2025.10.0+) +5. [Optional] Configure environment variables for backend storage customization (see below) +6. [Optional] Set an easy-to-remember vanity URL in Connect for easy access + +#### Environment Variables + +```bash +# No additional environment variables needed for local storage +# The extension will automatically configure local storage paths +``` + +### Option 2: External Storage with OAuth (Recommended for Production) + +This configuration uses external services for artifact storage and metadata database with **automatic OAuth-based authentication** - no need to manage access keys or connection strings! + +The extension supports **AWS RDS IAM authentication** and **Azure AD authentication** through Posit Connect's OAuth integrations. When configured: +- Database authentication tokens are automatically generated and refreshed +- AWS credentials for S3 access are obtained via OAuth (no access keys needed) +- Azure credentials for Blob Storage are obtained via OAuth (no connection strings needed) +- All credentials are automatically rotated before expiration + +#### AWS Setup with OAuth Integration + +##### Prerequisites + +1. 
**Enable IAM Authentication on RDS Database** + - Your RDS PostgreSQL instance **must have IAM database authentication enabled** + - This is required for OAuth-based authentication to work + - See [AWS RDS IAM Authentication Overview](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html) + +2. **Configure Database User for IAM Authentication** + - The PostgreSQL user must be granted the `rds_iam` role + - Connect to your RDS instance using a master/admin account and run: + ```sql + CREATE USER mlflow_user WITH LOGIN; + GRANT rds_iam TO mlflow_user; + ``` + - See [Creating a Database Account Using IAM Authentication](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.DBAccounts.html#UsingWithRDS.IAMDBAuth.DBAccounts.PostgreSQL) + +3. **Configure IAM Policy** + - The OAuth integration's IAM role needs permission to connect to the database + - Required IAM policy: + ```json + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "rds-db:connect" + ], + "Resource": [ + "arn:aws:rds-db:REGION:ACCOUNT_ID:dbuser:DB_RESOURCE_ID/mlflow_user" + ] + } + ] + } + ``` + - See [Creating and Using an IAM Policy for IAM Database Access](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.IAMPolicy.html) + +##### Setup Steps + +1. **Create PostgreSQL Database in AWS RDS** + - Navigate to AWS RDS Console + - Click "Create database" + - Select PostgreSQL engine + - **Enable IAM database authentication** (required for OAuth) + - Configure database: + - Database name: `mlflow` + - Master username: `postgres` (for initial setup) + - Master password: (save securely for initial setup) + - Configure security group to allow inbound traffic on port 5432 from Connect's IP + - Note the endpoint URL and DB resource ID + +2. 
**Grant rds_iam Role to Database User** + - Connect to the database using the master credentials: + ```bash + psql -h your-db.xxxxx.region.rds.amazonaws.com -U postgres -d mlflow + ``` + - Grant the `rds_iam` role: + ```sql + CREATE USER mlflow_user WITH LOGIN; + GRANT rds_iam TO mlflow_user; + ``` + +3. **Create S3 Bucket for Artifacts** + - Navigate to AWS S3 Console + - Click "Create bucket" + - Bucket name: `mlflow-artifacts-yourorg` (must be globally unique) + - Choose appropriate region (same as RDS for performance) + - Block public access: Enabled (recommended) + - Versioning: Optional (recommended for production) + +4. **Configure AWS OAuth Integration in Connect** + - In Connect, navigate to your content item + - Configure AWS OAuth integration with appropriate IAM role + - Ensure the IAM role has: + - RDS connect permission (`rds-db:connect`) + - S3 access permissions (PutObject, GetObject, DeleteObject, ListBucket) + +5. **Deploy Extension with Environment Variables** + +Configure the following environment variables in your Connect extension settings: + +```bash +# AWS RDS Configuration (OAuth-based authentication) +RDS_ENDPOINT=your-database.xxxxx.region.rds.amazonaws.com +RDS_PORT=5432 +RDS_DATABASE=mlflow +RDS_USERNAME=mlflow_user +RDS_DB_TYPE=postgresql +AWS_REGION=us-east-2 + +# S3 Artifact Storage (OAuth will provide credentials automatically) +MLFLOW_ARTIFACTS_DESTINATION=s3://mlflow-artifacts-yourorg/artifacts +``` + +**No AWS access keys or database passwords needed!** The extension automatically: +- Generates RDS IAM authentication tokens via OAuth +- Obtains temporary AWS credentials for S3 access +- Refreshes all tokens before expiration (every hour by default) + +#### Azure Setup with OAuth Integration + +##### Prerequisites + +1. **Enable Azure AD Authentication on Azure Database** + - Your Azure PostgreSQL or SQL Server must have Azure AD authentication enabled + - This is required for OAuth-based authentication + +2. 
**Grant Database Access to Azure AD Service Principal** + - The service principal from your OAuth integration needs database access + - For PostgreSQL: + ```sql + CREATE USER "service-principal-name" WITH LOGIN; + GRANT ALL PRIVILEGES ON DATABASE mlflow TO "service-principal-name"; + ``` + - For SQL Server: + ```sql + CREATE USER [service-principal-name] FROM EXTERNAL PROVIDER; + ALTER ROLE db_owner ADD MEMBER [service-principal-name]; + ``` + - See [Azure AD Authentication for PostgreSQL](https://docs.microsoft.com/en-us/azure/postgresql/connect-azure-active-directory) and [Azure AD Authentication for SQL Server](https://docs.microsoft.com/en-us/sql/relational-databases/security/authentication-access/azure-active-directory-authentication?view=sql-server-ver15) for details + +##### Setup Steps + +1. **Create Azure Database (PostgreSQL or SQL Server)** + - Use the Azure portal to create a new PostgreSQL or SQL Server instance + - Configure firewall rules to allow Connect's IP + - Note the server name, database name, and admin login + +2. **Configure Azure AD Authentication** + - Assign the Azure AD admin for the server in the Azure portal + - Create a new Azure AD user or use an existing one + - For PostgreSQL, run the following in the query editor: + ```sql + CREATE USER "service-principal-name" WITH LOGIN; + GRANT ALL PRIVILEGES ON DATABASE mlflow TO "service-principal-name"; + ``` + - For SQL Server, run the following in the query editor: + ```sql + CREATE USER [service-principal-name] FROM EXTERNAL PROVIDER; + ALTER ROLE db_owner ADD MEMBER [service-principal-name]; + ``` + +3. **Create Azure Blob Storage Account for Artifacts** + - Use the Azure portal to create a new Storage account + - Choose Blob storage and appropriate performance/tier options + - Note the storage account name and container + +4. 
**Configure Azure OAuth Integration in Connect**
+   - In Connect, navigate to your content item
+   - Configure Azure OAuth integration with appropriate permissions
+   - Ensure the service principal has:
+     - Contributor role on the Storage account
+     - Database access in Azure AD
+
+5. **Deploy Extension with Environment Variables**
+
+Configure the following environment variables in your Connect extension settings:
+
+```bash
+# Azure Database Configuration (OAuth-based authentication)
+AZURE_SQL_SERVER=your-server-name.database.windows.net
+AZURE_SQL_DATABASE=mlflow
+AZURE_SQL_USERNAME=service-principal-name
+AZURE_DATABASE_TYPE=postgresql
+AZURE_REGION=eastus
+
+# Azure Blob Storage Configuration (OAuth will provide credentials automatically)
+MLFLOW_ARTIFACTS_DESTINATION=wasbs://mlflow-artifacts@your-storage-account.blob.core.windows.net/artifacts
+```
+
+**No Azure connection strings or database passwords needed!** The extension automatically:
+- Authenticates to Azure SQL/PostgreSQL using Azure AD tokens
+- Obtains temporary credentials for Blob Storage access
+- Refreshes all tokens before expiration (every hour by default)
+
+#### GCP Setup with OAuth Integration
+
+##### Prerequisites
+
+1. **Enable Cloud SQL and Cloud Storage APIs**
+   - Enable the Cloud SQL API and Cloud Storage API in your GCP project
+   - This is required for the extension to access Cloud SQL and Cloud Storage
+
+2. **Create a Cloud SQL Instance**
+   - Create a new Cloud SQL instance (PostgreSQL or SQL Server)
+   - Note the instance connection name and database name
+
+3. **Create a Cloud Storage Bucket**
+   - Create a new Cloud Storage bucket for storing artifacts
+   - Note the bucket name
+
+4. **Configure IAM Permissions**
+   - The service account used by Posit Connect needs permissions to access Cloud SQL and Cloud Storage
+   - Assign the following roles:
+     - Cloud SQL Client
+     - Storage Object Admin
+
+5. 
**Create a Service Account Key (Optional)**
+   - If not using the default service account, create a service account key
+   - Download the JSON key file
+
+##### Setup Steps
+
+1. **Create PostgreSQL Database in Cloud SQL**
+   - Connect to your Cloud SQL instance using the Cloud SQL Auth proxy or public IP
+   - Create a new database for MLflow:
+     ```sql
+     CREATE DATABASE mlflow;
+     ```
+
+2. **Create a Cloud Storage Bucket for Artifacts**
+   - Use the GCP Console or `gsutil` to create a new bucket:
+     ```bash
+     gsutil mb gs://mlflow-artifacts-yourorg/
+     ```
+
+3. **Deploy Extension with Environment Variables**
+
+Configure the following environment variables in your Connect extension settings:
+
+```bash
+# GCP Cloud SQL Configuration
+CLOUD_SQL_CONNECTION_NAME=your-project:us-central1:your-sql-instance
+CLOUD_SQL_DATABASE=mlflow
+CLOUD_SQL_USERNAME=mlflow_user
+CLOUD_SQL_PASSWORD=your_password
+
+# GCP Cloud Storage Configuration (OAuth will provide credentials automatically)
+MLFLOW_ARTIFACTS_DESTINATION=gs://mlflow-artifacts-yourorg/artifacts
+```
+
+**No GCP service account keys needed for artifact storage!** Note that Cloud SQL authentication currently requires a database password (`CLOUD_SQL_PASSWORD`); OAuth-provided credentials cover Cloud Storage access only. The extension automatically:
+- Authenticates to Cloud SQL using Cloud SQL Auth proxy or public IP
+- Obtains temporary credentials for Cloud Storage access
+- Refreshes all tokens before expiration (every hour by default)
+
+## Connecting to the MLflow Server
+
+Once deployed, the MLflow server is accessible through Connect's URL structure. Authentication is handled via Connect API keys. 
+
+```python
+import os
+import mlflow
+
+# Configure connection to Connect-hosted MLflow server
+CONNECT_SERVER = os.getenv('CONNECT_SERVER', 'https://connect.example.com/')
+MLFLOW_TRACKING_URI = os.getenv('MLFLOW_TRACKING_URI', f'{CONNECT_SERVER}mlflow/')
+
+# Set up authentication using Connect API key
+os.environ["MLFLOW_TRACKING_TOKEN"] = os.getenv('MLFLOW_TRACKING_TOKEN',
+                                                os.getenv('CONNECT_API_KEY', ""))
+
+# Configure MLflow client
+mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
+
+# Now you can use MLflow as normal
+with mlflow.start_run():
+    mlflow.log_param("alpha", 0.5)
+    mlflow.log_metric("rmse", 0.876)
+    mlflow.log_artifact("model.pkl")
+```
+
+### Key Points
+
+- **Authentication**: Since the MLflow UI and API are deployed to Connect, a Connect API key is the only credential needed to access the endpoints
+- **Transparent Storage**: Client code is identical for both local and external storage deployments
+- **Server-Side Credentials**: AWS credentials for S3/RDS are configured server-side in the extension and not exposed to clients
+- **Security**: All traffic goes through Connect's authentication and authorization layer
+
+## Best Practices
+
+### Local Storage
+- ✅ Ideal for development and testing
+- ✅ Simple setup with no external dependencies
+- ✅ **Zero additional infrastructure costs**
+- ✅ **No cloud account or credentials required**
+- ✅ **Data stays within Connect's managed environment**
+- ⚠️ Limited by Connect server's disk capacity
+- ⚠️ Not recommended for production with multiple users
+- ⚠️ Backup and recovery depend on Connect's backup strategy
+
+### External Storage (S3 + RDS)
+- ✅ Scalable artifact storage
+- ✅ Robust metadata storage with PostgreSQL
+- ✅ Better suited for production environments
+- ✅ Supports team collaboration
+- ✅ Independent backup and disaster recovery
+- ⚠️ Requires AWS infrastructure setup and costs
+- ⚠️ Requires credential configuration (OAuth integrations recommended to avoid long-lived credentials)
+- ⚠️ Additional external services to manage 
and monitor + +## Management and Upgrades + +### Upgrading MLflow + +To upgrade MLflow to a newer version: + +1. Update the version in `requirements.txt` +2. Test the upgrade in a development environment first +3. Review MLflow release notes for breaking changes and schema migrations +4. Redeploy the extension to the **same app GUID** in Connect +5. Monitor logs during startup to ensure database migrations complete successfully + +**Important**: Always redeploy to the existing app GUID to preserve: +- The app's persistent storage directory (for local storage deployments) +- Configured environment variables +- Access permissions and API keys +- The MLflow tracking URI that clients are using + +### Backup and Recovery + +**For Local Storage:** +- Backups are handled by Connect's persistent storage backup mechanisms +- Consult your Connect administrator about backup schedules and retention +- The `app-data` directory contains all experiments, runs, and artifacts + +**For External Storage:** +- Back up your RDS database using AWS automated backups or snapshots +- Enable S3 versioning for artifact recovery +- Document your environment variable configuration separately + +## Troubleshooting + +### Connection Issues +- Verify Connect API key has appropriate permissions +- Check that the MLflow extension is running in Connect +- Ensure `MLFLOW_TRACKING_URI` points to the correct Connect URL + +### External Storage Issues +- Verify AWS credentials have correct S3 bucket permissions +- Check RDS security groups allow connections from Connect +- Ensure database exists and credentials are correct +- Verify S3 bucket exists and is in the correct region + +### Upgrade Issues +- Check Connect logs for database migration errors +- Ensure the app GUID matches the previous deployment +- Verify persistent storage directory is accessible and has sufficient space +- Review MLflow's migration documentation for version-specific requirements + +## Support + +For issues or questions, 
please contact your Posit Connect administrator or open an issue in the repository. diff --git a/extensions/mlflow-tracking-server/database_utils.py b/extensions/mlflow-tracking-server/database_utils.py new file mode 100644 index 00000000..719df8f5 --- /dev/null +++ b/extensions/mlflow-tracking-server/database_utils.py @@ -0,0 +1,378 @@ +import os +import logging +import re + +logger = logging.getLogger(__name__) + + +# ==================================================================== +# UTILITY FUNCTIONS +# ==================================================================== + +def mask_credentials(uri): + """Mask sensitive credentials in URIs for logging.""" + if not uri: + return uri + # Mask database passwords in URIs like postgresql://user:password@host:port/db + masked = re.sub(r'(://[^:]+:)[^@]+(@)', r'\1****\2', uri) + return masked + + +def detect_database_backend(db_uri: str) -> str: + """Detect the database backend type from the connection URI. + + Parameters + ---------- + db_uri : str + The database URI + + Returns + ------- + str + One of: 'local', 'aws_rds', 'azure_sql', 'standard' + """ + + # Only detect AWS RDS if the RDS_ENDPOINT environment variable is set + # This prevents overwriting standard connection strings that happen to use RDS + if os.getenv('RDS_ENDPOINT'): + return 'aws_rds' + + # Only detect Azure SQL if the AZURE_SQL_SERVER environment variable is set + # This prevents overwriting standard connection strings that happen to use Azure SQL + if os.getenv('AZURE_SQL_SERVER'): + return 'azure_sql' + + if not db_uri or db_uri.startswith('sqlite://'): + return 'local' + + # Standard connection string with username:password + return 'standard' + + +# ==================================================================== +# OAUTH INTEGRATION CHECKS +# ==================================================================== + +def check_aws_rds_oauth_integration(): + """Check if AWS OAuth integration is available for this content. 
+ + Returns + ------- + bool + True if AWS OAuth integration is available, False otherwise + """ + try: + from posit.connect import Client + from posit.connect.external.aws import get_content_credentials + + client = Client() + credentials = get_content_credentials(client) + return credentials is not None + + except ImportError: + return False + except Exception: + return False + + +def check_azure_sql_oauth_integration(): + """Check if Azure OAuth integration is available for this content. + + Returns + ------- + bool + True if Azure OAuth integration is available, False otherwise + """ + try: + from posit.connect import Client + + client = Client() + content = client.content.get() + + # Find Azure service account integration + association = content.oauth.associations.find_by( + integration_type="azure", + auth_type="Service Account" + ) + + return association is not None + + except ImportError: + return False + except Exception: + return False + + +# ==================================================================== +# DATABASE CONNECTION SETUP +# ==================================================================== + +def setup_aws_rds_connection(): + """Setup AWS RDS connection using IAM authentication. 
+ + Returns + ------- + str | None + Connection string without token, or None if setup failed + """ + try: + from posit.connect import Client + from posit.connect.external.aws import get_content_credentials + + print("Setting up AWS RDS connection with IAM authentication") + + # Get required environment variables + rds_endpoint = os.getenv("RDS_ENDPOINT") + rds_port = int(os.getenv("RDS_PORT", "5432")) + rds_database = os.getenv("RDS_DATABASE", "mlflow") + rds_username = os.getenv("RDS_USERNAME", "postgres") + db_type = os.getenv("RDS_DB_TYPE", "postgresql") + + if not all([rds_endpoint, rds_database, rds_username]): + print("ERROR: Missing required RDS environment variables") + return None + + # Verify we can get credentials + client = Client() + aws_credentials = get_content_credentials(client) + + if not aws_credentials: + print("ERROR: Failed to get AWS credentials from Connect") + return None + + print("AWS RDS credentials verified") + + # Build connection string without password but with SSL mode required + # Token will be injected by SQLAlchemy event listener + # For RDS IAM auth, we MUST use SSL + connection_string = f"{db_type}://{rds_username}@{rds_endpoint}:{rds_port}/{rds_database}?sslmode=require" + + return connection_string + + except ImportError: + print("ERROR: posit-sdk or boto3 not available, cannot setup AWS RDS connection") + return None + except Exception as e: + print(f"ERROR: Failed to setup AWS RDS connection: {e}") + return None + + +def setup_azure_sql_connection(): + """Setup Azure SQL connection using Azure AD authentication. 
+ + Returns + ------- + str | None + Connection string without token, or None if setup failed + """ + try: + from posit.connect import Client + + print("Setting up Azure database connection with Azure AD authentication") + + # Get required environment variables + sql_server = os.getenv("AZURE_SQL_SERVER") + sql_database = os.getenv("AZURE_SQL_DATABASE", "mlflow") + + if not all([sql_server, sql_database]): + print("ERROR: Missing required Azure SQL environment variables") + return None + + # Detect database type + is_postgres = 'postgres.database.azure.com' in sql_server.lower() + db_type = 'PostgreSQL' if is_postgres else 'SQL Server' + print(f"Detected Azure {db_type}") + + # Verify Azure integration exists + client = Client() + content = client.content.get() + + association = content.oauth.associations.find_by( + integration_type="azure", + auth_type="Service Account" + ) + + if not association: + print("WARNING: No Azure Service Account integration found") + return None + + print(f"Azure {db_type} integration verified") + + # Build connection string without password + # Token will be injected by SQLAlchemy event listener + if is_postgres: + username = os.getenv("AZURE_SQL_USERNAME", "postgres") + connection_string = f"postgresql://{username}@{sql_server}:5432/{sql_database}?sslmode=require" + else: + # Azure SQL Server connection string + connection_string = f"mssql+pyodbc://@{sql_server}/{sql_database}?driver=ODBC+Driver+17+for+SQL+Server" + + return connection_string + + except ImportError: + print("ERROR: posit-sdk not available, cannot setup Azure SQL connection") + return None + except Exception as e: + print(f"ERROR: Failed to setup Azure SQL connection: {e}") + return None + + +# ==================================================================== +# TOKEN REFRESH FUNCTIONS +# ==================================================================== + +def get_fresh_aws_rds_token(): + """Fetch a fresh AWS RDS IAM authentication token. 
+ + Returns + ------- + str | None + Fresh RDS auth token, or None if fetch failed + """ + try: + from posit.connect import Client + from posit.connect.external.aws import get_content_credentials + import boto3 + + # Get required environment variables + rds_endpoint = os.getenv("RDS_ENDPOINT") + rds_port = int(os.getenv("RDS_PORT", "5432")) + rds_username = os.getenv("RDS_USERNAME", "postgres") + aws_region = os.getenv("AWS_REGION", "us-east-2") + + if not all([rds_endpoint, rds_username]): + print("ERROR: Missing required RDS environment variables for token refresh") + return None + + # Get AWS credentials from Connect + client = Client() + aws_credentials = get_content_credentials(client) + + # Create RDS client with credentials + rds_client = boto3.client( + 'rds', + region_name=aws_region, + aws_access_key_id=aws_credentials["aws_access_key_id"], + aws_secret_access_key=aws_credentials["aws_secret_access_key"], + aws_session_token=aws_credentials["aws_session_token"], + ) + + # Generate authentication token + token = rds_client.generate_db_auth_token( + DBHostname=rds_endpoint, + Port=rds_port, + DBUsername=rds_username, + Region=aws_region, + ) + + return token + + except Exception as e: + print(f"ERROR: Failed to get fresh AWS RDS token: {e}") + return None + + +def get_fresh_azure_token(): + """Fetch a fresh Azure access token from Posit Connect. 
+ + Returns + ------- + str | None + Fresh access token, or None if fetch failed + """ + try: + from posit.connect import Client + + client = Client() + content = client.content.get() + + # Find Azure service account integration + association = content.oauth.associations.find_by( + integration_type="azure", + auth_type="Service Account" + ) + + if not association: + print("ERROR: No Azure Service Account integration found") + return None + + # Get Azure credentials with ossrdbms-aad scope + credentials = client.oauth.get_content_credentials( + audience=association['oauth_integration_guid'] + ) + + access_token = credentials.get('access_token') + if not access_token: + print("ERROR: No access token received from Azure integration") + return None + + return access_token + + except Exception as e: + print(f"ERROR: Failed to get fresh Azure token: {e}") + return None + + +# ==================================================================== +# EVENT LISTENERS +# ==================================================================== + +def setup_database_event_listeners(): + """Setup SQLAlchemy event listeners for automatic database token injection. + + Handles both AWS RDS IAM authentication and Azure AD authentication. + This must be called BEFORE MLflow creates its engine. + """ + try: + from sqlalchemy import event, Engine + + print("Setting up SQLAlchemy event listeners for database token refresh") + + @event.listens_for(Engine, "do_connect") + def receive_do_connect(_dialect, _conn_rec, _cargs, cparams): + """ + Intercept connection attempts and inject fresh tokens. + Called every time a new connection is established. 
+ """ + host = cparams.get('host', '') + + # Determine which database type based on environment variables + rds_endpoint = os.getenv('RDS_ENDPOINT', '').lower() + azure_sql_server = os.getenv('AZURE_SQL_SERVER', '').lower() + + # AWS RDS - check if RDS_ENDPOINT is set and matches the host + if rds_endpoint and (rds_endpoint in host.lower() or host.lower() in rds_endpoint): + print("Refreshing AWS RDS IAM token...") + fresh_token = get_fresh_aws_rds_token() + + if fresh_token: + cparams['password'] = fresh_token + # Ensure SSL is enabled (required for IAM auth) + if 'sslmode' not in cparams: + cparams['sslmode'] = 'require' + print("AWS RDS token refreshed successfully") + else: + print("ERROR: Failed to refresh AWS RDS token") + + # Azure PostgreSQL or SQL Server + elif azure_sql_server and (azure_sql_server in host.lower() or host.lower() in azure_sql_server): + print("Refreshing Azure SQL token...") + fresh_token = get_fresh_azure_token() + + if fresh_token: + cparams['password'] = fresh_token + print("Azure SQL token refreshed successfully") + else: + print("ERROR: Failed to refresh Azure SQL token") + + print("Database event listeners configured successfully") + return True + + except ImportError: + print("ERROR: SQLAlchemy not available, cannot setup event listeners") + return False + except Exception as e: + print(f"ERROR: Failed to setup database event listeners: {e}") + import traceback + print(traceback.format_exc()) + return False diff --git a/extensions/mlflow-tracking-server/main.py b/extensions/mlflow-tracking-server/main.py new file mode 100644 index 00000000..9873c372 --- /dev/null +++ b/extensions/mlflow-tracking-server/main.py @@ -0,0 +1,193 @@ +import os +import logging + +from database_utils import ( + mask_credentials, + detect_database_backend, + setup_aws_rds_connection, + setup_azure_sql_connection, + setup_database_event_listeners, +) +from storage_utils import ( + detect_storage_backend, + register_custom_artifact_repositories, +) + +# 
Configure logging before anything else +logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() + ] +) + +# Set specific loggers to appropriate levels +logging.getLogger("mlflow").setLevel(logging.DEBUG) +logging.getLogger("uvicorn").setLevel(logging.DEBUG) +logging.getLogger("uvicorn.access").setLevel(logging.DEBUG) +logging.getLogger("fastapi").setLevel(logging.DEBUG) + +logger = logging.getLogger(__name__) + +# Connect app persistent data directory +base_dir = os.path.join(os.getcwd(), "app-data") +db_path = os.path.join(base_dir, "mlflow.db") +artifact_path = os.path.join(base_dir, "artifacts") +os.makedirs(artifact_path, exist_ok=True) + +# Set environment variables for MLflow server before importing the app +# These are used by mlflow.server.handlers to initialize stores. +# Public env vars: +# MLFLOW_BACKEND_STORE_URI -> _MLFLOW_SERVER_FILE_STORE +# MLFLOW_DEFAULT_ARTIFACT_ROOT -> _MLFLOW_SERVER_ARTIFACT_ROOT +# MLFLOW_ARTIFACTS_DESTINATION -> _MLFLOW_SERVER_ARTIFACT_DESTINATION +# MLFLOW_SERVE_ARTIFACTS -> _MLFLOW_SERVER_SERVE_ARTIFACTS +# MLFLOW_ARTIFACTS_ONLY -> _MLFLOW_SERVER_ARTIFACTS_ONLY +# prometheus_multiproc_dir (no public equivalent) + +# Backend store URI (tracking/experiment data) +backend_store_uri = os.getenv("MLFLOW_BACKEND_STORE_URI") or os.getenv("_MLFLOW_SERVER_FILE_STORE") or f"sqlite:///{db_path}" +os.environ.setdefault("_MLFLOW_SERVER_FILE_STORE", backend_store_uri) + +# Registry store URI (model registry data) - uses same as backend +os.environ.setdefault("_MLFLOW_SERVER_REGISTRY_STORE", backend_store_uri) + +# Serve artifacts flag +serve_artifacts = os.getenv("MLFLOW_SERVE_ARTIFACTS") or os.getenv("_MLFLOW_SERVER_SERVE_ARTIFACTS") or "true" +os.environ.setdefault("_MLFLOW_SERVER_SERVE_ARTIFACTS", serve_artifacts) + +# Artifact root (default location for new experiments) +artifact_root = os.getenv("MLFLOW_DEFAULT_ARTIFACT_ROOT") or 
os.getenv("MLFLOW_ARTIFACT_ROOT") or os.getenv("_MLFLOW_SERVER_ARTIFACT_ROOT") or "mlflow-artifacts:/" +os.environ.setdefault("_MLFLOW_SERVER_ARTIFACT_ROOT", artifact_root) + +# Artifacts destination (physical storage location) +artifacts_destination = os.getenv("MLFLOW_ARTIFACTS_DESTINATION") or os.getenv("_MLFLOW_SERVER_ARTIFACT_DESTINATION") or artifact_path +os.environ.setdefault("_MLFLOW_SERVER_ARTIFACT_DESTINATION", artifacts_destination) + +# Configure SQLAlchemy connection pool recycling for OAuth token refresh +# Set to 1 hour (3600 seconds) to ensure connections are recycled well before +# the 24-hour token expiration, providing fresh tokens regularly +os.environ.setdefault("MLFLOW_SQLALCHEMYSTORE_POOL_RECYCLE", "3600") + +# IMPORTANT: Setup database event listeners BEFORE setting connection strings +# This ensures tokens are injected on the very first connection attempt +backend_db_type = detect_database_backend(os.environ.get('_MLFLOW_SERVER_FILE_STORE', f"sqlite:///{db_path}")) +print(f"Detected backend database type: {backend_db_type}") + +if backend_db_type in ['aws_rds', 'azure_sql']: + print("Setting up database event listeners for OAuth token refresh...") + if not setup_database_event_listeners(): + print("ERROR: Failed to setup database event listeners") + +# Detect and setup database connections +if backend_db_type == 'aws_rds': + print("AWS RDS backend detected, setting up IAM authentication...") + + rds_connection_string = setup_aws_rds_connection() + if rds_connection_string: + os.environ['_MLFLOW_SERVER_FILE_STORE'] = rds_connection_string + os.environ['_MLFLOW_SERVER_REGISTRY_STORE'] = rds_connection_string + print("AWS RDS backend connection configured successfully") + print("Tokens will be automatically refreshed every hour via connection pool recycling") + else: + print("WARNING: Failed to setup AWS RDS connection, using original connection string") +elif backend_db_type == 'azure_sql': + print("Azure SQL backend detected, setting up 
Azure AD authentication...") + azure_connection_string = setup_azure_sql_connection() + if azure_connection_string: + os.environ['_MLFLOW_SERVER_FILE_STORE'] = azure_connection_string + os.environ['_MLFLOW_SERVER_REGISTRY_STORE'] = azure_connection_string + print("Azure SQL backend connection configured successfully") + print("Tokens will be automatically refreshed every hour via connection pool recycling") + else: + print("WARNING: Failed to setup Azure SQL connection, using original connection string") + +# Detect storage backend and setup OAuth-based artifact repositories if needed +storage_backend = detect_storage_backend(os.environ['_MLFLOW_SERVER_ARTIFACT_DESTINATION']) +print(f"Detected storage backend: {storage_backend}") + +# IMPORTANT: Register custom artifact repositories BEFORE importing MLflow +# This ensures our OAuth-enabled repositories are used for artifact operations +# Only registers if OAuth integration is available, otherwise falls back to default credentials +if storage_backend in ['aws', 'azure']: + register_custom_artifact_repositories(storage_backend) +elif storage_backend == 'gcp': + print("GCP storage detected - using default credentials (env vars, service account, etc.)") +elif storage_backend == 'local': + print("Local storage detected - no cloud credentials needed") + +# By default, MLflow server runs without authentication. +# This is suitable for running behind a proxy that handles authentication. + +# The MLflow UI (Flask) needs a secret key for secure cookies. +# Generate a random one if not set. +if "MLFLOW_FLASK_SERVER_SECRET_KEY" not in os.environ: + os.environ["MLFLOW_FLASK_SERVER_SECRET_KEY"] = os.urandom(24).hex() + +# Import the FastAPI app from MLflow. +# This must be done AFTER setting the environment variables. 
+try: + print(f"Backend store URI: {mask_credentials(os.environ['_MLFLOW_SERVER_FILE_STORE'])}") + print(f"Registry store URI: {mask_credentials(os.environ['_MLFLOW_SERVER_REGISTRY_STORE'])}") + print(f"Artifact root: {os.environ['_MLFLOW_SERVER_ARTIFACT_ROOT']}") + print(f"Artifacts destination: {os.environ['_MLFLOW_SERVER_ARTIFACT_DESTINATION']}") + print(f"Serve artifacts: {os.environ['_MLFLOW_SERVER_SERVE_ARTIFACTS']}") + print(f"Pool recycle: {os.environ['MLFLOW_SQLALCHEMYSTORE_POOL_RECYCLE']}s") + + from mlflow.server.fastapi_app import app + + # Add middleware to log all requests + from fastapi import Request + import traceback + import time + + @app.middleware("http") + async def log_requests(request: Request, call_next): + start_time = time.time() + + # Log request + print(f"Request: {request.method} {request.url}") + + # Process request + try: + response = await call_next(request) + process_time = time.time() - start_time + + # Log response + print(f"Response: {response.status_code} - {process_time:.3f}s") + response.headers["X-Process-Time"] = str(process_time) + + return response + except Exception: + print(f"Request failed: {traceback.format_exc()}") + raise + +except ImportError as e: + print("Failed to import MLflow. 
Please ensure MLflow is installed (`pip install mlflow`).") + raise e + +if __name__ == "__main__": + import uvicorn + + # Run the FastAPI app with Uvicorn + # Use environment variables for host and port to support different deployment scenarios + host = os.getenv("MLFLOW_HOST", "0.0.0.0") + port = int(os.getenv("MLFLOW_PORT", "8000")) + + # For remote deployment, configure the server host for artifact URL generation + if host != "127.0.0.1" and host != "localhost": + server_host = os.getenv("MLFLOW_SERVER_HOST", f"{host}:{port}") + os.environ["_MLFLOW_SERVER_HOST"] = server_host + + print(f"Starting MLflow server on {host}:{port}") + print(f"Artifact serving enabled: {os.getenv('_MLFLOW_SERVER_SERVE_ARTIFACTS')}") + print(f"Artifacts destination: {os.getenv('_MLFLOW_SERVER_ARTIFACTS_DESTINATION')}") + + uvicorn.run( + app, + host=host, + port=port, + log_level="debug", + access_log=True + ) diff --git a/extensions/mlflow-tracking-server/manifest.json b/extensions/mlflow-tracking-server/manifest.json new file mode 100644 index 00000000..9acf8ec8 --- /dev/null +++ b/extensions/mlflow-tracking-server/manifest.json @@ -0,0 +1,46 @@ +{ + "version": 1, + "locale": "en_US.UTF-8", + "metadata": { + "appmode": "python-fastapi", + "entrypoint": "main" + }, + "python": { + "version": "3.12.7", + "package_manager": { + "name": "pip", + "version": "24.2", + "package_file": "requirements.txt" + } + }, + "environment": { + "python": { + "requires": ">=3.11.0" + } + }, + "extension": { + "name": "mlflow-tracking-server", + "title": "MLflow Tracking Server", + "description": "Provides a way to track and manage MLflow experiments and models. See extension README for setup instructions. 
Local storage option requires Connect version 2025.10.0+.", + "homepage": "https://github.com/posit-dev/connect-extensions/tree/main/extensions/mlflow-tracking-server", + "category": "extension", + "tags": ["mlops", "mlflow", "tracking", "experiments", "models"], + "minimumConnectVersion": "2025.04.0", + "requiredFeatures": [], + "version": "0.0.0" + }, + "files": { + "requirements.txt": { + "checksum": "075a8097f9eccd0a8d8442d624606988" + }, + ".gitignore": { + "checksum": "693ec79eaa892babde62587aaacf0d8b" + }, + "README.md": { + "checksum": "8cb20a8fa038108b80d38fc261808cb2" + }, + "main.py": { + "checksum": "be3fe27aebfb4440b167337bdf8aabbf" + } + } +} diff --git a/extensions/mlflow-tracking-server/requirements.txt b/extensions/mlflow-tracking-server/requirements.txt new file mode 100644 index 00000000..a3963cf3 --- /dev/null +++ b/extensions/mlflow-tracking-server/requirements.txt @@ -0,0 +1,12 @@ +mlflow[auth] +fastapi +posit-sdk +# Cloud storage providers +azure-storage-blob +azure-identity +google-cloud-storage +boto3 +# Database drivers +psycopg2-binary # PostgreSQL +pymysql # MySQL +pyodbc # MSSQL diff --git a/extensions/mlflow-tracking-server/storage_utils.py b/extensions/mlflow-tracking-server/storage_utils.py new file mode 100644 index 00000000..d7c6f596 --- /dev/null +++ b/extensions/mlflow-tracking-server/storage_utils.py @@ -0,0 +1,452 @@ +import logging + +logger = logging.getLogger(__name__) + + +def detect_storage_backend(artifacts_destination: str) -> str: + """Detect the storage backend type from the artifacts destination. 
+ + Parameters + ---------- + artifacts_destination : str + The artifacts destination URI + + Returns + ------- + str + One of: 'local', 'aws', 'azure', 'gcp' + """ + artifacts_destination = artifacts_destination.lower() + + if artifacts_destination.startswith('s3://') or artifacts_destination.startswith('s3a://'): + return 'aws' + elif (artifacts_destination.startswith('wasbs://') or + artifacts_destination.startswith('abfss://') or + 'blob.core.windows.net' in artifacts_destination or + 'dfs.core.windows.net' in artifacts_destination): + return 'azure' + elif artifacts_destination.startswith('gs://'): + return 'gcp' + else: + return 'local' + + +# ==================================================================== +# ARTIFACT STORAGE OAUTH INTEGRATION +# ==================================================================== + +def get_azure_storage_token(): + """Fetch a fresh Azure storage access token from Posit Connect. + + Returns + ------- + tuple[str, int] | None + Tuple of (access_token, expires_on_timestamp), or None if fetch failed + """ + try: + from posit.connect import Client + import time + + print("Fetching Azure storage token from Posit Connect...") + client = Client() + content = client.content.get() + + # Find Azure service account integration + association = content.oauth.associations.find_by( + integration_type="azure", + auth_type="Service Account" + ) + + if not association: + print("ERROR: No Azure Service Account integration found") + return None + + # Get Azure credentials for storage + # Use storage scope: https://storage.azure.com/.default + credentials = client.oauth.get_content_credentials( + audience=association['oauth_integration_guid'] + ) + + access_token = credentials.get('access_token') + if not access_token: + print("ERROR: No access token received from Azure integration") + return None + + # Calculate expiration time (tokens typically expire in 1 hour) + # If expires_in is provided, use it; otherwise default to 3600 seconds + 
expires_in = credentials.get('expires_in', 3600) + expires_on = int(time.time()) + expires_in + + print(f"Azure storage token fetched successfully, expires at {expires_on}") + return (access_token, expires_on) + + except Exception as e: + print(f"ERROR: Failed to get Azure storage token: {e}") + return None + + +def get_aws_storage_credentials(): + """Fetch fresh AWS storage credentials from Posit Connect. + + Returns + ------- + dict | None + Dictionary with aws_access_key_id, aws_secret_access_key, aws_session_token, + or None if fetch failed + """ + try: + from posit.connect import Client + from posit.connect.external.aws import get_content_credentials + + print("Fetching AWS storage credentials from Posit Connect...") + client = Client() + credentials = get_content_credentials(client) + + if not credentials: + print("ERROR: Failed to get AWS credentials") + return None + + print(f"AWS storage credentials fetched successfully, expires at {credentials.get('expiration', 'unknown')}") + return credentials + + except Exception as e: + print(f"ERROR: Failed to get AWS storage credentials: {e}") + return None + + +# ==================================================================== +# AZURE BLOB STORAGE CUSTOM ARTIFACT REPOSITORY +# ==================================================================== + +try: + import time + import threading + from urllib.parse import urlparse + from azure.core.credentials import AccessToken + from azure.storage.blob import BlobServiceClient + from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository + + class RefreshingTokenCredential: + """A thread-safe credential that auto-refreshes an OAuth2 token for Azure.""" + + def __init__(self, refresh_function): + """Initialize with a function that returns (token, expires_on).""" + self._refresh_function = refresh_function + self._access_token = None + self._expires_on = 0 + self._lock = threading.Lock() + + def get_token(self, *scopes, **kwargs) -> 
AccessToken: + """Get a valid access token, refreshing if necessary.""" + with self._lock: + # Refresh token 5 minutes before expiration + buffer_seconds = 300 + if not self._access_token or time.time() > (self._expires_on - buffer_seconds): + result = self._refresh_function() + if result: + new_token, new_expires_on = result + self._access_token = new_token + self._expires_on = new_expires_on + else: + print("ERROR: Failed to refresh Azure storage token") + + return AccessToken(self._access_token, self._expires_on) + + + class TokenAuthAzureBlobRepo(AzureBlobArtifactRepository): + """Custom Azure Blob Storage repository that uses OAuth token authentication. + + For wasbs:// URIs (Azure Blob Storage). + """ + + def __init__(self, artifact_uri: str, tracking_uri: str = None, registry_uri: str = None): + """Initialize with auto-refreshing OAuth token credential.""" + print(f"Initializing TokenAuthAzureBlobRepo for {artifact_uri}") + + # Create auto-refreshing credential + credential = RefreshingTokenCredential(refresh_function=get_azure_storage_token) + + # Parse the artifact URI to get the account URL + parsed_uri = urlparse(artifact_uri) + account_url = f"https://{parsed_uri.hostname}" + + # Create custom BlobServiceClient with OAuth token + custom_client = BlobServiceClient(account_url=account_url, credential=credential) + + # Initialize parent class with custom client + super().__init__( + artifact_uri=artifact_uri, + client=custom_client, + tracking_uri=tracking_uri, + registry_uri=registry_uri, + ) + + print(f"TokenAuthAzureBlobRepo initialized successfully for {account_url}") + + # Export the class for registration + AZURE_BLOB_REPO_CLASS = TokenAuthAzureBlobRepo + +except ImportError as e: + print(f"WARNING: Could not import Azure Blob Storage dependencies: {e}") + AZURE_BLOB_REPO_CLASS = None + + +# ==================================================================== +# AZURE DATA LAKE STORAGE GEN2 CUSTOM ARTIFACT REPOSITORY +# 
==================================================================== + +try: + from mlflow.store.artifact.azure_data_lake_artifact_repo import AzureDataLakeArtifactRepository + + class TokenAuthAzureDataLakeRepo(AzureDataLakeArtifactRepository): + """Custom Azure Data Lake Storage Gen2 repository that uses OAuth token authentication. + + For abfss:// URIs (Azure Data Lake Storage Gen2). + Uses credential_refresh_def to periodically refresh OAuth tokens. + """ + + def __init__(self, artifact_uri: str, tracking_uri: str = None, registry_uri: str = None): + """Initialize with auto-refreshing OAuth token credential.""" + print(f"Initializing TokenAuthAzureDataLakeRepo for {artifact_uri}") + + # Create auto-refreshing credential + credential = RefreshingTokenCredential(refresh_function=get_azure_storage_token) + + # Define credential refresh function for the parent class + def credential_refresh_def(): + """Return new credentials in the format expected by AzureDataLakeArtifactRepository.""" + # Create a new credential instance with the refreshing function + return {"credential": RefreshingTokenCredential(refresh_function=get_azure_storage_token)} + + # Initialize parent class with credential and refresh function + super().__init__( + artifact_uri=artifact_uri, + credential=credential, + credential_refresh_def=credential_refresh_def, + tracking_uri=tracking_uri, + ) + + print("TokenAuthAzureDataLakeRepo initialized successfully") + + # Export the class for registration + AZURE_DATA_LAKE_REPO_CLASS = TokenAuthAzureDataLakeRepo + +except ImportError as e: + print(f"WARNING: Could not import Azure Data Lake Storage dependencies: {e}") + AZURE_DATA_LAKE_REPO_CLASS = None + + +# ==================================================================== +# S3 CUSTOM ARTIFACT REPOSITORY +# ==================================================================== + +try: + from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository + import boto3 + + class 
RefreshingS3ArtifactRepository(S3ArtifactRepository): + """Custom S3 artifact repository that refreshes OAuth credentials every hour. + + Overrides _get_s3_client() to fetch fresh temporary credentials from Posit Connect + OAuth integration on each S3 operation, with automatic refresh every hour. + This ensures credentials remain valid even when repository instances are cached. + """ + + def __init__(self, artifact_uri: str, tracking_uri: str = None, registry_uri: str = None): + """Initialize with OAuth-based credential management. + + Does not fetch credentials at initialization - instead, credentials are fetched + on-demand when S3 operations are performed via _get_s3_client(). + """ + print(f"Initializing RefreshingS3ArtifactRepository for {artifact_uri}") + + # Store timestamp for credential refresh tracking + self._credentials_fetched_at = 0 + self._credential_refresh_interval = 3600 # Refresh every hour + self._cached_s3_client = None + + # Initialize parent without credentials - we'll override get_s3_client + super().__init__( + artifact_uri=artifact_uri, + tracking_uri=tracking_uri, + ) + + print("RefreshingS3ArtifactRepository initialized successfully") + + def _get_s3_client(self): + """Override to fetch fresh credentials from OAuth integration. + + This method is called by MLflow whenever S3 operations are performed, + ensuring credentials are refreshed as needed. 
+ """ + import time + + # Check if we need to refresh credentials (every hour or if no client cached) + current_time = time.time() + if (self._cached_s3_client is None or + current_time - self._credentials_fetched_at > self._credential_refresh_interval): + print("Fetching fresh AWS credentials for S3 access...") + temp_creds = get_aws_storage_credentials() + + if temp_creds: + # Create S3 client with fresh temporary credentials + self._cached_s3_client = boto3.client( + 's3', + aws_access_key_id=temp_creds["aws_access_key_id"], + aws_secret_access_key=temp_creds["aws_secret_access_key"], + aws_session_token=temp_creds["aws_session_token"], + ) + self._credentials_fetched_at = current_time + print("S3 client created with fresh OAuth credentials") + else: + print("ERROR: Failed to fetch AWS credentials from OAuth integration") + print("Falling back to default credentials (this may fail if no valid credentials available)") + # Only fall back to parent if we have no cached client + if self._cached_s3_client is None: + return super()._get_s3_client() + + # Return the cached client with OAuth credentials + return self._cached_s3_client + + # Export the class for registration + S3_REPO_CLASS = RefreshingS3ArtifactRepository + +except ImportError as e: + print(f"WARNING: Could not import S3 dependencies: {e}") + S3_REPO_CLASS = None + + +# ==================================================================== +# OAUTH INTEGRATION CHECKS +# ==================================================================== + +def check_aws_oauth_integration(): + """Check if AWS OAuth integration is available for this content. 
+ + Returns + ------- + bool + True if AWS OAuth integration is available, False otherwise + """ + try: + from posit.connect import Client + from posit.connect.external.aws import get_content_credentials + + client = Client() + credentials = get_content_credentials(client) + return credentials is not None + + except ImportError: + return False + except Exception: + return False + + +def check_azure_oauth_integration(): + """Check if Azure OAuth integration is available for this content. + + Returns + ------- + bool + True if Azure OAuth integration is available, False otherwise + """ + try: + from posit.connect import Client + + client = Client() + content = client.content.get() + + # Find Azure service account integration + association = content.oauth.associations.find_by( + integration_type="azure", + auth_type="Service Account" + ) + + return association is not None + + except ImportError: + return False + except Exception: + return False + + +# ==================================================================== +# ARTIFACT REPOSITORY REGISTRATION +# ==================================================================== + +def register_custom_artifact_repositories(storage_backend: str): + """Register custom artifact repositories with MLflow's registry if OAuth integration exists. + + This function checks for OAuth integration availability and only registers custom repositories + if an integration is found. Otherwise, MLflow will use default credentials (env vars, IAM roles, etc.) 
+ + Parameters + ---------- + storage_backend : str + The storage backend type ('aws', 'azure', 'gcp', or 'local') + + Returns + ------- + bool + True if custom repos were registered, False if no OAuth integration or registration failed + """ + try: + from mlflow.store.artifact.artifact_repository_registry import _artifact_repository_registry + + registered = False + + # Check for AWS OAuth integration + if storage_backend == 'aws': + if not check_aws_oauth_integration(): + print("No AWS OAuth integration found - using default credentials (env vars, IAM role, etc.)") + return False + + if S3_REPO_CLASS: + print("AWS OAuth integration detected - registering custom S3 artifact repository...") + _artifact_repository_registry.register("s3", S3_REPO_CLASS) + print("Custom S3 artifact repository registered successfully") + registered = True + else: + print("WARNING: S3 artifact repository class not available") + return False + + # Check for Azure OAuth integration + elif storage_backend == 'azure': + if not check_azure_oauth_integration(): + print("No Azure OAuth integration found - using default credentials (env vars, managed identity, etc.)") + return False + + print("Azure OAuth integration detected - registering custom Azure artifact repositories...") + + # Register Azure Blob Storage repository (wasbs://) + if AZURE_BLOB_REPO_CLASS: + _artifact_repository_registry.register("wasbs", AZURE_BLOB_REPO_CLASS) + print(" - Registered Azure Blob Storage repository (wasbs://)") + registered = True + else: + print(" WARNING: Azure Blob Storage repository class not available") + + # Register Azure Data Lake Gen2 repository (abfss://) + if AZURE_DATA_LAKE_REPO_CLASS: + _artifact_repository_registry.register("abfss", AZURE_DATA_LAKE_REPO_CLASS) + print(" - Registered Azure Data Lake Storage Gen2 repository (abfss://)") + registered = True + else: + print(" WARNING: Azure Data Lake Storage Gen2 repository class not available") + + if not registered: + print("ERROR: No Azure 
artifact repository classes available") + return False + + print("Custom Azure artifact repositories registered successfully") + + return registered + + except ImportError: + print("ERROR: MLflow not available, cannot register artifact repositories") + return False + except Exception as e: + print(f"ERROR: Failed to register custom artifact repositories: {e}") + import traceback + print(traceback.format_exc()) + return False