diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs index 68447f16f4..1a944d115b 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs @@ -5,14 +5,13 @@ using Azure.DataApiBuilder.Auth; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; -using Azure.DataApiBuilder.Core.Authorization; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Resolvers; using Azure.DataApiBuilder.Core.Resolvers.Factories; using Azure.DataApiBuilder.Core.Services; -using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.DependencyInjection; @@ -57,23 +56,21 @@ public async Task ExecuteAsync( CancellationToken cancellationToken = default) { ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; if (arguments == null) { - return Utils.McpResponseBuilder.BuildErrorResult("Invalid Arguments", "No arguments provided", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); } RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService(); if (!runtimeConfigProvider.TryGetConfig(out RuntimeConfig? runtimeConfig)) { - return Utils.McpResponseBuilder.BuildErrorResult("Invalid Configuration", "Runtime configuration not available", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidConfiguration", "Runtime configuration not available.", logger); } if (runtimeConfig.McpDmlTools?.CreateRecord != true) { - return Utils.McpResponseBuilder.BuildErrorResult( - "ToolDisabled", - "The create_record tool is disabled in the configuration.", - logger); + return McpErrorHelpers.ToolDisabled(toolName, logger); } try @@ -81,39 +78,21 @@ public async Task ExecuteAsync( cancellationToken.ThrowIfCancellationRequested(); JsonElement root = arguments.RootElement; - if (!root.TryGetProperty("entity", out JsonElement entityElement) || - !root.TryGetProperty("data", out JsonElement dataElement)) + if (!McpArgumentParser.TryParseEntityAndData(root, out string entityName, out JsonElement dataElement, out string parseError)) { - return Utils.McpResponseBuilder.BuildErrorResult("InvalidArguments", "Missing required arguments 'entity' or 'data'", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } - string entityName = entityElement.GetString() ?? string.Empty; - if (string.IsNullOrWhiteSpace(entityName)) + if (!McpMetadataHelper.TryResolveMetadata( + entityName, + runtimeConfig, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) { - return Utils.McpResponseBuilder.BuildErrorResult("InvalidArguments", "Entity name cannot be empty", logger); - } - - string dataSourceName; - try - { - dataSourceName = runtimeConfig.GetDataSourceNameFromEntityName(entityName); - } - catch (Exception) - { - return Utils.McpResponseBuilder.BuildErrorResult("InvalidConfiguration", $"Entity '{entityName}' not found in configuration", logger); - } - - IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); - ISqlMetadataProvider sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); - - DatabaseObject dbObject; - try - { - dbObject = sqlMetadataProvider.GetDatabaseObjectByKey(entityName); - } - catch (Exception) - { - return Utils.McpResponseBuilder.BuildErrorResult("InvalidConfiguration", $"Database object for entity '{entityName}' not found", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } // Create an HTTP context for authorization @@ -121,15 +100,20 @@ public async Task ExecuteAsync( HttpContext httpContext = httpContextAccessor.HttpContext ?? new DefaultHttpContext(); IAuthorizationResolver authorizationResolver = serviceProvider.GetRequiredService(); - if (httpContext is null || !authorizationResolver.IsValidRoleContext(httpContext)) + if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authorizationResolver, out string roleCtxError)) { - return Utils.McpResponseBuilder.BuildErrorResult("PermissionDenied", "Permission denied: Unable to resolve a valid role context for update operation.", logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "create", roleCtxError, logger); } - // Validate that we have at least one role authorized for create - if (!TryResolveAuthorizedRole(httpContext, authorizationResolver, entityName, out string authError)) + if (!McpAuthorizationHelper.TryResolveAuthorizedRole( + httpContext, + authorizationResolver, + entityName, + EntityActionOperation.Create, + out string? effectiveRole, + out string authError)) { - return Utils.McpResponseBuilder.BuildErrorResult("PermissionDenied", authError, logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "create", authError, logger); } JsonElement insertPayloadRoot = dataElement.Clone(); @@ -150,12 +134,13 @@ public async Task ExecuteAsync( } catch (Exception ex) { - return Utils.McpResponseBuilder.BuildErrorResult("ValidationFailed", $"Request validation failed: {ex.Message}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "ValidationFailed", $"Request validation failed: {ex.Message}", logger); } } else { - return Utils.McpResponseBuilder.BuildErrorResult( + return McpResponseBuilder.BuildErrorResult( + toolName, "InvalidCreateTarget", "The create_record tool is only available for tables.", logger); @@ -169,7 +154,7 @@ public async Task ExecuteAsync( if (result is CreatedResult createdResult) { - return Utils.McpResponseBuilder.BuildSuccessResult( + return McpResponseBuilder.BuildSuccessResult( new Dictionary { ["entity"] = entityName, @@ -184,14 +169,15 @@ public async Task ExecuteAsync( bool isError = objectResult.StatusCode.HasValue && objectResult.StatusCode.Value >= 400 && objectResult.StatusCode.Value != 403; if (isError) { - return Utils.McpResponseBuilder.BuildErrorResult( + return McpResponseBuilder.BuildErrorResult( + toolName, "CreateFailed", $"Failed to create record in entity '{entityName}'. Error: {JsonSerializer.Serialize(objectResult.Value)}", logger); } else { - return Utils.McpResponseBuilder.BuildSuccessResult( + return McpResponseBuilder.BuildSuccessResult( new Dictionary { ["entity"] = entityName, @@ -206,14 +192,15 @@ public async Task ExecuteAsync( { if (result is null) { - return Utils.McpResponseBuilder.BuildErrorResult( + return McpResponseBuilder.BuildErrorResult( + toolName, "UnexpectedError", $"Mutation engine returned null result for entity '{entityName}'", logger); } else { - return Utils.McpResponseBuilder.BuildSuccessResult( + return McpResponseBuilder.BuildSuccessResult( new Dictionary { ["entity"] = entityName, @@ -226,50 +213,8 @@ public async Task ExecuteAsync( } catch (Exception ex) { - return Utils.McpResponseBuilder.BuildErrorResult("Error", $"Error: {ex.Message}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "Error", $"Error: {ex.Message}", logger); } } - - private static bool TryResolveAuthorizedRole( - HttpContext httpContext, - IAuthorizationResolver authorizationResolver, - string entityName, - out string error) - { - error = string.Empty; - - string roleHeader = httpContext.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER].ToString(); - - if (string.IsNullOrWhiteSpace(roleHeader)) - { - error = "Client role header is missing or empty."; - return false; - } - - string[] roles = roleHeader - .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) - .Distinct(StringComparer.OrdinalIgnoreCase) - .ToArray(); - - if (roles.Length == 0) - { - error = "Client role header is missing or empty."; - return false; - } - - foreach (string role in roles) - { - bool allowed = authorizationResolver.AreRoleAndOperationDefinedForEntity( - entityName, role, EntityActionOperation.Create); - - if (allowed) - { - return true; - } - } - - error = "You do not have permission to create records for this entity."; - return false; - } } } diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DeleteRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DeleteRecordTool.cs index 7abac888c5..d7837c0103 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DeleteRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DeleteRecordTool.cs @@ -73,6 +73,7 @@ public async Task ExecuteAsync( CancellationToken cancellationToken = default) { ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; try { @@ -86,49 +87,37 @@ public async Task ExecuteAsync( // 2) Check if the tool is enabled in configuration before proceeding if (config.McpDmlTools?.DeleteRecord != true) { - return McpResponseBuilder.BuildErrorResult( - "ToolDisabled", - $"The {this.GetToolMetadata().Name} tool is disabled in the configuration.", - logger); + return McpErrorHelpers.ToolDisabled(GetToolMetadata().Name, logger); } // 3) Parsing & basic argument validation if (arguments is null) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", "No arguments provided.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); } if (!McpArgumentParser.TryParseEntityAndKeys(arguments.RootElement, out string entityName, out Dictionary keys, out string parseError)) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", parseError, logger); - } - - IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); - IMutationEngineFactory mutationEngineFactory = serviceProvider.GetRequiredService(); - - // 4) Resolve metadata for entity existence check - string dataSourceName; - ISqlMetadataProvider sqlMetadataProvider; - - try - { - dataSourceName = config.GetDataSourceNameFromEntityName(entityName); - sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); - } - catch (Exception) - { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } - if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) + // 4) Resolve metadata for entity existence + if (!McpMetadataHelper.TryResolveMetadata( + entityName, + config, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } // Validate it's a table or view if (dbObject.SourceType != EntitySourceType.Table && dbObject.SourceType != EntitySourceType.View) { - return McpResponseBuilder.BuildErrorResult("InvalidEntity", $"Entity '{entityName}' is not a table or view. Use 'execute-entity' for stored procedures.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidEntity", $"Entity '{entityName}' is not a table or view. Use 'execute-entity' for stored procedures.", logger); } // 5) Authorization @@ -138,7 +127,7 @@ public async Task ExecuteAsync( if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleError)) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", $"Permission denied: {roleError}", logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "delete", roleError, logger); } if (!McpAuthorizationHelper.TryResolveAuthorizedRole( @@ -149,10 +138,11 @@ public async Task ExecuteAsync( out string? effectiveRole, out string authError)) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", $"Permission denied: {authError}", logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "delete", authError, logger); } - // 6) Build and validate Delete context + // Need MetadataProviderFactory for RequestValidator; resolve here. + IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); RequestValidator requestValidator = new(metadataProviderFactory, runtimeConfigProvider); DeleteRequestContext context = new( @@ -164,7 +154,7 @@ public async Task ExecuteAsync( { if (kvp.Value is null) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", $"Primary key value for '{kvp.Key}' cannot be null.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Primary key value for '{kvp.Key}' cannot be null.", logger); } context.PrimaryKeyValuePairs[kvp.Key] = kvp.Value; @@ -172,7 +162,7 @@ public async Task ExecuteAsync( requestValidator.ValidatePrimaryKey(context); - // 7) Execute + IMutationEngineFactory mutationEngineFactory = serviceProvider.GetRequiredService(); DatabaseType dbType = config.GetDataSourceFromDataSourceName(dataSourceName).DatabaseType; IMutationEngine mutationEngine = mutationEngineFactory.GetMutationEngine(dbType); @@ -195,6 +185,7 @@ public async Task ExecuteAsync( { string keyDetails = McpJsonHelper.FormatKeyDetails(keys); return McpResponseBuilder.BuildErrorResult( + toolName, "RecordNotFound", $"No record found with the specified primary key: {keyDetails}", logger); @@ -203,6 +194,7 @@ public async Task ExecuteAsync( message.Contains("REFERENCE constraint", StringComparison.OrdinalIgnoreCase)) { return McpResponseBuilder.BuildErrorResult( + toolName, "ConstraintViolation", "Cannot delete record due to foreign key constraint. Other records depend on this record.", logger); @@ -211,6 +203,7 @@ public async Task ExecuteAsync( message.Contains("authorization", StringComparison.OrdinalIgnoreCase)) { return McpResponseBuilder.BuildErrorResult( + toolName, "PermissionDenied", "You do not have permission to delete this record.", logger); @@ -219,6 +212,7 @@ public async Task ExecuteAsync( message.Contains("type", StringComparison.OrdinalIgnoreCase)) { return McpResponseBuilder.BuildErrorResult( + toolName, "InvalidArguments", "Invalid data type for one or more key values.", logger); @@ -226,6 +220,7 @@ public async Task ExecuteAsync( // For any other DAB exceptions, return the message as-is return McpResponseBuilder.BuildErrorResult( + toolName, "DataApiBuilderError", dabEx.Message, logger); @@ -242,7 +237,7 @@ public async Task ExecuteAsync( 208 => $"Table '{dbObject.FullName}' not found in the database.", _ => $"Database error: {sqlEx.Message}" }; - return McpResponseBuilder.BuildErrorResult("DatabaseError", errorMessage, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseError", errorMessage, logger); } catch (DbException dbEx) { @@ -254,6 +249,7 @@ public async Task ExecuteAsync( if (errorMsg.Contains("foreign key") || errorMsg.Contains("constraint")) { return McpResponseBuilder.BuildErrorResult( + toolName, "ConstraintViolation", "Cannot delete record due to foreign key constraint. Other records depend on this record.", logger); @@ -261,24 +257,25 @@ public async Task ExecuteAsync( else if (errorMsg.Contains("not found") || errorMsg.Contains("does not exist")) { return McpResponseBuilder.BuildErrorResult( + toolName, "RecordNotFound", "No record found with the specified primary key.", logger); } - return McpResponseBuilder.BuildErrorResult("DatabaseError", $"Database error: {dbEx.Message}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseError", $"Database error: {dbEx.Message}", logger); } catch (InvalidOperationException ioEx) when (ioEx.Message.Contains("connection", StringComparison.OrdinalIgnoreCase)) { // Handle connection-related issues logger?.LogError(ioEx, "Database connection error"); - return McpResponseBuilder.BuildErrorResult("ConnectionError", "Failed to connect to the database.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "ConnectionError", "Failed to connect to the database.", logger); } catch (TimeoutException timeoutEx) { // Handle query timeout logger?.LogError(timeoutEx, "Delete operation timeout for {Entity}", entityName); - return McpResponseBuilder.BuildErrorResult("TimeoutError", "The delete operation timed out.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "TimeoutError", "The delete operation timed out.", logger); } catch (Exception ex) { @@ -289,6 +286,7 @@ public async Task ExecuteAsync( { string keyDetails = McpJsonHelper.FormatKeyDetails(keys); return McpResponseBuilder.BuildErrorResult( + toolName, "RecordNotFound", $"No entity found with the given key {keyDetails}.", logger); @@ -325,18 +323,18 @@ public async Task ExecuteAsync( } catch (OperationCanceledException) { - return McpResponseBuilder.BuildErrorResult("OperationCanceled", "The delete operation was canceled.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "OperationCanceled", "The delete operation was canceled.", logger); } catch (ArgumentException argEx) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger); } catch (Exception ex) { - ILogger? innerLogger = serviceProvider.GetService>(); - innerLogger?.LogError(ex, "Unexpected error in DeleteRecordTool."); + logger?.LogError(ex, "Unexpected error in DeleteRecordTool."); return McpResponseBuilder.BuildErrorResult( + toolName, "UnexpectedError", "An unexpected error occurred during the delete operation.", logger); diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs index 154b37ee80..cd2a7cc28b 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs @@ -67,6 +67,7 @@ public Task ExecuteAsync( CancellationToken cancellationToken = default) { ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; try { @@ -77,10 +78,7 @@ public Task ExecuteAsync( if (!IsToolEnabled(runtimeConfig)) { - return Task.FromResult(McpResponseBuilder.BuildErrorResult( - "ToolDisabled", - $"The {GetToolMetadata().Name} tool is disabled in the configuration.", - logger)); + return Task.FromResult(McpErrorHelpers.ToolDisabled(GetToolMetadata().Name, logger)); } // Get authorization services to determine current user's role @@ -158,6 +156,7 @@ public Task ExecuteAsync( if (entityFilter != null && entityFilter.Count > 0) { return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "EntitiesNotFound", $"No entities found matching the filter: {string.Join(", ", entityFilter)}", logger)); @@ -165,6 +164,7 @@ public Task ExecuteAsync( else { return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "NoEntitiesConfigured", "No entities are configured in the runtime configuration.", logger)); @@ -197,6 +197,7 @@ public Task ExecuteAsync( catch (OperationCanceledException) { return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "OperationCanceled", "The describe operation was canceled.", logger)); @@ -205,6 +206,7 @@ public Task ExecuteAsync( { logger?.LogError(dabEx, "Data API Builder error in DescribeEntitiesTool"); return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "DataApiBuilderError", dabEx.Message, logger)); @@ -212,6 +214,7 @@ public Task ExecuteAsync( catch (ArgumentException argEx) { return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "InvalidArguments", argEx.Message, logger)); @@ -220,6 +223,7 @@ public Task ExecuteAsync( { logger?.LogError(ioEx, "Invalid operation in DescribeEntitiesTool"); return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "InvalidOperation", "Failed to retrieve entity metadata: " + ioEx.Message, logger)); @@ -228,6 +232,7 @@ public Task ExecuteAsync( { logger?.LogError(ex, "Unexpected error in DescribeEntitiesTool"); return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "UnexpectedError", "An unexpected error occurred while describing entities.", logger)); diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ExecuteEntityTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ExecuteEntityTool.cs index be2fa7af36..e780c8ddeb 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ExecuteEntityTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ExecuteEntityTool.cs @@ -73,6 +73,7 @@ public async Task ExecuteAsync( CancellationToken cancellationToken = default) { ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; try { @@ -86,27 +87,24 @@ public async Task ExecuteAsync( // 2) Check if the tool is enabled in configuration before proceeding if (config.McpDmlTools?.ExecuteEntity != true) { - return McpResponseBuilder.BuildErrorResult( - "ToolDisabled", - $"The {this.GetToolMetadata().Name} tool is disabled in the configuration.", - logger); + return McpErrorHelpers.ToolDisabled(this.GetToolMetadata().Name, logger); } // 3) Parsing & basic argument validation if (arguments is null) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", "No arguments provided.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); } - if (!TryParseExecuteArguments(arguments.RootElement, out string entity, out Dictionary parameters, out string parseError)) + if (!McpArgumentParser.TryParseExecuteArguments(arguments.RootElement, out string entity, out Dictionary parameters, out string parseError)) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", parseError, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } // Entity is required if (string.IsNullOrWhiteSpace(entity)) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", "Entity is required", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Entity is required", logger); } IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); @@ -115,31 +113,25 @@ public async Task ExecuteAsync( // 4) Validate entity exists and is a stored procedure if (!config.Entities.TryGetValue(entity, out Entity? entityConfig)) { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entity}' not found in configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entity}' not found in configuration.", logger); } if (entityConfig.Source.Type != EntitySourceType.StoredProcedure) { - return McpResponseBuilder.BuildErrorResult("InvalidEntity", $"Entity {entity} cannot be executed.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidEntity", $"Entity {entity} cannot be executed.", logger); } - // 5) Resolve metadata - string dataSourceName; - ISqlMetadataProvider sqlMetadataProvider; - - try + // Use shared metadata helper. + if (!McpMetadataHelper.TryResolveMetadata( + entity, + config, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) { - dataSourceName = config.GetDataSourceNameFromEntityName(entity); - sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); - } - catch (Exception) - { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Failed to resolve entity metadata for '{entity}'.", logger); - } - - if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entity, out DatabaseObject? dbObject) || dbObject is null) - { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Failed to resolve database object for entity '{entity}'.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } // 6) Authorization - Never bypass permissions @@ -149,7 +141,7 @@ public async Task ExecuteAsync( if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleError)) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", roleError, logger); + return McpErrorHelpers.PermissionDenied(toolName, entity, "execute", roleError, logger); } if (!McpAuthorizationHelper.TryResolveAuthorizedRole( @@ -160,7 +152,7 @@ public async Task ExecuteAsync( out string? effectiveRole, out string authError)) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", authError, logger); + return McpErrorHelpers.PermissionDenied(toolName, entity, "execute", authError, logger); } // 7) Validate parameters against metadata @@ -171,7 +163,7 @@ public async Task ExecuteAsync( { if (!entityConfig.Source.Parameters.Any(p => p.Name == param.Key)) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", $"Invalid parameter: {param.Key}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Invalid parameter: {param.Key}", logger); } } } @@ -241,6 +233,7 @@ public async Task ExecuteAsync( message.Contains("authorization", StringComparison.OrdinalIgnoreCase)) { return McpResponseBuilder.BuildErrorResult( + toolName, "PermissionDenied", "You do not have permission to execute this stored procedure.", logger); @@ -249,6 +242,7 @@ public async Task ExecuteAsync( message.Contains("type", StringComparison.OrdinalIgnoreCase)) { return McpResponseBuilder.BuildErrorResult( + toolName, "InvalidArguments", "Invalid data type for one or more parameters.", logger); @@ -256,6 +250,7 @@ public async Task ExecuteAsync( // For any other DAB exceptions, return the message as-is return McpResponseBuilder.BuildErrorResult( + toolName, "DataApiBuilderError", dabEx.Message, logger); @@ -273,96 +268,55 @@ public async Task ExecuteAsync( 229 or 262 => $"Permission denied to execute stored procedure '{entityConfig.Source.Object}'.", _ => $"Database error: {sqlEx.Message}" }; - return McpResponseBuilder.BuildErrorResult("DatabaseError", errorMessage, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseError", errorMessage, logger); } catch (DbException dbEx) { // Handle generic database exceptions (works for PostgreSQL, MySQL, etc.) logger?.LogError(dbEx, "Database error executing stored procedure {StoredProcedure}", entity); - return McpResponseBuilder.BuildErrorResult("DatabaseError", $"Database error: {dbEx.Message}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseError", $"Database error: {dbEx.Message}", logger); } catch (InvalidOperationException ioEx) when (ioEx.Message.Contains("connection", StringComparison.OrdinalIgnoreCase)) { // Handle connection-related issues logger?.LogError(ioEx, "Database connection error"); - return McpResponseBuilder.BuildErrorResult("ConnectionError", "Failed to connect to the database.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "ConnectionError", "Failed to connect to the database.", logger); } catch (TimeoutException timeoutEx) { // Handle query timeout logger?.LogError(timeoutEx, "Stored procedure execution timeout for {StoredProcedure}", entity); - return McpResponseBuilder.BuildErrorResult("TimeoutError", "The stored procedure execution timed out.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "TimeoutError", "The stored procedure execution timed out.", logger); } catch (Exception ex) { // Generic database/execution errors logger?.LogError(ex, "Unexpected error executing stored procedure {StoredProcedure}", entity); - return McpResponseBuilder.BuildErrorResult("DatabaseError", "An error occurred while executing the stored procedure.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseError", "An error occurred while executing the stored procedure.", logger); } // 11) Build response with execution result - return BuildExecuteSuccessResponse(entity, parameters, queryResult, logger); + return BuildExecuteSuccessResponse(toolName, entity, parameters, queryResult, logger); } catch (OperationCanceledException) { - return McpResponseBuilder.BuildErrorResult("OperationCanceled", "The execute operation was canceled.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "OperationCanceled", "The execute operation was canceled.", logger); } catch (ArgumentException argEx) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger); } catch (Exception ex) { logger?.LogError(ex, "Unexpected error in ExecuteEntityTool."); return McpResponseBuilder.BuildErrorResult( + toolName, "UnexpectedError", "An unexpected error occurred during the execute operation.", logger); } } - /// - /// Parses the execute arguments from the JSON input. - /// - private static bool TryParseExecuteArguments( - JsonElement rootElement, - out string entity, - out Dictionary parameters, - out string parseError) - { - entity = string.Empty; - parameters = new Dictionary(); - parseError = string.Empty; - - if (rootElement.ValueKind != JsonValueKind.Object) - { - parseError = "Arguments must be an object"; - return false; - } - - // Extract entity name (required) - if (!rootElement.TryGetProperty("entity", out JsonElement entityElement) || - entityElement.ValueKind != JsonValueKind.String) - { - parseError = "Missing or invalid 'entity' parameter"; - return false; - } - - entity = entityElement.GetString() ?? string.Empty; - - // Extract parameters if provided (optional) - if (rootElement.TryGetProperty("parameters", out JsonElement parametersElement) && - parametersElement.ValueKind == JsonValueKind.Object) - { - foreach (JsonProperty property in parametersElement.EnumerateObject()) - { - parameters[property.Name] = GetParameterValue(property.Value); - } - } - - return true; - } - /// /// Converts a JSON element to its appropriate CLR type matching GraphQL data types. /// @@ -386,6 +340,7 @@ private static bool TryParseExecuteArguments( /// Builds a successful response for the execute operation. /// private static CallToolResult BuildExecuteSuccessResponse( + string toolName, string entityName, Dictionary? parameters, IActionResult? queryResult, @@ -426,16 +381,14 @@ private static CallToolResult BuildExecuteSuccessResponse( else if (queryResult is BadRequestObjectResult badRequest) { return McpResponseBuilder.BuildErrorResult( + toolName, "BadRequest", badRequest.Value?.ToString() ?? "Bad request", logger); } else if (queryResult is UnauthorizedObjectResult) { - return McpResponseBuilder.BuildErrorResult( - "PermissionDenied", - "You do not have permission to execute this entity", - logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "execute", "You do not have permission to execute this entity", logger); } else { diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs index 42b1f41ea0..1ed91c30a8 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs @@ -15,6 +15,7 @@ using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; using Azure.DataApiBuilder.Service.Exceptions; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http; @@ -78,6 +79,7 @@ public async Task ExecuteAsync( CancellationToken cancellationToken = default) { ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; // Get runtime config RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService(); @@ -85,10 +87,7 @@ public async Task ExecuteAsync( if (runtimeConfig.McpDmlTools?.ReadRecords is not true) { - return BuildErrorResult( - "ToolDisabled", - "The read_records tool is disabled in the configuration.", - logger); + return McpErrorHelpers.ToolDisabled(toolName, logger); } try @@ -105,18 +104,16 @@ public async Task ExecuteAsync( // Extract arguments if (arguments == null) { - return BuildErrorResult("InvalidArguments", "No arguments provided.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); } JsonElement root = arguments.RootElement; - if (!root.TryGetProperty("entity", out JsonElement entityElement) || string.IsNullOrWhiteSpace(entityElement.GetString())) + if (!McpArgumentParser.TryParseEntity(root, out entityName, out string parseError)) { - return BuildErrorResult("InvalidArguments", "Missing required argument 'entity'.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } - entityName = entityElement.GetString()!; - if (root.TryGetProperty("select", out JsonElement selectElement)) { select = selectElement.GetString(); @@ -142,27 +139,16 @@ public async Task ExecuteAsync( after = afterElement.GetString(); } - // Get required services & configuration - IQueryEngineFactory queryEngineFactory = serviceProvider.GetRequiredService(); - IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); - - // Check metadata for entity exists - string dataSourceName; - ISqlMetadataProvider sqlMetadataProvider; - - try + if (!McpMetadataHelper.TryResolveMetadata( + entityName, + runtimeConfig, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) { - dataSourceName = runtimeConfig.GetDataSourceNameFromEntityName(entityName); - sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); - } - catch (Exception) - { - return BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); - } - - if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) - { - return BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } // Authorization check in the existing entity @@ -171,20 +157,29 @@ public async Task ExecuteAsync( IHttpContextAccessor httpContextAccessor = serviceProvider.GetRequiredService(); HttpContext? httpContext = httpContextAccessor.HttpContext; - if (httpContext is null || !authResolver.IsValidRoleContext(httpContext)) + if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleCtxError)) { - return BuildErrorResult("PermissionDenied", $"You do not have permission to read records for entity '{entityName}'.", logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", roleCtxError, logger); } - if (!TryResolveAuthorizedRole(httpContext, authResolver, entityName, out string? effectiveRole, out string authError)) + if (!McpAuthorizationHelper.TryResolveAuthorizedRole( + httpContext!, + authResolver, + entityName, + EntityActionOperation.Read, + out string? effectiveRole, + out string readAuthError)) { - return BuildErrorResult("PermissionDenied", authError, logger); + string finalError = readAuthError.StartsWith("You do not have permission", StringComparison.OrdinalIgnoreCase) + ? $"You do not have permission to read records for entity '{entityName}'." + : readAuthError; + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", finalError, logger); } // Build and validate Find context - RequestValidator requestValidator = new(metadataProviderFactory, runtimeConfigProvider); + RequestValidator requestValidator = new(serviceProvider.GetRequiredService(), runtimeConfigProvider); FindRequestContext context = new(entityName, dbObject, true); - httpContext.Request.Method = "GET"; + httpContext!.Request.Method = "GET"; requestValidator.ValidateEntity(entityName); @@ -208,7 +203,7 @@ public async Task ExecuteAsync( { if (string.IsNullOrWhiteSpace(param)) { - return BuildErrorResult("InvalidArguments", "Parameters inside 'orderby' argument cannot be empty or null.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Parameters inside 'orderby' argument cannot be empty or null.", logger); } sortQueryString += $"{param}, "; @@ -230,193 +225,53 @@ public async Task ExecuteAsync( requirements: new[] { new ColumnsPermissionsRequirement() }); if (!authorizationResult.Succeeded) { - return BuildErrorResult("PermissionDenied", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); } // Execute + IQueryEngineFactory queryEngineFactory = serviceProvider.GetRequiredService(); IQueryEngine queryEngine = queryEngineFactory.GetQueryEngine(sqlMetadataProvider.GetDatabaseType()); JsonDocument? queryResult = await queryEngine.ExecuteAsync(context); - IActionResult actionResult = queryResult is null ? SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), context, metadataProviderFactory.GetMetadataProvider(dataSourceName), runtimeConfigProvider.GetConfig(), httpContext, true) - : SqlResponseHelpers.FormatFindResult(queryResult.RootElement.Clone(), context, metadataProviderFactory.GetMetadataProvider(dataSourceName), runtimeConfigProvider.GetConfig(), httpContext, true); + IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); + IActionResult actionResult = queryResult is null + ? SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true) + : SqlResponseHelpers.FormatFindResult(queryResult.RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true); // Normalize response - string rawPayloadJson = ExtractResultJson(actionResult); - JsonDocument result = JsonDocument.Parse(rawPayloadJson); + string rawPayloadJson = McpResponseBuilder.ExtractResultJson(actionResult); + using JsonDocument result = JsonDocument.Parse(rawPayloadJson); JsonElement queryRoot = result.RootElement; - return BuildSuccessResult( - entityName, - queryRoot.Clone(), - logger); + return McpResponseBuilder.BuildSuccessResult( + new Dictionary + { + ["entity"] = entityName, + ["result"] = queryRoot.Clone(), + ["message"] = $"Successfully read records for entity '{entityName}'" + }, + logger, + $"ReadRecordsTool success for entity {entityName}."); } catch (OperationCanceledException) { - return BuildErrorResult("OperationCanceled", "The read operation was canceled.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "OperationCanceled", "The read operation was canceled.", logger); } catch (DbException argEx) { - return BuildErrorResult("DatabaseOperationFailed", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseOperationFailed", argEx.Message, logger); } catch (ArgumentException argEx) { - return BuildErrorResult("InvalidArguments", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger); } catch (DataApiBuilderException argEx) { - return BuildErrorResult(argEx.StatusCode.ToString(), argEx.Message, logger); - } - catch (Exception) - { - return BuildErrorResult("UnexpectedError", "Unexpected error occurred in ReadRecordsTool.", logger); - } - } - - /// - /// Ensures that the role used on the request has the necessary authorizations. - /// - /// Contains request headers and metadata of the user. - /// Resolver used to check if role has necessary authorizations. - /// Name of the entity used in the request. - /// Role defined in client role header. - /// Error message given to the user. - /// True if the user role is authorized, along with the role. - private static bool TryResolveAuthorizedRole( - HttpContext httpContext, - IAuthorizationResolver authorizationResolver, - string entityName, - out string? effectiveRole, - out string error) - { - effectiveRole = null; - error = string.Empty; - - string roleHeader = httpContext.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER].ToString(); - - if (string.IsNullOrWhiteSpace(roleHeader)) - { - error = $"Client role header '{AuthorizationResolver.CLIENT_ROLE_HEADER}' is missing or empty."; - return false; - } - - string[] roles = roleHeader - .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) - .Distinct(StringComparer.OrdinalIgnoreCase) - .ToArray(); - - if (roles.Length == 0) - { - error = $"Client role header '{AuthorizationResolver.CLIENT_ROLE_HEADER}' is missing or empty."; - return false; + return McpResponseBuilder.BuildErrorResult(toolName, argEx.StatusCode.ToString(), argEx.Message, logger); } - - foreach (string role in roles) - { - bool allowed = authorizationResolver.AreRoleAndOperationDefinedForEntity( - entityName, role, EntityActionOperation.Read); - - if (allowed) - { - effectiveRole = role; - return true; - } - } - - error = $"You do not have permission to read records for entity '{entityName}'."; - return false; - } - - /// - /// Returns a result from the query in the case that it was successfully ran. - /// - /// Name of the entity used in the request. - /// Query result from engine. - /// MCP logger that returns all logged events. - private static CallToolResult BuildSuccessResult( - string entityName, - JsonElement engineRootElement, - ILogger? logger) - { - // Build normalized response - Dictionary normalized = new() - { - ["status"] = "success", - ["result"] = engineRootElement // only requested values - }; - - string output = JsonSerializer.Serialize(normalized, new JsonSerializerOptions { WriteIndented = true }); - - logger?.LogInformation("ReadRecordsTool success for entity {Entity}.", entityName); - - return new CallToolResult - { - Content = new List - { - new TextContentBlock { Type = "text", Text = output } - } - }; - } - - /// - /// Returns an error if the query failed to run at any point. - /// - /// Type of error that is encountered. - /// Error message given to the user. - /// MCP logger that returns all logged events. - private static CallToolResult BuildErrorResult( - string errorType, - string message, - ILogger? logger) - { - Dictionary errorObj = new() - { - ["status"] = "error", - ["error"] = new Dictionary - { - ["type"] = errorType, - ["message"] = message - } - }; - - string output = JsonSerializer.Serialize(errorObj); - - logger?.LogError("ReadRecordsTool error {ErrorType}: {Message}", errorType, message); - - return new CallToolResult - { - Content = - [ - new TextContentBlock { Type = "text", Text = output } - ], - IsError = true - }; - } - - /// - /// Extracts a JSON string from a typical IActionResult. - /// Falls back to "{}" for unsupported/empty cases to avoid leaking internals. - /// - private static string ExtractResultJson(IActionResult? result) - { - switch (result) + catch (Exception ex) { - case ObjectResult obj: - if (obj.Value is JsonElement je) - { - return je.GetRawText(); - } - - if (obj.Value is JsonDocument jd) - { - return jd.RootElement.GetRawText(); - } - - return JsonSerializer.Serialize(obj.Value ?? new object()); - - case ContentResult content: - return string.IsNullOrWhiteSpace(content.Content) ? "{}" : content.Content; - - default: - return "{}"; + logger?.LogError(ex, "Unexpected error in ReadRecordsTool."); + return McpResponseBuilder.BuildErrorResult(toolName, "UnexpectedError", "Unexpected error occurred in ReadRecordsTool.", logger); } } } diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs index 9e7d101fe6..195e27a0cd 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs @@ -5,7 +5,6 @@ using Azure.DataApiBuilder.Auth; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; -using Azure.DataApiBuilder.Core.Authorization; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Resolvers; @@ -13,6 +12,7 @@ using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.DependencyInjection; @@ -83,8 +83,7 @@ public async Task ExecuteAsync( CancellationToken cancellationToken = default) { ILogger? logger = serviceProvider.GetService>(); - - // 1) Resolve required services & configuration + string toolName = GetToolMetadata().Name; RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService(); RuntimeConfig config = runtimeConfigProvider.GetConfig(); @@ -92,10 +91,7 @@ public async Task ExecuteAsync( // 2)Check if the tool is enabled in configuration before proceeding. if (config.McpDmlTools?.UpdateRecord != true) { - return BuildErrorResult( - "ToolDisabled", - "The update_record tool is disabled in the configuration.", - logger); + return McpErrorHelpers.ToolDisabled(GetToolMetadata().Name, logger); } try @@ -106,34 +102,32 @@ public async Task ExecuteAsync( // 3) Parsing & basic argument validation (entity, keys, fields) if (arguments is null) { - return BuildErrorResult("InvalidArguments", "No arguments provided.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); } - if (!TryParseArguments(arguments.RootElement, out string entityName, out Dictionary keys, out Dictionary fields, out string parseError)) + if (!McpArgumentParser.TryParseEntityKeysAndFields( + arguments.RootElement, + out string entityName, + out Dictionary keys, + out Dictionary fields, + out string parseError)) { - return BuildErrorResult("InvalidArguments", parseError, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); IMutationEngineFactory mutationEngineFactory = serviceProvider.GetRequiredService(); - // 4) Resolve metadata for entity existence check - string dataSourceName; - ISqlMetadataProvider sqlMetadataProvider; - - try - { - dataSourceName = config.GetDataSourceNameFromEntityName(entityName); - sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); - } - catch (Exception) - { - return BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); - } - - if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) + if (!McpMetadataHelper.TryResolveMetadata( + entityName, + config, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) { - return BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } // 5) Authorization after we have a known entity @@ -143,12 +137,18 @@ public async Task ExecuteAsync( if (httpContext is null || !authResolver.IsValidRoleContext(httpContext)) { - return BuildErrorResult("PermissionDenied", "Permission denied: unable to resolve a valid role context for update operation.", logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "update", "unable to resolve a valid role context for update operation.", logger); } - if (!TryResolveAuthorizedRoleHasPermission(httpContext, authResolver, entityName, out string? effectiveRole, out string authError)) + if (!McpAuthorizationHelper.TryResolveAuthorizedRole( + httpContext!, + authResolver, + entityName, + EntityActionOperation.Update, + out string? effectiveRole, + out string authError)) { - return BuildErrorResult("PermissionDenied", $"Permission denied: {authError}", logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "update", authError, logger); } // 6) Build and validate Upsert (UpdateIncremental) context @@ -165,7 +165,7 @@ public async Task ExecuteAsync( { if (kvp.Value is null) { - return BuildErrorResult("InvalidArguments", $"Primary key value for '{kvp.Key}' cannot be null.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Primary key value for '{kvp.Key}' cannot be null.", logger); } context.PrimaryKeyValuePairs[kvp.Key] = kvp.Value; @@ -193,10 +193,7 @@ public async Task ExecuteAsync( if (errorMsg.Contains("No Update could be performed, record not found", StringComparison.OrdinalIgnoreCase)) { - return BuildErrorResult( - "InvalidArguments", - "No record found with the given key.", - logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No record found with the given key.", logger); } else { @@ -208,265 +205,51 @@ public async Task ExecuteAsync( cancellationToken.ThrowIfCancellationRequested(); // 8) Normalize response (success or engine error payload) - string rawPayloadJson = ExtractResultJson(mutationResult); + string rawPayloadJson = McpResponseBuilder.ExtractResultJson(mutationResult); using JsonDocument resultDoc = JsonDocument.Parse(rawPayloadJson); JsonElement root = resultDoc.RootElement; - return BuildSuccessResult( - entityName: entityName, - engineRootElement: root.Clone(), - logger: logger); + // Extract first item of value[] array (updated record) + Dictionary filteredResult = new(); + if (root.TryGetProperty("value", out JsonElement valueArray) && + valueArray.ValueKind == JsonValueKind.Array && + valueArray.GetArrayLength() > 0) + { + JsonElement firstItem = valueArray[0]; + foreach (JsonProperty prop in firstItem.EnumerateObject()) + { + filteredResult[prop.Name] = McpResponseBuilder.GetJsonValue(prop.Value); + } + } + + return McpResponseBuilder.BuildSuccessResult( + new Dictionary + { + ["entity"] = entityName, + ["result"] = filteredResult, + ["message"] = $"Successfully updated record in entity '{entityName}'" + }, + logger, + $"UpdateRecordTool success for entity {entityName}."); } catch (OperationCanceledException) { - return BuildErrorResult("OperationCanceled", "The update operation was canceled.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "OperationCanceled", "The update operation was canceled.", logger); } catch (ArgumentException argEx) { - return BuildErrorResult("InvalidArguments", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger); } catch (Exception ex) { - ILogger? innerLogger = serviceProvider.GetService>(); - innerLogger?.LogError(ex, "Unexpected error in UpdateRecordTool."); + logger?.LogError(ex, "Unexpected error in UpdateRecordTool."); - return BuildErrorResult( + return McpResponseBuilder.BuildErrorResult( + toolName, "UnexpectedError", ex.Message ?? "An unexpected error occurred during the update operation.", logger); } } - - #region Parsing & Authorization - - private static bool TryParseArguments( - JsonElement root, - out string entityName, - out Dictionary keys, - out Dictionary fields, - out string error) - { - entityName = string.Empty; - keys = new Dictionary(); - fields = new Dictionary(); - error = string.Empty; - - if (!root.TryGetProperty("entity", out JsonElement entityEl) || - !root.TryGetProperty("keys", out JsonElement keysEl) || - !root.TryGetProperty("fields", out JsonElement fieldsEl)) - { - error = "Missing required arguments 'entity', 'keys', or 'fields'."; - return false; - } - - // Parse and validate required arguments: entity, keys, fields - entityName = entityEl.GetString() ?? string.Empty; - if (string.IsNullOrWhiteSpace(entityName)) - { - throw new ArgumentException("Entity is required", nameof(entityName)); - } - - if (keysEl.ValueKind != JsonValueKind.Object || fieldsEl.ValueKind != JsonValueKind.Object) - { - throw new ArgumentException("'keys' and 'fields' must be JSON objects."); - } - - try - { - keys = JsonSerializer.Deserialize>(keysEl.GetRawText()) ?? new Dictionary(); - fields = JsonSerializer.Deserialize>(fieldsEl.GetRawText()) ?? new Dictionary(); - } - catch (Exception ex) - { - throw new ArgumentException("Failed to parse 'keys' or 'fields'", ex); - } - - if (keys.Count == 0) - { - throw new ArgumentException("Keys are required to update an entity"); - } - - if (fields.Count == 0) - { - throw new ArgumentException("At least one field must be provided to update an entity", nameof(fields)); - } - - foreach (KeyValuePair kv in keys) - { - if (kv.Value is null || (kv.Value is string str && string.IsNullOrWhiteSpace(str))) - { - throw new ArgumentException($"Key value for '{kv.Key}' cannot be null or empty."); - } - } - - return true; - } - - private static bool TryResolveAuthorizedRoleHasPermission( - HttpContext httpContext, - IAuthorizationResolver authorizationResolver, - string entityName, - out string? effectiveRole, - out string error) - { - effectiveRole = null; - error = string.Empty; - - string roleHeader = httpContext.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER].ToString(); - - if (string.IsNullOrWhiteSpace(roleHeader)) - { - error = "Client role header is missing or empty."; - return false; - } - - string[] roles = roleHeader - .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) - .Distinct(StringComparer.OrdinalIgnoreCase) - .ToArray(); - - if (roles.Length == 0) - { - error = "Client role header is missing or empty."; - return false; - } - - foreach (string role in roles) - { - bool allowed = authorizationResolver.AreRoleAndOperationDefinedForEntity( - entityName, role, EntityActionOperation.Update); - - if (allowed) - { - effectiveRole = role; - return true; - } - } - - error = "You do not have permission to update records for this entity."; - return false; - } - - #endregion - - #region Response Builders & Utilities - - private static CallToolResult BuildSuccessResult( - string entityName, - JsonElement engineRootElement, - ILogger? logger) - { - // Extract only requested keys and updated fields from engineRootElement - Dictionary filteredResult = new(); - - // Navigate to "value" array in the engine result - if (engineRootElement.TryGetProperty("value", out JsonElement valueArray) && - valueArray.ValueKind == JsonValueKind.Array && - valueArray.GetArrayLength() > 0) - { - JsonElement firstItem = valueArray[0]; - - // Include all properties from the result - foreach (JsonProperty prop in firstItem.EnumerateObject()) - { - filteredResult[prop.Name] = GetJsonValue(prop.Value); - } - } - - // Build normalized response - Dictionary normalized = new() - { - ["status"] = "success", - ["result"] = filteredResult - }; - - string output = JsonSerializer.Serialize(normalized, new JsonSerializerOptions { WriteIndented = true }); - - logger?.LogInformation("UpdateRecordTool success for entity {Entity}.", entityName); - - return new CallToolResult - { - Content = new List - { - new TextContentBlock { Type = "text", Text = output } - } - }; - } - - /// - /// Converts JsonElement to .NET object dynamically. - /// - private static object? GetJsonValue(JsonElement element) - { - return element.ValueKind switch - { - JsonValueKind.String => element.GetString(), - JsonValueKind.Number => element.TryGetInt64(out long l) ? l : element.GetDouble(), - JsonValueKind.True => true, - JsonValueKind.False => false, - JsonValueKind.Null => null, - _ => element.GetRawText() // fallback for arrays/objects - }; - } - - private static CallToolResult BuildErrorResult( - string errorType, - string message, - ILogger? logger) - { - Dictionary errorObj = new() - { - ["status"] = "error", - ["error"] = new Dictionary - { - ["type"] = errorType, - ["message"] = message - } - }; - - string output = JsonSerializer.Serialize(errorObj); - - logger?.LogWarning("UpdateRecordTool error {ErrorType}: {Message}", errorType, message); - - return new CallToolResult - { - Content = - [ - new TextContentBlock { Type = "text", Text = output } - ], - IsError = true - }; - } - - /// - /// Extracts a JSON string from a typical IActionResult. - /// Falls back to "{}" for unsupported/empty cases to avoid leaking internals. - /// - private static string ExtractResultJson(IActionResult? result) - { - switch (result) - { - case ObjectResult obj: - if (obj.Value is JsonElement je) - { - return je.GetRawText(); - } - - if (obj.Value is JsonDocument jd) - { - return jd.RootElement.GetRawText(); - } - - return JsonSerializer.Serialize(obj.Value ?? new object()); - - case ContentResult content: - return string.IsNullOrWhiteSpace(content.Content) ? "{}" : content.Content; - - default: - return "{}"; - } - } - - #endregion } } diff --git a/src/Azure.DataApiBuilder.Mcp/Model/McpErrorCode.cs b/src/Azure.DataApiBuilder.Mcp/Model/McpErrorCode.cs new file mode 100644 index 0000000000..ed13f62783 --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/Model/McpErrorCode.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.DataApiBuilder.Mcp.Model +{ + /// + /// MCP error codes standardized for built-in tools. + /// + public enum McpErrorCode + { + ToolDisabled, + PermissionDenied + } +} diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpArgumentParser.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpArgumentParser.cs index 04d14eb5d6..02344c2956 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpArgumentParser.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpArgumentParser.cs @@ -11,26 +11,24 @@ namespace Azure.DataApiBuilder.Mcp.Utils public static class McpArgumentParser { /// - /// Parses entity and keys arguments for delete/update operations. + /// Parses only the entity name from arguments. /// - public static bool TryParseEntityAndKeys( + public static bool TryParseEntity( JsonElement root, out string entityName, - out Dictionary keys, - out string error) + out string error, + CancellationToken cancellationToken = default) { + cancellationToken.ThrowIfCancellationRequested(); entityName = string.Empty; - keys = new Dictionary(); error = string.Empty; - if (!root.TryGetProperty("entity", out JsonElement entityEl) || - !root.TryGetProperty("keys", out JsonElement keysEl)) + if (!root.TryGetProperty("entity", out JsonElement entityEl)) { - error = "Missing required arguments 'entity' or 'keys'."; + error = "Missing required argument 'entity'."; return false; } - // Parse and validate entity name entityName = entityEl.GetString() ?? string.Empty; if (string.IsNullOrWhiteSpace(entityName)) { @@ -38,7 +36,65 @@ public static bool TryParseEntityAndKeys( return false; } - // Parse and validate keys + return true; + } + + /// + /// Parses entity and data arguments for create operations. + /// + public static bool TryParseEntityAndData( + JsonElement root, + out string entityName, + out JsonElement dataElement, + out string error, + CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + dataElement = default; + + if (!TryParseEntity(root, out entityName, out error, cancellationToken)) + { + return false; + } + + if (!root.TryGetProperty("data", out dataElement)) + { + error = "Missing required argument 'data'."; + return false; + } + + if (dataElement.ValueKind != JsonValueKind.Object) + { + error = "'data' must be a JSON object."; + return false; + } + + return true; + } + + /// + /// Parses entity and keys arguments for delete/update operations. + /// + public static bool TryParseEntityAndKeys( + JsonElement root, + out string entityName, + out Dictionary keys, + out string error, + CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + keys = new Dictionary(); + if (!TryParseEntity(root, out entityName, out error, cancellationToken)) + { + return false; + } + + if (!root.TryGetProperty("keys", out JsonElement keysEl)) + { + error = "Missing required argument 'keys'."; + return false; + } + if (keysEl.ValueKind != JsonValueKind.Object) { error = "'keys' must be a JSON object."; @@ -64,6 +120,8 @@ public static bool TryParseEntityAndKeys( // Validate key values foreach (KeyValuePair kv in keys) { + cancellationToken.ThrowIfCancellationRequested(); + if (kv.Value is null || (kv.Value is string str && string.IsNullOrWhiteSpace(str))) { error = $"Primary key value for '{kv.Key}' cannot be null or empty"; @@ -82,12 +140,14 @@ public static bool TryParseEntityKeysAndFields( out string entityName, out Dictionary keys, out Dictionary fields, - out string error) + out string error, + CancellationToken cancellationToken = default) { + cancellationToken.ThrowIfCancellationRequested(); fields = new Dictionary(); // First parse entity and keys - if (!TryParseEntityAndKeys(root, out entityName, out keys, out error)) + if (!TryParseEntityAndKeys(root, out entityName, out keys, out error, cancellationToken)) { return false; } @@ -123,5 +183,61 @@ public static bool TryParseEntityKeysAndFields( return true; } + + /// + /// Parses the execute arguments from the JSON input. + /// + public static bool TryParseExecuteArguments( + JsonElement rootElement, + out string entity, + out Dictionary parameters, + out string parseError, + CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + entity = string.Empty; + parameters = new Dictionary(); + + if (rootElement.ValueKind != JsonValueKind.Object) + { + parseError = "Arguments must be an object"; + return false; + } + + if (!TryParseEntity(rootElement, out entity, out parseError, cancellationToken)) + { + return false; + } + + // Extract parameters if provided (optional) + if (rootElement.TryGetProperty("parameters", out JsonElement parametersElement) && + parametersElement.ValueKind == JsonValueKind.Object) + { + foreach (JsonProperty property in parametersElement.EnumerateObject()) + { + cancellationToken.ThrowIfCancellationRequested(); + parameters[property.Name] = GetExecuteParameterValue(property.Value); + } + } + + return true; + } + + // Local helper replicating ExecuteEntityTool.GetParameterValue without refactoring other tools. + private static object? GetExecuteParameterValue(JsonElement element) + { + return element.ValueKind switch + { + JsonValueKind.String => element.GetString(), + JsonValueKind.Number => + element.TryGetInt64(out long longValue) ? longValue : + element.TryGetDecimal(out decimal decimalValue) ? decimalValue : + element.GetDouble(), + JsonValueKind.True => true, + JsonValueKind.False => false, + JsonValueKind.Null => null, + _ => element.ToString() + }; + } } } diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs new file mode 100644 index 0000000000..75335b2db1 --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; + +namespace Azure.DataApiBuilder.Mcp.Utils +{ + /// + /// Helper utilities for creating standardized MCP error responses. + /// Only includes helpers currently being centralized. + /// + public static class McpErrorHelpers + { + public static CallToolResult PermissionDenied(string toolName, string entityName, string operation, string detail, ILogger? logger) + { + string message = $"Permission denied for {operation} on entity '{entityName}'. {detail}"; + return McpResponseBuilder.BuildErrorResult(toolName, Model.McpErrorCode.PermissionDenied.ToString(), message, logger); + } + + // Centralized language for 'tool disabled' errors. Pass the tool name, e.g. "read_records". + public static CallToolResult ToolDisabled(string toolName, ILogger? logger) + { + string message = $"The {toolName} tool is disabled in the configuration."; + return McpResponseBuilder.BuildErrorResult(toolName, Model.McpErrorCode.ToolDisabled.ToString(), message, logger); + } + } +} diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpMetadataHelper.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpMetadataHelper.cs new file mode 100644 index 0000000000..d92117dba1 --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpMetadataHelper.cs @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.DataApiBuilder.Config.DatabasePrimitives; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Service.Exceptions; // Added for DataApiBuilderException +using Microsoft.Extensions.DependencyInjection; + +namespace Azure.DataApiBuilder.Mcp.Utils +{ + /// + /// Utility class for resolving metadata and datasource information for MCP tools. + /// + public static class McpMetadataHelper + { + public static bool TryResolveMetadata( + string entityName, + RuntimeConfig config, + IServiceProvider serviceProvider, + out Azure.DataApiBuilder.Core.Services.ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string error, + CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + sqlMetadataProvider = default!; + dbObject = default!; + dataSourceName = string.Empty; + error = string.Empty; + + if (string.IsNullOrWhiteSpace(entityName)) + { + error = "Entity name cannot be null or empty."; + return false; + } + + IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); + + // Resolve datasource name for the entity. + try + { + cancellationToken.ThrowIfCancellationRequested(); + dataSourceName = config.GetDataSourceNameFromEntityName(entityName); + } + catch (DataApiBuilderException dabEx) when (dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.EntityNotFound) + { + error = $"Entity '{entityName}' is not defined in the configuration."; + return false; + } + catch (DataApiBuilderException dabEx) + { + // Other DAB exceptions during entity->datasource resolution. + error = dabEx.Message; + return false; + } + + // Resolve metadata provider for the datasource. + try + { + cancellationToken.ThrowIfCancellationRequested(); + sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); + } + catch (DataApiBuilderException dabEx) when (dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.DataSourceNotFound) + { + error = $"Data source '{dataSourceName}' for entity '{entityName}' is not defined in the configuration."; + return false; + } + catch (DataApiBuilderException dabEx) + { + // Other DAB exceptions during metadata provider resolution. + error = dabEx.Message; + return false; + } + + cancellationToken.ThrowIfCancellationRequested(); + + // Validate entity exists in metadata mapping. + if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? temp) || temp is null) + { + error = $"Entity '{entityName}' is not defined in the configuration."; + return false; + } + + dbObject = temp; + return true; + } + } +} diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs index afbccbda38..49cacef2c3 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs @@ -43,12 +43,14 @@ public static CallToolResult BuildSuccessResult( /// Builds an error response for MCP tools. /// public static CallToolResult BuildErrorResult( + string toolName, string errorType, string message, ILogger? logger = null) { Dictionary errorObj = new() { + ["toolName"] = toolName, ["status"] = "error", ["error"] = new Dictionary { @@ -99,5 +101,21 @@ public static string ExtractResultJson(IActionResult? result) return "{}"; } } + + /// + /// Extracts value from a JsonElement. + /// + public static object? GetJsonValue(JsonElement element) + { + return element.ValueKind switch + { + JsonValueKind.String => element.GetString(), + JsonValueKind.Number => element.TryGetInt64(out long l) ? l : element.GetDouble(), + JsonValueKind.True => true, + JsonValueKind.False => false, + JsonValueKind.Null => null, + _ => element.GetRawText() + }; + } } } diff --git a/src/Cli.Tests/EndToEndTests.cs b/src/Cli.Tests/EndToEndTests.cs index 7fe017501f..5dbf97ca5e 100644 --- a/src/Cli.Tests/EndToEndTests.cs +++ b/src/Cli.Tests/EndToEndTests.cs @@ -116,10 +116,11 @@ public void TestInitializingRestAndGraphQLGlobalSettings() string[] args = { "init", "-c", TEST_RUNTIME_CONFIG_FILE, "--connection-string", SAMPLE_TEST_CONN_STRING, "--database-type", "mssql", "--rest.path", "/rest-api", "--rest.enabled", "false", "--graphql.path", "/graphql-api" }; Program.Execute(args, _cliLogger!, _fileSystem!, _runtimeConfigLoader!); + DeserializationVariableReplacementSettings replacementSettings = new(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true); Assert.IsTrue(_runtimeConfigLoader!.TryLoadConfig( TEST_RUNTIME_CONFIG_FILE, out RuntimeConfig? runtimeConfig, - replaceEnvVar: true)); + replacementSettings: replacementSettings)); SqlConnectionStringBuilder builder = new(runtimeConfig.DataSource.ConnectionString); Assert.AreEqual(ProductInfo.GetDataApiBuilderUserAgent(), builder.ApplicationName); @@ -195,10 +196,11 @@ public void TestEnablingMultipleCreateOperation(CliBool isMultipleCreateEnabled, Program.Execute(args.ToArray(), _cliLogger!, _fileSystem!, _runtimeConfigLoader!); + DeserializationVariableReplacementSettings replacementSettings = new(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true); Assert.IsTrue(_runtimeConfigLoader!.TryLoadConfig( TEST_RUNTIME_CONFIG_FILE, out RuntimeConfig? runtimeConfig, - replaceEnvVar: true)); + replacementSettings: replacementSettings)); Assert.IsNotNull(runtimeConfig); Assert.AreEqual(expectedDbType, runtimeConfig.DataSource.DatabaseType); diff --git a/src/Cli.Tests/EnvironmentTests.cs b/src/Cli.Tests/EnvironmentTests.cs index 151d5babb2..2d6378cf74 100644 --- a/src/Cli.Tests/EnvironmentTests.cs +++ b/src/Cli.Tests/EnvironmentTests.cs @@ -19,7 +19,13 @@ public class EnvironmentTests [TestInitialize] public void TestInitialize() { - StringJsonConverterFactory converterFactory = new(EnvironmentVariableReplacementFailureMode.Throw); + DeserializationVariableReplacementSettings replacementSettings = new( + azureKeyVaultOptions: null, + doReplaceEnvVar: true, + doReplaceAkvVar: false, + envFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + + StringJsonConverterFactory converterFactory = new(replacementSettings); _options = new() { PropertyNameCaseInsensitive = true diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index 9a56f83c4a..7c35335089 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -2700,9 +2700,10 @@ private static bool TryUpdateConfiguredAzureKeyVaultOptions( // Azure Key Vault Endpoint if (options.AzureKeyVaultEndpoint is not null) { + // Ensure endpoint flag is marked user provided so converter writes it. updatedAzureKeyVaultOptions = updatedAzureKeyVaultOptions is not null - ? updatedAzureKeyVaultOptions with { Endpoint = options.AzureKeyVaultEndpoint } - : new AzureKeyVaultOptions { Endpoint = options.AzureKeyVaultEndpoint }; + ? updatedAzureKeyVaultOptions with { Endpoint = options.AzureKeyVaultEndpoint, UserProvidedEndpoint = true } + : new AzureKeyVaultOptions(endpoint: options.AzureKeyVaultEndpoint); _logger.LogInformation("Updated RuntimeConfig with azure-key-vault.endpoint as '{endpoint}'", options.AzureKeyVaultEndpoint); } @@ -2711,7 +2712,7 @@ private static bool TryUpdateConfiguredAzureKeyVaultOptions( { updatedRetryPolicyOptions = updatedRetryPolicyOptions is not null ? updatedRetryPolicyOptions with { Mode = options.AzureKeyVaultRetryPolicyMode.Value, UserProvidedMode = true } - : new AKVRetryPolicyOptions { Mode = options.AzureKeyVaultRetryPolicyMode.Value, UserProvidedMode = true }; + : new AKVRetryPolicyOptions(mode: options.AzureKeyVaultRetryPolicyMode.Value); _logger.LogInformation("Updated RuntimeConfig with azure-key-vault.retry-policy.mode as '{mode}'", options.AzureKeyVaultRetryPolicyMode.Value); } @@ -2726,7 +2727,7 @@ private static bool TryUpdateConfiguredAzureKeyVaultOptions( updatedRetryPolicyOptions = updatedRetryPolicyOptions is not null ? updatedRetryPolicyOptions with { MaxCount = options.AzureKeyVaultRetryPolicyMaxCount.Value, UserProvidedMaxCount = true } - : new AKVRetryPolicyOptions { MaxCount = options.AzureKeyVaultRetryPolicyMaxCount.Value, UserProvidedMaxCount = true }; + : new AKVRetryPolicyOptions(maxCount: options.AzureKeyVaultRetryPolicyMaxCount.Value); _logger.LogInformation("Updated RuntimeConfig with azure-key-vault.retry-policy.max-count as '{maxCount}'", options.AzureKeyVaultRetryPolicyMaxCount.Value); } @@ -2741,7 +2742,7 @@ private static bool TryUpdateConfiguredAzureKeyVaultOptions( updatedRetryPolicyOptions = updatedRetryPolicyOptions is not null ? updatedRetryPolicyOptions with { DelaySeconds = options.AzureKeyVaultRetryPolicyDelaySeconds.Value, UserProvidedDelaySeconds = true } - : new AKVRetryPolicyOptions { DelaySeconds = options.AzureKeyVaultRetryPolicyDelaySeconds.Value, UserProvidedDelaySeconds = true }; + : new AKVRetryPolicyOptions(delaySeconds: options.AzureKeyVaultRetryPolicyDelaySeconds.Value); _logger.LogInformation("Updated RuntimeConfig with azure-key-vault.retry-policy.delay-seconds as '{delaySeconds}'", options.AzureKeyVaultRetryPolicyDelaySeconds.Value); } @@ -2756,7 +2757,7 @@ private static bool TryUpdateConfiguredAzureKeyVaultOptions( updatedRetryPolicyOptions = updatedRetryPolicyOptions is not null ? updatedRetryPolicyOptions with { MaxDelaySeconds = options.AzureKeyVaultRetryPolicyMaxDelaySeconds.Value, UserProvidedMaxDelaySeconds = true } - : new AKVRetryPolicyOptions { MaxDelaySeconds = options.AzureKeyVaultRetryPolicyMaxDelaySeconds.Value, UserProvidedMaxDelaySeconds = true }; + : new AKVRetryPolicyOptions(maxDelaySeconds: options.AzureKeyVaultRetryPolicyMaxDelaySeconds.Value); _logger.LogInformation("Updated RuntimeConfig with azure-key-vault.retry-policy.max-delay-seconds as '{maxDelaySeconds}'", options.AzureKeyVaultRetryPolicyMaxDelaySeconds.Value); } @@ -2771,16 +2772,17 @@ private static bool TryUpdateConfiguredAzureKeyVaultOptions( updatedRetryPolicyOptions = updatedRetryPolicyOptions is not null ? updatedRetryPolicyOptions with { NetworkTimeoutSeconds = options.AzureKeyVaultRetryPolicyNetworkTimeoutSeconds.Value, UserProvidedNetworkTimeoutSeconds = true } - : new AKVRetryPolicyOptions { NetworkTimeoutSeconds = options.AzureKeyVaultRetryPolicyNetworkTimeoutSeconds.Value, UserProvidedNetworkTimeoutSeconds = true }; + : new AKVRetryPolicyOptions(networkTimeoutSeconds: options.AzureKeyVaultRetryPolicyNetworkTimeoutSeconds.Value); _logger.LogInformation("Updated RuntimeConfig with azure-key-vault.retry-policy.network-timeout-seconds as '{networkTimeoutSeconds}'", options.AzureKeyVaultRetryPolicyNetworkTimeoutSeconds.Value); } - // Update Azure Key Vault options with retry policy if retry policy was modified + // Update Azure Key Vault options with retry policy if modified if (updatedRetryPolicyOptions is not null) { + // Ensure outer AKV object marks retry policy as user provided so it serializes. updatedAzureKeyVaultOptions = updatedAzureKeyVaultOptions is not null - ? updatedAzureKeyVaultOptions with { RetryPolicy = updatedRetryPolicyOptions } - : new AzureKeyVaultOptions { RetryPolicy = updatedRetryPolicyOptions }; + ? updatedAzureKeyVaultOptions with { RetryPolicy = updatedRetryPolicyOptions, UserProvidedRetryPolicy = true } + : new AzureKeyVaultOptions(retryPolicy: updatedRetryPolicyOptions); } // Update runtime config if Azure Key Vault options were modified diff --git a/src/Cli/Exporter.cs b/src/Cli/Exporter.cs index d4f103e868..896b485692 100644 --- a/src/Cli/Exporter.cs +++ b/src/Cli/Exporter.cs @@ -44,7 +44,8 @@ public static bool Export(ExportOptions options, ILogger logger, FileSystemRunti } // Load the runtime configuration from the file - if (!loader.TryLoadConfig(runtimeConfigFile, out RuntimeConfig? runtimeConfig, replaceEnvVar: true)) + DeserializationVariableReplacementSettings replacementSettings = new(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true); + if (!loader.TryLoadConfig(runtimeConfigFile, out RuntimeConfig? runtimeConfig, replacementSettings: replacementSettings)) { logger.LogError("Failed to read the config file: {0}.", runtimeConfigFile); return false; diff --git a/src/Config/Azure.DataApiBuilder.Config.csproj b/src/Config/Azure.DataApiBuilder.Config.csproj index a494bc38ae..6b5bdf0955 100644 --- a/src/Config/Azure.DataApiBuilder.Config.csproj +++ b/src/Config/Azure.DataApiBuilder.Config.csproj @@ -15,6 +15,7 @@ + diff --git a/src/Config/Converters/AKVRetryPolicyOptionsConverterFactory.cs b/src/Config/Converters/AKVRetryPolicyOptionsConverterFactory.cs index 06d00b64d3..553e43db53 100644 --- a/src/Config/Converters/AKVRetryPolicyOptionsConverterFactory.cs +++ b/src/Config/Converters/AKVRetryPolicyOptionsConverterFactory.cs @@ -12,9 +12,9 @@ namespace Azure.DataApiBuilder.Config.Converters; /// internal class AKVRetryPolicyOptionsConverterFactory : JsonConverterFactory { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + // Currently allows for Azure Key Vault (via @akv('secret-name')) and Environment Variable replacement. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -25,34 +25,34 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new AKVRetryPolicyOptionsConverter(_replaceEnvVar); + return new AKVRetryPolicyOptionsConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal AKVRetryPolicyOptionsConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal AKVRetryPolicyOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class AKVRetryPolicyOptionsConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + // Currently allows for Azure Key Vault (via @akv('')) and Environment Variable replacement. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public AKVRetryPolicyOptionsConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public AKVRetryPolicyOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// /// Defines how DAB reads AKV Retry Policy options and defines which values are /// used to instantiate those options. /// - /// Thrown when improperly formatted cache options are provided. + /// Thrown when improperly formatted retry policy options are provided. public override AKVRetryPolicyOptions? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { if (reader.TokenType is JsonTokenType.StartObject) @@ -82,7 +82,7 @@ public AKVRetryPolicyOptionsConverter(bool replaceEnvVar) } else { - mode = EnumExtensions.Deserialize(reader.DeserializeString(_replaceEnvVar)!); + mode = EnumExtensions.Deserialize(reader.DeserializeString(_replacementSettings)!); } break; diff --git a/src/Config/Converters/AzureKeyVaultOptionsConverterFactory.cs b/src/Config/Converters/AzureKeyVaultOptionsConverterFactory.cs new file mode 100644 index 0000000000..92ed0c1a85 --- /dev/null +++ b/src/Config/Converters/AzureKeyVaultOptionsConverterFactory.cs @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json; +using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.ObjectModel; + +namespace Azure.DataApiBuilder.Config.Converters; + +/// +/// Converter factory for AzureKeyVaultOptions that can optionally perform variable replacement. +/// +internal class AzureKeyVaultOptionsConverterFactory : JsonConverterFactory +{ + // Determines whether to replace environment variable with its + // value or not while deserializing. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; + + /// How to handle variable replacement during deserialization. + internal AzureKeyVaultOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) + { + _replacementSettings = replacementSettings; + } + + /// + public override bool CanConvert(Type typeToConvert) + { + return typeToConvert.IsAssignableTo(typeof(AzureKeyVaultOptions)); + } + + /// + public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) + { + return new AzureKeyVaultOptionsConverter(_replacementSettings); + } + + private class AzureKeyVaultOptionsConverter : JsonConverter + { + // Determines whether to replace environment variable with its + // value or not while deserializing. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; + + /// Whether to replace environment variable with its + /// value or not while deserializing. + public AzureKeyVaultOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) + { + _replacementSettings = replacementSettings; + } + + /// + /// Reads AzureKeyVaultOptions with optional variable replacement. + /// + public override AzureKeyVaultOptions? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType is JsonTokenType.Null) + { + return null; + } + + if (reader.TokenType is JsonTokenType.StartObject) + { + string? endpoint = null; + AKVRetryPolicyOptions? retryPolicy = null; + + while (reader.Read()) + { + if (reader.TokenType is JsonTokenType.EndObject) + { + return new AzureKeyVaultOptions(endpoint, retryPolicy); + } + + string? property = reader.GetString(); + reader.Read(); + + switch (property) + { + case "endpoint": + if (reader.TokenType is JsonTokenType.String) + { + endpoint = reader.DeserializeString(_replacementSettings); + } + + break; + + case "retry-policy": + if (reader.TokenType is JsonTokenType.StartObject) + { + // Uses the AKVRetryPolicyOptionsConverter to read the retry-policy object. + retryPolicy = JsonSerializer.Deserialize(ref reader, options); + } + + break; + + default: + throw new JsonException($"Unexpected property {property}"); + } + } + } + + throw new JsonException("Invalid AzureKeyVaultOptions format"); + } + + /// + /// When writing the AzureKeyVaultOptions back to a JSON file, only write the properties + /// if they are user provided. This avoids polluting the written JSON file with properties + /// the user most likely omitted when writing the original DAB runtime config file. + /// This Write operation is only used when a RuntimeConfig object is serialized to JSON. + /// + public override void Write(Utf8JsonWriter writer, AzureKeyVaultOptions value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + + if (value?.UserProvidedEndpoint is true) + { + writer.WritePropertyName("endpoint"); + JsonSerializer.Serialize(writer, value.Endpoint, options); + } + + if (value?.UserProvidedRetryPolicy is true) + { + writer.WritePropertyName("retry-policy"); + JsonSerializer.Serialize(writer, value.RetryPolicy, options); + } + + writer.WriteEndObject(); + } + } +} diff --git a/src/Config/Converters/AzureLogAnalyticsAuthOptionsConverter.cs b/src/Config/Converters/AzureLogAnalyticsAuthOptionsConverter.cs index 1428c0d75f..d4b7623aa2 100644 --- a/src/Config/Converters/AzureLogAnalyticsAuthOptionsConverter.cs +++ b/src/Config/Converters/AzureLogAnalyticsAuthOptionsConverter.cs @@ -9,15 +9,14 @@ namespace Azure.DataApiBuilder.Config.Converters; internal class AzureLogAnalyticsAuthOptionsConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public AzureLogAnalyticsAuthOptionsConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public AzureLogAnalyticsAuthOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -48,7 +47,7 @@ public AzureLogAnalyticsAuthOptionsConverter(bool replaceEnvVar) case "custom-table-name": if (reader.TokenType is not JsonTokenType.Null) { - customTableName = reader.DeserializeString(_replaceEnvVar); + customTableName = reader.DeserializeString(_replacementSettings); } break; @@ -56,7 +55,7 @@ public AzureLogAnalyticsAuthOptionsConverter(bool replaceEnvVar) case "dcr-immutable-id": if (reader.TokenType is not JsonTokenType.Null) { - dcrImmutableId = reader.DeserializeString(_replaceEnvVar); + dcrImmutableId = reader.DeserializeString(_replacementSettings); } break; @@ -64,7 +63,7 @@ public AzureLogAnalyticsAuthOptionsConverter(bool replaceEnvVar) case "dce-endpoint": if (reader.TokenType is not JsonTokenType.Null) { - dceEndpoint = reader.DeserializeString(_replaceEnvVar); + dceEndpoint = reader.DeserializeString(_replacementSettings); } break; diff --git a/src/Config/Converters/AzureLogAnalyticsOptionsConverterFactory.cs b/src/Config/Converters/AzureLogAnalyticsOptionsConverterFactory.cs index 3fcbe8c7bd..fc7c72d655 100644 --- a/src/Config/Converters/AzureLogAnalyticsOptionsConverterFactory.cs +++ b/src/Config/Converters/AzureLogAnalyticsOptionsConverterFactory.cs @@ -12,9 +12,8 @@ namespace Azure.DataApiBuilder.Config.Converters; /// internal class AzureLogAnalyticsOptionsConverterFactory : JsonConverterFactory { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -25,27 +24,26 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new AzureLogAnalyticsOptionsConverter(_replaceEnvVar); + return new AzureLogAnalyticsOptionsConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal AzureLogAnalyticsOptionsConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal AzureLogAnalyticsOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class AzureLogAnalyticsOptionsConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal AzureLogAnalyticsOptionsConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal AzureLogAnalyticsOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -57,7 +55,7 @@ internal AzureLogAnalyticsOptionsConverter(bool replaceEnvVar) { if (reader.TokenType is JsonTokenType.StartObject) { - AzureLogAnalyticsAuthOptionsConverter authOptionsConverter = new(_replaceEnvVar); + AzureLogAnalyticsAuthOptionsConverter authOptionsConverter = new(_replacementSettings); bool? enabled = null; AzureLogAnalyticsAuthOptions? auth = null; @@ -91,7 +89,7 @@ internal AzureLogAnalyticsOptionsConverter(bool replaceEnvVar) case "dab-identifier": if (reader.TokenType is not JsonTokenType.Null) { - logType = reader.DeserializeString(_replaceEnvVar); + logType = reader.DeserializeString(_replacementSettings); } break; diff --git a/src/Config/Converters/DataSourceConverterFactory.cs b/src/Config/Converters/DataSourceConverterFactory.cs index dabbee405e..1788ebf2b4 100644 --- a/src/Config/Converters/DataSourceConverterFactory.cs +++ b/src/Config/Converters/DataSourceConverterFactory.cs @@ -9,9 +9,8 @@ namespace Azure.DataApiBuilder.Config.Converters; internal class DataSourceConverterFactory : JsonConverterFactory { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,27 +21,26 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new DataSourceConverter(_replaceEnvVar); + return new DataSourceConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal DataSourceConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal DataSourceConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class DataSourceConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public DataSourceConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public DataSourceConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } public override DataSource? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) @@ -69,11 +67,11 @@ public DataSourceConverter(bool replaceEnvVar) switch (propertyName) { case "database-type": - databaseType = EnumExtensions.Deserialize(reader.DeserializeString(_replaceEnvVar)!); + databaseType = EnumExtensions.Deserialize(reader.DeserializeString(_replacementSettings)!); break; case "connection-string": - connectionString = reader.DeserializeString(replaceEnvVar: _replaceEnvVar)!; + connectionString = reader.DeserializeString(_replacementSettings)!; break; case "health": @@ -106,7 +104,7 @@ public DataSourceConverter(bool replaceEnvVar) if (reader.TokenType is JsonTokenType.String) { // Determine whether to resolve the environment variable or keep as-is. - string stringValue = reader.DeserializeString(replaceEnvVar: _replaceEnvVar)!; + string stringValue = reader.DeserializeString(_replacementSettings)!; if (bool.TryParse(stringValue, out bool boolValue)) { diff --git a/src/Config/Converters/DatasourceHealthOptionsConvertorFactory.cs b/src/Config/Converters/DatasourceHealthOptionsConvertorFactory.cs index 52272c57a7..d8286ff7a0 100644 --- a/src/Config/Converters/DatasourceHealthOptionsConvertorFactory.cs +++ b/src/Config/Converters/DatasourceHealthOptionsConvertorFactory.cs @@ -11,7 +11,7 @@ internal class DataSourceHealthOptionsConvertorFactory : JsonConverterFactory { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,27 +22,27 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new HealthCheckOptionsConverter(_replaceEnvVar); + return new HealthCheckOptionsConverter(_replacementSettings); } /// Whether to replace environment variable with its /// value or not while deserializing. - internal DataSourceHealthOptionsConvertorFactory(bool replaceEnvVar) + internal DataSourceHealthOptionsConvertorFactory(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class HealthCheckOptionsConverter : JsonConverter { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// Whether to replace environment variable with its /// value or not while deserializing. - public HealthCheckOptionsConverter(bool replaceEnvVar) + public HealthCheckOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -85,7 +85,7 @@ public HealthCheckOptionsConverter(bool replaceEnvVar) case "name": if (reader.TokenType is not JsonTokenType.Null) { - name = reader.DeserializeString(_replaceEnvVar); + name = reader.DeserializeString(_replacementSettings); } break; diff --git a/src/Config/Converters/EntityCacheOptionsConverterFactory.cs b/src/Config/Converters/EntityCacheOptionsConverterFactory.cs index 32a616ab81..641efd062f 100644 --- a/src/Config/Converters/EntityCacheOptionsConverterFactory.cs +++ b/src/Config/Converters/EntityCacheOptionsConverterFactory.cs @@ -14,7 +14,7 @@ internal class EntityCacheOptionsConverterFactory : JsonConverterFactory { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -25,27 +25,25 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new EntityCacheOptionsConverter(_replaceEnvVar); + return new EntityCacheOptionsConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal EntityCacheOptionsConverterFactory(bool replaceEnvVar) + /// The replacement settings to use while deserializing. + internal EntityCacheOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class EntityCacheOptionsConverter : JsonConverter { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public EntityCacheOptionsConverter(bool replaceEnvVar) + /// The replacement settings to use while deserializing. + public EntityCacheOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -110,7 +108,7 @@ public EntityCacheOptionsConverter(bool replaceEnvVar) throw new JsonException("level property cannot be null."); } - level = EnumExtensions.Deserialize(reader.DeserializeString(_replaceEnvVar)!); + level = EnumExtensions.Deserialize(reader.DeserializeString(_replacementSettings)!); break; } diff --git a/src/Config/Converters/EntityGraphQLOptionsConverterFactory.cs b/src/Config/Converters/EntityGraphQLOptionsConverterFactory.cs index 576850b1cb..abe094e970 100644 --- a/src/Config/Converters/EntityGraphQLOptionsConverterFactory.cs +++ b/src/Config/Converters/EntityGraphQLOptionsConverterFactory.cs @@ -9,9 +9,8 @@ namespace Azure.DataApiBuilder.Config.Converters; internal class EntityGraphQLOptionsConverterFactory : JsonConverterFactory { - /// Determines whether to replace environment variable with its - /// value or not while deserializing. - private bool _replaceEnvVar; + /// Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,27 +21,26 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new EntityGraphQLOptionsConverter(_replaceEnvVar); + return new EntityGraphQLOptionsConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal EntityGraphQLOptionsConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal EntityGraphQLOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class EntityGraphQLOptionsConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public EntityGraphQLOptionsConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public EntityGraphQLOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -73,7 +71,7 @@ public EntityGraphQLOptionsConverter(bool replaceEnvVar) case "type": if (reader.TokenType is JsonTokenType.String) { - singular = reader.DeserializeString(_replaceEnvVar) ?? string.Empty; + singular = reader.DeserializeString(_replacementSettings) ?? string.Empty; } else if (reader.TokenType is JsonTokenType.StartObject) { @@ -95,10 +93,10 @@ public EntityGraphQLOptionsConverter(bool replaceEnvVar) switch (property2) { case "singular": - singular = reader.DeserializeString(_replaceEnvVar) ?? string.Empty; + singular = reader.DeserializeString(_replacementSettings) ?? string.Empty; break; case "plural": - plural = reader.DeserializeString(_replaceEnvVar) ?? string.Empty; + plural = reader.DeserializeString(_replacementSettings) ?? string.Empty; break; } } @@ -112,7 +110,7 @@ public EntityGraphQLOptionsConverter(bool replaceEnvVar) break; case "operation": - string? op = reader.DeserializeString(_replaceEnvVar); + string? op = reader.DeserializeString(_replacementSettings); if (op is not null) { @@ -136,7 +134,7 @@ public EntityGraphQLOptionsConverter(bool replaceEnvVar) if (reader.TokenType is JsonTokenType.String) { - string? singular = reader.DeserializeString(_replaceEnvVar); + string? singular = reader.DeserializeString(_replacementSettings); return new EntityGraphQLOptions(singular ?? string.Empty, string.Empty); } diff --git a/src/Config/Converters/EntityRestOptionsConverterFactory.cs b/src/Config/Converters/EntityRestOptionsConverterFactory.cs index cc33943caa..f8c9096673 100644 --- a/src/Config/Converters/EntityRestOptionsConverterFactory.cs +++ b/src/Config/Converters/EntityRestOptionsConverterFactory.cs @@ -9,9 +9,8 @@ namespace Azure.DataApiBuilder.Config.Converters; internal class EntityRestOptionsConverterFactory : JsonConverterFactory { - /// Determines whether to replace environment variable with its - /// value or not while deserializing. - private bool _replaceEnvVar; + /// Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,27 +21,26 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new EntityRestOptionsConverter(_replaceEnvVar); + return new EntityRestOptionsConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal EntityRestOptionsConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal EntityRestOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } internal class EntityRestOptionsConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public EntityRestOptionsConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public EntityRestOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -67,7 +65,7 @@ public EntityRestOptionsConverter(bool replaceEnvVar) if (reader.TokenType is JsonTokenType.String || reader.TokenType is JsonTokenType.Null) { - restOptions = restOptions with { Path = reader.DeserializeString(_replaceEnvVar) }; + restOptions = restOptions with { Path = reader.DeserializeString(_replacementSettings) }; break; } @@ -87,7 +85,7 @@ public EntityRestOptionsConverter(bool replaceEnvVar) break; } - methods.Add(EnumExtensions.Deserialize(reader.DeserializeString(replaceEnvVar: true)!)); + methods.Add(EnumExtensions.Deserialize(reader.DeserializeString(new DeserializationVariableReplacementSettings())!)); } restOptions = restOptions with { Methods = methods.ToArray() }; @@ -107,7 +105,7 @@ public EntityRestOptionsConverter(bool replaceEnvVar) if (reader.TokenType is JsonTokenType.String) { - return new EntityRestOptions(Array.Empty(), reader.DeserializeString(_replaceEnvVar), true); + return new EntityRestOptions(Array.Empty(), reader.DeserializeString(_replacementSettings), true); } if (reader.TokenType is JsonTokenType.True || reader.TokenType is JsonTokenType.False) diff --git a/src/Config/Converters/EntitySourceConverterFactory.cs b/src/Config/Converters/EntitySourceConverterFactory.cs index a748382e01..2edafe31e1 100644 --- a/src/Config/Converters/EntitySourceConverterFactory.cs +++ b/src/Config/Converters/EntitySourceConverterFactory.cs @@ -9,9 +9,8 @@ namespace Azure.DataApiBuilder.Config.Converters; internal class EntitySourceConverterFactory : JsonConverterFactory { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,34 +21,33 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new EntitySourceConverter(_replaceEnvVar); + return new EntitySourceConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal EntitySourceConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal EntitySourceConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class EntitySourceConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public EntitySourceConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public EntitySourceConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } public override EntitySource? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { if (reader.TokenType == JsonTokenType.String) { - string? obj = reader.DeserializeString(_replaceEnvVar); + string? obj = reader.DeserializeString(_replacementSettings); return new EntitySource(obj ?? string.Empty, EntitySourceType.Table, new(), Array.Empty()); } diff --git a/src/Config/Converters/EnumMemberJsonEnumConverterFactory.cs b/src/Config/Converters/EnumMemberJsonEnumConverterFactory.cs index 1d6dd9f7c4..4455a474e1 100644 --- a/src/Config/Converters/EnumMemberJsonEnumConverterFactory.cs +++ b/src/Config/Converters/EnumMemberJsonEnumConverterFactory.cs @@ -114,7 +114,7 @@ public JsonStringEnumConverterEx() public override TEnum Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { // Always replace env variable in case of Enum otherwise string to enum conversion will fail. - string? stringValue = reader.DeserializeString(replaceEnvVar: true); + string? stringValue = reader.DeserializeString(new(doReplaceEnvVar: true)); if (stringValue == null) { diff --git a/src/Config/Converters/FileSinkConverter.cs b/src/Config/Converters/FileSinkConverter.cs index cc7d138a1b..4299fb913b 100644 --- a/src/Config/Converters/FileSinkConverter.cs +++ b/src/Config/Converters/FileSinkConverter.cs @@ -7,18 +7,17 @@ using Serilog; namespace Azure.DataApiBuilder.Config.Converters; + class FileSinkConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; - - /// - /// Whether to replace environment variable with its value or not while deserializing. - /// - public FileSinkConverter(bool replaceEnvVar) + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; + + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public FileSinkConverter(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -59,7 +58,7 @@ public FileSinkConverter(bool replaceEnvVar) case "path": if (reader.TokenType is not JsonTokenType.Null) { - path = reader.DeserializeString(_replaceEnvVar); + path = reader.DeserializeString(_replacementSettings); } break; @@ -67,7 +66,7 @@ public FileSinkConverter(bool replaceEnvVar) case "rolling-interval": if (reader.TokenType is not JsonTokenType.Null) { - rollingInterval = EnumExtensions.Deserialize(reader.DeserializeString(_replaceEnvVar)!); + rollingInterval = EnumExtensions.Deserialize(reader.DeserializeString(_replacementSettings)!); } break; diff --git a/src/Config/Converters/GraphQLRuntimeOptionsConverterFactory.cs b/src/Config/Converters/GraphQLRuntimeOptionsConverterFactory.cs index 082c982e7e..109caef0d5 100644 --- a/src/Config/Converters/GraphQLRuntimeOptionsConverterFactory.cs +++ b/src/Config/Converters/GraphQLRuntimeOptionsConverterFactory.cs @@ -9,9 +9,8 @@ namespace Azure.DataApiBuilder.Config.Converters; internal class GraphQLRuntimeOptionsConverterFactory : JsonConverterFactory { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,25 +21,26 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new GraphQLRuntimeOptionsConverter(_replaceEnvVar); + return new GraphQLRuntimeOptionsConverter(_replacementSettings); } - internal GraphQLRuntimeOptionsConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal GraphQLRuntimeOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class GraphQLRuntimeOptionsConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal GraphQLRuntimeOptionsConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal GraphQLRuntimeOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } public override GraphQLRuntimeOptions? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) @@ -117,7 +117,7 @@ internal GraphQLRuntimeOptionsConverter(bool replaceEnvVar) case "path": if (reader.TokenType is JsonTokenType.String) { - string? path = reader.DeserializeString(_replaceEnvVar); + string? path = reader.DeserializeString(_replacementSettings); if (path is null) { path = "/graphql"; diff --git a/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs b/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs index db9acfa603..d75cbbef5a 100644 --- a/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs +++ b/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs @@ -14,7 +14,7 @@ internal class McpRuntimeOptionsConverterFactory : JsonConverterFactory { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -25,25 +25,25 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new McpRuntimeOptionsConverter(_replaceEnvVar); + return new McpRuntimeOptionsConverter(_replacementSettings); } - internal McpRuntimeOptionsConverterFactory(bool replaceEnvVar) + internal McpRuntimeOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class McpRuntimeOptionsConverter : JsonConverter { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// Whether to replace environment variable with its /// value or not while deserializing. - internal McpRuntimeOptionsConverter(bool replaceEnvVar) + internal McpRuntimeOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -89,7 +89,7 @@ internal McpRuntimeOptionsConverter(bool replaceEnvVar) case "path": if (reader.TokenType is not JsonTokenType.Null) { - path = reader.DeserializeString(_replaceEnvVar); + path = reader.DeserializeString(_replacementSettings); } break; diff --git a/src/Config/Converters/RuntimeHealthOptionsConvertorFactory.cs b/src/Config/Converters/RuntimeHealthOptionsConvertorFactory.cs index d49cc264e7..9c5f46dce2 100644 --- a/src/Config/Converters/RuntimeHealthOptionsConvertorFactory.cs +++ b/src/Config/Converters/RuntimeHealthOptionsConvertorFactory.cs @@ -11,7 +11,7 @@ internal class RuntimeHealthOptionsConvertorFactory : JsonConverterFactory { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,25 +22,25 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new HealthCheckOptionsConverter(_replaceEnvVar); + return new HealthCheckOptionsConverter(_replacementSettings); } - internal RuntimeHealthOptionsConvertorFactory(bool replaceEnvVar) + internal RuntimeHealthOptionsConvertorFactory(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class HealthCheckOptionsConverter : JsonConverter { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// Whether to replace environment variable with its /// value or not while deserializing. - internal HealthCheckOptionsConverter(bool replaceEnvVar) + internal HealthCheckOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -102,7 +102,7 @@ internal HealthCheckOptionsConverter(bool replaceEnvVar) { if (reader.TokenType == JsonTokenType.String) { - string? currentRole = reader.DeserializeString(_replaceEnvVar); + string? currentRole = reader.DeserializeString(_replacementSettings); if (!string.IsNullOrEmpty(currentRole)) { stringList.Add(currentRole); diff --git a/src/Config/Converters/StringJsonConverterFactory.cs b/src/Config/Converters/StringJsonConverterFactory.cs index 078b611789..c3f5333237 100644 --- a/src/Config/Converters/StringJsonConverterFactory.cs +++ b/src/Config/Converters/StringJsonConverterFactory.cs @@ -4,21 +4,20 @@ using System.Text.Json; using System.Text.Json.Serialization; using System.Text.RegularExpressions; -using Azure.DataApiBuilder.Service.Exceptions; namespace Azure.DataApiBuilder.Config.Converters; /// -/// Custom string json converter factory to replace environment variables of the pattern -/// @env('ENV_NAME') with their value during deserialization. +/// Custom string json converter factory to replace environment variables and other variable patterns +/// during deserialization using the DeserializationVariableReplacementSettings. /// public class StringJsonConverterFactory : JsonConverterFactory { - private EnvironmentVariableReplacementFailureMode _replacementFailureMode; + private readonly DeserializationVariableReplacementSettings _replacementSettings; - public StringJsonConverterFactory(EnvironmentVariableReplacementFailureMode replacementFailureMode) + public StringJsonConverterFactory(DeserializationVariableReplacementSettings replacementSettings) { - _replacementFailureMode = replacementFailureMode; + _replacementSettings = replacementSettings; } public override bool CanConvert(Type typeToConvert) @@ -28,32 +27,16 @@ public override bool CanConvert(Type typeToConvert) public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new StringJsonConverter(_replacementFailureMode); + return new StringJsonConverter(_replacementSettings); } class StringJsonConverter : JsonConverter { - // @env\(' : match @env(' - // .*? : lazy match any character except newline 0 or more times - // (?='\)) : look ahead for ') which will combine with our lazy match - // ie: in @env('hello')goodbye') we match @env('hello') - // '\) : consume the ') into the match (look ahead doesn't capture) - // This pattern lazy matches any string that starts with @env(' and ends with ') - // ie: fooBAR@env('hello-world')bash)FOO') match: @env('hello-world') - // This matching pattern allows for the @env('') to be safely nested - // within strings that contain ') after our match. - // ie: if the environment variable "Baz" has the value of "Bar" - // fooBarBaz: "('foo@env('Baz')Baz')" would parse into - // fooBarBaz: "('fooBarBaz')" - // Note that there is no escape character currently for ') to exist - // within the name of the environment variable, but that ') is not - // a valid environment variable name in certain shells. - const string ENV_PATTERN = @"@env\('.*?(?='\))'\)"; - private EnvironmentVariableReplacementFailureMode _replacementFailureMode; + private DeserializationVariableReplacementSettings _replacementSettings; - public StringJsonConverter(EnvironmentVariableReplacementFailureMode replacementFailureMode) + public StringJsonConverter(DeserializationVariableReplacementSettings replacementSettings) { - _replacementFailureMode = replacementFailureMode; + _replacementSettings = replacementSettings; } public override string? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) @@ -61,7 +44,18 @@ public StringJsonConverter(EnvironmentVariableReplacementFailureMode replacement if (reader.TokenType == JsonTokenType.String) { string? value = reader.GetString(); - return Regex.Replace(value!, ENV_PATTERN, new MatchEvaluator(ReplaceMatchWithEnvVariable)); + if (string.IsNullOrWhiteSpace(value)) + { + return value; + } + + // Apply all replacement strategies configured in the settings + foreach (KeyValuePair> strategy in _replacementSettings.ReplacementStrategies) + { + value = strategy.Key.Replace(value, new MatchEvaluator(strategy.Value)); + } + + return value; } if (reader.TokenType == JsonTokenType.Null) @@ -76,30 +70,5 @@ public override void Write(Utf8JsonWriter writer, string value, JsonSerializerOp { writer.WriteStringValue(value); } - - private string ReplaceMatchWithEnvVariable(Match match) - { - // [^@env\(] : any substring that is not @env( - // .* : any char except newline any number of times - // (?=\)) : look ahead for end char of ) - // This pattern greedy matches all characters that are not a part of @env() - // ie: @env('hello@env('goodbye')world') match: 'hello@env('goodbye')world' - string innerPattern = @"[^@env\(].*(?=\))"; - - // strips first and last characters, ie: '''hello'' --> ''hello' - string envName = Regex.Match(match.Value, innerPattern).Value[1..^1]; - string? envValue = Environment.GetEnvironmentVariable(envName); - if (_replacementFailureMode == EnvironmentVariableReplacementFailureMode.Throw) - { - return envValue is not null ? envValue : - throw new DataApiBuilderException(message: $"Environmental Variable, {envName}, not found.", - statusCode: System.Net.HttpStatusCode.ServiceUnavailable, - subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); - } - else - { - return envValue ?? match.Value; - } - } } } diff --git a/src/Config/Converters/Utf8JsonReaderExtensions.cs b/src/Config/Converters/Utf8JsonReaderExtensions.cs index 20c6821d02..5e16357227 100644 --- a/src/Config/Converters/Utf8JsonReaderExtensions.cs +++ b/src/Config/Converters/Utf8JsonReaderExtensions.cs @@ -13,14 +13,12 @@ static internal class Utf8JsonReaderExtensions /// substitution is applied. /// /// The reader that we want to pull the string from. - /// Whether to replace environment variable with its - /// value or not while deserializing. + /// The replacement settings to use while deserializing. /// The failure mode to use when replacing environment variables. /// The result of deserialization. /// Thrown if the is not String. public static string? DeserializeString(this Utf8JsonReader reader, - bool replaceEnvVar, - EnvironmentVariableReplacementFailureMode replacementFailureMode = EnvironmentVariableReplacementFailureMode.Throw) + DeserializationVariableReplacementSettings? replacementSettings) { if (reader.TokenType is JsonTokenType.Null) { @@ -34,9 +32,9 @@ static internal class Utf8JsonReaderExtensions // Add the StringConverterFactory so that we can do environment variable substitution. JsonSerializerOptions options = new(); - if (replaceEnvVar) + if (replacementSettings is not null) { - options.Converters.Add(new StringJsonConverterFactory(replacementFailureMode)); + options.Converters.Add(new StringJsonConverterFactory(replacementSettings)); } return JsonSerializer.Deserialize(ref reader, options); diff --git a/src/Config/DeserializationVariableReplacementSettings.cs b/src/Config/DeserializationVariableReplacementSettings.cs new file mode 100644 index 0000000000..350824409b --- /dev/null +++ b/src/Config/DeserializationVariableReplacementSettings.cs @@ -0,0 +1,288 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.RegularExpressions; +using Azure.Core; +using Azure.DataApiBuilder.Config.Converters; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Service.Exceptions; +using Azure.Identity; +using Azure.Security.KeyVault.Secrets; +using Microsoft.Extensions.Logging; + +namespace Azure.DataApiBuilder.Config +{ + public class DeserializationVariableReplacementSettings + { + public bool DoReplaceEnvVar { get; set; } + public bool DoReplaceAkvVar { get; set; } + public EnvironmentVariableReplacementFailureMode EnvFailureMode { get; set; } = EnvironmentVariableReplacementFailureMode.Throw; + + // @env\(' : match @env(' + // @akv\(' : match @akv(' + // .*? : lazy match any character except newline 0 or more times + // (?='\)) : look ahead for ')' which will combine with our lazy match + // ie: in @env('hello')goodbye') we match @env('hello') + // '\) : consume the ') into the match (look ahead doesn't capture) + // This pattern lazy matches any string that starts with @env(' and ends with ') OR @akv(' and ends with ') + // Example: fooBAR@env('hello-world')bash)FOO') match: @env('hello-world') + // Example: fooBAR@akv('secret-name')bash)FOO') match: @akv('secret-name') + // This matching pattern allows for the @env('') / @akv('') to be safely nested + // within strings that contain ')' after our match. + // Note that there is no escape character currently for ')' to exist within the name of the variable. + public const string OUTER_ENV_PATTERN = @"@env\('.*?(?='\))'\)"; + public const string OUTER_AKV_PATTERN = @"@akv\('.*?(?='\))'\)"; + + // [^@env\(] : any substring that is not @env( + // [^@akv\(] : any substring that is not @akv( + // .* : any char except newline any number of times + // (?=\)) : look ahead for end char of ) + // This pattern greedy matches all characters that are not a part of @env() / @akv() + // ie: @env('hello@env('goodbye')world') match: 'hello@env('goodbye')world' + public const string INNER_ENV_PATTERN = @"[^@env\(].*(?=\))"; + public const string INNER_AKV_PATTERN = @"[^@akv\(].*(?=\))"; + + private readonly AzureKeyVaultOptions? _azureKeyVaultOptions; + private readonly SecretClient? _akvClient; + private readonly Dictionary? _akvFileSecrets; + private readonly ILogger? _logger; + + public Dictionary> ReplacementStrategies { get; private set; } = new(); + + public DeserializationVariableReplacementSettings( + AzureKeyVaultOptions? azureKeyVaultOptions = null, + bool doReplaceEnvVar = false, + bool doReplaceAkvVar = false, + EnvironmentVariableReplacementFailureMode envFailureMode = EnvironmentVariableReplacementFailureMode.Throw, + ILogger? logger = null) + { + _azureKeyVaultOptions = azureKeyVaultOptions; + DoReplaceEnvVar = doReplaceEnvVar; + DoReplaceAkvVar = doReplaceAkvVar; + EnvFailureMode = envFailureMode; + _logger = logger; + + if (DoReplaceEnvVar) + { + ReplacementStrategies.Add( + new Regex(OUTER_ENV_PATTERN, RegexOptions.Compiled), + ReplaceEnvVariable); + } + + if (DoReplaceAkvVar && _azureKeyVaultOptions is not null) + { + // Determine if endpoint points to a local .akv file. If so, load secrets from file; otherwise, use remote AKV. + if (IsLocalAkvFileEndpoint(_azureKeyVaultOptions.Endpoint)) + { + _akvFileSecrets = LoadAkvFileSecrets(_azureKeyVaultOptions.Endpoint!, _logger); + } + else + { + _akvClient = CreateSecretClient(_azureKeyVaultOptions); + } + + ReplacementStrategies.Add( + new Regex(OUTER_AKV_PATTERN, RegexOptions.Compiled), + ReplaceAkvVariable); + } + } + + // Checks if the endpoint is a path to a local .akv file. + private static bool IsLocalAkvFileEndpoint(string? endpoint) + => !string.IsNullOrWhiteSpace(endpoint) + && endpoint.EndsWith(".akv", StringComparison.OrdinalIgnoreCase) + && File.Exists(endpoint); + + // Loads key=value pairs from a .akv file, similar to .env style. Lines starting with '#' are comments. + private static Dictionary LoadAkvFileSecrets(string filePath, ILogger? logger = null) + { + Dictionary secrets = new(StringComparer.OrdinalIgnoreCase); + foreach (string rawLine in File.ReadAllLines(filePath)) + { + string line = rawLine.Trim(); + if (string.IsNullOrEmpty(line) || line.StartsWith('#')) + { + continue; + } + + int eqIndex = line.IndexOf('='); + if (eqIndex <= 0) + { + logger?.LogDebug("Ignoring malformed line in AKV secrets file {FilePath}: {Line}", filePath, rawLine); + continue; + } + + string key = line.Substring(0, eqIndex).Trim(); + string value = line[(eqIndex + 1)..].Trim(); + + // Remove optional surrounding quotes + if (value.Length >= 2 && ((value.StartsWith('"') && value.EndsWith('"')) || (value.StartsWith('\'') && value.EndsWith('\'')))) + { + value = value[1..^1]; + } + + if (!string.IsNullOrEmpty(key)) + { + if (!secrets.TryAdd(key, value)) + { + logger?.LogDebug("Duplicate key '{Key}' encountered in AKV secrets file {FilePath}. Skipping later value.", key, filePath); + } + } + } + + return secrets; + } + + private string ReplaceEnvVariable(Match match) + { + // strips first and last characters, ie: '''hello'' --> ''hello' + string name = Regex.Match(match.Value, INNER_ENV_PATTERN).Value[1..^1]; + string? value = Environment.GetEnvironmentVariable(name); + if (EnvFailureMode is EnvironmentVariableReplacementFailureMode.Throw) + { + return value is not null ? value : + throw new DataApiBuilderException( + message: $"Environmental Variable, {name}, not found.", + statusCode: System.Net.HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + else + { + return value ?? match.Value; + } + } + + private string ReplaceAkvVariable(Match match) + { + // strips first and last characters, ie: '''hello'' --> ''hello' + string name = Regex.Match(match.Value, INNER_AKV_PATTERN).Value[1..^1]; + + // Validate AKV secret name per rules: + // Allowed: alphanumeric and hyphen (-) + // Disallowed: spaces or any other symbols + // Must start and end with alphanumeric + // Length: 1 to 127 chars + if (!IsValidAkvSecretName(name, out string validationError)) + { + throw new DataApiBuilderException( + message: $"Azure Key Vault secret name '{name}' is invalid. {validationError} Requirements: allowed characters are alphanumeric and hyphen (-); must start and end with an alphanumeric character; length 1-127 characters; case-insensitive.", + statusCode: System.Net.HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + + string? value = GetAkvVariable(name); + if (EnvFailureMode == EnvironmentVariableReplacementFailureMode.Throw) + { + return value is not null ? value : + throw new DataApiBuilderException(message: $"Azure Key Vault Variable, '{name}', not found.", + statusCode: System.Net.HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + else + { + return value ?? match.Value; + } + } + + private static bool IsValidAkvSecretName(string name, out string error) + { + error = string.Empty; + if (string.IsNullOrEmpty(name)) + { + error = "Name cannot be null or empty."; + return false; + } + + if (name.Length < 1 || name.Length > 127) + { + error = $"Length {name.Length} is outside allowed range (1-127)."; + return false; + } + + // Must start and end with alphanumeric + if (!char.IsLetterOrDigit(name[0]) || !char.IsLetterOrDigit(name[^1])) + { + error = "Must start and end with an alphanumeric character."; + return false; + } + + // Allowed characters: letters, digits, hyphen. + for (int i = 0; i < name.Length; i++) + { + char c = name[i]; + if (!(char.IsLetterOrDigit(c) || c == '-')) + { + error = $"Invalid character '{c}' at position {i}."; + return false; + } + } + + return true; + } + + private static SecretClient CreateSecretClient(AzureKeyVaultOptions options) + { + if (string.IsNullOrWhiteSpace(options.Endpoint)) + { + throw new DataApiBuilderException( + "Missing 'endpoint' property is required to connect to Azure Key Vault.", + System.Net.HttpStatusCode.InternalServerError, + DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + + // If endpoint is a local .akv file, we should not create a SecretClient. + if (IsLocalAkvFileEndpoint(options.Endpoint)) + { + throw new DataApiBuilderException( + "Attempted to create Azure Key Vault client for local .akv file endpoint.", + System.Net.HttpStatusCode.InternalServerError, + DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + + SecretClientOptions clientOptions = new(); + + if (options.RetryPolicy is not null) + { + // Convert AKVRetryPolicyMode to RetryMode + RetryMode retryMode = options.RetryPolicy.Mode switch + { + AKVRetryPolicyMode.Fixed => RetryMode.Fixed, + AKVRetryPolicyMode.Exponential => RetryMode.Exponential, + null => RetryMode.Exponential, + _ => RetryMode.Exponential + }; + + clientOptions.Retry.Mode = retryMode; + clientOptions.Retry.MaxRetries = options.RetryPolicy.MaxCount ?? AKVRetryPolicyOptions.DEFAULT_MAX_COUNT; + clientOptions.Retry.Delay = TimeSpan.FromSeconds(options.RetryPolicy.DelaySeconds ?? AKVRetryPolicyOptions.DEFAULT_DELAY_SECONDS); + clientOptions.Retry.MaxDelay = TimeSpan.FromSeconds(options.RetryPolicy.MaxDelaySeconds ?? AKVRetryPolicyOptions.DEFAULT_MAX_DELAY_SECONDS); + clientOptions.Retry.NetworkTimeout = TimeSpan.FromSeconds(options.RetryPolicy.NetworkTimeoutSeconds ?? AKVRetryPolicyOptions.DEFAULT_NETWORK_TIMEOUT_SECONDS); + } + + return new SecretClient(new Uri(options.Endpoint), new DefaultAzureCredential(), clientOptions); // CodeQL [SM05137] DefaultAzureCredential will use Managed Identity if available or fallback to default. + } + + private string? GetAkvVariable(string name) + { + // If using local .akv file secrets, return from dictionary. + if (_akvFileSecrets is not null) + { + return _akvFileSecrets.TryGetValue(name, out string? value) ? value : null; + } + + if (_akvClient is null) + { + throw new InvalidOperationException("Azure Key Vault client is not initialized."); + } + + try + { + return _akvClient.GetSecret(name).Value.Value; + } + catch (Azure.RequestFailedException ex) when (ex.Status == 404) + { + return null; + } + } + } +} diff --git a/src/Config/FileSystemRuntimeConfigLoader.cs b/src/Config/FileSystemRuntimeConfigLoader.cs index 9c2a8e50b5..614cfbd11c 100644 --- a/src/Config/FileSystemRuntimeConfigLoader.cs +++ b/src/Config/FileSystemRuntimeConfigLoader.cs @@ -182,17 +182,16 @@ private void OnNewFileContentsDetected(object? sender, EventArgs e) /// /// The path to the dab-config.json file. /// The loaded RuntimeConfig, or null if none was loaded. - /// Whether to replace environment variable with its - /// value or not while deserializing. /// ILogger for logging errors. /// When not null indicates we need to overwrite mode and how to do so. + /// Settings for variable replacement during deserialization. If null, uses default settings with environment variable replacement disabled. /// True if the config was loaded, otherwise false. public bool TryLoadConfig( string path, [NotNullWhen(true)] out RuntimeConfig? config, - bool replaceEnvVar = false, ILogger? logger = null, - bool? isDevMode = null) + bool? isDevMode = null, + DeserializationVariableReplacementSettings? replacementSettings = null) { if (_fileSystem.File.Exists(path)) { @@ -226,7 +225,15 @@ public bool TryLoadConfig( } } - if (!string.IsNullOrEmpty(json) && TryParseConfig(json, out RuntimeConfig, connectionString: _connectionString, replaceEnvVar: replaceEnvVar)) + // Use default replacement settings if none provided + replacementSettings ??= new DeserializationVariableReplacementSettings(); + + if (!string.IsNullOrEmpty(json) && TryParseConfig( + json, + out RuntimeConfig, + replacementSettings, + logger: null, + connectionString: _connectionString)) { if (TrySetupConfigFileWatcher()) { @@ -292,12 +299,13 @@ public bool TryLoadConfig( /// Tries to load the config file using the filename known to the RuntimeConfigLoader and for the default environment. /// /// The loaded RuntimeConfig, or null if none was loaded. - /// Whether to replace environment variable with its - /// value or not while deserializing. + /// Settings for variable replacement during deserialization. If null, uses default settings with environment variable replacement disabled. /// True if the config was loaded, otherwise false. public override bool TryLoadKnownConfig([NotNullWhen(true)] out RuntimeConfig? config, bool replaceEnvVar = false) { - return TryLoadConfig(ConfigFilePath, out config, replaceEnvVar); + // Convert legacy replaceEnvVar parameter to replacement settings for backward compatibility + DeserializationVariableReplacementSettings? replacementSettings = new(azureKeyVaultOptions: null, doReplaceEnvVar: replaceEnvVar, doReplaceAkvVar: replaceEnvVar); + return TryLoadConfig(ConfigFilePath, out config, replacementSettings: replacementSettings); } /// @@ -307,7 +315,11 @@ public override bool TryLoadKnownConfig([NotNullWhen(true)] out RuntimeConfig? c private void HotReloadConfig(bool isDevMode, ILogger? logger = null) { logger?.LogInformation(message: "Starting hot-reload process for config: {ConfigFilePath}", ConfigFilePath); - if (!TryLoadConfig(ConfigFilePath, out _, replaceEnvVar: true, isDevMode: isDevMode)) + + // Use default replacement settings for hot reload + DeserializationVariableReplacementSettings replacementSettings = new(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true); + + if (!TryLoadConfig(ConfigFilePath, out _, logger: logger, isDevMode: isDevMode, replacementSettings: replacementSettings)) { throw new DataApiBuilderException( message: "Deserialization of the configuration file failed.", @@ -467,7 +479,7 @@ public override string GetPublishedDraftSchemaLink() string? schemaPath = _fileSystem.Path.Combine(assemblyDirectory, "dab.draft.schema.json"); string schemaFileContent = _fileSystem.File.ReadAllText(schemaPath); - Dictionary? jsonDictionary = JsonSerializer.Deserialize>(schemaFileContent, GetSerializationOptions()); + Dictionary? jsonDictionary = JsonSerializer.Deserialize>(schemaFileContent, GetSerializationOptions(replacementSettings: null)); if (jsonDictionary is null) { diff --git a/src/Config/ObjectModel/AzureKeyVaultOptions.cs b/src/Config/ObjectModel/AzureKeyVaultOptions.cs index 27094cd16f..ebd1e909c1 100644 --- a/src/Config/ObjectModel/AzureKeyVaultOptions.cs +++ b/src/Config/ObjectModel/AzureKeyVaultOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; namespace Azure.DataApiBuilder.Config.ObjectModel; @@ -12,4 +13,40 @@ public record AzureKeyVaultOptions [JsonPropertyName("retry-policy")] public AKVRetryPolicyOptions? RetryPolicy { get; init; } + + /// + /// Flag which informs CLI and JSON serializer whether to write endpoint + /// property and value to the runtime config file. + /// When user doesn't provide the endpoint property/value, which signals DAB to use the default, + /// the DAB CLI should not write the default value to a serialized config. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(Endpoint))] + public bool UserProvidedEndpoint { get; init; } = false; + + /// + /// Flag which informs CLI and JSON serializer whether to write retry-policy + /// property and value to the runtime config file. + /// When user doesn't provide the retry-policy property/value, which signals DAB to use the default, + /// the DAB CLI should not write the default value to a serialized config. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(RetryPolicy))] + public bool UserProvidedRetryPolicy { get; init; } = false; + + [JsonConstructor] + public AzureKeyVaultOptions(string? endpoint = null, AKVRetryPolicyOptions? retryPolicy = null) + { + if (endpoint is not null) + { + Endpoint = endpoint; + UserProvidedEndpoint = true; + } + + if (retryPolicy is not null) + { + RetryPolicy = retryPolicy; + UserProvidedRetryPolicy = true; + } + } } diff --git a/src/Config/ObjectModel/RuntimeConfig.cs b/src/Config/ObjectModel/RuntimeConfig.cs index a450e1265c..6896d82161 100644 --- a/src/Config/ObjectModel/RuntimeConfig.cs +++ b/src/Config/ObjectModel/RuntimeConfig.cs @@ -298,7 +298,10 @@ public RuntimeConfig( foreach (string dataSourceFile in DataSourceFiles.SourceFiles) { - if (loader.TryLoadConfig(dataSourceFile, out RuntimeConfig? config, replaceEnvVar: true)) + // Use default replacement settings for environment variable replacement + DeserializationVariableReplacementSettings replacementSettings = new(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true); + + if (loader.TryLoadConfig(dataSourceFile, out RuntimeConfig? config, replacementSettings: replacementSettings)) { try { @@ -448,7 +451,7 @@ public bool CheckDataSourceExists(string dataSourceName) public string ToJson(JsonSerializerOptions? jsonSerializerOptions = null) { // get default serializer options if none provided. - jsonSerializerOptions = jsonSerializerOptions ?? RuntimeConfigLoader.GetSerializationOptions(); + jsonSerializerOptions = jsonSerializerOptions ?? RuntimeConfigLoader.GetSerializationOptions(replacementSettings: null); return JsonSerializer.Serialize(this, jsonSerializerOptions); } diff --git a/src/Config/RuntimeConfigLoader.cs b/src/Config/RuntimeConfigLoader.cs index f78c32ebc1..bad5aa8680 100644 --- a/src/Config/RuntimeConfigLoader.cs +++ b/src/Config/RuntimeConfigLoader.cs @@ -129,25 +129,86 @@ protected void SignalConfigChanged(string message = "") /// public abstract string GetPublishedDraftSchemaLink(); + /// + /// Extracts AzureKeyVaultOptions from JSON string with configurable variable replacement. + /// + /// JSON that represents the config file. + /// Whether to enable environment variable replacement during extraction. + /// Failure mode for environment variable replacement if enabled. + /// AzureKeyVaultOptions if present, null otherwise. + private static AzureKeyVaultOptions? ExtractAzureKeyVaultOptions( + string json, + bool enableEnvReplacement, + EnvironmentVariableReplacementFailureMode replacementFailureMode = EnvironmentVariableReplacementFailureMode.Throw) + { + JsonSerializerOptions options = new() + { + PropertyNameCaseInsensitive = false, + PropertyNamingPolicy = new HyphenatedNamingPolicy(), + ReadCommentHandling = JsonCommentHandling.Skip + }; + DeserializationVariableReplacementSettings envOnlySettings = new( + azureKeyVaultOptions: null, + doReplaceEnvVar: enableEnvReplacement, + doReplaceAkvVar: false, + envFailureMode: replacementFailureMode); + options.Converters.Add(new StringJsonConverterFactory(envOnlySettings)); + options.Converters.Add(new EnumMemberJsonEnumConverterFactory()); + options.Converters.Add(new AzureKeyVaultOptionsConverterFactory(replacementSettings: envOnlySettings)); + options.Converters.Add(new AKVRetryPolicyOptionsConverterFactory(replacementSettings: envOnlySettings)); + + try + { + using JsonDocument doc = JsonDocument.Parse(json); + if (doc.RootElement.TryGetProperty("azure-key-vault", out JsonElement akvElement)) + { + return JsonSerializer.Deserialize(akvElement.GetRawText(), options); + } + } + catch + { + // If we can't extract AKV options, return null and proceed without AKV variable replacement + return null; + } + + return null; + } + /// /// Parses a JSON string into a RuntimeConfig object for single database scenario. /// /// JSON that represents the config file. /// The parsed config, or null if it parsed unsuccessfully. - /// True if the config was parsed, otherwise false. + /// Settings for variable replacement during deserialization. If null, no variable replacement will be performed. /// logger to log messages /// connectionString to add to config if specified - /// Whether to replace environment variable with its - /// value or not while deserializing. By default, no replacement happens. - /// Determines failure mode for env variable replacement. + /// True if the config was parsed, otherwise false. public static bool TryParseConfig(string json, [NotNullWhen(true)] out RuntimeConfig? config, + DeserializationVariableReplacementSettings? replacementSettings = null, ILogger? logger = null, - string? connectionString = null, - bool replaceEnvVar = false, - EnvironmentVariableReplacementFailureMode replacementFailureMode = EnvironmentVariableReplacementFailureMode.Throw) + string? connectionString = null) { - JsonSerializerOptions options = GetSerializationOptions(replaceEnvVar, replacementFailureMode); + // First pass: extract AzureKeyVault options if AKV replacement is requested + if (replacementSettings?.DoReplaceAkvVar is true) + { + AzureKeyVaultOptions? azureKeyVaultOptions = ExtractAzureKeyVaultOptions( + json: json, + enableEnvReplacement: replacementSettings.DoReplaceEnvVar, + replacementFailureMode: replacementSettings.EnvFailureMode); + + // Update replacement settings with the extracted AKV options + if (azureKeyVaultOptions is not null) + { + replacementSettings = new DeserializationVariableReplacementSettings( + azureKeyVaultOptions: azureKeyVaultOptions, + doReplaceEnvVar: replacementSettings.DoReplaceEnvVar, + doReplaceAkvVar: replacementSettings.DoReplaceAkvVar, + envFailureMode: replacementSettings.EnvFailureMode); + } + } + + JsonSerializerOptions options = GetSerializationOptions(replacementSettings); try { @@ -180,11 +241,11 @@ public static bool TryParseConfig(string json, DataSource ds = config.GetDataSourceFromDataSourceName(dataSourceKey); // Add Application Name for telemetry for MsSQL or PgSql - if (ds.DatabaseType is DatabaseType.MSSQL && replaceEnvVar) + if (ds.DatabaseType is DatabaseType.MSSQL && replacementSettings?.DoReplaceEnvVar == true) { updatedConnection = GetConnectionStringWithApplicationName(connectionValue); } - else if (ds.DatabaseType is DatabaseType.PostgreSQL && replaceEnvVar) + else if (ds.DatabaseType is DatabaseType.PostgreSQL && replacementSettings?.DoReplaceEnvVar == true) { updatedConnection = GetPgSqlConnectionStringWithApplicationName(connectionValue); } @@ -225,11 +286,10 @@ ex is JsonException || /// /// Get Serializer options for the config file. /// - /// Whether to replace environment variable with value or not while deserializing. - /// By default, no replacement happens. + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. public static JsonSerializerOptions GetSerializationOptions( - bool replaceEnvVar = false, - EnvironmentVariableReplacementFailureMode replacementFailureMode = EnvironmentVariableReplacementFailureMode.Throw) + DeserializationVariableReplacementSettings? replacementSettings = null) { JsonSerializerOptions options = new() { @@ -241,33 +301,37 @@ public static JsonSerializerOptions GetSerializationOptions( Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping }; options.Converters.Add(new EnumMemberJsonEnumConverterFactory()); - options.Converters.Add(new RuntimeHealthOptionsConvertorFactory(replaceEnvVar)); - options.Converters.Add(new DataSourceHealthOptionsConvertorFactory(replaceEnvVar)); + options.Converters.Add(new RuntimeHealthOptionsConvertorFactory(replacementSettings)); + options.Converters.Add(new DataSourceHealthOptionsConvertorFactory(replacementSettings)); options.Converters.Add(new EntityHealthOptionsConvertorFactory()); options.Converters.Add(new RestRuntimeOptionsConverterFactory()); - options.Converters.Add(new GraphQLRuntimeOptionsConverterFactory(replaceEnvVar)); - options.Converters.Add(new McpRuntimeOptionsConverterFactory(replaceEnvVar)); + options.Converters.Add(new GraphQLRuntimeOptionsConverterFactory(replacementSettings)); + options.Converters.Add(new McpRuntimeOptionsConverterFactory(replacementSettings)); options.Converters.Add(new DmlToolsConfigConverter()); - options.Converters.Add(new EntitySourceConverterFactory(replaceEnvVar)); - options.Converters.Add(new EntityGraphQLOptionsConverterFactory(replaceEnvVar)); - options.Converters.Add(new EntityRestOptionsConverterFactory(replaceEnvVar)); + options.Converters.Add(new EntitySourceConverterFactory(replacementSettings)); + options.Converters.Add(new EntityGraphQLOptionsConverterFactory(replacementSettings)); + options.Converters.Add(new EntityRestOptionsConverterFactory(replacementSettings)); options.Converters.Add(new EntityActionConverterFactory()); options.Converters.Add(new DataSourceFilesConverter()); - options.Converters.Add(new EntityCacheOptionsConverterFactory(replaceEnvVar)); + options.Converters.Add(new EntityCacheOptionsConverterFactory(replacementSettings)); options.Converters.Add(new RuntimeCacheOptionsConverterFactory()); options.Converters.Add(new RuntimeCacheLevel2OptionsConverterFactory()); options.Converters.Add(new MultipleCreateOptionsConverter()); options.Converters.Add(new MultipleMutationOptionsConverter(options)); - options.Converters.Add(new DataSourceConverterFactory(replaceEnvVar)); + options.Converters.Add(new DataSourceConverterFactory(replacementSettings)); options.Converters.Add(new HostOptionsConvertorFactory()); - options.Converters.Add(new AKVRetryPolicyOptionsConverterFactory(replaceEnvVar)); - options.Converters.Add(new AzureLogAnalyticsOptionsConverterFactory(replaceEnvVar)); - options.Converters.Add(new AzureLogAnalyticsAuthOptionsConverter(replaceEnvVar)); - options.Converters.Add(new FileSinkConverter(replaceEnvVar)); + options.Converters.Add(new AKVRetryPolicyOptionsConverterFactory(replacementSettings)); + options.Converters.Add(new AzureLogAnalyticsOptionsConverterFactory(replacementSettings)); + options.Converters.Add(new AzureLogAnalyticsAuthOptionsConverter(replacementSettings)); + options.Converters.Add(new FileSinkConverter(replacementSettings)); + + // Add AzureKeyVaultOptionsConverterFactory to ensure AKV config is deserialized properly + options.Converters.Add(new AzureKeyVaultOptionsConverterFactory(replacementSettings)); - if (replaceEnvVar) + // Only add the extensible string converter if we have replacement settings + if (replacementSettings is not null) { - options.Converters.Add(new StringJsonConverterFactory(replacementFailureMode)); + options.Converters.Add(new StringJsonConverterFactory(replacementSettings)); } return options; diff --git a/src/Core/Configurations/RuntimeConfigProvider.cs b/src/Core/Configurations/RuntimeConfigProvider.cs index faeb2b94d0..b46a716f48 100644 --- a/src/Core/Configurations/RuntimeConfigProvider.cs +++ b/src/Core/Configurations/RuntimeConfigProvider.cs @@ -6,7 +6,6 @@ using System.IO.Abstractions; using System.Net; using Azure.DataApiBuilder.Config; -using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.NamingPolicies; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Service.Exceptions; @@ -189,8 +188,7 @@ public async Task Initialize( if (RuntimeConfigLoader.TryParseConfig( configuration, out RuntimeConfig? runtimeConfig, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Ignore)) + replacementSettings: null)) { _configLoader.RuntimeConfig = runtimeConfig; @@ -257,8 +255,7 @@ public async Task Initialize( string? graphQLSchema, string connectionString, string? accessToken, - bool replaceEnvVar = true, - EnvironmentVariableReplacementFailureMode replacementFailureMode = EnvironmentVariableReplacementFailureMode.Throw) + DeserializationVariableReplacementSettings? replacementSettings) { if (string.IsNullOrEmpty(connectionString)) { @@ -272,7 +269,7 @@ public async Task Initialize( IsLateConfigured = true; - if (RuntimeConfigLoader.TryParseConfig(jsonConfig, out RuntimeConfig? runtimeConfig, replaceEnvVar: replaceEnvVar, replacementFailureMode: replacementFailureMode)) + if (RuntimeConfigLoader.TryParseConfig(jsonConfig, out RuntimeConfig? runtimeConfig, replacementSettings)) { _configLoader.RuntimeConfig = runtimeConfig.DataSource.DatabaseType switch { diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index 14f097915c..542508f71f 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -5,6 +5,7 @@ + diff --git a/src/Service.Tests/Caching/CachingConfigProcessingTests.cs b/src/Service.Tests/Caching/CachingConfigProcessingTests.cs index 2780af63c5..a6daebf3e4 100644 --- a/src/Service.Tests/Caching/CachingConfigProcessingTests.cs +++ b/src/Service.Tests/Caching/CachingConfigProcessingTests.cs @@ -5,7 +5,6 @@ using System.Text; using System.Text.Json; using Azure.DataApiBuilder.Config; -using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.ObjectModel; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -56,10 +55,7 @@ public void EntityCacheOptionsDeserialization_ValidJson( RuntimeConfigLoader.TryParseConfig( json: fullConfig, out RuntimeConfig? config, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); // Assert Assert.IsNotNull(config, message: "Config must not be null, runtime config JSON deserialization failed."); @@ -103,10 +99,7 @@ public void EntityCacheOptionsDeserialization_InvalidValues(string entityCacheCo bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( json: fullConfig, out _, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); // Assert Assert.IsFalse(isParsingSuccessful, message: "Expected JSON parsing to fail."); @@ -141,10 +134,7 @@ public void GlobalCacheOptionsDeserialization_ValidValues( RuntimeConfigLoader.TryParseConfig( json: fullConfig, out RuntimeConfig? config, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); // Assert Assert.IsNotNull(config, message: "Config must not be null, runtime config JSON deserialization failed."); @@ -187,10 +177,7 @@ public void GlobalCacheOptionsDeserialization_InvalidValues(string globalCacheCo bool parsingSuccessful = RuntimeConfigLoader.TryParseConfig( json: fullConfig, out _, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); // Assert Assert.IsFalse(parsingSuccessful, message: "Expected JSON parsing to fail."); @@ -216,10 +203,7 @@ public void GlobalCacheOptionsOverridesEntityCacheOptions(string globalCacheConf RuntimeConfigLoader.TryParseConfig( json: fullConfig, out RuntimeConfig? config, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); // Assert Assert.IsNotNull(config, message: "Config must not be null, runtime config JSON deserialization failed."); @@ -252,10 +236,7 @@ public void UserDefinedTtlWrittenToSerializedJsonConfigFile(bool expectIsUserDef RuntimeConfigLoader.TryParseConfig( json: fullConfig, out RuntimeConfig? config, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); Assert.IsNotNull(config, message: "Test setup failure. Config must not be null, runtime config JSON deserialization failed."); // Act @@ -300,10 +281,7 @@ public void CachePropertyNotWrittenToSerializedJsonConfigFile(string cacheConfig RuntimeConfigLoader.TryParseConfig( json: fullConfig, out RuntimeConfig? config, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); Assert.IsNotNull(config, message: "Test setup failure. Config must not be null, runtime config JSON deserialization failed."); // Act @@ -342,10 +320,7 @@ public void DefaultTtlNotWrittenToSerializedJsonConfigFile(string cacheConfig) RuntimeConfigLoader.TryParseConfig( json: fullConfig, out RuntimeConfig? config, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); Assert.IsNotNull(config, message: "Test setup failure. Config must not be null, runtime config JSON deserialization failed."); // Act diff --git a/src/Service.Tests/Configuration/ConfigurationTests.cs b/src/Service.Tests/Configuration/ConfigurationTests.cs index 65f6e6643b..0614e7688f 100644 --- a/src/Service.Tests/Configuration/ConfigurationTests.cs +++ b/src/Service.Tests/Configuration/ConfigurationTests.cs @@ -838,9 +838,9 @@ public void MsSqlConnStringSupplementedWithAppNameProperty( // Act bool configParsed = RuntimeConfigLoader.TryParseConfig( - runtimeConfig.ToJson(), - out RuntimeConfig updatedRuntimeConfig, - replaceEnvVar: true); + json: runtimeConfig.ToJson(), + config: out RuntimeConfig updatedRuntimeConfig, + replacementSettings: new(doReplaceEnvVar: true)); // Assert Assert.AreEqual( @@ -891,9 +891,9 @@ public void PgSqlConnStringSupplementedWithAppNameProperty( // Act bool configParsed = RuntimeConfigLoader.TryParseConfig( - runtimeConfig.ToJson(), - out RuntimeConfig updatedRuntimeConfig, - replaceEnvVar: true); + json: runtimeConfig.ToJson(), + config: out RuntimeConfig updatedRuntimeConfig, + replacementSettings: new(doReplaceEnvVar: true)); // Assert Assert.AreEqual( @@ -956,9 +956,9 @@ public void TestConnectionStringIsCorrectlyUpdatedWithApplicationName( // Act bool configParsed = RuntimeConfigLoader.TryParseConfig( - runtimeConfig.ToJson(), - out RuntimeConfig updatedRuntimeConfig, - replaceEnvVar: true); + json: runtimeConfig.ToJson(), + config: out RuntimeConfig updatedRuntimeConfig, + replacementSettings: new(doReplaceEnvVar: true)); // Assert Assert.AreEqual( @@ -2346,7 +2346,12 @@ public async Task TestSPRestDefaultsForManuallyConstructedConfigs( HttpStatusCode expectedResponseStatusCode) { string configJson = TestHelper.AddPropertiesToJson(TestHelper.BASE_CONFIG, entityJson); - RuntimeConfigLoader.TryParseConfig(configJson, out RuntimeConfig deserializedConfig, logger: null, GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL)); + RuntimeConfigLoader.TryParseConfig( + configJson, + out RuntimeConfig deserializedConfig, + replacementSettings: new(), + logger: null, + GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL)); string configFileName = "custom-config.json"; File.WriteAllText(configFileName, deserializedConfig.ToJson()); string[] args = new[] @@ -2429,7 +2434,12 @@ public async Task SanityTestForRestAndGQLRequestsWithoutMultipleMutationFeatureF // The configuration file is constructed by merging hard-coded JSON strings to simulate the scenario where users manually edit the // configuration file (instead of using CLI). string configJson = TestHelper.AddPropertiesToJson(TestHelper.BASE_CONFIG, BOOK_ENTITY_JSON); - Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(configJson, out RuntimeConfig deserializedConfig, logger: null, GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL))); + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig( + configJson, + out RuntimeConfig deserializedConfig, + replacementSettings: new(), + logger: null, + GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL))); string configFileName = "custom-config.json"; File.WriteAllText(configFileName, deserializedConfig.ToJson()); string[] args = new[] @@ -3290,7 +3300,12 @@ public async Task ValidateStrictModeAsDefaultForRestRequestBody(bool includeExtr // The BASE_CONFIG omits the rest.request-body-strict option in the runtime section. string configJson = TestHelper.AddPropertiesToJson(TestHelper.BASE_CONFIG, entityJson); - RuntimeConfigLoader.TryParseConfig(configJson, out RuntimeConfig deserializedConfig, logger: null, GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL)); + RuntimeConfigLoader.TryParseConfig( + configJson, + out RuntimeConfig deserializedConfig, + replacementSettings: new(), + logger: null, + GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL)); const string CUSTOM_CONFIG = "custom-config.json"; File.WriteAllText(CUSTOM_CONFIG, deserializedConfig.ToJson()); string[] args = new[] @@ -5494,7 +5509,7 @@ public static string GetConnectionStringFromEnvironmentConfig(string environment string sqlFile = new FileSystemRuntimeConfigLoader(fileSystem).GetFileNameForEnvironment(environment, considerOverrides: true); string configPayload = File.ReadAllText(sqlFile); - RuntimeConfigLoader.TryParseConfig(configPayload, out RuntimeConfig runtimeConfig, replaceEnvVar: true); + RuntimeConfigLoader.TryParseConfig(configPayload, out RuntimeConfig runtimeConfig, replacementSettings: new()); return runtimeConfig.DataSource.ConnectionString; } diff --git a/src/Service.Tests/UnitTests/MySqlQueryExecutorUnitTests.cs b/src/Service.Tests/UnitTests/MySqlQueryExecutorUnitTests.cs index cbfef36664..63deed78d3 100644 --- a/src/Service.Tests/UnitTests/MySqlQueryExecutorUnitTests.cs +++ b/src/Service.Tests/UnitTests/MySqlQueryExecutorUnitTests.cs @@ -81,7 +81,8 @@ await provider.Initialize( provider.GetConfig().ToJson(), graphQLSchema: null, connectionString: connectionString, - accessToken: CONFIG_TOKEN); + accessToken: CONFIG_TOKEN, + replacementSettings: new()); mySqlQueryExecutor = new(provider, dbExceptionParser.Object, queryExecutorLogger.Object, httpContextAccessor.Object); } } diff --git a/src/Service.Tests/UnitTests/PostgreSqlQueryExecutorUnitTests.cs b/src/Service.Tests/UnitTests/PostgreSqlQueryExecutorUnitTests.cs index ccaa90b353..6039c46a72 100644 --- a/src/Service.Tests/UnitTests/PostgreSqlQueryExecutorUnitTests.cs +++ b/src/Service.Tests/UnitTests/PostgreSqlQueryExecutorUnitTests.cs @@ -89,7 +89,8 @@ await provider.Initialize( provider.GetConfig().ToJson(), graphQLSchema: null, connectionString: connectionString, - accessToken: CONFIG_TOKEN); + accessToken: CONFIG_TOKEN, + replacementSettings: new()); postgreSqlQueryExecutor = new(provider, dbExceptionParser.Object, queryExecutorLogger.Object, httpContextAccessor.Object); } } diff --git a/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs b/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs index b98de993e2..8329dc2134 100644 --- a/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs +++ b/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs @@ -13,6 +13,7 @@ using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Service.Exceptions; +using Microsoft.Data.SqlClient; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace Azure.DataApiBuilder.Service.Tests.UnitTests @@ -79,18 +80,18 @@ public void CheckConfigEnvParsingTest( if (replaceEnvVar) { Assert.IsTrue(RuntimeConfigLoader.TryParseConfig( - GetModifiedJsonString(repValues, @"""postgresql"""), out expectedConfig, replaceEnvVar: replaceEnvVar), + GetModifiedJsonString(repValues, @"""postgresql"""), out expectedConfig, replacementSettings: new DeserializationVariableReplacementSettings(azureKeyVaultOptions: null, doReplaceEnvVar: replaceEnvVar, doReplaceAkvVar: false)), "Should read the expected config"); } else { Assert.IsTrue(RuntimeConfigLoader.TryParseConfig( - GetModifiedJsonString(repKeys, @"""postgresql"""), out expectedConfig, replaceEnvVar: replaceEnvVar), + GetModifiedJsonString(repKeys, @"""postgresql"""), out expectedConfig, replacementSettings: new DeserializationVariableReplacementSettings(azureKeyVaultOptions: null, doReplaceEnvVar: replaceEnvVar, doReplaceAkvVar: false)), "Should read the expected config"); } Assert.IsTrue(RuntimeConfigLoader.TryParseConfig( - GetModifiedJsonString(repKeys, @"""@env('enumVarName')"""), out RuntimeConfig actualConfig, replaceEnvVar: replaceEnvVar), + GetModifiedJsonString(repKeys, @"""@env('enumVarName')"""), out RuntimeConfig actualConfig, replacementSettings: new DeserializationVariableReplacementSettings(azureKeyVaultOptions: null, doReplaceEnvVar: replaceEnvVar, doReplaceAkvVar: false)), "Should read actual config"); Assert.AreEqual(expectedConfig.ToJson(), actualConfig.ToJson()); } @@ -130,7 +131,7 @@ public void TestConfigParsingWithEnvVarReplacement(bool replaceEnvVar, string da string configWithEnvVar = _configWithVariableDataSource.Replace("{0}", GetDataSourceConfigForGivenDatabase(databaseType)); bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( - configWithEnvVar, out RuntimeConfig runtimeConfig, replaceEnvVar: replaceEnvVar); + configWithEnvVar, out RuntimeConfig runtimeConfig, replacementSettings: new DeserializationVariableReplacementSettings(azureKeyVaultOptions: null, doReplaceEnvVar: replaceEnvVar, doReplaceAkvVar: true)); // Assert Assert.IsTrue(isParsingSuccessful); @@ -178,7 +179,7 @@ public void TestConfigParsingWhenDataSourceOptionsForCosmosDBContainsInvalidValu string configWithEnvVar = _configWithVariableDataSource.Replace("{0}", GetDataSourceOptionsForCosmosDBWithInvalidValues()); bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( - configWithEnvVar, out RuntimeConfig runtimeConfig, replaceEnvVar: true); + configWithEnvVar, out RuntimeConfig runtimeConfig, replacementSettings: new DeserializationVariableReplacementSettings(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true)); // Assert Assert.IsTrue(isParsingSuccessful); @@ -240,6 +241,7 @@ public void CheckCommentParsingInConfigFile() /// but have the effect of default values when deserialized. /// It starts with a minimal config and incrementally /// adds the optional subproperties. At each step, tests for valid deserialization. + /// [TestMethod] public void TestNullableOptionalProps() { @@ -302,7 +304,7 @@ public void CheckConfigEnvParsingThrowExceptions(string invalidEnvVarName) { string json = @"{ ""foo"" : ""@env('envVarName'), @env('" + invalidEnvVarName + @"')"" }"; SetEnvVariables(); - StringJsonConverterFactory stringConverterFactory = new(EnvironmentVariableReplacementFailureMode.Throw); + StringJsonConverterFactory stringConverterFactory = new(new(doReplaceEnvVar: true, envFailureMode: EnvironmentVariableReplacementFailureMode.Throw)); JsonSerializerOptions options = new() { PropertyNameCaseInsensitive = true }; options.Converters.Add(stringConverterFactory); Assert.ThrowsException(() => JsonSerializer.Deserialize(json, options)); @@ -324,7 +326,7 @@ public void TestDataSourceDeserializationFailures(string dbType, string connecti ""entities"":{ } }"; // replaceEnvVar: true is needed to make sure we do post-processing for the connection string case - Assert.IsFalse(RuntimeConfigLoader.TryParseConfig(configJson, out RuntimeConfig deserializedConfig, replaceEnvVar: true)); + Assert.IsFalse(RuntimeConfigLoader.TryParseConfig(configJson, out RuntimeConfig deserializedConfig, replacementSettings: new DeserializationVariableReplacementSettings(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true))); Assert.IsNull(deserializedConfig); } @@ -343,7 +345,8 @@ public void TestLoadRuntimeConfigFailures( MockFileSystem fileSystem = new(); FileSystemRuntimeConfigLoader loader = new(fileSystem); - Assert.IsFalse(loader.TryLoadConfig(configFileName, out RuntimeConfig _)); + // Use null replacement settings for this test + Assert.IsFalse(loader.TryLoadConfig(configFileName, out RuntimeConfig _, replacementSettings: null)); } /// @@ -430,7 +433,7 @@ public static string GetModifiedJsonString(string[] reps, string enumString) ""host"": { ""mode"": ""development"", ""cors"": { - ""origins"": [ """ + reps[++index % reps.Length] + @""", """ + reps[++index % reps.Length] + @""" ], + ""origins"": [ """ + reps[++index % reps.Length] + @""", """ + reps[++index % reps.Length] + @"""], ""allow-credentials"": true }, ""authentication"": { @@ -670,5 +673,179 @@ private static bool TryParseAndAssertOnDefaults(string json, out RuntimeConfig p #endregion Helper Functions record StubJsonType(string Foo); + + /// + /// Test to verify Azure Key Vault variable replacement from local .akv file. + /// + [TestMethod] + public void TestAkvVariableReplacementFromLocalFile() + { + // Arrange: create a temporary .akv secrets file + string akvFilePath = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString() + ".akv"); + string secretConnectionString = "Server=tcp:127.0.0.1,1433;Persist Security Info=False;Trusted_Connection=True;TrustServerCertificate=True;MultipleActiveResultSets=False;Connection Timeout=5;"; + File.WriteAllText(akvFilePath, $"DBCONN={secretConnectionString}\nAPI_KEY=abcd\n# Comment line should be ignored\n MALFORMEDLINE \n"); + + // Escape backslashes for JSON + string escapedPath = akvFilePath.Replace("\\", "\\\\"); + + string jsonConfig = $$""" + { + "$schema": "https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch-alpha/dab.draft.schema.json", + "data-source": { + "database-type": "mssql", + "connection-string": "@akv('DBCONN')" + }, + "azure-key-vault": { + "endpoint": "{{escapedPath}}" + }, + "entities": { } + } + """; + + try + { + // Act + DeserializationVariableReplacementSettings replacementSettings = new( + azureKeyVaultOptions: null, + doReplaceEnvVar: false, + doReplaceAkvVar: true); + bool parsed = RuntimeConfigLoader.TryParseConfig(jsonConfig, out RuntimeConfig config, replacementSettings: replacementSettings); + + // Assert + Assert.IsTrue(parsed, "Config should parse successfully with local AKV file replacement."); + Assert.IsNotNull(config, "Config should not be null."); + Assert.AreEqual(secretConnectionString, config.DataSource.ConnectionString, "Connection string should be replaced from AKV local file secret."); + } + finally + { + // Cleanup + if (File.Exists(akvFilePath)) + { + File.Delete(akvFilePath); + } + } + } + + /// + /// Validates that when an AKV secret's value itself contains an @env('...') pattern, it is NOT further resolved + /// because replacement only runs once per original JSON token. Demonstrates that nested env patterns inside + /// AKV secret values are left intact. + /// + [TestMethod] + public void TestAkvSecretValueContainingEnvPatternIsNotEnvExpanded() + { + string akvFilePath = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString() + ".akv"); + // Valid MSSQL connection string which embeds an @env('env') pattern in the Database value. + // This pattern should NOT be expanded because replacement only runs once on the original JSON token (@akv('DBCONN')). + string secretValueWithEnvPattern = "Server=localhost;Database=@env('env');User Id=sa;Password=XXXX;"; + File.WriteAllText(akvFilePath, $"DBCONN={secretValueWithEnvPattern}\n"); + string escapedPath = akvFilePath.Replace("\\", "\\\\"); + + // Set env variable to prove it would be different if expansion occurred. + Environment.SetEnvironmentVariable("env", "SHOULD_NOT_APPEAR"); + + string jsonConfig = $$""" + { + "$schema": "https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch-alpha/dab.draft.schema.json", + "data-source": { + "database-type": "mssql", + "connection-string": "@akv('DBCONN')" + }, + "azure-key-vault": { + "endpoint": "{{escapedPath}}" + }, + "entities": { } + } + """; + + try + { + DeserializationVariableReplacementSettings replacementSettings = new( + azureKeyVaultOptions: null, + doReplaceEnvVar: true, + doReplaceAkvVar: true); + bool parsed = RuntimeConfigLoader.TryParseConfig(jsonConfig, out RuntimeConfig config, replacementSettings: replacementSettings); + Assert.IsTrue(parsed, "Config should parse successfully."); + Assert.IsNotNull(config); + + string actual = config.DataSource.ConnectionString; + Assert.IsTrue(actual.Contains("@env('env')"), "Nested @env pattern inside AKV secret should remain unexpanded."); + Assert.IsFalse(actual.Contains("SHOULD_NOT_APPEAR"), "Env var value should not be expanded inside AKV secret."); + Assert.IsTrue(actual.Contains("Application Name="), "Application Name should be appended for MSSQL when env replacement is enabled."); + + SqlConnectionStringBuilder builderOriginal = new(secretValueWithEnvPattern.Replace("Server=", "Data Source=").Replace("Database=", "Initial Catalog=")); + SqlConnectionStringBuilder builderActual = new(actual); + Assert.AreEqual(builderOriginal["Data Source"], builderActual["Data Source"], "Server/Data Source should match."); + Assert.AreEqual(builderOriginal["Initial Catalog"], builderActual["Initial Catalog"], "Database/Initial Catalog should match (with env pattern retained)."); + Assert.AreEqual(builderOriginal["User ID"], builderActual["User ID"], "User Id should match."); + Assert.AreEqual(builderOriginal["Password"], builderActual["Password"], "Password should match."); + } + finally + { + if (File.Exists(akvFilePath)) + { + File.Delete(akvFilePath); + } + + Environment.SetEnvironmentVariable("env", null); + } + } + + /// + /// Validates two-pass replacement where an env var resolves to an AKV pattern which then resolves to the secret value. + /// connection-string = @env('env_variable'), env_variable value = @akv('DBCONN'), AKV secret DBCONN holds the final connection string. + /// + [TestMethod] + public void TestEnvVariableResolvingToAkvPatternIsExpandedInSecondPass() + { + string akvFilePath = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString() + ".akv"); + string finalSecretValue = "Server=localhost;Database=Test;User Id=sa;Password=XXXX;"; + File.WriteAllText(akvFilePath, $"DBCONN={finalSecretValue}\n"); + string escapedPath = akvFilePath.Replace("\\", "\\\\"); + Environment.SetEnvironmentVariable("env_variable", "@akv('DBCONN')"); + + string jsonConfig = $$""" + { + "$schema": "https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch-alpha/dab.draft.schema.json", + "data-source": { + "database-type": "mssql", + "connection-string": "@env('env_variable')" + }, + "azure-key-vault": { + "endpoint": "{{escapedPath}}" + }, + "entities": { } + } + """; + + try + { + DeserializationVariableReplacementSettings replacementSettings = new( + azureKeyVaultOptions: null, + doReplaceEnvVar: true, + doReplaceAkvVar: true); + bool parsed = RuntimeConfigLoader.TryParseConfig(jsonConfig, out RuntimeConfig config, replacementSettings: replacementSettings); + Assert.IsTrue(parsed, "Config should parse successfully."); + Assert.IsNotNull(config); + + string expected = RuntimeConfigLoader.GetConnectionStringWithApplicationName(finalSecretValue); + SqlConnectionStringBuilder builderExpected = new(expected); + SqlConnectionStringBuilder builderActual = new(config.DataSource.ConnectionString); + Assert.AreEqual(builderExpected["Data Source"], builderActual["Data Source"], "Data Source should match."); + Assert.AreEqual(builderExpected["Initial Catalog"], builderActual["Initial Catalog"], "Initial Catalog should match."); + Assert.AreEqual(builderExpected["User ID"], builderActual["User ID"], "User ID should match."); + Assert.AreEqual(builderExpected["Password"], builderActual["Password"], "Password should match."); + Assert.IsTrue(builderActual.ApplicationName?.Contains("dab_"), "Application Name should be appended including product identifier."); + } + finally + { + if (File.Exists(akvFilePath)) + { + File.Delete(akvFilePath); + } + + Environment.SetEnvironmentVariable("env_variable", null); + } + } } } diff --git a/src/Service.Tests/UnitTests/SerializationDeserializationTests.cs b/src/Service.Tests/UnitTests/SerializationDeserializationTests.cs index 44978cd6aa..74d548fef4 100644 --- a/src/Service.Tests/UnitTests/SerializationDeserializationTests.cs +++ b/src/Service.Tests/UnitTests/SerializationDeserializationTests.cs @@ -8,6 +8,7 @@ using System.Reflection; using System.Text.Json; using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Services.MetadataProviders.Converters; @@ -187,7 +188,7 @@ public void TestSourceDefinitionCyclicObjectsSerializationDeserialization() _sourceDefinition.SourceEntityRelationshipMap.Add("persons", metadata); - // In serialization options we need ReferenceHandler = ReferenceHandler.Preserve, or else it doesnot seialize objects with cycle references + // In serialization options we need ReferenceHandler = ReferenceHandler.Preserve, or else it does not serialize objects with cycle references // SourceDefinition -> RelationShipMetadata -> ForeignKeyDefinition RelationshipPair ->DatabaseTable -> SourceDefinition Assert.ThrowsException(() => { @@ -489,5 +490,68 @@ private RelationShipPair GetRelationShipPair() }; return new(_databaseTable, table2); } + + /// + /// Verifies that when merging multiple runtime configs, if the child config omits + /// the azure-key-vault section, the merged result still contains the AzureKeyVaultOptions (including retry-policy) + /// inherited from the parent config. + /// + [TestMethod] + public void TestMergedConfigInheritsAzureKeyVaultOptions() + { + // Arrange + + // Parent config with AKV section. + string parentConfig = @"{ + ""data-source"": { ""database-type"": ""mssql"", ""connection-string"": ""Server=.;Database=Parent;Trusted_Connection=True;"" }, + ""runtime"": { ""rest"": { ""enabled"": true }, ""graphql"": { ""enabled"": true } }, + ""entities"": {}, + ""azure-key-vault"": { + ""endpoint"": ""https://myvault.vault.azure.net/"", + ""retry-policy"": { + ""mode"": ""fixed"", + ""max-count"": 7, + ""delay-seconds"": 3, + ""max-delay-seconds"": 15, + ""network-timeout-seconds"": 20 + } + } +}"; + + // Child config overrides some properties but omits azure-key-vault entirely. + string childConfig = @"{ + ""data-source"": { ""database-type"": ""mssql"", ""connection-string"": ""Server=.;Database=Child;Trusted_Connection=True;"" }, + ""runtime"": { ""rest"": { ""enabled"": true }, ""graphql"": { ""enabled"": true } }, + ""entities"": {} +}"; + // Act + + // Merge child over parent. + string mergedJson = MergeJsonProvider.Merge(parentConfig, childConfig); + + // Parse with AKV replacement enabled so extraction path executes. + DeserializationVariableReplacementSettings replacementSettings = new( + azureKeyVaultOptions: null, + doReplaceEnvVar: false, + doReplaceAkvVar: true); + + // Assert + + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(mergedJson, out RuntimeConfig mergedConfig, replacementSettings: replacementSettings), "Merged runtime config failed to parse."); + Assert.IsNotNull(mergedConfig, "Merged runtime config is null."); + + // Validate AKV inheritance. + Assert.IsNotNull(mergedConfig.AzureKeyVault, "AzureKeyVaultOptions should be inherited from base config."); + Assert.AreEqual("https://myvault.vault.azure.net/", mergedConfig.AzureKeyVault!.Endpoint, "Inherited AKV endpoint mismatch."); + Assert.IsNotNull(mergedConfig.AzureKeyVault.RetryPolicy, "RetryPolicy should be inherited."); + Assert.AreEqual(AKVRetryPolicyMode.Fixed, mergedConfig.AzureKeyVault.RetryPolicy!.Mode, "Inherited retry-policy mode mismatch."); + Assert.AreEqual(7, mergedConfig.AzureKeyVault.RetryPolicy.MaxCount, "Inherited retry-policy max-count mismatch."); + Assert.AreEqual(3, mergedConfig.AzureKeyVault.RetryPolicy.DelaySeconds, "Inherited retry-policy delay-seconds mismatch."); + Assert.AreEqual(15, mergedConfig.AzureKeyVault.RetryPolicy.MaxDelaySeconds, "Inherited retry-policy max-delay-seconds mismatch."); + Assert.AreEqual(20, mergedConfig.AzureKeyVault.RetryPolicy.NetworkTimeoutSeconds, "Inherited retry-policy network-timeout-seconds mismatch."); + + // Ensure child override for connection-string applied while AKV remained from base. + Assert.AreEqual("Server=.;Database=Child;Trusted_Connection=True;", mergedConfig.DataSource.ConnectionString, "Child connection-string override not applied."); + } } } diff --git a/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs b/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs index 908b7019c4..f549b1dd3e 100644 --- a/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs +++ b/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs @@ -115,7 +115,8 @@ await provider.Initialize( provider.GetConfig().ToJson(), graphQLSchema: null, connectionString: connectionString, - accessToken: CONFIG_TOKEN); + accessToken: CONFIG_TOKEN, + replacementSettings: new()); msSqlQueryExecutor = new(provider, dbExceptionParser.Object, queryExecutorLogger.Object, httpContextAccessor.Object); } } diff --git a/src/Service/Controllers/ConfigurationController.cs b/src/Service/Controllers/ConfigurationController.cs index be3f9bd727..4ad8fb40f4 100644 --- a/src/Service/Controllers/ConfigurationController.cs +++ b/src/Service/Controllers/ConfigurationController.cs @@ -91,8 +91,8 @@ public async Task Index([FromBody] ConfigurationPostParameters con configuration.Schema, configuration.ConnectionString, configuration.AccessToken, - replaceEnvVar: false, - replacementFailureMode: Config.Converters.EnvironmentVariableReplacementFailureMode.Ignore); + replacementSettings: new(azureKeyVaultOptions: null, doReplaceEnvVar: false, doReplaceAkvVar: false, envFailureMode: Config.Converters.EnvironmentVariableReplacementFailureMode.Ignore) + ); if (initResult && _configurationProvider.TryGetConfig(out _)) {