Skip to content
Open
66 changes: 43 additions & 23 deletions InertiaCore/Extensions/Configure.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.ViewFeatures;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace InertiaCore.Extensions;

Expand All @@ -24,20 +25,52 @@ public static IApplicationBuilder UseInertia(this IApplicationBuilder app)
Inertia.Version(Vite.GetManifestHash);
}

app.Use(async (context, next) =>
// Check if TempData services are available for error bag functionality
CheckTempDataAvailability(app);

app.UseMiddleware<Middleware>();

return app;
}

private static void CheckTempDataAvailability(IApplicationBuilder app)
{
// Skip warning in test environments
var environment = app.ApplicationServices.GetService<Microsoft.AspNetCore.Hosting.IWebHostEnvironment>();
if (environment?.EnvironmentName == "Test" ||
(environment?.EnvironmentName != "Development" && IsTestEnvironment()))
{
if (context.IsInertiaRequest()
&& context.Request.Method == "GET"
&& context.Request.Headers[InertiaHeader.Version] != Inertia.GetVersion())
return;
}

try
{
var tempDataFactory = app.ApplicationServices.GetService<ITempDataDictionaryFactory>();
if (tempDataFactory == null)
{
await OnVersionChange(context, app);
return;
var logger = app.ApplicationServices.GetService<ILogger<IApplicationBuilder>>();
logger?.LogWarning("TempData services are not configured. Error bag functionality will be limited. " +
"Consider adding services.AddSession() and app.UseSession() to enable full error bag support.");
}
}
catch (Exception)
{
// If we can't check for TempData services, that's also a sign they might not be configured
var logger = app.ApplicationServices.GetService<ILogger<IApplicationBuilder>>();
logger?.LogWarning("Unable to verify TempData configuration. Error bag functionality may be limited. " +
"Ensure services.AddSession() and app.UseSession() are configured for full error bag support.");
}
}

await next();
});

return app;
private static bool IsTestEnvironment()
{
// Check if we're running in a test context by looking for common test assemblies
var assemblies = AppDomain.CurrentDomain.GetAssemblies();
return assemblies.Any(a =>
a.FullName?.Contains("nunit", StringComparison.OrdinalIgnoreCase) == true ||
a.FullName?.Contains("xunit", StringComparison.OrdinalIgnoreCase) == true ||
a.FullName?.Contains("mstest", StringComparison.OrdinalIgnoreCase) == true ||
a.FullName?.Contains("testhost", StringComparison.OrdinalIgnoreCase) == true);
}

public static IServiceCollection AddInertia(this IServiceCollection services,
Expand All @@ -64,17 +97,4 @@ public static IServiceCollection AddViteHelper(this IServiceCollection services,

return services;
}

private static async Task OnVersionChange(HttpContext context, IApplicationBuilder app)
{
var tempData = app.ApplicationServices.GetRequiredService<ITempDataDictionaryFactory>()
.GetTempData(context);

if (tempData.Any()) tempData.Keep();

context.Response.Headers.Override(InertiaHeader.Location, context.RequestedUri());
context.Response.StatusCode = (int)HttpStatusCode.Conflict;

await context.Response.CompleteAsync();
}
}
132 changes: 132 additions & 0 deletions InertiaCore/Extensions/InertiaExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Extensions;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.ModelBinding;
using Microsoft.AspNetCore.Mvc.ViewFeatures;
using Microsoft.Extensions.DependencyInjection;
using System.Text;

namespace InertiaCore.Extensions;
Expand Down Expand Up @@ -63,4 +66,133 @@ internal static string MD5(this string s)

return sb.ToString();
}

/// <summary>
/// Gets the TempData dictionary for the current HTTP context.
/// </summary>
internal static ITempDataDictionary? GetTempData(this HttpContext context)
{
try
{
var tempDataFactory = context.RequestServices?.GetRequiredService<ITempDataDictionaryFactory>();
return tempDataFactory?.GetTempData(context);
}
catch (InvalidOperationException)
{
// Service provider not available, return null
return null;
}
}

/// <summary>
/// Sets validation errors in TempData for the specified error bag.
/// </summary>
public static void SetValidationErrors(this ITempDataDictionary tempData, Dictionary<string, string> errors, string bagName = "default")
{
// Deserialize existing error bags from JSON
var errorBags = new Dictionary<string, Dictionary<string, string>>();
if (tempData["__ValidationErrors"] is string existingJson && !string.IsNullOrEmpty(existingJson))
{
try
{
errorBags = JsonSerializer.Deserialize<Dictionary<string, Dictionary<string, string>>>(existingJson)
?? new Dictionary<string, Dictionary<string, string>>();
}
catch (JsonException)
{
// If deserialization fails, start fresh
errorBags = new Dictionary<string, Dictionary<string, string>>();
}
}

errorBags[bagName] = errors;

// Serialize back to JSON for storage
tempData["__ValidationErrors"] = JsonSerializer.Serialize(errorBags);
}

/// <summary>
/// Sets validation errors in TempData from ModelState for the specified error bag.
/// </summary>
public static void SetValidationErrors(this ITempDataDictionary tempData, ModelStateDictionary modelState, string bagName = "default")
{
var errors = modelState.ToDictionary(
kvp => kvp.Key,
kvp => kvp.Value?.Errors.FirstOrDefault()?.ErrorMessage ?? ""
);
tempData.SetValidationErrors(errors, bagName);
}

