Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 37 additions & 92 deletions src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
using Azure.DataApiBuilder.Auth;
using Azure.DataApiBuilder.Config.DatabasePrimitives;
using Azure.DataApiBuilder.Config.ObjectModel;
using Azure.DataApiBuilder.Core.Authorization;
using Azure.DataApiBuilder.Core.Configurations;
using Azure.DataApiBuilder.Core.Models;
using Azure.DataApiBuilder.Core.Resolvers;
using Azure.DataApiBuilder.Core.Resolvers.Factories;
using Azure.DataApiBuilder.Core.Services;
using Azure.DataApiBuilder.Core.Services.MetadataProviders;
using Azure.DataApiBuilder.Mcp.Model;
using Azure.DataApiBuilder.Mcp.Utils;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -57,79 +56,64 @@ public async Task<CallToolResult> ExecuteAsync(
CancellationToken cancellationToken = default)
{
ILogger<CreateRecordTool>? logger = serviceProvider.GetService<ILogger<CreateRecordTool>>();
string toolName = GetToolMetadata().Name;
if (arguments == null)
{
return Utils.McpResponseBuilder.BuildErrorResult("Invalid Arguments", "No arguments provided", logger);
return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger);
}

RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService<RuntimeConfigProvider>();
if (!runtimeConfigProvider.TryGetConfig(out RuntimeConfig? runtimeConfig))
{
return Utils.McpResponseBuilder.BuildErrorResult("Invalid Configuration", "Runtime configuration not available", logger);
return McpResponseBuilder.BuildErrorResult(toolName, "InvalidConfiguration", "Runtime configuration not available.", logger);
}

if (runtimeConfig.McpDmlTools?.CreateRecord != true)
{
return Utils.McpResponseBuilder.BuildErrorResult(
"ToolDisabled",
"The create_record tool is disabled in the configuration.",
logger);
return McpErrorHelpers.ToolDisabled(toolName, logger);
}

try
{
cancellationToken.ThrowIfCancellationRequested();
JsonElement root = arguments.RootElement;

if (!root.TryGetProperty("entity", out JsonElement entityElement) ||
!root.TryGetProperty("data", out JsonElement dataElement))
if (!McpArgumentParser.TryParseEntityAndData(root, out string entityName, out JsonElement dataElement, out string parseError))
{
return Utils.McpResponseBuilder.BuildErrorResult("InvalidArguments", "Missing required arguments 'entity' or 'data'", logger);
return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger);
}

string entityName = entityElement.GetString() ?? string.Empty;
if (string.IsNullOrWhiteSpace(entityName))
if (!McpMetadataHelper.TryResolveMetadata(
entityName,
runtimeConfig,
serviceProvider,
out ISqlMetadataProvider sqlMetadataProvider,
out DatabaseObject dbObject,
out string dataSourceName,
out string metadataError))
{
return Utils.McpResponseBuilder.BuildErrorResult("InvalidArguments", "Entity name cannot be empty", logger);
}

string dataSourceName;
try
{
dataSourceName = runtimeConfig.GetDataSourceNameFromEntityName(entityName);
}
catch (Exception)
{
return Utils.McpResponseBuilder.BuildErrorResult("InvalidConfiguration", $"Entity '{entityName}' not found in configuration", logger);
}

IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService<IMetadataProviderFactory>();
ISqlMetadataProvider sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName);

DatabaseObject dbObject;
try
{
dbObject = sqlMetadataProvider.GetDatabaseObjectByKey(entityName);
}
catch (Exception)
{
return Utils.McpResponseBuilder.BuildErrorResult("InvalidConfiguration", $"Database object for entity '{entityName}' not found", logger);
return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger);
}

// Create an HTTP context for authorization
IHttpContextAccessor httpContextAccessor = serviceProvider.GetRequiredService<IHttpContextAccessor>();
HttpContext httpContext = httpContextAccessor.HttpContext ?? new DefaultHttpContext();
IAuthorizationResolver authorizationResolver = serviceProvider.GetRequiredService<IAuthorizationResolver>();

if (httpContext is null || !authorizationResolver.IsValidRoleContext(httpContext))
if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authorizationResolver, out string roleCtxError))
{
return Utils.McpResponseBuilder.BuildErrorResult("PermissionDenied", "Permission denied: Unable to resolve a valid role context for update operation.", logger);
return McpErrorHelpers.PermissionDenied(toolName, entityName, "create", roleCtxError, logger);
}