/// <summary>
/// Retrieve and clear validation errors from TempData, supporting error bags.
/// </summary>
public static Dictionary<string, string> GetAndClearValidationErrors(this ITempDataDictionary tempData, HttpRequest request)
{
var errors = new Dictionary<string, string>();

if (!tempData.ContainsKey("__ValidationErrors"))
return errors;

// Deserialize from JSON
Dictionary<string, Dictionary<string, string>> storedErrors;
if (tempData["__ValidationErrors"] is string jsonString && !string.IsNullOrEmpty(jsonString))
{
try
{
storedErrors = JsonSerializer.Deserialize<Dictionary<string, Dictionary<string, string>>>(jsonString) ?? new Dictionary<string, Dictionary<string, string>>();
}
catch (JsonException)
{
// If deserialization fails, return empty
return errors;
}
}
else
{
return errors;
}

// Check if there's a specific error bag in the request header
var errorBag = "default";
if (request.Headers.ContainsKey(InertiaHeader.ErrorBag))
{
errorBag = request.Headers[InertiaHeader.ErrorBag].ToString();
}

// If there's only the default bag and no specific bag requested, return the default bag directly
if (storedErrors.Count == 1 && storedErrors.ContainsKey("default") && errorBag == "default")
{
foreach (var kvp in storedErrors["default"])
{
errors[kvp.Key] = kvp.Value;
}
}
// If there are multiple bags or a specific bag is requested, return the named bag
else if (storedErrors.ContainsKey(errorBag))
{
foreach (var kvp in storedErrors[errorBag])
{
errors[kvp.Key] = kvp.Value;
}
}
// If no specific bag and multiple bags exist, return all bags
else if (errorBag == "default" && storedErrors.Count > 1)
{
// Return all error bags as nested structure
// This will be handled differently but for now just return default or first available
var firstBag = storedErrors.Values.FirstOrDefault();
if (firstBag != null)
{
foreach (var kvp in firstBag)
{
errors[kvp.Key] = kvp.Value;
}
}
}

// Clear the temp data after reading (one-time use)
tempData.Remove("__ValidationErrors");

return errors;
}
}
4 changes: 4 additions & 0 deletions InertiaCore/Inertia.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ public static class Inertia

internal static void UseFactory(IResponseFactory factory) => _factory = factory;

internal static void ResetFactory() => _factory = default!;

public static Response Render(string component, object? props = null) => _factory.Render(component, props);

public static Task<IHtmlContent> Head(dynamic model) => _factory.Head(model);
Expand All @@ -27,6 +29,8 @@ public static class Inertia

public static LocationResult Location(string url) => _factory.Location(url);

public static BackResult Back(string? fallbackUrl = null) => _factory.Back(fallbackUrl);

public static void Share(string key, object? value) => _factory.Share(key, value);

public static void Share(IDictionary<string, object?> data) => _factory.Share(data);
Expand Down
112 changes: 112 additions & 0 deletions InertiaCore/Middleware.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
using InertiaCore;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc.ViewFeatures;
using System.Net;
using InertiaCore.Utils;
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.DependencyInjection;
using InertiaCore.Extensions;

namespace InertiaCore;

public class Middleware
{
private readonly RequestDelegate _next;

public Middleware(RequestDelegate next)
{
_next = next;
}

public async Task InvokeAsync(HttpContext context)
{
if (context.IsInertiaRequest()
&& context.Request.Method == "GET"
&& context.Request.Headers[InertiaHeader.Version] != Inertia.GetVersion())
{
await OnVersionChange(context);
return;
}

await _next(context);

// Handle empty responses for Inertia requests
if (context.IsInertiaRequest()
&& context.Response.StatusCode == 200
&& await IsEmptyResponse(context))
{
await OnEmptyResponse(context);
}
}

private static async Task OnVersionChange(HttpContext context)
{
var tempData = context.RequestServices.GetRequiredService<ITempDataDictionaryFactory>()
.GetTempData(context);

if (tempData.Any()) tempData.Keep();

context.Response.Headers.Override(InertiaHeader.Location, context.RequestedUri());
context.Response.StatusCode = (int)HttpStatusCode.Conflict;

await context.Response.CompleteAsync();
}

private static async Task<bool> IsEmptyResponse(HttpContext context)
{
// Check if Content-Length is 0 or not set
if (context.Response.Headers.ContentLength.HasValue)
{
return context.Response.Headers.ContentLength.Value == 0;
}

// Check if response body is empty or only whitespace
if (context.Response.Body.CanSeek && context.Response.Body.Length >= 0)
{
var position = context.Response.Body.Position;

// Check if the stream has any content
if (context.Response.Body.Length == 0)
{
return true;
}

context.Response.Body.Seek(0, SeekOrigin.Begin);

using var reader = new StreamReader(context.Response.Body, leaveOpen: true);
var content = await reader.ReadToEndAsync();

context.Response.Body.Seek(position, SeekOrigin.Begin);

return string.IsNullOrWhiteSpace(content);
}

// For non-seekable streams, check if the response body position is still 0
// This indicates nothing has been written to the response
try
{
return context.Response.Body.Position == 0;
}
catch
{
// If we can't determine, assume it's not empty to be safe
return false;
}
}

private static async Task OnEmptyResponse(HttpContext context)
{
// Use Inertia.Back() to redirect back
var backResult = Inertia.Back();

// Determine the redirect URL using the same logic as BackResult
var referrer = context.Request.Headers.Referer.ToString();
var redirectUrl = !string.IsNullOrEmpty(referrer) ? referrer : "/";

// Set the appropriate headers and status code for a back redirect
context.Response.StatusCode = (int)HttpStatusCode.SeeOther;
context.Response.Headers.Override("Location", redirectUrl);

await context.Response.CompleteAsync();
}
}
Loading