// Validate that we have at least one role authorized for create
if (!TryResolveAuthorizedRole(httpContext, authorizationResolver, entityName, out string authError))
if (!McpAuthorizationHelper.TryResolveAuthorizedRole(
httpContext,
authorizationResolver,
entityName,
EntityActionOperation.Create,
out string? effectiveRole,
out string authError))
{
return Utils.McpResponseBuilder.BuildErrorResult("PermissionDenied", authError, logger);
return McpErrorHelpers.PermissionDenied(toolName, entityName, "create", authError, logger);
}

JsonElement insertPayloadRoot = dataElement.Clone();
Expand All @@ -150,12 +134,13 @@ public async Task<CallToolResult> ExecuteAsync(
}
catch (Exception ex)
{
return Utils.McpResponseBuilder.BuildErrorResult("ValidationFailed", $"Request validation failed: {ex.Message}", logger);
return McpResponseBuilder.BuildErrorResult(toolName, "ValidationFailed", $"Request validation failed: {ex.Message}", logger);
}
}
else
{
return Utils.McpResponseBuilder.BuildErrorResult(
return McpResponseBuilder.BuildErrorResult(
toolName,
"InvalidCreateTarget",
"The create_record tool is only available for tables.",
logger);
Expand All @@ -169,7 +154,7 @@ public async Task<CallToolResult> ExecuteAsync(

if (result is CreatedResult createdResult)
{
return Utils.McpResponseBuilder.BuildSuccessResult(
return McpResponseBuilder.BuildSuccessResult(
new Dictionary<string, object?>
{
["entity"] = entityName,
Expand All @@ -184,14 +169,15 @@ public async Task<CallToolResult> ExecuteAsync(
bool isError = objectResult.StatusCode.HasValue && objectResult.StatusCode.Value >= 400 && objectResult.StatusCode.Value != 403;
if (isError)
{
return Utils.McpResponseBuilder.BuildErrorResult(
return McpResponseBuilder.BuildErrorResult(
toolName,
"CreateFailed",
$"Failed to create record in entity '{entityName}'. Error: {JsonSerializer.Serialize(objectResult.Value)}",
logger);
}
else
{
return Utils.McpResponseBuilder.BuildSuccessResult(
return McpResponseBuilder.BuildSuccessResult(
new Dictionary<string, object?>
{
["entity"] = entityName,
Expand All @@ -206,14 +192,15 @@ public async Task<CallToolResult> ExecuteAsync(
{
if (result is null)
{
return Utils.McpResponseBuilder.BuildErrorResult(
return McpResponseBuilder.BuildErrorResult(
toolName,
"UnexpectedError",
$"Mutation engine returned null result for entity '{entityName}'",
logger);
}
else
{
return Utils.McpResponseBuilder.BuildSuccessResult(
return McpResponseBuilder.BuildSuccessResult(
new Dictionary<string, object?>
{
["entity"] = entityName,
Expand All @@ -226,50 +213,8 @@ public async Task<CallToolResult> ExecuteAsync(
}
catch (Exception ex)
{
return Utils.McpResponseBuilder.BuildErrorResult("Error", $"Error: {ex.Message}", logger);
return McpResponseBuilder.BuildErrorResult(toolName, "Error", $"Error: {ex.Message}", logger);
}
}

private static bool TryResolveAuthorizedRole(
HttpContext httpContext,
IAuthorizationResolver authorizationResolver,
string entityName,
out string error)
{
error = string.Empty;

string roleHeader = httpContext.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER].ToString();

if (string.IsNullOrWhiteSpace(roleHeader))
{
error = "Client role header is missing or empty.";
return false;
}

string[] roles = roleHeader
.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)
.Distinct(StringComparer.OrdinalIgnoreCase)
.ToArray();

if (roles.Length == 0)
{
error = "Client role header is missing or empty.";
return false;
}

foreach (string role in roles)
{
bool allowed = authorizationResolver.AreRoleAndOperationDefinedForEntity(
entityName, role, EntityActionOperation.Create);

if (allowed)
{
return true;
}
}

error = "You do not have permission to create records for this entity.";
return false;
}
}
}
Loading