diff --git a/src/Framework/App.Ref/src/Microsoft.AspNetCore.App.Ref.sfxproj b/src/Framework/App.Ref/src/Microsoft.AspNetCore.App.Ref.sfxproj index f0a83c2e417b..bd685a879b2d 100644 --- a/src/Framework/App.Ref/src/Microsoft.AspNetCore.App.Ref.sfxproj +++ b/src/Framework/App.Ref/src/Microsoft.AspNetCore.App.Ref.sfxproj @@ -76,6 +76,11 @@ Private="false" OutputItemType="AspNetCoreAnalyzer" ReferenceOutputAssembly="false" /> + + diff --git a/src/Http/Http.Abstractions/src/Metadata/IDisableValidationMetadata.cs b/src/Http/Http.Abstractions/src/Metadata/IDisableValidationMetadata.cs new file mode 100644 index 000000000000..f45eff5a09fe --- /dev/null +++ b/src/Http/Http.Abstractions/src/Metadata/IDisableValidationMetadata.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.Metadata; + +/// +/// A marker interface which can be used to identify metadata that disables validation +/// on a given endpoint. +/// +public interface IDisableValidationMetadata +{ +} diff --git a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt index d7c55b7606ff..2524b5a9cefe 100644 --- a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt @@ -1,4 +1,44 @@ #nullable enable +abstract Microsoft.AspNetCore.Http.Validation.ValidatableParameterInfo.GetValidationAttributes() -> System.ComponentModel.DataAnnotations.ValidationAttribute![]! +abstract Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo.GetValidationAttributes() -> System.ComponentModel.DataAnnotations.ValidationAttribute![]! +Microsoft.AspNetCore.Http.Metadata.IDisableValidationMetadata Microsoft.AspNetCore.Http.ProducesResponseTypeMetadata.Description.get -> string? Microsoft.AspNetCore.Http.ProducesResponseTypeMetadata.Description.set -> void Microsoft.AspNetCore.Http.Metadata.IProducesResponseTypeMetadata.Description.get -> string? +Microsoft.AspNetCore.Http.Validation.IValidatableInfo +Microsoft.AspNetCore.Http.Validation.IValidatableInfo.ValidateAsync(object? value, Microsoft.AspNetCore.Http.Validation.ValidateContext! context, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver +Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver.TryGetValidatableParameterInfo(System.Reflection.ParameterInfo! parameterInfo, out Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) -> bool +Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver.TryGetValidatableTypeInfo(System.Type! type, out Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) -> bool +Microsoft.AspNetCore.Http.Validation.ValidatableParameterInfo +Microsoft.AspNetCore.Http.Validation.ValidatableParameterInfo.ValidatableParameterInfo(System.Type! parameterType, string! name, string! displayName) -> void +Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo +Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo.ValidatablePropertyInfo(System.Type! declaringType, System.Type! propertyType, string! name, string! displayName) -> void +Microsoft.AspNetCore.Http.Validation.ValidatableTypeAttribute +Microsoft.AspNetCore.Http.Validation.ValidatableTypeAttribute.ValidatableTypeAttribute() -> void +Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo +Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo.ValidatableTypeInfo(System.Type! type, System.Collections.Generic.IReadOnlyList! members) -> void +Microsoft.AspNetCore.Http.Validation.ValidateContext +Microsoft.AspNetCore.Http.Validation.ValidateContext.CurrentDepth.get -> int +Microsoft.AspNetCore.Http.Validation.ValidateContext.CurrentDepth.set -> void +Microsoft.AspNetCore.Http.Validation.ValidateContext.CurrentValidationPath.get -> string! +Microsoft.AspNetCore.Http.Validation.ValidateContext.CurrentValidationPath.set -> void +Microsoft.AspNetCore.Http.Validation.ValidateContext.ValidateContext() -> void +Microsoft.AspNetCore.Http.Validation.ValidateContext.ValidationContext.get -> System.ComponentModel.DataAnnotations.ValidationContext? +Microsoft.AspNetCore.Http.Validation.ValidateContext.ValidationContext.set -> void +Microsoft.AspNetCore.Http.Validation.ValidateContext.ValidationErrors.get -> System.Collections.Generic.Dictionary? +Microsoft.AspNetCore.Http.Validation.ValidateContext.ValidationErrors.set -> void +Microsoft.AspNetCore.Http.Validation.ValidateContext.ValidationOptions.get -> Microsoft.AspNetCore.Http.Validation.ValidationOptions! +Microsoft.AspNetCore.Http.Validation.ValidateContext.ValidationOptions.set -> void +Microsoft.AspNetCore.Http.Validation.ValidationOptions +Microsoft.AspNetCore.Http.Validation.ValidationOptions.MaxDepth.get -> int +Microsoft.AspNetCore.Http.Validation.ValidationOptions.MaxDepth.set -> void +Microsoft.AspNetCore.Http.Validation.ValidationOptions.Resolvers.get -> System.Collections.Generic.IList! +Microsoft.AspNetCore.Http.Validation.ValidationOptions.TryGetValidatableParameterInfo(System.Reflection.ParameterInfo! parameterInfo, out Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) -> bool +Microsoft.AspNetCore.Http.Validation.ValidationOptions.TryGetValidatableTypeInfo(System.Type! type, out Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableTypeInfo) -> bool +Microsoft.AspNetCore.Http.Validation.ValidationOptions.ValidationOptions() -> void +Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions +static Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services, System.Action? configureOptions = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +virtual Microsoft.AspNetCore.Http.Validation.ValidatableParameterInfo.ValidateAsync(object? value, Microsoft.AspNetCore.Http.Validation.ValidateContext! context, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +virtual Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo.ValidateAsync(object? value, Microsoft.AspNetCore.Http.Validation.ValidateContext! context, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +virtual Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo.ValidateAsync(object? value, Microsoft.AspNetCore.Http.Validation.ValidateContext! context, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! diff --git a/src/Http/Http.Abstractions/src/Validation/IValidatableInfo.cs b/src/Http/Http.Abstractions/src/Validation/IValidatableInfo.cs new file mode 100644 index 000000000000..91766f69cfc1 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/IValidatableInfo.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Represents an interface for validating a value. +/// +public interface IValidatableInfo +{ + /// + /// Validates the specified . + /// + /// The value to validate. + /// The validation context. + /// A cancellation token to support cancellation of the validation. + /// A representing the asynchronous operation. + Task ValidateAsync(object? value, ValidateContext context, CancellationToken cancellationToken); +} diff --git a/src/Http/Http.Abstractions/src/Validation/IValidatableInfoResolver.cs b/src/Http/Http.Abstractions/src/Validation/IValidatableInfoResolver.cs new file mode 100644 index 000000000000..b4d4abe31c2d --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/IValidatableInfoResolver.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Provides an interface for resolving the validation information associated +/// with a given or . +/// +public interface IValidatableInfoResolver +{ + /// + /// Gets validation information for the specified type. + /// + /// The type to get validation information for. + /// + /// The output parameter that will contain the validatable information if found. + /// + /// if the validatable type information was found; otherwise, false. + bool TryGetValidatableTypeInfo(Type type, [NotNullWhen(true)] out IValidatableInfo? validatableInfo); + + /// + /// Gets validation information for the specified parameter. + /// + /// The parameter to get validation information for. + /// The output parameter that will contain the validatable information if found. + /// if the validatable parameter information was found; otherwise, false. + bool TryGetValidatableParameterInfo(ParameterInfo parameterInfo, [NotNullWhen(true)] out IValidatableInfo? validatableInfo); +} diff --git a/src/Http/Http.Abstractions/src/Validation/RuntimeValidatableParameterInfoResolver.cs b/src/Http/Http.Abstractions/src/Validation/RuntimeValidatableParameterInfoResolver.cs new file mode 100644 index 000000000000..59949eba767c --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/RuntimeValidatableParameterInfoResolver.cs @@ -0,0 +1,61 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Reflection; + +namespace Microsoft.AspNetCore.Http.Validation; + +internal sealed class RuntimeValidatableParameterInfoResolver : IValidatableInfoResolver +{ + // TODO: the implementation currently relies on static discovery of types. + public bool TryGetValidatableTypeInfo(Type type, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + + public bool TryGetValidatableParameterInfo(ParameterInfo parameterInfo, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + if (parameterInfo.Name == null) + { + throw new InvalidOperationException($"Encountered a parameter of type '{parameterInfo.ParameterType}' without a name. Parameters must have a name."); + } + + var validationAttributes = parameterInfo + .GetCustomAttributes() + .ToArray(); + validatableInfo = new RuntimeValidatableParameterInfo( + parameterType: parameterInfo.ParameterType, + name: parameterInfo.Name, + displayName: GetDisplayName(parameterInfo), + validationAttributes: validationAttributes + ); + return true; + } + + private static string GetDisplayName(ParameterInfo parameterInfo) + { + var displayAttribute = parameterInfo.GetCustomAttribute(); + if (displayAttribute != null) + { + return displayAttribute.Name ?? parameterInfo.Name!; + } + + return parameterInfo.Name!; + } + + private sealed class RuntimeValidatableParameterInfo( + Type parameterType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) : + ValidatableParameterInfo(parameterType, name, displayName) + { + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + + private readonly ValidationAttribute[] _validationAttributes = validationAttributes; + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/TypeExtensions.cs b/src/Http/Http.Abstractions/src/Validation/TypeExtensions.cs new file mode 100644 index 000000000000..244f0b3fe888 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/TypeExtensions.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.Validation; + +internal static class TypeExtensions +{ + /// + /// Determines whether the specified type is an enumerable type. + /// + /// The type to check. + /// if the type is enumerable; otherwise, . + public static bool IsEnumerable(this Type type) + { + // Check if type itself is an IEnumerable + if (type.IsGenericType && + (type.GetGenericTypeDefinition() == typeof(IEnumerable<>) || + type.GetGenericTypeDefinition() == typeof(ICollection<>) || + type.GetGenericTypeDefinition() == typeof(List<>) || + type.GetGenericTypeDefinition() == typeof(IList<>))) + { + return true; + } + + // Or an array + if (type.IsArray) + { + return true; + } + + // Then evaluate if it implements IEnumerable and is not a string + if (typeof(IEnumerable).IsAssignableFrom(type) && + type != typeof(string)) + { + return true; + } + + return false; + } + + /// + /// Determines whether the specified type is a nullable type. + /// + /// The type to check. + /// if the type is nullable; otherwise, . + public static bool IsNullable(this Type type) + { + if (type.IsValueType) + { + return false; + } + + if (type.IsGenericType && + type.GetGenericTypeDefinition() == typeof(Nullable<>)) + { + return true; + } + + return false; + } + + /// + /// Tries to get the from the specified array of validation attributes. + /// + /// The array of to search. + /// The found if available, otherwise null. + /// if a is found; otherwise, . + public static bool TryGetRequiredAttribute(this ValidationAttribute[] attributes, [NotNullWhen(true)] out RequiredAttribute? requiredAttribute) + { + foreach (var attribute in attributes) + { + if (attribute is RequiredAttribute requiredAttr) + { + requiredAttribute = requiredAttr; + return true; + } + } + + requiredAttribute = null; + return false; + } + + /// + /// Gets all types that the specified type implements or inherits from. + /// + /// The type to analyze. + /// A collection containing all implemented interfaces and all base types of the given type. + public static List GetAllImplementedTypes([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] this Type type) + { + ArgumentNullException.ThrowIfNull(type); + + var implementedTypes = new List(); + + // Yield all interfaces directly and indirectly implemented by this type + foreach (var interfaceType in type.GetInterfaces()) + { + implementedTypes.Add(interfaceType); + } + + // Finally, walk up the inheritance chain + var baseType = type.BaseType; + while (baseType != null && baseType != typeof(object)) + { + implementedTypes.Add(baseType); + baseType = baseType.BaseType; + } + + return implementedTypes; + } + + /// + /// Determines whether the specified type implements the given interface. + /// + /// The type to check. + /// The interface type to check for. + /// True if the type implements the specified interface; otherwise, false. + public static bool ImplementsInterface(this Type type, Type interfaceType) + { + ArgumentNullException.ThrowIfNull(type); + ArgumentNullException.ThrowIfNull(interfaceType); + + // Check if interfaceType is actually an interface + if (!interfaceType.IsInterface) + { + throw new ArgumentException($"Type {interfaceType.FullName} is not an interface.", nameof(interfaceType)); + } + + return interfaceType.IsAssignableFrom(type); + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidatableParameterInfo.cs b/src/Http/Http.Abstractions/src/Validation/ValidatableParameterInfo.cs new file mode 100644 index 000000000000..58895fc00014 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidatableParameterInfo.cs @@ -0,0 +1,140 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.ComponentModel.DataAnnotations; +using System.Diagnostics; + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Contains validation information for a parameter. +/// +public abstract class ValidatableParameterInfo : IValidatableInfo +{ + private RequiredAttribute? _requiredAttribute; + + /// + /// Creates a new instance of . + /// + /// The associated with the parameter. + /// The parameter name. + /// The display name for the parameter. + protected ValidatableParameterInfo( + Type parameterType, + string name, + string displayName) + { + ParameterType = parameterType; + Name = name; + DisplayName = displayName; + } + + /// + /// Gets the parameter type. + /// + internal Type ParameterType { get; } + + /// + /// Gets the parameter name. + /// + internal string Name { get; } + + /// + /// Gets the display name for the parameter. + /// + internal string DisplayName { get; } + + /// + /// Gets the validation attributes for this parameter. + /// + /// An array of validation attributes to apply to this parameter. + protected abstract ValidationAttribute[] GetValidationAttributes(); + + /// + /// + /// If the parameter is a collection, each item in the collection will be validated. + /// If the parameter is not a collection but has a validatable type, the single value will be validated. + /// + public virtual async Task ValidateAsync(object? value, ValidateContext context, CancellationToken cancellationToken) + { + Debug.Assert(context.ValidationContext is not null); + + // Skip validation if value is null and parameter is optional + if (value == null && ParameterType.IsNullable()) + { + return; + } + + context.ValidationContext.DisplayName = DisplayName; + context.ValidationContext.MemberName = Name; + + var validationAttributes = GetValidationAttributes(); + + if (_requiredAttribute is not null || validationAttributes.TryGetRequiredAttribute(out _requiredAttribute)) + { + var result = _requiredAttribute.GetValidationResult(value, context.ValidationContext); + + if (result is not null && result != ValidationResult.Success && result.ErrorMessage is not null) + { + var key = string.IsNullOrEmpty(context.CurrentValidationPath) ? Name : $"{context.CurrentValidationPath}.{Name}"; + context.AddValidationError(key, [result.ErrorMessage]); + return; + } + } + + // Validate against validation attributes + for (var i = 0; i < validationAttributes.Length; i++) + { + var attribute = validationAttributes[i]; + try + { + var result = attribute.GetValidationResult(value, context.ValidationContext); + if (result is not null && result != ValidationResult.Success && result.ErrorMessage is not null) + { + var key = string.IsNullOrEmpty(context.CurrentValidationPath) ? Name : $"{context.CurrentValidationPath}.{Name}"; + context.AddOrExtendValidationErrors(key, [result.ErrorMessage]); + } + } + catch (Exception ex) + { + var key = string.IsNullOrEmpty(context.CurrentValidationPath) ? Name : $"{context.CurrentValidationPath}.{Name}"; + context.AddValidationError(key, [ex.Message]); + } + } + + // If the parameter is a collection, validate each item + if (ParameterType.IsEnumerable() && value is IEnumerable enumerable) + { + var index = 0; + var currentPrefix = context.CurrentValidationPath; + + foreach (var item in enumerable) + { + if (item != null) + { + context.CurrentValidationPath = string.IsNullOrEmpty(currentPrefix) + ? $"{Name}[{index}]" + : $"{currentPrefix}.{Name}[{index}]"; + + if (context.ValidationOptions.TryGetValidatableTypeInfo(item.GetType(), out var validatableType)) + { + await validatableType.ValidateAsync(item, context, cancellationToken); + } + } + index++; + } + + context.CurrentValidationPath = currentPrefix; + } + // If not enumerable, validate the single value + else if (value != null) + { + var valueType = value.GetType(); + if (context.ValidationOptions.TryGetValidatableTypeInfo(valueType, out var validatableType)) + { + await validatableType.ValidateAsync(value, context, cancellationToken); + } + } + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidatablePropertyInfo.cs b/src/Http/Http.Abstractions/src/Validation/ValidatablePropertyInfo.cs new file mode 100644 index 000000000000..f6dfe94e688b --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidatablePropertyInfo.cs @@ -0,0 +1,175 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Contains validation information for a member of a type. +/// +public abstract class ValidatablePropertyInfo : IValidatableInfo +{ + private RequiredAttribute? _requiredAttribute; + + /// + /// Creates a new instance of . + /// + protected ValidatablePropertyInfo( + [param: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] + Type declaringType, + Type propertyType, + string name, + string displayName) + { + DeclaringType = declaringType; + PropertyType = propertyType; + Name = name; + DisplayName = displayName; + } + + /// + /// Gets the member type. + /// + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] + internal Type DeclaringType { get; } + + /// + /// Gets the member type. + /// + internal Type PropertyType { get; } + + /// + /// Gets the member name. + /// + internal string Name { get; } + + /// + /// Gets the display name for the member as designated by the . + /// + internal string DisplayName { get; } + + /// + /// Gets the validation attributes for this member. + /// + /// An array of validation attributes to apply to this member. + protected abstract ValidationAttribute[] GetValidationAttributes(); + + /// + public virtual async Task ValidateAsync(object? value, ValidateContext context, CancellationToken cancellationToken) + { + Debug.Assert(context.ValidationContext is not null); + + var property = DeclaringType.GetProperty(Name) ?? throw new InvalidOperationException($"Property '{Name}' not found on type '{DeclaringType.Name}'."); + var propertyValue = property.GetValue(value); + var validationAttributes = GetValidationAttributes(); + + // Calculate and save the current path + var originalPrefix = context.CurrentValidationPath; + if (string.IsNullOrEmpty(originalPrefix)) + { + context.CurrentValidationPath = Name; + } + else + { + context.CurrentValidationPath = $"{originalPrefix}.{Name}"; + } + + context.ValidationContext.DisplayName = DisplayName; + context.ValidationContext.MemberName = Name; + + // Check required attribute first + if (_requiredAttribute is not null || validationAttributes.TryGetRequiredAttribute(out _requiredAttribute)) + { + var result = _requiredAttribute.GetValidationResult(propertyValue, context.ValidationContext); + + if (result is not null && result != ValidationResult.Success && result.ErrorMessage is not null) + { + context.AddValidationError(context.CurrentValidationPath, [result.ErrorMessage]); + context.CurrentValidationPath = originalPrefix; // Restore prefix + return; + } + } + + // Validate any other attributes + ValidateValue(propertyValue, context.CurrentValidationPath, validationAttributes); + + // Check if we've reached the maximum depth before validating complex properties + if (context.CurrentDepth >= context.ValidationOptions.MaxDepth) + { + throw new InvalidOperationException( + $"Maximum validation depth of {context.ValidationOptions.MaxDepth} exceeded at '{context.CurrentValidationPath}' in '{DeclaringType.Name}.{Name}'. " + + "This is likely caused by a circular reference in the object graph. " + + "Consider increasing the MaxDepth in ValidationOptions if deeper validation is required."); + } + + // Increment depth counter + context.CurrentDepth++; + + try + { + // Handle enumerable values + if (PropertyType.IsEnumerable() && propertyValue is System.Collections.IEnumerable enumerable) + { + var index = 0; + var currentPrefix = context.CurrentValidationPath; + + foreach (var item in enumerable) + { + context.CurrentValidationPath = $"{currentPrefix}[{index}]"; + + if (item != null) + { + var itemType = item.GetType(); + if (context.ValidationOptions.TryGetValidatableTypeInfo(itemType, out var validatableType)) + { + await validatableType.ValidateAsync(item, context, cancellationToken); + } + } + + index++; + } + + // Restore prefix to the property name before validating the next item + context.CurrentValidationPath = currentPrefix; + } + else if (propertyValue != null) + { + // Validate as a complex object + var valueType = propertyValue.GetType(); + if (context.ValidationOptions.TryGetValidatableTypeInfo(valueType, out var validatableType)) + { + await validatableType.ValidateAsync(propertyValue, context, cancellationToken); + } + } + } + finally + { + // Always decrement the depth counter and restore prefix + context.CurrentDepth--; + context.CurrentValidationPath = originalPrefix; + } + + void ValidateValue(object? val, string errorPrefix, ValidationAttribute[] validationAttributes) + { + for (var i = 0; i < validationAttributes.Length; i++) + { + var attribute = validationAttributes[i]; + try + { + var result = attribute.GetValidationResult(val, context.ValidationContext); + if (result is not null && result != ValidationResult.Success && result.ErrorMessage is not null) + { + context.AddOrExtendValidationErrors(errorPrefix.TrimStart('.'), [result.ErrorMessage]); + } + } + catch (Exception ex) + { + context.AddOrExtendValidationErrors(errorPrefix.TrimStart('.'), [ex.Message]); + } + } + } + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidatableTypeAttribute.cs b/src/Http/Http.Abstractions/src/Validation/ValidatableTypeAttribute.cs new file mode 100644 index 000000000000..0ea382c59a55 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidatableTypeAttribute.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Indicates that a type is validatable to support discovery by the +/// validations generator. +/// +[AttributeUsage(AttributeTargets.Class)] +public sealed class ValidatableTypeAttribute : Attribute +{ +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidatableTypeInfo.cs b/src/Http/Http.Abstractions/src/Validation/ValidatableTypeInfo.cs new file mode 100644 index 000000000000..6bf3ff182d19 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidatableTypeInfo.cs @@ -0,0 +1,126 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Contains validation information for a type. +/// +public abstract class ValidatableTypeInfo : IValidatableInfo +{ + private readonly int _membersCount; + private readonly List _subTypes; + + /// + /// Creates a new instance of . + /// + /// The type being validated. + /// The members that can be validated. + protected ValidatableTypeInfo( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] Type type, + IReadOnlyList members) + { + Type = type; + Members = members; + _membersCount = members.Count; + _subTypes = type.GetAllImplementedTypes(); + } + + /// + /// The type being validated. + /// + internal Type Type { get; } + + /// + /// The members that can be validated. + /// + internal IReadOnlyList Members { get; } + + /// + public virtual async Task ValidateAsync(object? value, ValidateContext context, CancellationToken cancellationToken) + { + Debug.Assert(context.ValidationContext is not null); + if (value == null) + { + return; + } + + // Check if we've exceeded the maximum depth + if (context.CurrentDepth >= context.ValidationOptions.MaxDepth) + { + throw new InvalidOperationException( + $"Maximum validation depth of {context.ValidationOptions.MaxDepth} exceeded at '{context.CurrentValidationPath}' in '{Type.Name}'. " + + "This is likely caused by a circular reference in the object graph. " + + "Consider increasing the MaxDepth in ValidationOptions if deeper validation is required."); + } + + var originalPrefix = context.CurrentValidationPath; + + try + { + var actualType = value.GetType(); + + // First validate members + for (var i = 0; i < _membersCount; i++) + { + await Members[i].ValidateAsync(value, context, cancellationToken); + context.CurrentValidationPath = originalPrefix; + } + + // Then validate sub-types if any + foreach (var subType in _subTypes) + { + // Check if the actual type is assignable to the sub-type + // and validate it if it is + if (subType.IsAssignableFrom(actualType)) + { + if (context.ValidationOptions.TryGetValidatableTypeInfo(subType, out var subTypeInfo)) + { + await subTypeInfo.ValidateAsync(value, context, cancellationToken); + context.CurrentValidationPath = originalPrefix; + } + } + } + + // Finally validate IValidatableObject if implemented + if (Type.ImplementsInterface(typeof(IValidatableObject)) && value is IValidatableObject validatable) + { + // Important: Set the DisplayName to the type name for top-level validations + // and restore the original validation context properties + var originalDisplayName = context.ValidationContext.DisplayName; + var originalMemberName = context.ValidationContext.MemberName; + + // Set the display name to the class name for IValidatableObject validation + context.ValidationContext.DisplayName = Type.Name; + context.ValidationContext.MemberName = null; + + var validationResults = validatable.Validate(context.ValidationContext); + foreach (var validationResult in validationResults) + { + if (validationResult != ValidationResult.Success && validationResult.ErrorMessage is not null) + { + var memberName = validationResult.MemberNames.First(); + var key = string.IsNullOrEmpty(originalPrefix) ? + memberName : + $"{originalPrefix}.{memberName}"; + + context.AddOrExtendValidationError(key, validationResult.ErrorMessage); + } + } + + // Restore the original validation context properties + context.ValidationContext.DisplayName = originalDisplayName; + context.ValidationContext.MemberName = originalMemberName; + } + } + finally + { + context.CurrentValidationPath = originalPrefix; + } + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidateContext.cs b/src/Http/Http.Abstractions/src/Validation/ValidateContext.cs new file mode 100644 index 000000000000..a78ca7cabfb5 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidateContext.cs @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Represents the context for validating a validatable object. +/// +public sealed class ValidateContext +{ + /// + /// Gets or sets the validation context used for validating objects that implement or have . + /// This context provides access to service provider and other validation metadata. + /// + public ValidationContext? ValidationContext { get; set; } + + /// + /// Gets or sets the prefix used to identify the current object being validated in a complex object graph. + /// This is used to build property paths in validation error messages (e.g., "Customer.Address.Street"). + /// + public string CurrentValidationPath { get; set; } = string.Empty; + + /// + /// Gets or sets the validation options that control validation behavior, + /// including validation depth limits and resolver registration. + /// + public required ValidationOptions ValidationOptions { get; set; } + + /// + /// Gets or sets the dictionary of validation errors collected during validation. + /// Keys are property names or paths, and values are arrays of error messages. + /// In the default implementation, this dictionary is initialized when the first error is added. + /// + public Dictionary? ValidationErrors { get; set; } + + /// + /// Gets or sets the current depth in the validation hierarchy. + /// This is used to prevent stack overflows from circular references. + /// + public int CurrentDepth { get; set; } + + internal void AddValidationError(string key, string[] error) + { + ValidationErrors ??= []; + + ValidationErrors[key] = error; + } + + internal void AddOrExtendValidationErrors(string key, string[] errors) + { + ValidationErrors ??= []; + + if (ValidationErrors.TryGetValue(key, out var existingErrors)) + { + var newErrors = new string[existingErrors.Length + errors.Length]; + existingErrors.CopyTo(newErrors, 0); + errors.CopyTo(newErrors, existingErrors.Length); + ValidationErrors[key] = newErrors; + } + else + { + ValidationErrors[key] = errors; + } + } + + internal void AddOrExtendValidationError(string key, string error) + { + ValidationErrors ??= []; + + if (ValidationErrors.TryGetValue(key, out var existingErrors) && !existingErrors.Contains(error)) + { + ValidationErrors[key] = [.. existingErrors, error]; + } + else + { + ValidationErrors[key] = [error]; + } + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidationOptions.cs b/src/Http/Http.Abstractions/src/Validation/ValidationOptions.cs new file mode 100644 index 000000000000..d27ba37eaf13 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidationOptions.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Provides configuration options for the validation system. +/// +public class ValidationOptions +{ + /// + /// Gets the list of resolvers that provide validation metadata for types and parameters. + /// Resolvers are processed in order, with the first resolver providing a non-null result being used. + /// + /// + /// Source-generated resolvers are typically inserted at the beginning of this list + /// to ensure they are checked before any runtime-based resolvers. + /// + public IList Resolvers { get; } = []; + + /// + /// Gets or sets the maximum depth for validation of nested objects. + /// This prevents stack overflows from circular references or extremely deep object graphs. + /// Default value is 32. + /// + public int MaxDepth { get; set; } = 32; + + /// + /// Attempts to get validation information for the specified type. + /// + /// The type to get validation information for. + /// When this method returns, contains the validation information for the specified type, + /// if the type was found; otherwise, null. + /// true if validation information was found for the specified type; otherwise, false. + public bool TryGetValidatableTypeInfo(Type type, [NotNullWhen(true)] out IValidatableInfo? validatableTypeInfo) + { + foreach (var resolver in Resolvers) + { + if (resolver.TryGetValidatableTypeInfo(type, out validatableTypeInfo)) + { + return true; + } + } + + validatableTypeInfo = null; + return false; + } + + /// + /// Attempts to get validation information for the specified parameter. + /// + /// The parameter to get validation information for. + /// When this method returns, contains the validation information for the specified parameter, + /// if validation information was found; otherwise, null. + /// true if validation information was found for the specified parameter; otherwise, false. + public bool TryGetValidatableParameterInfo(ParameterInfo parameterInfo, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + foreach (var resolver in Resolvers) + { + if (resolver.TryGetValidatableParameterInfo(parameterInfo, out validatableInfo)) + { + return true; + } + } + + validatableInfo = null; + return false; + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidationServiceCollectionExtensions.cs b/src/Http/Http.Abstractions/src/Validation/ValidationServiceCollectionExtensions.cs new file mode 100644 index 000000000000..77a128842ea4 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidationServiceCollectionExtensions.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Http.Validation; + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Extension methods for adding validation services. +/// +public static class ValidationServiceCollectionExtensions +{ + /// + /// Adds the validation services to the specified . + /// + /// The to add the services to. + /// An optional action to configure the . + /// The for chaining. + public static IServiceCollection AddValidation(this IServiceCollection services, Action? configureOptions = null) + { + services.Configure(options => + { + if (configureOptions is not null) + { + configureOptions(options); + } + // Support ParameterInfo resolution at runtime + options.Resolvers.Add(new RuntimeValidatableParameterInfoResolver()); + }); + return services; + } +} diff --git a/src/Http/Http.Abstractions/test/Validation/ValidatableInfoResolverTests.cs b/src/Http/Http.Abstractions/test/Validation/ValidatableInfoResolverTests.cs new file mode 100644 index 000000000000..6960162a13c5 --- /dev/null +++ b/src/Http/Http.Abstractions/test/Validation/ValidatableInfoResolverTests.cs @@ -0,0 +1,221 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Reflection; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Moq; + +namespace Microsoft.AspNetCore.Http.Validation.Tests; + +public class ValidatableInfoResolverTests +{ + public delegate void TryGetValidatableTypeInfoCallback(Type type, out IValidatableInfo? validatableInfo); + public delegate void TryGetValidatableParameterInfoCallback(ParameterInfo parameter, out IValidatableInfo? validatableInfo); + + [Fact] + public void GetValidatableTypeInfo_ReturnsNull_ForNonValidatableType() + { + // Arrange + var resolver = new Mock(); + IValidatableInfo? validatableInfo = null; + resolver.Setup(r => r.TryGetValidatableTypeInfo(It.IsAny(), out validatableInfo)).Returns(false); + + // Act + var result = resolver.Object.TryGetValidatableTypeInfo(typeof(NonValidatableType), out validatableInfo); + + // Assert + Assert.False(result); + Assert.Null(validatableInfo); + } + + [Fact] + public void GetValidatableTypeInfo_ReturnsTypeInfo_ForValidatableType() + { + // Arrange + var mockTypeInfo = new Mock( + typeof(ValidatableType), + Array.Empty()).Object; + + var resolver = new Mock(); + IValidatableInfo? validatableInfo = null; + resolver + .Setup(r => r.TryGetValidatableTypeInfo(typeof(ValidatableType), out validatableInfo)) + .Callback(new TryGetValidatableTypeInfoCallback((t, out info) => + { + info = mockTypeInfo; // Set the out parameter to our mock + })) + .Returns(true); + + // Act + var result = resolver.Object.TryGetValidatableTypeInfo(typeof(ValidatableType), out validatableInfo); + + // Assert + Assert.True(result); + var validatableTypeInfo = Assert.IsAssignableFrom(validatableInfo); + Assert.Equal(typeof(ValidatableType), validatableTypeInfo.Type); + } + + [Fact] + public void GetValidatableParameterInfo_ReturnsNull_ForNonValidatableParameter() + { + // Arrange + var method = typeof(TestMethods).GetMethod(nameof(TestMethods.MethodWithNonValidatableParam))!; + var parameter = method.GetParameters()[0]; + + var resolver = new Mock(); + IValidatableInfo? validatableInfo = null; + resolver.Setup(r => r.TryGetValidatableParameterInfo(It.IsAny(), out validatableInfo)).Returns(false); + + // Act + var result = resolver.Object.TryGetValidatableParameterInfo(parameter, out validatableInfo); + + // Assert + Assert.False(result); + } + + [Fact] + public void GetValidatableParameterInfo_ReturnsParameterInfo_ForValidatableParameter() + { + // Arrange + var method = typeof(TestMethods).GetMethod(nameof(TestMethods.MethodWithValidatableParam))!; + var parameter = method.GetParameters()[0]; + + var mockParamInfo = new Mock( + typeof(string), + "model", + "model").Object; + + var resolver = new Mock(); + + // Setup using the same pattern as in the type info test + resolver.Setup(r => r.TryGetValidatableParameterInfo(parameter, out It.Ref.IsAny)) + .Callback(new TryGetValidatableParameterInfoCallback((ParameterInfo p, out IValidatableInfo? info) => + { + info = mockParamInfo; // Set the out parameter to our mock + })) + .Returns(true); + + // Act + var result = resolver.Object.TryGetValidatableParameterInfo(parameter, out var validatableInfo); + + // Assert + Assert.True(result); + var validatableParamInfo = Assert.IsAssignableFrom(validatableInfo); + Assert.Equal("model", validatableParamInfo.Name); + } + + [Fact] + public void ResolversChain_ProcessesInCorrectOrder() + { + // Arrange + var services = new ServiceCollection(); + + var resolver1 = new Mock(); + var resolver2 = new Mock(); + var resolver3 = new Mock(); + + // Create the object that will be returned by resolver2 + var mockTypeInfo = new Mock(typeof(ValidatableType), Array.Empty()).Object; + + // Setup resolver1 to return false (doesn't handle this type) + resolver1 + .Setup(r => r.TryGetValidatableTypeInfo(typeof(ValidatableType), out It.Ref.IsAny)) + .Callback(new TryGetValidatableTypeInfoCallback((Type t, out IValidatableInfo? info) => + { + info = null; + })) + .Returns(false); + + // Setup resolver2 to return true and set the mock type info + resolver2 + .Setup(r => r.TryGetValidatableTypeInfo(typeof(ValidatableType), out It.Ref.IsAny)) + .Callback(new TryGetValidatableTypeInfoCallback((Type t, out IValidatableInfo? info) => + { + info = mockTypeInfo; + })) + .Returns(true); + + services.AddValidation(Options => + { + Options.Resolvers.Add(resolver1.Object); + Options.Resolvers.Add(resolver2.Object); + Options.Resolvers.Add(resolver3.Object); + }); + + var serviceProvider = services.BuildServiceProvider(); + var validationOptions = serviceProvider.GetRequiredService>().Value; + + // Act + var result = validationOptions.TryGetValidatableTypeInfo(typeof(ValidatableType), out var validatableInfo); + + // Assert + Assert.True(result); + Assert.NotNull(validatableInfo); + Assert.Equal(typeof(ValidatableType), ((ValidatableTypeInfo)validatableInfo).Type); + + // Verify that resolvers were called in the expected order + resolver1.Verify(r => r.TryGetValidatableTypeInfo(typeof(ValidatableType), out It.Ref.IsAny), Times.Once); + resolver2.Verify(r => r.TryGetValidatableTypeInfo(typeof(ValidatableType), out It.Ref.IsAny), Times.Once); + resolver3.Verify(r => r.TryGetValidatableTypeInfo(typeof(ValidatableType), out It.Ref.IsAny), Times.Never); + } + + // Test types + private class NonValidatableType { } + + [ValidatableType] + private class ValidatableType + { + [Required] + public string Name { get; set; } = ""; + } + + private static class TestMethods + { + public static void MethodWithNonValidatableParam(NonValidatableType param) { } + public static void MethodWithValidatableParam(ValidatableType model) { } + } + + // Test implementations + private class TestValidatablePropertyInfo : ValidatablePropertyInfo + { + private readonly ValidationAttribute[] _validationAttributes; + + public TestValidatablePropertyInfo( + Type containingType, + Type propertyType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + : base(containingType, propertyType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + private class TestValidatableParameterInfo : ValidatableParameterInfo + { + private readonly ValidationAttribute[] _validationAttributes; + + public TestValidatableParameterInfo( + Type parameterType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + : base(parameterType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + private class TestValidatableTypeInfo( + Type type, + ValidatablePropertyInfo[] members) : ValidatableTypeInfo(type, members) + { + } +} diff --git a/src/Http/Http.Abstractions/test/Validation/ValidatableParameterInfoTests.cs b/src/Http/Http.Abstractions/test/Validation/ValidatableParameterInfoTests.cs new file mode 100644 index 000000000000..c89f182c92f9 --- /dev/null +++ b/src/Http/Http.Abstractions/test/Validation/ValidatableParameterInfoTests.cs @@ -0,0 +1,430 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Http.Validation.Tests; + +public class ValidatableParameterInfoTests +{ + [Fact] + public async Task Validate_RequiredParameter_AddsErrorWhenNull() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(string), + name: "testParam", + displayName: "Test Parameter", + validationAttributes: [new RequiredAttribute()]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync(null, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("testParam", error.Key); + Assert.Equal("The Test Parameter field is required.", error.Value.Single()); + } + + [Fact] + public async Task Validate_RequiredParameter_ShortCircuitsOtherValidations() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(string), + name: "testParam", + displayName: "Test Parameter", + // Most ValidationAttributes skip validation if the value is null + // so we use a custom one that always fails to assert on the behavior here + validationAttributes: [new RequiredAttribute(), new CustomTestValidationAttribute()]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync(null, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("testParam", error.Key); + Assert.Equal("The Test Parameter field is required.", error.Value.Single()); + } + + [Fact] + public async Task Validate_SkipsValidation_WhenNullAndNotRequired() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(string), + name: "testParam", + displayName: "Test Parameter", + validationAttributes: [new StringLengthAttribute(10)]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync(null, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.Null(errors); // No errors added + } + + [Fact] + public async Task Validate_WithRangeAttribute_ValidatesCorrectly() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(int), + name: "testParam", + displayName: "Test Parameter", + validationAttributes: [new RangeAttribute(10, 100)]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync(5, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("testParam", error.Key); + Assert.Equal("The field Test Parameter must be between 10 and 100.", error.Value.First()); + } + + [Fact] + public async Task Validate_WithDisplayNameAttribute_UsesDisplayNameInErrorMessage() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(string), + name: "testParam", + displayName: "Custom Display Name", + validationAttributes: [new RequiredAttribute()]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync(null, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("testParam", error.Key); + // The error message should use the display name + Assert.Equal("The Custom Display Name field is required.", error.Value.First()); + } + + [Fact] + public async Task Validate_WhenValidatableTypeHasErrors_AddsNestedErrors() + { + // Arrange + var personTypeInfo = new TestValidatableTypeInfo( + typeof(Person), + [ + new TestValidatablePropertyInfo( + typeof(Person), + typeof(string), + "Name", + "Name", + [new RequiredAttribute()]) + ]); + + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(Person), + name: "person", + displayName: "Person", + validationAttributes: []); + + var typeMapping = new Dictionary + { + { typeof(Person), personTypeInfo } + }; + + var context = CreateValidatableContext(typeMapping); + var person = new Person(); // Name is null, so should fail validation + + // Act + await paramInfo.ValidateAsync(person, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("Name", error.Key); + Assert.Equal("The Name field is required.", error.Value[0]); + } + + [Fact] + public async Task Validate_WithEnumerableOfValidatableType_ValidatesEachItem() + { + // Arrange + var personTypeInfo = new TestValidatableTypeInfo( + typeof(Person), + [ + new TestValidatablePropertyInfo( + typeof(Person), + typeof(string), + "Name", + "Name", + [new RequiredAttribute()]) + ]); + + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(IEnumerable), + name: "people", + displayName: "People", + validationAttributes: []); + + var typeMapping = new Dictionary + { + { typeof(Person), personTypeInfo } + }; + + var context = CreateValidatableContext(typeMapping); + var people = new List + { + new() { Name = "Valid" }, + new() // Name is null, should fail + }; + + // Act + await paramInfo.ValidateAsync(people, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("people[1].Name", error.Key); + Assert.Equal("The Name field is required.", error.Value[0]); + } + + [Fact] + public async Task Validate_MultipleErrorsOnSameParameter_CollectsAllErrors() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(int), + name: "testParam", + displayName: "Test Parameter", + validationAttributes: + [ + new RangeAttribute(10, 100) { ErrorMessage = "Range error" }, + new CustomTestValidationAttribute { ErrorMessage = "Custom error" } + ]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync(5, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("testParam", error.Key); + Assert.Collection(error.Value, + e => Assert.Equal("Range error", e), + e => Assert.Equal("Custom error", e)); + } + + [Fact] + public async Task Validate_WithContextPrefix_AddsErrorsWithCorrectPrefix() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(int), + name: "testParam", + displayName: "Test Parameter", + validationAttributes: [new RangeAttribute(10, 100)]); + + var context = CreateValidatableContext(); + context.CurrentValidationPath = "parent"; + + // Act + await paramInfo.ValidateAsync(5, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("parent.testParam", error.Key); + Assert.Equal("The field Test Parameter must be between 10 and 100.", error.Value.First()); + } + + [Fact] + public async Task Validate_ExceptionDuringValidation_CapturesExceptionAsError() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(string), + name: "testParam", + displayName: "Test Parameter", + validationAttributes: [new ThrowingValidationAttribute()]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync("test", context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("testParam", error.Key); + Assert.Equal("Test exception", error.Value.First()); + } + + private TestValidatableParameterInfo CreateTestParameterInfo( + Type parameterType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + { + return new TestValidatableParameterInfo( + parameterType, + name, + displayName, + validationAttributes); + } + + private ValidateContext CreateValidatableContext( + Dictionary? typeMapping = null) + { + var serviceProvider = new ServiceCollection().BuildServiceProvider(); + var validationContext = new ValidationContext(new object(), serviceProvider, null); + + return new ValidateContext + { + ValidationContext = validationContext, + ValidationOptions = new TestValidationOptions(typeMapping ?? new Dictionary()) + }; + } + + private class TestValidatableParameterInfo : ValidatableParameterInfo + { + private readonly ValidationAttribute[] _validationAttributes; + + public TestValidatableParameterInfo( + Type parameterType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + : base(parameterType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + private class TestValidatablePropertyInfo : ValidatablePropertyInfo + { + private readonly ValidationAttribute[] _validationAttributes; + + public TestValidatablePropertyInfo( + Type containingType, + Type propertyType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + : base(containingType, propertyType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + private class TestValidatableTypeInfo( + Type type, + ValidatablePropertyInfo[] members) : ValidatableTypeInfo(type, members) + { + } + + private class TestValidationOptions : ValidationOptions + { + public TestValidationOptions(Dictionary typeInfoMappings) + { + // Create a custom resolver that uses the dictionary + var resolver = new DictionaryBasedResolver(typeInfoMappings); + + // Add it to the resolvers collection + Resolvers.Add(resolver); + } + + // Private resolver implementation that uses a dictionary lookup + private class DictionaryBasedResolver : IValidatableInfoResolver + { + private readonly Dictionary _typeInfoMappings; + + public DictionaryBasedResolver(Dictionary typeInfoMappings) + { + _typeInfoMappings = typeInfoMappings; + } + + public ValidatableTypeInfo? TryGetValidatableTypeInfo(Type type) + { + _typeInfoMappings.TryGetValue(type, out var info); + return info; + } + + public ValidatableParameterInfo? GetValidatableParameterInfo(ParameterInfo parameterInfo) + { + // Not implemented in the test + return null; + } + + public bool TryGetValidatableTypeInfo(Type type, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + if (_typeInfoMappings.TryGetValue(type, out var validatableTypeInfo)) + { + validatableInfo = validatableTypeInfo; + return true; + } + validatableInfo = null; + return false; + } + + public bool TryGetValidatableParameterInfo(ParameterInfo parameterInfo, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + } + } + + // Test data classes and validation attributes + + private class Person + { + public string? Name { get; set; } + } + + private class CustomTestValidationAttribute : ValidationAttribute + { + public override bool IsValid(object? value) + { + // Always fail for testing + return false; + } + } + + private class ThrowingValidationAttribute : ValidationAttribute + { + public override bool IsValid(object? value) + { + throw new InvalidOperationException("Test exception"); + } + } +} diff --git a/src/Http/Http.Abstractions/test/Validation/ValidatableTypeInfoTests.cs b/src/Http/Http.Abstractions/test/Validation/ValidatableTypeInfoTests.cs new file mode 100644 index 000000000000..fe75387e8e22 --- /dev/null +++ b/src/Http/Http.Abstractions/test/Validation/ValidatableTypeInfoTests.cs @@ -0,0 +1,710 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace Microsoft.AspNetCore.Http.Validation.Tests; + +public class ValidatableTypeInfoTests +{ + [Fact] + public async Task Validate_ValidatesComplexType_WithNestedProperties() + { + // Arrange + var personType = new TestValidatableTypeInfo( + typeof(Person), + [ + CreatePropertyInfo(typeof(Person), typeof(string), "Name", "Name", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(Person), typeof(int), "Age", "Age", + [new RangeAttribute(0, 120)]), + CreatePropertyInfo(typeof(Person), typeof(Address), "Address", "Address", + []) + ]); + + var addressType = new TestValidatableTypeInfo( + typeof(Address), + [ + CreatePropertyInfo(typeof(Address), typeof(string), "Street", "Street", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(Address), typeof(string), "City", "City", + [new RequiredAttribute()]) + ]); + + var validationOptions = new TestValidationOptions(new Dictionary + { + { typeof(Person), personType }, + { typeof(Address), addressType } + }); + + var context = new ValidateContext + { + ValidationOptions = validationOptions, + }; + + var personWithMissingRequiredFields = new Person + { + Age = 150, // Invalid age + Address = new Address() // Missing required City and Street + }; + context.ValidationContext = new ValidationContext(personWithMissingRequiredFields); + + // Act + await personType.ValidateAsync(personWithMissingRequiredFields, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + Assert.Collection(context.ValidationErrors, + kvp => + { + Assert.Equal("Name", kvp.Key); + Assert.Equal("The Name field is required.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Age", kvp.Key); + Assert.Equal("The field Age must be between 0 and 120.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Address.Street", kvp.Key); + Assert.Equal("The Street field is required.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Address.City", kvp.Key); + Assert.Equal("The City field is required.", kvp.Value.First()); + }); + } + + [Fact] + public async Task Validate_HandlesIValidatableObject_Implementation() + { + // Arrange + var employeeType = new TestValidatableTypeInfo( + typeof(Employee), + [ + CreatePropertyInfo(typeof(Employee), typeof(string), "Name", "Name", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(Employee), typeof(string), "Department", "Department", + []), + CreatePropertyInfo(typeof(Employee), typeof(decimal), "Salary", "Salary", + []) + ]); + + var context = new ValidateContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(Employee), employeeType } + }) + }; + + var employee = new Employee + { + Name = "John Doe", + Department = "IT", + Salary = -5000 // Negative salary will trigger IValidatableObject validation + }; + context.ValidationContext = new ValidationContext(employee); + + // Act + await employeeType.ValidateAsync(employee, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + var error = Assert.Single(context.ValidationErrors); + Assert.Equal("Salary", error.Key); + Assert.Equal("Salary must be a positive value.", error.Value.First()); + } + + [Fact] + public async Task Validate_HandlesPolymorphicTypes_WithSubtypes() + { + // Arrange + var baseType = new TestValidatableTypeInfo( + typeof(Vehicle), + [ + CreatePropertyInfo(typeof(Vehicle), typeof(string), "Make", "Make", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(Vehicle), typeof(string), "Model", "Model", + [new RequiredAttribute()]) + ]); + + var derivedType = new TestValidatableTypeInfo( + typeof(Car), + [ + CreatePropertyInfo(typeof(Car), typeof(int), "Doors", "Doors", + [new RangeAttribute(2, 5)]) + ]); + + var context = new ValidateContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(Vehicle), baseType }, + { typeof(Car), derivedType } + }) + }; + + var car = new Car + { + // Missing Make and Model (required in base type) + Doors = 7 // Invalid number of doors + }; + context.ValidationContext = new ValidationContext(car); + + // Act + await derivedType.ValidateAsync(car, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + Assert.Collection(context.ValidationErrors, + kvp => + { + Assert.Equal("Doors", kvp.Key); + Assert.Equal("The field Doors must be between 2 and 5.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Make", kvp.Key); + Assert.Equal("The Make field is required.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Model", kvp.Key); + Assert.Equal("The Model field is required.", kvp.Value.First()); + }); + } + + [Fact] + public async Task Validate_HandlesCollections_OfValidatableTypes() + { + // Arrange + var itemType = new TestValidatableTypeInfo( + typeof(OrderItem), + [ + CreatePropertyInfo(typeof(OrderItem), typeof(string), "ProductName", "ProductName", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(OrderItem), typeof(int), "Quantity", "Quantity", + [new RangeAttribute(1, 100)]) + ]); + + var orderType = new TestValidatableTypeInfo( + typeof(Order), + [ + CreatePropertyInfo(typeof(Order), typeof(string), "OrderNumber", "OrderNumber", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(Order), typeof(List), "Items", "Items", + []) + ]); + + var context = new ValidateContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(OrderItem), itemType }, + { typeof(Order), orderType } + }) + }; + + var order = new Order + { + OrderNumber = "ORD-12345", + Items = + [ + new OrderItem { ProductName = "Valid Product", Quantity = 5 }, + new OrderItem { /* Missing ProductName (required) */ Quantity = 0 /* Invalid quantity */ }, + new OrderItem { ProductName = "Another Product", Quantity = 200 /* Invalid quantity */ } + ] + }; + context.ValidationContext = new ValidationContext(order); + + // Act + await orderType.ValidateAsync(order, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + Assert.Collection(context.ValidationErrors, + kvp => + { + Assert.Equal("Items[1].ProductName", kvp.Key); + Assert.Equal("The ProductName field is required.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Items[1].Quantity", kvp.Key); + Assert.Equal("The field Quantity must be between 1 and 100.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Items[2].Quantity", kvp.Key); + Assert.Equal("The field Quantity must be between 1 and 100.", kvp.Value.First()); + }); + } + + [Fact] + public async Task Validate_HandlesNullValues_Appropriately() + { + // Arrange + var personType = new TestValidatableTypeInfo( + typeof(Person), + [ + CreatePropertyInfo(typeof(Person), typeof(string), "Name", "Name", + []), + CreatePropertyInfo(typeof(Person), typeof(Address), "Address", "Address", + []) + ]); + + var context = new ValidateContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(Person), personType } + }) + }; + + var person = new Person + { + Name = null, + Address = null + }; + context.ValidationContext = new ValidationContext(person); + + // Act + await personType.ValidateAsync(person, context, default); + + // Assert + Assert.Null(context.ValidationErrors); // No validation errors for nullable properties with null values + } + + [Fact] + public async Task Validate_RespectsMaxDepthOption_ForCircularReferences() + { + // Arrange + // Create a type that can contain itself (circular reference) + var nodeType = new TestValidatableTypeInfo( + typeof(TreeNode), + [ + CreatePropertyInfo(typeof(TreeNode), typeof(string), "Name", "Name", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(TreeNode), typeof(TreeNode), "Parent", "Parent", + []), + CreatePropertyInfo(typeof(TreeNode), typeof(List), "Children", "Children", + []) + ]); + + // Create a validation options with a small max depth + var validationOptions = new TestValidationOptions(new Dictionary + { + { typeof(TreeNode), nodeType } + }); + validationOptions.MaxDepth = 3; // Set a small max depth to trigger the limit + + var context = new ValidateContext + { + ValidationOptions = validationOptions, + ValidationErrors = [] + }; + + // Create a deep tree with circular references + var rootNode = new TreeNode { Name = "Root" }; + var level1 = new TreeNode { Name = "Level1", Parent = rootNode }; + var level2 = new TreeNode { Name = "Level2", Parent = level1 }; + var level3 = new TreeNode { Name = "Level3", Parent = level2 }; + var level4 = new TreeNode { Name = "" }; // Invalid: missing required name + var level5 = new TreeNode { Name = "" }; // Invalid but beyond max depth, should not be validated + + rootNode.Children.Add(level1); + level1.Children.Add(level2); + level2.Children.Add(level3); + level3.Children.Add(level4); + level4.Children.Add(level5); + + // Add a circular reference + level5.Children.Add(rootNode); + + context.ValidationContext = new ValidationContext(rootNode); + + // Act + Assert + var exception = await Assert.ThrowsAsync( + async () => await nodeType.ValidateAsync(rootNode, context, default)); + + Assert.NotNull(exception); + Assert.Equal("Maximum validation depth of 3 exceeded at 'Children[0].Parent.Children[0]' in 'TreeNode'. This is likely caused by a circular reference in the object graph. Consider increasing the MaxDepth in ValidationOptions if deeper validation is required.", exception.Message); + Assert.Equal(0, context.CurrentDepth); + } + + [Fact] + public async Task Validate_HandlesCustomValidationAttributes() + { + // Arrange + var productType = new TestValidatableTypeInfo( + typeof(Product), + [ + CreatePropertyInfo(typeof(Product), typeof(string), "SKU", "SKU", [new RequiredAttribute(), new CustomSkuValidationAttribute()]), + ]); + + var context = new ValidateContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(Product), productType } + }) + }; + + var product = new Product { SKU = "INVALID-SKU" }; + context.ValidationContext = new ValidationContext(product); + + // Act + await productType.ValidateAsync(product, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + var error = Assert.Single(context.ValidationErrors); + Assert.Equal("SKU", error.Key); + Assert.Equal("SKU must start with 'PROD-'.", error.Value.First()); + } + + [Fact] + public async Task Validate_HandlesMultipleErrorsOnSameProperty() + { + // Arrange + var userType = new TestValidatableTypeInfo( + typeof(User), + [ + CreatePropertyInfo(typeof(User), typeof(string), "Password", "Password", + [ + new RequiredAttribute(), + new MinLengthAttribute(8) { ErrorMessage = "Password must be at least 8 characters." }, + new PasswordComplexityAttribute() + ]) + ]); + + var context = new ValidateContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(User), userType } + }) + }; + + var user = new User { Password = "abc" }; // Too short and not complex enough + context.ValidationContext = new ValidationContext(user); + + // Act + await userType.ValidateAsync(user, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + Assert.Single(context.ValidationErrors.Keys); // Only the "Password" key + Assert.Equal(2, context.ValidationErrors["Password"].Length); // But with 2 errors + Assert.Contains("Password must be at least 8 characters.", context.ValidationErrors["Password"]); + Assert.Contains("Password must contain at least one number and one special character.", context.ValidationErrors["Password"]); + } + + [Fact] + public async Task Validate_HandlesMultiLevelInheritance() + { + // Arrange + var baseType = new TestValidatableTypeInfo( + typeof(BaseEntity), + [ + CreatePropertyInfo(typeof(BaseEntity), typeof(Guid), "Id", "Id", []) + ]); + + var intermediateType = new TestValidatableTypeInfo( + typeof(IntermediateEntity), + [ + CreatePropertyInfo(typeof(IntermediateEntity), typeof(DateTime), "CreatedAt", "CreatedAt", [new PastDateAttribute()]) + ]); + + var derivedType = new TestValidatableTypeInfo( + typeof(DerivedEntity), + [ + CreatePropertyInfo(typeof(DerivedEntity), typeof(string), "Name", "Name", [new RequiredAttribute()]) + ]); + + var context = new ValidateContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(BaseEntity), baseType }, + { typeof(IntermediateEntity), intermediateType }, + { typeof(DerivedEntity), derivedType } + }) + }; + + var entity = new DerivedEntity + { + Name = "", // Invalid: required + CreatedAt = DateTime.Now.AddDays(1) // Invalid: future date + }; + context.ValidationContext = new ValidationContext(entity); + + // Act + await derivedType.ValidateAsync(entity, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + Assert.Collection(context.ValidationErrors, + kvp => + { + Assert.Equal("Name", kvp.Key); + Assert.Equal("The Name field is required.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("CreatedAt", kvp.Key); + Assert.Equal("Date must be in the past.", kvp.Value.First()); + }); + } + + [Fact] + public async Task Validate_RequiredOnPropertyShortCircuitsOtherValidations() + { + // Arrange + var userType = new TestValidatableTypeInfo( + typeof(User), + [ + CreatePropertyInfo(typeof(User), typeof(string), "Password", "Password", + [new RequiredAttribute(), new PasswordComplexityAttribute()]) + ]); + + var context = new ValidateContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(User), userType } + }) + }; + + var user = new User { Password = null }; // Invalid: required + context.ValidationContext = new ValidationContext(user); + + // Act + await userType.ValidateAsync(user, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + Assert.Single(context.ValidationErrors.Keys); + var error = Assert.Single(context.ValidationErrors); + Assert.Equal("Password", error.Key); + Assert.Equal("The Password field is required.", error.Value.Single()); + } + + private ValidatablePropertyInfo CreatePropertyInfo( + Type containingType, + Type propertyType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + { + return new TestValidatablePropertyInfo( + containingType, + propertyType, + name, + displayName, + validationAttributes); + } + + // Test model classes + private class Person + { + public string? Name { get; set; } + public int Age { get; set; } + public Address? Address { get; set; } + } + + private class Address + { + public string? Street { get; set; } + public string? City { get; set; } + } + + private class Employee : IValidatableObject + { + public string? Name { get; set; } + public string? Department { get; set; } + public decimal Salary { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Salary < 0) + { + yield return new ValidationResult("Salary must be a positive value.", new[] { nameof(Salary) }); + } + } + } + + private class Vehicle + { + public string? Make { get; set; } + public string? Model { get; set; } + } + + private class Car : Vehicle + { + public int Doors { get; set; } + } + + private class Order + { + public string? OrderNumber { get; set; } + public List Items { get; set; } = []; + } + + private class OrderItem + { + public string? ProductName { get; set; } + public int Quantity { get; set; } + } + + private class TreeNode + { + public string Name { get; set; } = string.Empty; + public TreeNode? Parent { get; set; } + public List Children { get; set; } = []; + } + + private class Product + { + public string SKU { get; set; } = string.Empty; + } + + private class User + { + public string? Password { get; set; } = string.Empty; + } + + private class BaseEntity + { + public Guid Id { get; set; } = Guid.NewGuid(); + } + + private class IntermediateEntity : BaseEntity + { + public DateTime CreatedAt { get; set; } + } + + private class DerivedEntity : IntermediateEntity + { + public string Name { get; set; } = string.Empty; + } + + private class PastDateAttribute : ValidationAttribute + { + protected override ValidationResult? IsValid(object? value, ValidationContext validationContext) + { + if (value is DateTime date && date > DateTime.Now) + { + return new ValidationResult("Date must be in the past."); + } + + return ValidationResult.Success; + } + } + + private class CustomSkuValidationAttribute : ValidationAttribute + { + protected override ValidationResult? IsValid(object? value, ValidationContext validationContext) + { + if (value is string sku && !sku.StartsWith("PROD-", StringComparison.Ordinal)) + { + return new ValidationResult("SKU must start with 'PROD-'."); + } + + return ValidationResult.Success; + } + } + + private class PasswordComplexityAttribute : ValidationAttribute + { + protected override ValidationResult? IsValid(object? value, ValidationContext validationContext) + { + if (value is string password) + { + var hasDigit = password.Any(c => char.IsDigit(c)); + var hasSpecial = password.Any(c => !char.IsLetterOrDigit(c)); + + if (!hasDigit || !hasSpecial) + { + return new ValidationResult("Password must contain at least one number and one special character."); + } + } + + return ValidationResult.Success; + } + } + + // Test implementations + private class TestValidatablePropertyInfo : ValidatablePropertyInfo + { + private readonly ValidationAttribute[] _validationAttributes; + + public TestValidatablePropertyInfo( + Type containingType, + Type propertyType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + : base(containingType, propertyType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + private class TestValidatableTypeInfo : ValidatableTypeInfo + { + public TestValidatableTypeInfo( + Type type, + ValidatablePropertyInfo[] members) + : base(type, members) + { + } + } + + private class TestValidationOptions : ValidationOptions + { + public TestValidationOptions(Dictionary typeInfoMappings) + { + // Create a custom resolver that uses the dictionary + var resolver = new DictionaryBasedResolver(typeInfoMappings); + + // Add it to the resolvers collection + Resolvers.Add(resolver); + } + + // Private resolver implementation that uses a dictionary lookup + private class DictionaryBasedResolver : IValidatableInfoResolver + { + private readonly Dictionary _typeInfoMappings; + + public DictionaryBasedResolver(Dictionary typeInfoMappings) + { + _typeInfoMappings = typeInfoMappings; + } + + public bool TryGetValidatableTypeInfo(Type type, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + if (_typeInfoMappings.TryGetValue(type, out var info)) + { + validatableInfo = info; + return true; + } + validatableInfo = null; + return false; + } + + public bool TryGetValidatableParameterInfo(ParameterInfo parameterInfo, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + } + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Emitters/ValidationsGenerator.Emitter.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Emitters/ValidationsGenerator.Emitter.cs new file mode 100644 index 000000000000..8a9127bdc89d --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Emitters/ValidationsGenerator.Emitter.cs @@ -0,0 +1,218 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; +using System.Text; +using Microsoft.CodeAnalysis.CSharp; +using System.IO; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + public static string GeneratedCodeConstructor => $@"global::System.CodeDom.Compiler.GeneratedCodeAttribute(""{typeof(ValidationsGenerator).Assembly.FullName}"", ""{typeof(ValidationsGenerator).Assembly.GetName().Version}"")"; + public static string GeneratedCodeAttribute => $"[{GeneratedCodeConstructor}]"; + + internal static void Emit(SourceProductionContext context, (InterceptableLocation? AddValidation, ImmutableArray ValidatableTypes) emitInputs) + { + if (emitInputs.AddValidation is null) + { + // Avoid generating code if no AddValidation call was found. + return; + } + var source = Emit(emitInputs.AddValidation, emitInputs.ValidatableTypes); + context.AddSource("ValidatableInfoResolver.g.cs", SourceText.From(source, Encoding.UTF8)); + } + + private static string Emit(InterceptableLocation addValidation, ImmutableArray validatableTypes) => $$""" +#nullable enable annotations +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + {{GeneratedCodeAttribute}} + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + {{GeneratedCodeAttribute}} + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName) : base(containingType, propertyType, name, displayName) + { + ContainingType = containingType; + Name = name; + } + + internal global::System.Type ContainingType { get; } + internal string Name { get; } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() + => ValidationAttributeCache.GetValidationAttributes(ContainingType, Name); + } + + {{GeneratedCodeAttribute}} + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + {{GeneratedCodeAttribute}} + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; +{{EmitTypeChecks(validatableTypes)}} + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + +{{EmitCreateMethods(validatableTypes)}} + } + + {{GeneratedCodeAttribute}} + file static class GeneratedServiceCollectionExtensions + { + {{addValidation.GetInterceptsLocationAttributeSyntax()}} + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + {{GeneratedCodeAttribute}} + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type ContainingType, string PropertyName); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes( + global::System.Type containingType, + string propertyName) + { + var key = new CacheKey(containingType, propertyName); + return _cache.GetOrAdd(key, static k => + { + var property = k.ContainingType.GetProperty(k.PropertyName); + if (property == null) + { + return []; + } + + return [.. global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(property, inherit: true)]; + }); + } + } +} +"""; + + private static string EmitTypeChecks(ImmutableArray validatableTypes) + { + var sw = new StringWriter(); + var cw = new CodeWriter(sw, baseIndent: 3); + foreach (var validatableType in validatableTypes) + { + var typeName = validatableType.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + cw.WriteLine($"if (type == typeof({typeName}))"); + cw.StartBlock(); + cw.WriteLine($"validatableInfo = Create{SanitizeTypeName(validatableType.Type.MetadataName)}();"); + cw.WriteLine("return true;"); + cw.EndBlock(); + } + return sw.ToString(); + } + + private static string EmitCreateMethods(ImmutableArray validatableTypes) + { + var sw = new StringWriter(); + var cw = new CodeWriter(sw, baseIndent: 2); + foreach (var validatableType in validatableTypes) + { + cw.WriteLine($@"private ValidatableTypeInfo Create{SanitizeTypeName(validatableType.Type.MetadataName)}()"); + cw.StartBlock(); + cw.WriteLine("return new GeneratedValidatableTypeInfo("); + cw.Indent++; + cw.WriteLine($"type: typeof({validatableType.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}),"); + if (validatableType.Members.IsDefaultOrEmpty) + { + cw.WriteLine("members: []"); + } + else + { + cw.WriteLine("members: ["); + cw.Indent++; + foreach (var member in validatableType.Members) + { + EmitValidatableMemberForCreate(member, cw); + } + cw.Indent--; + cw.WriteLine("]"); + } + cw.Indent--; + cw.WriteLine(");"); + cw.EndBlock(); + } + return sw.ToString(); + } + + private static void EmitValidatableMemberForCreate(ValidatableProperty member, CodeWriter cw) + { + cw.WriteLine("new GeneratedValidatablePropertyInfo("); + cw.Indent++; + cw.WriteLine($"containingType: typeof({member.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}),"); + cw.WriteLine($"propertyType: typeof({member.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}),"); + cw.WriteLine($"name: \"{member.Name}\","); + cw.WriteLine($"displayName: \"{member.DisplayName}\""); + cw.Indent--; + cw.WriteLine("),"); + } + + private static string SanitizeTypeName(string typeName) + { + // Replace invalid characters with underscores and remove generic notation + return typeName + .Replace(".", "_") + .Replace("<", "_") + .Replace(">", "_") + .Replace(",", "_") + .Replace(" ", "_"); + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ISymbolExtensions.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ISymbolExtensions.cs new file mode 100644 index 000000000000..54efe204c1ec --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ISymbolExtensions.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal static class ISymbolExtensions +{ + public static string GetDisplayName(this ISymbol property, INamedTypeSymbol displayAttribute) + { + var displayNameAttribute = property.GetAttributes() + .FirstOrDefault(attribute => + attribute.AttributeClass is { } attributeClass && + SymbolEqualityComparer.Default.Equals(attributeClass, displayAttribute)); + if (displayNameAttribute is not null) + { + if (displayNameAttribute.ConstructorArguments.Length > 0) + { + return displayNameAttribute.ConstructorArguments[0].Value?.ToString() ?? property.Name; + } + else if (displayNameAttribute.NamedArguments.Length > 0) + { + return displayNameAttribute.NamedArguments[0].Value.Value?.ToString() ?? property.Name; + } + return property.Name; + } + + return property.Name; + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ITypeSymbolExtensions.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ITypeSymbolExtensions.cs new file mode 100644 index 000000000000..a09a575e2782 --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ITypeSymbolExtensions.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal static class ITypeSymbolExtensions +{ + public static bool IsEnumerable(this ITypeSymbol type, INamedTypeSymbol enumerable) + { + if (type.SpecialType == SpecialType.System_String) + { + return false; + } + + return type.ImplementsInterface(enumerable) || SymbolEqualityComparer.Default.Equals(type, enumerable); + } + + public static bool ImplementsValidationAttribute(this ITypeSymbol typeSymbol, INamedTypeSymbol validationAttributeSymbol) + { + var baseType = typeSymbol.BaseType; + while (baseType != null) + { + if (SymbolEqualityComparer.Default.Equals(baseType, validationAttributeSymbol)) + { + return true; + } + baseType = baseType.BaseType; + } + + return false; + } + + public static ITypeSymbol UnwrapType(this ITypeSymbol type, INamedTypeSymbol enumerable) + { + if (type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T && + type is INamedTypeSymbol { TypeArguments.Length: 1 }) + { + // Extract the T from a Nullable + type = ((INamedTypeSymbol)type).TypeArguments[0]; + } + + if (type.NullableAnnotation == NullableAnnotation.Annotated) + { + // Extract the underlying type from a reference type + type = type.OriginalDefinition; + } + + if (type is INamedTypeSymbol namedType && namedType.IsEnumerable(enumerable) && namedType.TypeArguments.Length == 1) + { + // Extract the T from an IEnumerable or List + type = namedType.TypeArguments[0]; + } + + return type; + } + + internal static bool ImplementsInterface(this ITypeSymbol type, ITypeSymbol interfaceType) + { + foreach (var iface in type.AllInterfaces) + { + if (SymbolEqualityComparer.Default.Equals(interfaceType, iface)) + { + return true; + } + } + return false; + } + + internal static ImmutableArray? GetJsonDerivedTypes(this ITypeSymbol type, INamedTypeSymbol jsonDerivedTypeAttribute) + { + var derivedTypes = ImmutableArray.CreateBuilder(); + foreach (var attribute in type.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, jsonDerivedTypeAttribute)) + { + var derivedType = (INamedTypeSymbol?)attribute.ConstructorArguments[0].Value; + if (derivedType is not null && !SymbolEqualityComparer.Default.Equals(derivedType, type)) + { + derivedTypes.Add(derivedType); + } + } + } + + return derivedTypes.Count == 0 ? null : derivedTypes.ToImmutable(); + } + + // Types exempted here have special binding rules in RDF and RDG and are not validatable + // types themselves so we short-circuit on them. + internal static bool IsExemptType(this ITypeSymbol type, RequiredSymbols requiredSymbols) + { + return SymbolEqualityComparer.Default.Equals(type, requiredSymbols.HttpContext) + || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.HttpRequest) + || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.HttpResponse) + || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.CancellationToken) + || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.IFormCollection) + || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.IFormFileCollection) + || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.IFormFile) + || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.Stream) + || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.PipeReader); + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/IncrementalValuesProviderExtensions.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/IncrementalValuesProviderExtensions.cs new file mode 100644 index 000000000000..bfbbd9369b62 --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/IncrementalValuesProviderExtensions.cs @@ -0,0 +1,112 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal static class IncrementalValuesProviderExtensions +{ + public static IncrementalValuesProvider Distinct(this IncrementalValuesProvider source, IEqualityComparer comparer) + { + return source + .Collect() + .WithComparer(ImmutableArrayEqualityComparer.Instance) + .SelectMany((values, cancellationToken) => + { + if (values.IsEmpty) + { + return values; + } + + var results = ImmutableArray.CreateBuilder(values.Length); + HashSet set = new(comparer); + + foreach (var value in values) + { + if (set.Add(value)) + { + results.Add(value); + } + } + + return results.DrainToImmutable(); + }); + } + + public static IncrementalValuesProvider Concat( + this IncrementalValuesProvider> first, + IncrementalValuesProvider> second) + { + return first.Collect() + .Combine(second.Collect()) + .SelectMany((tuple, _) => + { + if (tuple.Left.IsEmpty && tuple.Right.IsEmpty) + { + return []; + } + + var results = ImmutableArray.CreateBuilder(tuple.Left.Length + tuple.Right.Length); + for (var i = 0; i < tuple.Left.Length; i++) + { + results.AddRange(tuple.Left[i]); + } + for (var i = 0; i < tuple.Right.Length; i++) + { + results.AddRange(tuple.Right[i]); + } + return results.DrainToImmutable(); + }); + } + + private sealed class ImmutableArrayEqualityComparer : IEqualityComparer> + { + public static readonly ImmutableArrayEqualityComparer Instance = new(); + + public bool Equals(ImmutableArray x, ImmutableArray y) + { + if (x.IsDefault) + { + return y.IsDefault; + } + else if (y.IsDefault) + { + return false; + } + + if (x.Length != y.Length) + { + return false; + } + + for (var i = 0; i < x.Length; i++) + { + if (!EqualityComparer.Default.Equals(x[i], y[i])) + { + return false; + } + } + + return true; + } + + public int GetHashCode(ImmutableArray obj) + { + if (obj.IsDefault) + { + return 0; + } + var hashCode = -450793227; + foreach (var item in obj) + { + hashCode = (hashCode * -1521134295) + EqualityComparer.Default.GetHashCode(item); + } + + return hashCode; + } + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/RequiredSymbols.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/RequiredSymbols.cs new file mode 100644 index 000000000000..ea16b1de1490 --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/RequiredSymbols.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class RequiredSymbols( + INamedTypeSymbol DisplayAttribute, + INamedTypeSymbol ValidationAttribute, + INamedTypeSymbol IEnumerable, + INamedTypeSymbol IValidatableObject, + INamedTypeSymbol JsonDerivedTypeAttribute, + INamedTypeSymbol RequiredAttribute, + INamedTypeSymbol CustomValidationAttribute, + INamedTypeSymbol HttpContext, + INamedTypeSymbol HttpRequest, + INamedTypeSymbol HttpResponse, + INamedTypeSymbol CancellationToken, + INamedTypeSymbol IFormCollection, + INamedTypeSymbol IFormFileCollection, + INamedTypeSymbol IFormFile, + INamedTypeSymbol Stream, + INamedTypeSymbol PipeReader +); diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableProperty.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableProperty.cs new file mode 100644 index 000000000000..658f27a82e6b --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableProperty.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class ValidatableProperty( + ITypeSymbol ContainingType, + ITypeSymbol Type, + string Name, + string DisplayName, + ImmutableArray Attributes +); diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableType.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableType.cs new file mode 100644 index 000000000000..c6d7e36f36a9 --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableType.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class ValidatableType( + ITypeSymbol Type, + ImmutableArray Members +); diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableTypeComparer.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableTypeComparer.cs new file mode 100644 index 000000000000..fcd99f51dc0b --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableTypeComparer.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed class ValidatableTypeComparer : IEqualityComparer +{ + public static ValidatableTypeComparer Instance { get; } = new(); + + public bool Equals(ValidatableType? x, ValidatableType? y) + { + if (x is null && y is null) + { + return true; + } + if (x is null || y is null) + { + return false; + } + return SymbolEqualityComparer.Default.Equals(x.Type, y.Type); + } + + public int GetHashCode(ValidatableType? obj) + { + return SymbolEqualityComparer.Default.GetHashCode(obj?.Type); + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidationAttribute.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidationAttribute.cs new file mode 100644 index 000000000000..c29e12a99c0d --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidationAttribute.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class ValidationAttribute( + string Name, + string ClassName, + List Arguments, + Dictionary NamedArguments, + bool IsCustomValidationAttribute +); diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AddValidation.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AddValidation.cs new file mode 100644 index 000000000000..fbf7673b19fd --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AddValidation.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; +using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + internal bool FindAddValidation(SyntaxNode syntaxNode, CancellationToken cancellationToken) + { + if (syntaxNode is InvocationExpressionSyntax + && syntaxNode.TryGetMapMethodName(out var method) + && method == "AddValidation") + { + return true; + } + return false; + } + + internal InterceptableLocation? TransformAddValidation(GeneratorSyntaxContext context, CancellationToken cancellationToken) + { + var node = (InvocationExpressionSyntax)context.Node; + var semanticModel = context.SemanticModel; + return semanticModel.GetInterceptableLocation(node, cancellationToken); + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AttributeParser.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AttributeParser.cs new file mode 100644 index 000000000000..5bea9a8ad218 --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AttributeParser.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Threading; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + internal static bool ShouldTransformSymbolWithAttribute(SyntaxNode syntaxNode, CancellationToken cancellationToken) + { + return syntaxNode is ClassDeclarationSyntax; + } + + internal ImmutableArray TransformValidatableTypeWithAttribute(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) + { + var validatableTypes = new HashSet(ValidatableTypeComparer.Instance); + List visitedTypes = []; + var requiredSymbols = ExtractRequiredSymbols(context.SemanticModel.Compilation, cancellationToken); + if (TryExtractValidatableType((ITypeSymbol)context.TargetSymbol, requiredSymbols, ref validatableTypes, ref visitedTypes)) + { + return [..validatableTypes]; + } + return []; + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.EndpointsParser.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.EndpointsParser.cs new file mode 100644 index 000000000000..da325a562d1f --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.EndpointsParser.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Immutable; +using System.Linq; +using System.Threading; +using Microsoft.AspNetCore.Analyzers.Infrastructure; +using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + internal bool FindEndpoints(SyntaxNode syntaxNode, CancellationToken cancellationToken) + { + if (syntaxNode is InvocationExpressionSyntax + && syntaxNode.TryGetMapMethodName(out var method)) + { + return method == "MapMethods" || InvocationOperationExtensions.KnownMethods.Contains(method); + } + return false; + } + + internal IInvocationOperation? TransformEndpoints(GeneratorSyntaxContext context, CancellationToken cancellationToken) + { + if (context.Node is not InvocationExpressionSyntax node) + { + return null; + } + var operation = context.SemanticModel.GetOperation(node, cancellationToken); + AnalyzerDebug.Assert(operation != null, "Operation should not be null."); + return operation is IInvocationOperation invocationOperation + ? invocationOperation + : null; + } + + internal ImmutableArray ExtractValidatableEndpoint((IInvocationOperation? Operation, RequiredSymbols RequiredSymbols) input, CancellationToken cancellationToken) + { + AnalyzerDebug.Assert(input.Operation != null, "Operation should not be null."); + var validatableTypes = ExtractValidatableTypes(input.Operation, input.RequiredSymbols); + return validatableTypes; + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.RequiredSymbolsParser.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.RequiredSymbolsParser.cs new file mode 100644 index 000000000000..ca1486a518aa --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.RequiredSymbolsParser.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + internal RequiredSymbols ExtractRequiredSymbols(Compilation compilation, CancellationToken cancellationToken) + { + return new RequiredSymbols( + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.DisplayAttribute")!, + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.ValidationAttribute")!, + compilation.GetTypeByMetadataName("System.Collections.IEnumerable")!, + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.IValidatableObject")!, + compilation.GetTypeByMetadataName("System.Text.Json.Serialization.JsonDerivedTypeAttribute")!, + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.RequiredAttribute")!, + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.CustomValidationAttribute")!, + compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Http.HttpContext")!, + compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Http.HttpRequest")!, + compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Http.HttpResponse")!, + compilation.GetTypeByMetadataName("System.Threading.CancellationToken")!, + compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Http.IFormCollection")!, + compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Http.IFormFileCollection")!, + compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Http.IFormFile")!, + compilation.GetTypeByMetadataName("System.IO.Stream")!, + compilation.GetTypeByMetadataName("System.IO.Pipelines.PipeReader")! + ); + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.TypesParser.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.TypesParser.cs new file mode 100644 index 000000000000..ecba1bd124ef --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.TypesParser.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Microsoft.AspNetCore.Analyzers.Infrastructure; +using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Operations; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + private static readonly SymbolDisplayFormat _symbolDisplayFormat = new( + globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Included, + typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces); + + internal ImmutableArray ExtractValidatableTypes(IInvocationOperation operation, RequiredSymbols requiredSymbols) + { + AnalyzerDebug.Assert(operation.SemanticModel != null, "SemanticModel should not be null."); + var parameters = operation.TryGetRouteHandlerMethod(operation.SemanticModel, out var method) + ? method.Parameters + : []; + var validatableTypes = new HashSet(ValidatableTypeComparer.Instance); + List visitedTypes = []; + foreach (var parameter in parameters) + { + _ = TryExtractValidatableType(parameter.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes); + } + return [.. validatableTypes]; + } + + internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet validatableTypes, ref List visitedTypes) + { + if (typeSymbol.SpecialType != SpecialType.None) + { + return false; + } + + if (visitedTypes.Contains(typeSymbol)) + { + return true; + } + + if (typeSymbol.IsExemptType(requiredSymbols)) + { + return false; + } + + visitedTypes.Add(typeSymbol); + + // Extract validatable types discovered in base types of this type and add them to the top-level list. + var current = typeSymbol.BaseType; + var hasValidatableBaseType = false; + while (current != null && current.SpecialType != SpecialType.System_Object) + { + hasValidatableBaseType |= TryExtractValidatableType(current, requiredSymbols, ref validatableTypes, ref visitedTypes); + current = current.BaseType; + } + + // Extract validatable types discovered in members of this type and add them to the top-level list. + var members = ExtractValidatableMembers(typeSymbol, requiredSymbols, ref validatableTypes, ref visitedTypes); + + // Extract the validatable types discovered in the JsonDerivedTypeAttributes of this type and add them to the top-level list. + var derivedTypes = typeSymbol.GetJsonDerivedTypes(requiredSymbols.JsonDerivedTypeAttribute); + var hasValidatableDerivedTypes = false; + foreach (var derivedType in derivedTypes ?? []) + { + hasValidatableDerivedTypes |= TryExtractValidatableType(derivedType, requiredSymbols, ref validatableTypes, ref visitedTypes); + } + + // No validatable members or derived types found, so we don't need to add this type. + if (members.IsDefaultOrEmpty && !hasValidatableBaseType && !hasValidatableDerivedTypes) + { + return false; + } + + // Add the type itself as a validatable type itself. + validatableTypes.Add(new ValidatableType( + Type: typeSymbol, + Members: members)); + + return true; + } + + internal ImmutableArray ExtractValidatableMembers(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet validatableTypes, ref List visitedTypes) + { + var members = new List(); + foreach (var member in typeSymbol.GetMembers().OfType()) + { + var hasValidatableType = TryExtractValidatableType(member.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes); + var attributes = ExtractValidationAttributes(member, requiredSymbols, out var isRequired); + // If the member has no validation attributes or validatable types and is not required, skip it. + if (attributes.IsDefaultOrEmpty && !hasValidatableType && !isRequired) + { + continue; + } + members.Add(new ValidatableProperty( + ContainingType: member.ContainingType, + Type: member.Type, + Name: member.Name, + DisplayName: member.GetDisplayName(requiredSymbols.DisplayAttribute), + Attributes: attributes)); + } + + return [.. members]; + } + + internal static ImmutableArray ExtractValidationAttributes(ISymbol symbol, RequiredSymbols requiredSymbols, out bool isRequired) + { + var attributes = symbol.GetAttributes(); + if (attributes.Length == 0) + { + isRequired = false; + return []; + } + + var validationAttributes = attributes + .Where(attribute => attribute.AttributeClass != null) + .Where(attribute => attribute.AttributeClass!.ImplementsValidationAttribute(requiredSymbols.ValidationAttribute)); + isRequired = validationAttributes.Any(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, requiredSymbols.RequiredAttribute)); + return [.. validationAttributes + .Where(attr => !SymbolEqualityComparer.Default.Equals(attr.AttributeClass, requiredSymbols.ValidationAttribute)) + .Select(attribute => new ValidationAttribute( + Name: symbol.Name + attribute.AttributeClass!.Name, + ClassName: attribute.AttributeClass!.ToDisplayString(_symbolDisplayFormat), + Arguments: [.. attribute.ConstructorArguments.Select(a => a.ToCSharpString())], + NamedArguments: attribute.NamedArguments.ToDictionary(namedArgument => namedArgument.Key, namedArgument => namedArgument.Value.ToCSharpString()), + IsCustomValidationAttribute: SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, requiredSymbols.CustomValidationAttribute)))]; + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/ValidationsGenerator.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/ValidationsGenerator.cs index ac7be3762c0d..4949ed71825f 100644 --- a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/ValidationsGenerator.cs +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/ValidationsGenerator.cs @@ -1,14 +1,52 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Linq; using Microsoft.CodeAnalysis; namespace Microsoft.AspNetCore.Http.ValidationsGenerator; +[Generator] public sealed partial class ValidationsGenerator : IIncrementalGenerator { public void Initialize(IncrementalGeneratorInitializationContext context) { - return; + // Resolve the symbols that will be required when making comparisons + // in future steps. + var requiredSymbols = context.CompilationProvider.Select(ExtractRequiredSymbols); + + // Find the builder.Services.AddValidation() call in the application. + var addValidation = context.SyntaxProvider.CreateSyntaxProvider( + predicate: FindAddValidation, + transform: TransformAddValidation + ); + // Extract types that have been marked with [ValidatableType]. + var validatableTypesWithAttribute = context.SyntaxProvider.ForAttributeWithMetadataName( + "Microsoft.AspNetCore.Http.Validation.ValidatableTypeAttribute", + predicate: ShouldTransformSymbolWithAttribute, + transform: TransformValidatableTypeWithAttribute + ); + // Extract all minimal API endpoints in the application. + var endpoints = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: FindEndpoints, + transform: TransformEndpoints) + .Where(endpoint => endpoint is not null); + // Extract validatable types from all endpoints. + var validatableTypesFromEndpoints = endpoints + .Combine(requiredSymbols) + .Select(ExtractValidatableEndpoint); + // Join all validatable types encountered in the type graph. + var validatableTypes = validatableTypesWithAttribute + .Concat(validatableTypesFromEndpoints) + .Distinct(ValidatableTypeComparer.Instance) + .Collect(); + + var emitInputs = addValidation + .Combine(validatableTypes); + + // Emit the IValidatableInfo resolver injection and + // ValidatableTypeInfo for all validatable types. + context.RegisterSourceOutput(emitInputs, Emit); } } diff --git a/src/Http/Http.Extensions/test/ValidationEndpointConventionBuilderExtensionsTests.cs b/src/Http/Http.Extensions/test/ValidationEndpointConventionBuilderExtensionsTests.cs new file mode 100644 index 000000000000..c871c407bd31 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationEndpointConventionBuilderExtensionsTests.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Text; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http.Metadata; +using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Http.Extensions.Tests; + +public class ValidationEndpointConventionBuilderExtensionsTests : LoggedTest +{ + [Fact] + public async Task DisableValidation_PreventsValidationFilterRegistration() + { + // Arrange + var services = new ServiceCollection(); + services.AddValidation(); + services.AddSingleton(LoggerFactory); + var serviceProvider = services.BuildServiceProvider(); + + var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(serviceProvider)); + + // Act - Create two endpoints - one with validation disabled, one without + var regularBuilder = builder.MapGet("test-enabled", ([Range(5, 10)] int param) => "Validation enabled here."); + var disabledBuilder = builder.MapGet("test-disabled", ([Range(5, 10)] int param) => "Validation disabled here."); + + disabledBuilder.DisableValidation(); + + // Build the endpoints + var dataSource = Assert.Single(builder.DataSources); + var endpoints = dataSource.Endpoints; + + // Assert + Assert.Equal(2, endpoints.Count); + + // Get filter factories from both endpoints + var regularEndpoint = endpoints[0]; + var disabledEndpoint = endpoints[1]; + + // Verify the disabled endpoint has the IDisableValidationMetadata + Assert.Contains(disabledEndpoint.Metadata, m => m is IDisableValidationMetadata); + + // Verify that invalid arguments on the disabled endpoint do not trigger validation + var context = new DefaultHttpContext + { + RequestServices = serviceProvider + }; + context.Request.Method = "GET"; + context.Request.QueryString = new QueryString("?param=15"); + var ms = new MemoryStream(); + context.Response.Body = ms; + + await disabledEndpoint.RequestDelegate(context); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + Assert.Equal("Validation disabled here.", Encoding.UTF8.GetString(ms.ToArray())); + + context = new DefaultHttpContext + { + RequestServices = serviceProvider + }; + context.Request.Method = "GET"; + context.Request.QueryString = new QueryString("?param=15"); + await regularEndpoint.RequestDelegate(context); + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + } + + private class DefaultEndpointRouteBuilder(IApplicationBuilder applicationBuilder) : IEndpointRouteBuilder + { + private IApplicationBuilder ApplicationBuilder { get; } = applicationBuilder ?? throw new ArgumentNullException(nameof(applicationBuilder)); + public IApplicationBuilder CreateApplicationBuilder() => ApplicationBuilder.New(); + public ICollection DataSources { get; } = []; + public IServiceProvider ServiceProvider => ApplicationBuilder.ApplicationServices; + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.ComplexType.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.ComplexType.cs new file mode 100644 index 000000000000..5dd548cf184c --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.ComplexType.cs @@ -0,0 +1,373 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestBase +{ + [Fact] + public async Task CanValidateComplexTypes() + { + // Arrange + var source = """ +using System; +using System.ComponentModel.DataAnnotations; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Validation; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); + +builder.Services.AddValidation(); + +var app = builder.Build(); + +app.MapPost("/complex-type", (ComplexType complexType) => Results.Ok("Passed"!)); + +app.Run(); + +public class ComplexType +{ + [Range(10, 100)] + public int IntegerWithRange { get; set; } = 10; + + [Range(10, 100), Display(Name = "Valid identifier")] + public int IntegerWithRangeAndDisplayName { get; set; } = 50; + + [Required] + public SubType PropertyWithMemberAttributes { get; set; } = new SubType(); + + public SubType PropertyWithoutMemberAttributes { get; set; } = new SubType(); + + public SubTypeWithInheritance PropertyWithInheritance { get; set; } = new SubTypeWithInheritance(); + + public List ListOfSubTypes { get; set; } = []; + + [DerivedValidation(ErrorMessage = "Value must be an even number")] + public int IntegerWithDerivedValidationAttribute { get; set; } + + [CustomValidation(typeof(CustomValidators), nameof(CustomValidators.Validate))] + public int IntegerWithCustomValidation { get; set; } = 0; + + [DerivedValidation, Range(10, 100)] + public int PropertyWithMultipleAttributes { get; set; } = 10; +} + +public class DerivedValidationAttribute : ValidationAttribute +{ + public override bool IsValid(object? value) => value is int number && number % 2 == 0; +} + +public class SubType +{ + [Required] + public string RequiredProperty { get; set; } = "some-value"; + + [StringLength(10)] + public string? StringWithLength { get; set; } +} + +public class SubTypeWithInheritance : SubType +{ + [EmailAddress] + public string? EmailString { get; set; } +} + +public static class CustomValidators +{ + public static ValidationResult Validate(int number, ValidationContext validationContext) + { + var parent = (ComplexType)validationContext.ObjectInstance; + + if (parent.IntegerWithRange == number) + { + return new ValidationResult( + "Can't use the same number value in two properties on the same class.", + new[] { validationContext.MemberName }); + } + + return ValidationResult.Success; + } +} +"""; + await Verify(source, out var compilation); + await VerifyEndpoint(compilation, "/complex-type", async (endpoint, serviceProvider) => + { + await InvalidIntegerWithRangeProducesError(endpoint); + await InvalidIntegerWithRangeAndDisplayNameProducesError(endpoint); + await MissingRequiredSubtypePropertyProducesError(endpoint); + await InvalidRequiredSubtypePropertyProducesError(endpoint); + await InvalidSubTypeWithInheritancePropertyProducesError(endpoint); + await InvalidListOfSubTypesProducesError(endpoint); + await InvalidPropertyWithDerivedValidationAttributeProducesError(endpoint); + await InvalidPropertyWithMultipleAttributesProducesError(endpoint); + await InvalidPropertyWithCustomValidationProducesError(endpoint); + await ValidInputProducesNoWarnings(endpoint); + + async Task InvalidIntegerWithRangeProducesError(Endpoint endpoint) + { + + var payload = """ + { + "IntegerWithRange": 5 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("IntegerWithRange", kvp.Key); + Assert.Equal("The field IntegerWithRange must be between 10 and 100.", kvp.Value.Single()); + }); + } + + async Task InvalidIntegerWithRangeAndDisplayNameProducesError(Endpoint endpoint) + { + var payload = """ + { + "IntegerWithRangeAndDisplayName": 5 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("IntegerWithRangeAndDisplayName", kvp.Key); + Assert.Equal("The field Valid identifier must be between 10 and 100.", kvp.Value.Single()); + }); + } + + async Task MissingRequiredSubtypePropertyProducesError(Endpoint endpoint) + { + var payload = """ + { + "PropertyWithMemberAttributes": null + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("PropertyWithMemberAttributes", kvp.Key); + Assert.Equal("The PropertyWithMemberAttributes field is required.", kvp.Value.Single()); + }); + } + + async Task InvalidRequiredSubtypePropertyProducesError(Endpoint endpoint) + { + var payload = """ + { + "PropertyWithMemberAttributes": { + "RequiredProperty": "", + "StringWithLength": "way-too-long" + } + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, + kvp => + { + Assert.Equal("PropertyWithMemberAttributes.RequiredProperty", kvp.Key); + Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("PropertyWithMemberAttributes.StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }); + } + + async Task InvalidSubTypeWithInheritancePropertyProducesError(Endpoint endpoint) + { + var payload = """ + { + "PropertyWithInheritance": { + "RequiredProperty": "", + "StringWithLength": "way-too-long", + "EmailString": "not-an-email" + } + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, + kvp => + { + Assert.Equal("PropertyWithInheritance.EmailString", kvp.Key); + Assert.Equal("The EmailString field is not a valid e-mail address.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("PropertyWithInheritance.RequiredProperty", kvp.Key); + Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("PropertyWithInheritance.StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }); + } + + async Task InvalidListOfSubTypesProducesError(Endpoint endpoint) + { + var payload = """ + { + "ListOfSubTypes": [ + { + "RequiredProperty": "", + "StringWithLength": "way-too-long" + }, + { + "RequiredProperty": "valid", + "StringWithLength": "way-too-long" + }, + { + "RequiredProperty": "valid", + "StringWithLength": "valid" + } + ] + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, + kvp => + { + Assert.Equal("ListOfSubTypes[0].RequiredProperty", kvp.Key); + Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("ListOfSubTypes[0].StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("ListOfSubTypes[1].StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }); + } + + async Task InvalidPropertyWithDerivedValidationAttributeProducesError(Endpoint endpoint) + { + var payload = """ + { + "IntegerWithDerivedValidationAttribute": 5 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("IntegerWithDerivedValidationAttribute", kvp.Key); + Assert.Equal("Value must be an even number", kvp.Value.Single()); + }); + } + + async Task InvalidPropertyWithMultipleAttributesProducesError(Endpoint endpoint) + { + var payload = """ + { + "PropertyWithMultipleAttributes": 5 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("PropertyWithMultipleAttributes", kvp.Key); + Assert.Collection(kvp.Value, + error => + { + Assert.Equal("The field PropertyWithMultipleAttributes is invalid.", error); + }, + error => + { + Assert.Equal("The field PropertyWithMultipleAttributes must be between 10 and 100.", error); + }); + }); + } + + async Task InvalidPropertyWithCustomValidationProducesError(Endpoint endpoint) + { + var payload = """ + { + "IntegerWithRange": 42, + "IntegerWithCustomValidation": 42 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("IntegerWithCustomValidation", kvp.Key); + var error = Assert.Single(kvp.Value); + Assert.Equal("Can't use the same number value in two properties on the same class.", error); + }); + } + + async Task ValidInputProducesNoWarnings(Endpoint endpoint) + { + var payload = """ + { + "IntegerWithRange": 50, + "IntegerWithRangeAndDisplayName": 50, + "PropertyWithMemberAttributes": { + "RequiredProperty": "valid", + "StringWithLength": "valid" + }, + "PropertyWithoutMemberAttributes": { + "RequiredProperty": "valid", + "StringWithLength": "valid" + }, + "PropertyWithInheritance": { + "RequiredProperty": "valid", + "StringWithLength": "valid", + "EmailString": "test@example.com" + }, + "ListOfSubTypes": [], + "IntegerWithDerivedValidationAttribute": 2, + "IntegerWithCustomValidation": 0, + "PropertyWithMultipleAttributes": 12 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + await endpoint.RequestDelegate(context); + + Assert.Equal(200, context.Response.StatusCode); + } + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.IValidatableObject.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.IValidatableObject.cs new file mode 100644 index 000000000000..30de31e208b0 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.IValidatableObject.cs @@ -0,0 +1,193 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestBase +{ + [Fact] + public async Task CanValidateIValidatableObject() + { + var source = """ +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Validation; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); +builder.Services.AddSingleton(); +builder.Services.AddValidation(); + +var app = builder.Build(); + +app.MapPost("/validatable-object", (ComplexValidatableType model) => Results.Ok()); + +app.Run(); + +public class ComplexValidatableType: IValidatableObject +{ + [Display(Name = "Value 1")] + public int Value1 { get; set; } + + [EmailAddress] + [Required] + public required string Value2 { get; set; } = "test@example.com"; + + public ValidatableSubType SubType { get; set; } = new ValidatableSubType(); + + public IEnumerable Validate(ValidationContext validationContext) + { + var rangeService = (IRangeService?)validationContext.GetService(typeof(IRangeService)); + var minimum = rangeService?.GetMinimum(); + var maximum = rangeService?.GetMaximum(); + if (Value1 < minimum || Value1 > maximum) + { + yield return new ValidationResult($"The field {nameof(Value1)} must be between {minimum} and {maximum}.", [nameof(Value1)]); + } + } +} + +public class SubType +{ + [Required] + public string RequiredProperty { get; set; } = "some-value"; + + [StringLength(10)] + public string? StringWithLength { get; set; } +} + +public class ValidatableSubType : SubType, IValidatableObject +{ + public string Value3 { get; set; } = "some-value"; + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Value3 != "some-value") + { + yield return new ValidationResult($"The field {validationContext.DisplayName} must be 'some-value'.", [nameof(Value3)]); + } + } +} + +public interface IRangeService +{ + int GetMinimum(); + int GetMaximum(); +} + +public class RangeService : IRangeService +{ + public int GetMinimum() => 10; + public int GetMaximum() => 100; +} +"""; + + await Verify(source, out var compilation); + await VerifyEndpoint(compilation, "/validatable-object", async (endpoint, serviceProvider) => + { + await ValidateMethodCalledIfPropertyValidationsFail(); + await ValidateForSubtypeInvokedFirst(); + await ValidateForTopLevelInvoked(); + + async Task ValidateMethodCalledIfPropertyValidationsFail() + { + var httpContext = CreateHttpContextWithPayload(""" + { + "Value1": 5, + "Value2": "", + "SubType": { + "Value3": "foo", + "RequiredProperty": "", + "StringWithLength": "" + } + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value2", error.Key); + Assert.Collection(error.Value, + msg => Assert.Equal("The Value2 field is required.", msg)); + }, + error => + { + Assert.Equal("SubType.RequiredProperty", error.Key); + Assert.Equal("The RequiredProperty field is required.", error.Value.Single()); + }, + error => + { + Assert.Equal("SubType.Value3", error.Key); + Assert.Equal("The field ValidatableSubType must be 'some-value'.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value1 must be between 10 and 100.", error.Value.Single()); + }); + } + + async Task ValidateForSubtypeInvokedFirst() + { + var httpContext = CreateHttpContextWithPayload(""" + { + "Value1": 5, + "Value2": "test@test.com", + "SubType": { + "Value3": "foo", + "RequiredProperty": "some-value-2", + "StringWithLength": "element" + } + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("SubType.Value3", error.Key); + Assert.Equal("The field ValidatableSubType must be 'some-value'.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value1 must be between 10 and 100.", error.Value.Single()); + }); + } + + async Task ValidateForTopLevelInvoked() + { + var httpContext = CreateHttpContextWithPayload(""" + { + "Value1": 5, + "Value2": "test@test.com", + "SubType": { + "Value3": "some-value", + "RequiredProperty": "some-value-2", + "StringWithLength": "element" + } + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value1 must be between 10 and 100.", error.Value.Single()); + }); + } + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.NoOp.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.NoOp.cs new file mode 100644 index 000000000000..4249a7455bf0 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.NoOp.cs @@ -0,0 +1,114 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestBase +{ + [Fact] + public async Task DoesNotEmitIfNoAddValidationCallExists() + { + // Arrange + var source = """ +using System; +using System.ComponentModel.DataAnnotations; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Validation; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); + +var app = builder.Build(); + +app.MapPost("/complex-type", (ComplexType complexType) => Results.Ok("Passed")); + +app.Run(); + +public class ComplexType +{ + [Range(10, 100)] + public int IntegerWithRange { get; set; } = 10; +} +"""; + await Verify(source, out var compilation); + // Verify that we don't validate types if no AddValidation call exists + await VerifyEndpoint(compilation, "/complex-type", async (endpoint, serviceProvider) => + { + var payload = """ + { + "IntegerWithRange": 5 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + }); + } + + [Fact] + public async Task DoesNotEmitForExemptTypes() + { + var source = """ +using System; +using System.ComponentModel.DataAnnotations; +using System.IO; +using System.IO.Pipelines; +using System.Threading; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Validation; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); + +builder.Services.AddValidation(); + +var app = builder.Build(); + +app.MapGet("/exempt-1", (HttpContext context) => Results.Ok("Exempt Passed!")); +app.MapGet("/exempt-2", (HttpRequest request) => Results.Ok("Exempt Passed!")); +app.MapGet("/exempt-3", (HttpResponse response) => Results.Ok("Exempt Passed!")); +app.MapGet("/exempt-4", (IFormCollection formCollection) => Results.Ok("Exempt Passed!")); +app.MapGet("/exempt-5", (IFormFileCollection formFileCollection) => Results.Ok("Exempt Passed!")); +app.MapGet("/exempt-6", (IFormFile formFile) => Results.Ok("Exempt Passed!")); +app.MapGet("/exempt-7", (Stream stream) => Results.Ok("Exempt Passed!")); +app.MapGet("/exempt-8", (PipeReader pipeReader) => Results.Ok("Exempt Passed!")); +app.MapGet("/exempt-9", (CancellationToken cancellationToken) => Results.Ok("Exempt Passed!")); +app.MapPost("/complex-type", (ComplexType complexType) => Results.Ok("Passed")); + +app.Run(); + +public class ComplexType +{ + [Range(10, 100)] + public int IntegerWithRange { get; set; } = 10; +} +"""; + await Verify(source, out var compilation); + // Verify that we can validate non-exempt types + await VerifyEndpoint(compilation, "/complex-type", async (endpoint, serviceProvider) => + { + var payload = """ + { + "IntegerWithRange": 5 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("IntegerWithRange", kvp.Key); + Assert.Equal("The field IntegerWithRange must be between 10 and 100.", kvp.Value.Single()); + }); + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Parameters.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Parameters.cs index 6a6def95f50a..3be6e96fe7a9 100644 --- a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Parameters.cs +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Parameters.cs @@ -11,11 +11,18 @@ public async Task CanValidateParameters() var source = """ using System; using System.ComponentModel.DataAnnotations; +using System.Collections.Generic; +using System.Linq; using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Validation; using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; var builder = WebApplication.CreateBuilder(); +builder.Services.AddValidation(); + var app = builder.Build(); app.MapGet("/params", ( @@ -33,12 +40,45 @@ public class CustomValidationAttribute : ValidationAttribute } """; await Verify(source, out var compilation); - VerifyEndpoint(compilation, "/params", async endpoint => + await VerifyEndpoint(compilation, "/params", async (endpoint, serviceProvider) => { - var context = CreateHttpContext(); + var context = CreateHttpContext(serviceProvider); context.Request.QueryString = new QueryString("?value1=5&value2=5&value3=&value4=3&value5=5"); await endpoint.RequestDelegate(context); - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("value1", error.Key); + Assert.Equal("The field value1 must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("value2", error.Key); + Assert.Equal("The field Valid identifier must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("value3", error.Key); + Assert.Equal("The value3 field is required.", error.Value.Single()); + }, + error => + { + Assert.Equal("value4", error.Key); + Assert.Equal("Value must be an even number", error.Value.Single()); + }, + error => + { + Assert.Equal("value5", error.Key); + Assert.Collection(error.Value, error => + { + Assert.Equal("The field value5 is invalid.", error); + }, + error => + { + Assert.Equal("The field value5 must be between 10 and 100.", error); + }); + }); }); } } diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Polymorphism.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Polymorphism.cs new file mode 100644 index 000000000000..54148e784a0a --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Polymorphism.cs @@ -0,0 +1,202 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestBase +{ + [Fact] + public async Task CanValidatePolymorphicTypes() + { + var source = """ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Validation; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); +builder.Services.AddValidation(); + +var app = builder.Build(); + +app.MapPost("/basic-polymorphism", (BaseType model) => Results.Ok()); +app.MapPost("/validatable-polymorphism", (BaseValidatableType model) => Results.Ok()); +app.MapPost("/polymorphism-container", (ContainerType model) => Results.Ok()); + +app.Run(); + +public class ContainerType +{ + public BaseType BaseType { get; set; } = new BaseType(); + public BaseValidatableType BaseValidatableType { get; set; } = new BaseValidatableType(); +} + +[JsonDerivedType(typeof(BaseType), typeDiscriminator: "base")] +[JsonDerivedType(typeof(DerivedType), typeDiscriminator: "derived")] +public class BaseType +{ + [Display(Name = "Value 1")] + [Range(10, 100)] + public int Value1 { get; set; } + + [EmailAddress] + [Required] + public string Value2 { get; set; } = "test@example.com"; +} + +public class DerivedType : BaseType +{ + [Base64String] + public string? Value3 { get; set; } +} + +[JsonDerivedType(typeof(BaseValidatableType), typeDiscriminator: "base")] +[JsonDerivedType(typeof(DerivedValidatableType), typeDiscriminator: "derived")] +public class BaseValidatableType : IValidatableObject +{ + [Display(Name = "Value 1")] + public int Value1 { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Value1 < 10 || Value1 > 100) + { + yield return new ValidationResult("The field Value 1 must be between 10 and 100.", new[] { nameof(Value1) }); + } + } +} + +public class DerivedValidatableType : BaseValidatableType +{ + [EmailAddress] + public required string Value3 { get; set; } +} +"""; + await Verify(source, out var compilation); + + await VerifyEndpoint(compilation, "/basic-polymorphism", async (endpoint, serviceProvider) => + { + var httpContext = CreateHttpContextWithPayload(""" + { + "$type": "derived", + "Value1": 5, + "Value2": "invalid-email", + "Value3": "invalid-base64" + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value3", error.Key); + Assert.Equal("The Value3 field is not a valid Base64 encoding.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value2", error.Key); + Assert.Equal("The Value2 field is not a valid e-mail address.", error.Value.Single()); + }); + }); + + await VerifyEndpoint(compilation, "/validatable-polymorphism", async (endpoint, serviceProvider) => + { + var httpContext = CreateHttpContextWithPayload(""" + { + "$type": "derived", + "Value1": 5, + "Value3": "invalid-email" + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value3", error.Key); + Assert.Equal("The Value3 field is not a valid e-mail address.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }); + + httpContext = CreateHttpContextWithPayload(""" + { + "$type": "derived", + "Value1": 5, + "Value3": "test@example.com" + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails1 = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails1.Errors, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }); + }); + + await VerifyEndpoint(compilation, "/polymorphism-container", async (endpoint, serviceProvider) => + { + var httpContext = CreateHttpContextWithPayload(""" + { + "BaseType": { + "$type": "derived", + "Value1": 5, + "Value2": "invalid-email", + "Value3": "invalid-base64" + }, + "BaseValidatableType": { + "$type": "derived", + "Value1": 5, + "Value3": "test@example.com" + } + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("BaseType.Value3", error.Key); + Assert.Equal("The Value3 field is not a valid Base64 encoding.", error.Value.Single()); + }, + error => + { + Assert.Equal("BaseType.Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("BaseType.Value2", error.Key); + Assert.Equal("The Value2 field is not a valid e-mail address.", error.Value.Single()); + }, + error => + { + Assert.Equal("BaseValidatableType.Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }); + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Recursion.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Recursion.cs new file mode 100644 index 000000000000..4affa35f8997 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Recursion.cs @@ -0,0 +1,158 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestBase +{ + [Fact] + public async Task CanValidateRecursiveTypes() + { + var source = """ +using System.ComponentModel.DataAnnotations; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); +builder.Services.AddValidation(options => +{ + options.MaxDepth = 8; +}); + +var app = builder.Build(); + +app.MapPost("/recursive-type", (RecursiveType model) => Results.Ok()); + +app.Run(); + +public class RecursiveType +{ + [Range(10, 100)] + public int Value { get; set; } + public RecursiveType? Next { get; set; } +} +"""; + await Verify(source, out var compilation); + + await VerifyEndpoint(compilation, "/recursive-type", async (endpoint, serviceProvider) => + { + await ThrowsExceptionForDeeplyNestedType(endpoint); + await ValidatesTypeWithLimitedNesting(endpoint); + + async Task ThrowsExceptionForDeeplyNestedType(Endpoint endpoint) + { + var httpContext = CreateHttpContextWithPayload(""" + { + "value": 1, + "next": { + "value": 2, + "next": { + "value": 3, + "next": { + "value": 4, + "next": { + "value": 5, + "next": { + "value": 6, + "next": { + "value": 7, + "next": { + "value": 8, + "next": { + "value": 9, + "next": { + "value": 10 + } + } + } + } + } + } + } + } + } + } + """, serviceProvider); + + var exception = await Assert.ThrowsAsync(async () => await endpoint.RequestDelegate(httpContext)); + } + + async Task ValidatesTypeWithLimitedNesting(Endpoint endpoint) + { + var httpContext = CreateHttpContextWithPayload(""" + { + "value": 1, + "next": { + "value": 2, + "next": { + "value": 3, + "next": { + "value": 4, + "next": { + "value": 5, + "next": { + "value": 6, + "next": { + "value": 7, + "next": { + "value": 8 + } + } + } + } + } + } + } + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Next.Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Next.Next.Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Next.Next.Next.Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Next.Next.Next.Next.Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Next.Next.Next.Next.Next.Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }); + } + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.ValidatableType.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.ValidatableType.cs new file mode 100644 index 000000000000..4628c9574004 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.ValidatableType.cs @@ -0,0 +1,377 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using Microsoft.AspNetCore.Http.Validation; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestBase +{ + [Fact] + public async Task CanValidateTypesWithAttribute() + { + var source = """ +using System; +using System.ComponentModel.DataAnnotations; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Validation; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); + +builder.Services.AddValidation(); + +var app = builder.Build(); + +app.Run(); + +[ValidatableType] +public class ComplexType +{ + [Range(10, 100)] + public int IntegerWithRange { get; set; } = 10; + + [Range(10, 100), Display(Name = "Valid identifier")] + public int IntegerWithRangeAndDisplayName { get; set; } = 50; + + [Required] + public SubType PropertyWithMemberAttributes { get; set; } = new SubType(); + + public SubType PropertyWithoutMemberAttributes { get; set; } = new SubType(); + + public SubTypeWithInheritance PropertyWithInheritance { get; set; } = new SubTypeWithInheritance(); + + public List ListOfSubTypes { get; set; } = []; + + [CustomValidation(ErrorMessage = "Value must be an even number")] + public int IntegerWithCustomValidationAttribute { get; set; } + + [CustomValidation, Range(10, 100)] + public int PropertyWithMultipleAttributes { get; set; } = 10; +} + +public class CustomValidationAttribute : ValidationAttribute +{ + public override bool IsValid(object? value) => value is int number && number % 2 == 0; +} + +public class SubType +{ + [Required] + public string RequiredProperty { get; set; } = "some-value"; + + [StringLength(10)] + public string? StringWithLength { get; set; } +} + +public class SubTypeWithInheritance : SubType +{ + [EmailAddress] + public string? EmailString { get; set; } +} +"""; + await Verify(source, out var compilation); + VerifyValidatableType(compilation, "ComplexType", async (validationOptions, type) => + { + Assert.True(validationOptions.TryGetValidatableTypeInfo(type, out var validatableTypeInfo)); + + await InvalidIntegerWithRangeProducesError(validatableTypeInfo); + await InvalidIntegerWithRangeAndDisplayNameProducesError(validatableTypeInfo); + await MissingRequiredSubtypePropertyProducesError(validatableTypeInfo); + await InvalidRequiredSubtypePropertyProducesError(validatableTypeInfo); + await InvalidSubTypeWithInheritancePropertyProducesError(validatableTypeInfo); + await InvalidListOfSubTypesProducesError(validatableTypeInfo); + await InvalidPropertyWithDerivedValidationAttributeProducesError(validatableTypeInfo); + await InvalidPropertyWithMultipleAttributesProducesError(validatableTypeInfo); + await InvalidPropertyWithCustomValidationProducesError(validatableTypeInfo); + await ValidInputProducesNoWarnings(validatableTypeInfo); + + async Task InvalidIntegerWithRangeProducesError(IValidatableInfo validatableInfo) + { + var instance = Activator.CreateInstance(type); + type.GetProperty("IntegerWithRange")?.SetValue(instance, 5); + var context = new ValidateContext + { + ValidationOptions = validationOptions, + ValidationContext = new ValidationContext(instance) + }; + + await validatableTypeInfo.ValidateAsync(instance, context, CancellationToken.None); + + Assert.Collection(context.ValidationErrors, kvp => + { + Assert.Equal("IntegerWithRange", kvp.Key); + Assert.Equal("The field IntegerWithRange must be between 10 and 100.", kvp.Value.Single()); + }); + } + + async Task InvalidIntegerWithRangeAndDisplayNameProducesError(IValidatableInfo validatableInfo) + { + var instance = Activator.CreateInstance(type); + type.GetProperty("IntegerWithRangeAndDisplayName")?.SetValue(instance, 5); + var context = new ValidateContext + { + ValidationOptions = validationOptions, + ValidationContext = new ValidationContext(instance) + }; + + await validatableInfo.ValidateAsync(instance, context, CancellationToken.None); + + Assert.Collection(context.ValidationErrors, kvp => + { + Assert.Equal("IntegerWithRangeAndDisplayName", kvp.Key); + Assert.Equal("The field Valid identifier must be between 10 and 100.", kvp.Value.Single()); + }); + } + + async Task MissingRequiredSubtypePropertyProducesError(IValidatableInfo validatableInfo) + { + var instance = Activator.CreateInstance(type); + type.GetProperty("PropertyWithMemberAttributes")?.SetValue(instance, null); + var context = new ValidateContext + { + ValidationOptions = validationOptions, + ValidationContext = new ValidationContext(instance) + }; + + await validatableInfo.ValidateAsync(instance, context, CancellationToken.None); + + Assert.Collection(context.ValidationErrors, kvp => + { + Assert.Equal("PropertyWithMemberAttributes", kvp.Key); + Assert.Equal("The PropertyWithMemberAttributes field is required.", kvp.Value.Single()); + }); + } + + async Task InvalidRequiredSubtypePropertyProducesError(IValidatableInfo validatableInfo) + { + var instance = Activator.CreateInstance(type); + var subType = Activator.CreateInstance(type.Assembly.GetType("SubType")!); + subType.GetType().GetProperty("RequiredProperty")?.SetValue(subType, ""); + subType.GetType().GetProperty("StringWithLength")?.SetValue(subType, "way-too-long"); + type.GetProperty("PropertyWithMemberAttributes")?.SetValue(instance, subType); + var context = new ValidateContext + { + ValidationOptions = validationOptions, + ValidationContext = new ValidationContext(instance) + }; + + await validatableInfo.ValidateAsync(instance, context, CancellationToken.None); + + Assert.Collection(context.ValidationErrors, + kvp => + { + Assert.Equal("PropertyWithMemberAttributes.RequiredProperty", kvp.Key); + Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("PropertyWithMemberAttributes.StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }); + } + + async Task InvalidSubTypeWithInheritancePropertyProducesError(IValidatableInfo validatableInfo) + { + var instance = Activator.CreateInstance(type); + var inheritanceType = Activator.CreateInstance(type.Assembly.GetType("SubTypeWithInheritance")!); + inheritanceType.GetType().GetProperty("RequiredProperty")?.SetValue(inheritanceType, ""); + inheritanceType.GetType().GetProperty("StringWithLength")?.SetValue(inheritanceType, "way-too-long"); + inheritanceType.GetType().GetProperty("EmailString")?.SetValue(inheritanceType, "not-an-email"); + type.GetProperty("PropertyWithInheritance")?.SetValue(instance, inheritanceType); + var context = new ValidateContext + { + ValidationOptions = validationOptions, + ValidationContext = new ValidationContext(instance) + }; + + await validatableInfo.ValidateAsync(instance, context, CancellationToken.None); + + Assert.Collection(context.ValidationErrors, + kvp => + { + Assert.Equal("PropertyWithInheritance.EmailString", kvp.Key); + Assert.Equal("The EmailString field is not a valid e-mail address.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("PropertyWithInheritance.RequiredProperty", kvp.Key); + Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("PropertyWithInheritance.StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }); + } + + async Task InvalidListOfSubTypesProducesError(IValidatableInfo validatableInfo) + { + var instance = Activator.CreateInstance(type); + var subTypeList = Activator.CreateInstance(typeof(List<>).MakeGenericType(type.Assembly.GetType("SubType")!)); + + // Create first invalid item + var subType1 = Activator.CreateInstance(type.Assembly.GetType("SubType")!); + subType1.GetType().GetProperty("RequiredProperty")?.SetValue(subType1, ""); + subType1.GetType().GetProperty("StringWithLength")?.SetValue(subType1, "way-too-long"); + + // Create second invalid item + var subType2 = Activator.CreateInstance(type.Assembly.GetType("SubType")!); + subType2.GetType().GetProperty("RequiredProperty")?.SetValue(subType2, "valid"); + subType2.GetType().GetProperty("StringWithLength")?.SetValue(subType2, "way-too-long"); + + // Create valid item + var subType3 = Activator.CreateInstance(type.Assembly.GetType("SubType")!); + subType3.GetType().GetProperty("RequiredProperty")?.SetValue(subType3, "valid"); + subType3.GetType().GetProperty("StringWithLength")?.SetValue(subType3, "valid"); + + // Add to list + subTypeList.GetType().GetMethod("Add")?.Invoke(subTypeList, [subType1]); + subTypeList.GetType().GetMethod("Add")?.Invoke(subTypeList, [subType2]); + subTypeList.GetType().GetMethod("Add")?.Invoke(subTypeList, [subType3]); + + type.GetProperty("ListOfSubTypes")?.SetValue(instance, subTypeList); + var context = new ValidateContext + { + ValidationOptions = validationOptions, + ValidationContext = new ValidationContext(instance) + }; + + await validatableInfo.ValidateAsync(instance, context, CancellationToken.None); + + Assert.Collection(context.ValidationErrors, + kvp => + { + Assert.Equal("ListOfSubTypes[0].RequiredProperty", kvp.Key); + Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("ListOfSubTypes[0].StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("ListOfSubTypes[1].StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }); + } + + async Task InvalidPropertyWithDerivedValidationAttributeProducesError(IValidatableInfo validatableInfo) + { + var instance = Activator.CreateInstance(type); + type.GetProperty("IntegerWithCustomValidationAttribute")?.SetValue(instance, 5); // Odd number, should fail + var context = new ValidateContext + { + ValidationOptions = validationOptions, + ValidationContext = new ValidationContext(instance) + }; + + await validatableInfo.ValidateAsync(instance, context, CancellationToken.None); + + Assert.Collection(context.ValidationErrors, kvp => + { + Assert.Equal("IntegerWithCustomValidationAttribute", kvp.Key); + Assert.Equal("Value must be an even number", kvp.Value.Single()); + }); + } + + async Task InvalidPropertyWithMultipleAttributesProducesError(IValidatableInfo validatableInfo) + { + var instance = Activator.CreateInstance(type); + type.GetProperty("PropertyWithMultipleAttributes")?.SetValue(instance, 5); + var context = new ValidateContext + { + ValidationOptions = validationOptions, + ValidationContext = new ValidationContext(instance) + }; + + await validatableInfo.ValidateAsync(instance, context, CancellationToken.None); + + Assert.Collection(context.ValidationErrors, kvp => + { + Assert.Equal("PropertyWithMultipleAttributes", kvp.Key); + Assert.Collection(kvp.Value, + error => + { + Assert.Equal("The field PropertyWithMultipleAttributes is invalid.", error); + }, + error => + { + Assert.Equal("The field PropertyWithMultipleAttributes must be between 10 and 100.", error); + }); + }); + } + + async Task InvalidPropertyWithCustomValidationProducesError(IValidatableInfo validatableInfo) + { + var instance = Activator.CreateInstance(type); + type.GetProperty("IntegerWithCustomValidationAttribute")?.SetValue(instance, 3); // Odd number should fail + var context = new ValidateContext + { + ValidationOptions = validationOptions, + ValidationContext = new ValidationContext(instance) + }; + + await validatableInfo.ValidateAsync(instance, context, CancellationToken.None); + + Assert.Collection(context.ValidationErrors, kvp => + { + Assert.Equal("IntegerWithCustomValidationAttribute", kvp.Key); + Assert.Equal("Value must be an even number", kvp.Value.Single()); + }); + } + + async Task ValidInputProducesNoWarnings(IValidatableInfo validatableInfo) + { + var instance = Activator.CreateInstance(type); + + // Set all properties with valid values + type.GetProperty("IntegerWithRange")?.SetValue(instance, 50); + type.GetProperty("IntegerWithRangeAndDisplayName")?.SetValue(instance, 50); + + // Create and set PropertyWithMemberAttributes + var subType1 = Activator.CreateInstance(type.Assembly.GetType("SubType")!); + subType1.GetType().GetProperty("RequiredProperty")?.SetValue(subType1, "valid"); + subType1.GetType().GetProperty("StringWithLength")?.SetValue(subType1, "valid"); + type.GetProperty("PropertyWithMemberAttributes")?.SetValue(instance, subType1); + + // Create and set PropertyWithoutMemberAttributes + var subType2 = Activator.CreateInstance(type.Assembly.GetType("SubType")!); + subType2.GetType().GetProperty("RequiredProperty")?.SetValue(subType2, "valid"); + subType2.GetType().GetProperty("StringWithLength")?.SetValue(subType2, "valid"); + type.GetProperty("PropertyWithoutMemberAttributes")?.SetValue(instance, subType2); + + // Create and set PropertyWithInheritance + var inheritanceType = Activator.CreateInstance(type.Assembly.GetType("SubTypeWithInheritance")!); + inheritanceType.GetType().GetProperty("RequiredProperty")?.SetValue(inheritanceType, "valid"); + inheritanceType.GetType().GetProperty("StringWithLength")?.SetValue(inheritanceType, "valid"); + inheritanceType.GetType().GetProperty("EmailString")?.SetValue(inheritanceType, "test@example.com"); + type.GetProperty("PropertyWithInheritance")?.SetValue(instance, inheritanceType); + + // Create empty list for ListOfSubTypes + var emptyList = Activator.CreateInstance(typeof(List<>).MakeGenericType(type.Assembly.GetType("SubType")!)); + type.GetProperty("ListOfSubTypes")?.SetValue(instance, emptyList); + + // Set custom validation attributes + type.GetProperty("IntegerWithCustomValidationAttribute")?.SetValue(instance, 2); // Even number should pass + type.GetProperty("PropertyWithMultipleAttributes")?.SetValue(instance, 12); + + var context = new ValidateContext + { + ValidationOptions = validationOptions, + ValidationContext = new ValidationContext(instance) + }; + + await validatableInfo.ValidateAsync(instance, context, CancellationToken.None); + + Assert.Null(context.ValidationErrors); + } + }); + } +} \ No newline at end of file diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTestBase.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTestBase.cs index b8740de3e928..2310c2c3aaf9 100644 --- a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTestBase.cs +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTestBase.cs @@ -2,12 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Globalization; using System.Reflection; using System.Runtime.Loader; using System.Text; +using System.Text.Json; +using System.Text.RegularExpressions; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting.Server; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Http.Validation; using Microsoft.AspNetCore.InternalTesting; using Microsoft.AspNetCore.Routing; using Microsoft.CodeAnalysis; @@ -16,16 +20,19 @@ using Microsoft.CodeAnalysis.Text; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; -using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using static Microsoft.AspNetCore.Http.Generators.Tests.RequestDelegateCreationTestBase; namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; [UsesVerify] -public class ValidationsGeneratorTestBase : LoggedTestBase +public partial class ValidationsGeneratorTestBase : LoggedTestBase { + [GeneratedRegex(@"\[global::System\.Runtime\.CompilerServices\.InterceptsLocationAttribute\([^)]*\)\]")] + private static partial Regex InterceptsLocationRegex(); + private static readonly CSharpParseOptions ParseOptions = new CSharpParseOptions(LanguageVersion.Preview) - .WithFeatures([new KeyValuePair("InterceptorsNamespaces", "Microsoft.AspNetCore.Http.Validations.Generated")]); + .WithFeatures([new KeyValuePair("InterceptorsNamespaces", "Microsoft.AspNetCore.Http.Validation.Generated")]); internal static Task Verify(string source, out Compilation compilation) { @@ -50,9 +57,11 @@ internal static Task Verify(string source, out Compilation compilation) MetadataReference.CreateFromFile(typeof(ValidateOptionsResult).Assembly.Location), MetadataReference.CreateFromFile(typeof(IHttpMethodMetadata).Assembly.Location), MetadataReference.CreateFromFile(typeof(IResult).Assembly.Location), - MetadataReference.CreateFromFile(typeof(HttpJsonServiceExtensions).Assembly.Location) + MetadataReference.CreateFromFile(typeof(HttpJsonServiceExtensions).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IValidatableInfoResolver).Assembly.Location), + MetadataReference.CreateFromFile(typeof(EndpointFilterFactoryContext).Assembly.Location), ]); - var inputCompilation = CSharpCompilation.Create("OpenApiXmlCommentGeneratorSample", + var inputCompilation = CSharpCompilation.Create("ValidationsGeneratorSample", [CSharpSyntaxTree.ParseText(source, options: ParseOptions, path: "Program.cs")], references, new CSharpCompilationOptions(OutputKind.ConsoleApplication)); @@ -60,13 +69,48 @@ internal static Task Verify(string source, out Compilation compilation) var driver = CSharpGeneratorDriver.Create(generators: [generator.AsSourceGenerator()], parseOptions: ParseOptions); return Verifier .Verify(driver.RunGeneratorsAndUpdateCompilation(inputCompilation, out compilation, out var diagnostics)) + .ScrubLinesWithReplace(line => InterceptsLocationRegex().Replace(line, "[InterceptsLocation]")) .UseDirectory(SkipOnHelixAttribute.OnHelix() ? Path.Combine(Environment.GetEnvironmentVariable("HELIX_WORKITEM_ROOT"), "ValidationsGenerator", "snapshots") : "snapshots"); } - internal static void VerifyEndpoint(Compilation compilation, string routePattern, Action verifyFunc) + internal static void VerifyValidatableType(Compilation compilation, string typeName, Action verifyFunc) + { + if (TryResolveServicesFromCompilation(compilation, targetAssemblyName: "Microsoft.AspNetCore.Http.Abstractions", typeName: "Microsoft.AspNetCore.Http.Validation.ValidationOptions", out var services, out var serviceType, out var outputAssemblyName) is false) + { + throw new InvalidOperationException("Could not resolve services from compilation."); + } + var targetAssembly = AppDomain.CurrentDomain.GetAssemblies().FirstOrDefault(assembly => assembly.GetName().Name == outputAssemblyName); + var type = targetAssembly.GetType(typeName, throwOnError: false); + + // Get IOptions first + var optionsType = typeof(IOptions<>).MakeGenericType(serviceType); + var optionsInstance = services.GetService(optionsType) ?? throw new InvalidOperationException("Could not resolve IOptions."); + + // Then access the Value property + var valueProperty = optionsType.GetProperty("Value"); + var service = (ValidationOptions)valueProperty.GetValue(optionsInstance) ?? throw new InvalidOperationException("Could not resolve ValidationOptions."); + verifyFunc(service, type); + } + + internal static async Task VerifyEndpoint(Compilation compilation, string routePattern, Func verifyFunc) + { + if (TryResolveServicesFromCompilation(compilation, targetAssemblyName: "Microsoft.AspNetCore.Routing", typeName: "Microsoft.AspNetCore.Routing.EndpointDataSource", out var services, out var serviceType, out var outputAssemblyName) is false) + { + throw new InvalidOperationException("Could not resolve services from compilation."); + } + var service = services.GetService(serviceType) ?? throw new InvalidOperationException("Could not resolve EndpointDataSource."); + var endpoints = (IReadOnlyList)service.GetType().GetProperty("Endpoints", BindingFlags.Instance | BindingFlags.Public).GetValue(service); + var endpoint = endpoints.FirstOrDefault(endpoint => endpoint is RouteEndpoint routeEndpoint && routeEndpoint.RoutePattern.RawText == routePattern); + await verifyFunc(endpoint, services); + } + + private static bool TryResolveServicesFromCompilation(Compilation compilation, string targetAssemblyName, string typeName, out IServiceProvider serviceProvider, out Type serviceType, out string outputAssemblyName) { + serviceProvider = null; + serviceType = null; + outputAssemblyName = $"TestProject-{Guid.NewGuid()}"; var assemblyName = compilation.AssemblyName; var symbolsName = Path.ChangeExtension(assemblyName, "pdb"); @@ -76,7 +120,7 @@ internal static void VerifyEndpoint(Compilation compilation, string routePattern var emitOptions = new EmitOptions( debugInformationFormat: DebugInformationFormat.PortablePdb, pdbFilePath: symbolsName, - outputNameOverride: $"TestProject-{Guid.NewGuid()}"); + outputNameOverride: outputAssemblyName); var embeddedTexts = new List(); @@ -135,28 +179,24 @@ void OnEntryPointExit(Exception exception) if (factory == null) { - return; + return false; } var services = ((IHost)factory([$"--{HostDefaults.ApplicationKey}={assemblyName}"])).Services; var applicationLifetime = services.GetRequiredService(); - using (var registration = applicationLifetime.ApplicationStarted.Register(() => waitForStartTcs.TrySetResult(0))) - { - waitForStartTcs.Task.Wait(); - var targetAssembly = AppDomain.CurrentDomain.GetAssemblies().FirstOrDefault(assembly => assembly.GetName().Name == "Microsoft.AspNetCore.Routing"); - var serviceType = targetAssembly.GetType("Microsoft.AspNetCore.Routing.EndpointDataSource", throwOnError: false); - - if (serviceType == null) - { - return; - } + using var registration = applicationLifetime.ApplicationStarted.Register(() => waitForStartTcs.TrySetResult(0)); + waitForStartTcs.Task.Wait(); + var targetAssembly = AppDomain.CurrentDomain.GetAssemblies().FirstOrDefault(assembly => assembly.GetName().Name == targetAssemblyName); + serviceType = targetAssembly.GetType(typeName, throwOnError: false); - var service = services.GetService(serviceType) ?? throw new InvalidOperationException("Could not resolve EndpointDataSource."); - var endpoints = (IReadOnlyList)serviceType.GetProperty("Endpoints", BindingFlags.Instance | BindingFlags.Public).GetValue(service); - var endpoint = endpoints.FirstOrDefault(endpoint => endpoint is RouteEndpoint routeEndpoint && routeEndpoint.RoutePattern.RawText == routePattern); - verifyFunc(endpoint); + if (serviceType == null) + { + return false; } + + serviceProvider = services; + return true; } private sealed class NoopHostLifetime : IHostLifetime @@ -510,10 +550,10 @@ private sealed class HostAbortedException : Exception } } - internal HttpContext CreateHttpContext(IServiceProvider serviceProvider = null) + internal HttpContext CreateHttpContext(IServiceProvider serviceProvider) { var httpContext = new DefaultHttpContext(); - httpContext.RequestServices = serviceProvider ?? CreateServiceProvider(); + httpContext.RequestServices = serviceProvider; var outStream = new MemoryStream(); httpContext.Response.Body = outStream; @@ -521,14 +561,24 @@ internal HttpContext CreateHttpContext(IServiceProvider serviceProvider = null) return httpContext; } - internal ServiceProvider CreateServiceProvider(Action configureServices = null) + internal HttpContext CreateHttpContextWithPayload(string requestData, IServiceProvider serviceProvider = null) { - var serviceCollection = new ServiceCollection(); - serviceCollection.AddSingleton(LoggerFactory); - if (configureServices is not null) - { - configureServices(serviceCollection); - } - return serviceCollection.BuildServiceProvider(); + var httpContext = CreateHttpContext(serviceProvider); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); + httpContext.Request.Headers["Content-Type"] = "application/json"; + + var stream = new MemoryStream(System.Text.Encoding.UTF8.GetBytes(requestData)); + httpContext.Request.Body = stream; + httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(CultureInfo.InvariantCulture); + return httpContext; + } + + internal static async Task AssertBadRequest(HttpContext context) + { + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + context.Response.Body.Position = 0; + using var reader = new StreamReader(context.Response.Body); + var responseBody = await reader.ReadToEndAsync(); + return JsonSerializer.Deserialize(responseBody); } } diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateComplexTypes#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateComplexTypes#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..3030800620aa --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateComplexTypes#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,227 @@ +//HintName: ValidatableInfoResolver.g.cs +#nullable enable annotations +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName) : base(containingType, propertyType, name, displayName) + { + ContainingType = containingType; + Name = name; + } + + internal global::System.Type ContainingType { get; } + internal string Name { get; } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() + => ValidationAttributeCache.GetValidationAttributes(ContainingType, Name); + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + if (type == typeof(global::SubType)) + { + validatableInfo = CreateSubType(); + return true; + } + if (type == typeof(global::SubTypeWithInheritance)) + { + validatableInfo = CreateSubTypeWithInheritance(); + return true; + } + if (type == typeof(global::ComplexType)) + { + validatableInfo = CreateComplexType(); + return true; + } + + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + + private ValidatableTypeInfo CreateSubType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::SubType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubType), + propertyType: typeof(string), + name: "RequiredProperty", + displayName: "RequiredProperty" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubType), + propertyType: typeof(string), + name: "StringWithLength", + displayName: "StringWithLength" + ), + ] + ); + } + private ValidatableTypeInfo CreateSubTypeWithInheritance() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::SubTypeWithInheritance), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubTypeWithInheritance), + propertyType: typeof(string), + name: "EmailString", + displayName: "EmailString" + ), + ] + ); + } + private ValidatableTypeInfo CreateComplexType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::ComplexType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithRange", + displayName: "IntegerWithRange" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithRangeAndDisplayName", + displayName: "Valid identifier" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::SubType), + name: "PropertyWithMemberAttributes", + displayName: "PropertyWithMemberAttributes" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::SubType), + name: "PropertyWithoutMemberAttributes", + displayName: "PropertyWithoutMemberAttributes" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::SubTypeWithInheritance), + name: "PropertyWithInheritance", + displayName: "PropertyWithInheritance" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::System.Collections.Generic.List), + name: "ListOfSubTypes", + displayName: "ListOfSubTypes" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithDerivedValidationAttribute", + displayName: "IntegerWithDerivedValidationAttribute" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithCustomValidation", + displayName: "IntegerWithCustomValidation" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "PropertyWithMultipleAttributes", + displayName: "PropertyWithMultipleAttributes" + ), + ] + ); + } + + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [InterceptsLocation] + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type ContainingType, string PropertyName); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes( + global::System.Type containingType, + string propertyName) + { + var key = new CacheKey(containingType, propertyName); + return _cache.GetOrAdd(key, static k => + { + var property = k.ContainingType.GetProperty(k.PropertyName); + if (property == null) + { + return []; + } + + return [.. global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(property, inherit: true)]; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateIValidatableObject#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateIValidatableObject#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..3750fbc9a1f3 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateIValidatableObject#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,178 @@ +//HintName: ValidatableInfoResolver.g.cs +#nullable enable annotations +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName) : base(containingType, propertyType, name, displayName) + { + ContainingType = containingType; + Name = name; + } + + internal global::System.Type ContainingType { get; } + internal string Name { get; } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() + => ValidationAttributeCache.GetValidationAttributes(ContainingType, Name); + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + if (type == typeof(global::SubType)) + { + validatableInfo = CreateSubType(); + return true; + } + if (type == typeof(global::ValidatableSubType)) + { + validatableInfo = CreateValidatableSubType(); + return true; + } + if (type == typeof(global::ComplexValidatableType)) + { + validatableInfo = CreateComplexValidatableType(); + return true; + } + + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + + private ValidatableTypeInfo CreateSubType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::SubType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubType), + propertyType: typeof(string), + name: "RequiredProperty", + displayName: "RequiredProperty" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubType), + propertyType: typeof(string), + name: "StringWithLength", + displayName: "StringWithLength" + ), + ] + ); + } + private ValidatableTypeInfo CreateValidatableSubType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::ValidatableSubType), + members: [] + ); + } + private ValidatableTypeInfo CreateComplexValidatableType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::ComplexValidatableType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexValidatableType), + propertyType: typeof(string), + name: "Value2", + displayName: "Value2" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexValidatableType), + propertyType: typeof(global::ValidatableSubType), + name: "SubType", + displayName: "SubType" + ), + ] + ); + } + + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [InterceptsLocation] + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type ContainingType, string PropertyName); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes( + global::System.Type containingType, + string propertyName) + { + var key = new CacheKey(containingType, propertyName); + return _cache.GetOrAdd(key, static k => + { + var property = k.ContainingType.GetProperty(k.PropertyName); + if (property == null) + { + return []; + } + + return [.. global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(property, inherit: true)]; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateParameters#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateParameters#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..b1fafa43e639 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateParameters#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,116 @@ +//HintName: ValidatableInfoResolver.g.cs +#nullable enable annotations +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName) : base(containingType, propertyType, name, displayName) + { + ContainingType = containingType; + Name = name; + } + + internal global::System.Type ContainingType { get; } + internal string Name { get; } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() + => ValidationAttributeCache.GetValidationAttributes(ContainingType, Name); + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + + + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [InterceptsLocation] + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type ContainingType, string PropertyName); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes( + global::System.Type containingType, + string propertyName) + { + var key = new CacheKey(containingType, propertyName); + return _cache.GetOrAdd(key, static k => + { + var property = k.ContainingType.GetProperty(k.PropertyName); + if (property == null) + { + return []; + } + + return [.. global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(property, inherit: true)]; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidatePolymorphicTypes#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidatePolymorphicTypes#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..78f5b721df90 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidatePolymorphicTypes#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,216 @@ +//HintName: ValidatableInfoResolver.g.cs +#nullable enable annotations +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName) : base(containingType, propertyType, name, displayName) + { + ContainingType = containingType; + Name = name; + } + + internal global::System.Type ContainingType { get; } + internal string Name { get; } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() + => ValidationAttributeCache.GetValidationAttributes(ContainingType, Name); + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + if (type == typeof(global::DerivedType)) + { + validatableInfo = CreateDerivedType(); + return true; + } + if (type == typeof(global::BaseType)) + { + validatableInfo = CreateBaseType(); + return true; + } + if (type == typeof(global::DerivedValidatableType)) + { + validatableInfo = CreateDerivedValidatableType(); + return true; + } + if (type == typeof(global::BaseValidatableType)) + { + validatableInfo = CreateBaseValidatableType(); + return true; + } + if (type == typeof(global::ContainerType)) + { + validatableInfo = CreateContainerType(); + return true; + } + + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + + private ValidatableTypeInfo CreateDerivedType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::DerivedType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::DerivedType), + propertyType: typeof(string), + name: "Value3", + displayName: "Value3" + ), + ] + ); + } + private ValidatableTypeInfo CreateBaseType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::BaseType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::BaseType), + propertyType: typeof(int), + name: "Value1", + displayName: "Value 1" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::BaseType), + propertyType: typeof(string), + name: "Value2", + displayName: "Value2" + ), + ] + ); + } + private ValidatableTypeInfo CreateDerivedValidatableType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::DerivedValidatableType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::DerivedValidatableType), + propertyType: typeof(string), + name: "Value3", + displayName: "Value3" + ), + ] + ); + } + private ValidatableTypeInfo CreateBaseValidatableType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::BaseValidatableType), + members: [] + ); + } + private ValidatableTypeInfo CreateContainerType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::ContainerType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ContainerType), + propertyType: typeof(global::BaseType), + name: "BaseType", + displayName: "BaseType" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ContainerType), + propertyType: typeof(global::BaseValidatableType), + name: "BaseValidatableType", + displayName: "BaseValidatableType" + ), + ] + ); + } + + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [InterceptsLocation] + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type ContainingType, string PropertyName); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes( + global::System.Type containingType, + string propertyName) + { + var key = new CacheKey(containingType, propertyName); + return _cache.GetOrAdd(key, static k => + { + var property = k.ContainingType.GetProperty(k.PropertyName); + if (property == null) + { + return []; + } + + return [.. global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(property, inherit: true)]; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateRecursiveTypes#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateRecursiveTypes#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..85f8662d5a3f --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateRecursiveTypes#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,141 @@ +//HintName: ValidatableInfoResolver.g.cs +#nullable enable annotations +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName) : base(containingType, propertyType, name, displayName) + { + ContainingType = containingType; + Name = name; + } + + internal global::System.Type ContainingType { get; } + internal string Name { get; } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() + => ValidationAttributeCache.GetValidationAttributes(ContainingType, Name); + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + if (type == typeof(global::RecursiveType)) + { + validatableInfo = CreateRecursiveType(); + return true; + } + + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + + private ValidatableTypeInfo CreateRecursiveType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::RecursiveType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::RecursiveType), + propertyType: typeof(int), + name: "Value", + displayName: "Value" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::RecursiveType), + propertyType: typeof(global::RecursiveType), + name: "Next", + displayName: "Next" + ), + ] + ); + } + + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [InterceptsLocation] + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type ContainingType, string PropertyName); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes( + global::System.Type containingType, + string propertyName) + { + var key = new CacheKey(containingType, propertyName); + return _cache.GetOrAdd(key, static k => + { + var property = k.ContainingType.GetProperty(k.PropertyName); + if (property == null) + { + return []; + } + + return [.. global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(property, inherit: true)]; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateTypesWithAttribute#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateTypesWithAttribute#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..3262b7948771 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateTypesWithAttribute#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,221 @@ +//HintName: ValidatableInfoResolver.g.cs +#nullable enable annotations +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName) : base(containingType, propertyType, name, displayName) + { + ContainingType = containingType; + Name = name; + } + + internal global::System.Type ContainingType { get; } + internal string Name { get; } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() + => ValidationAttributeCache.GetValidationAttributes(ContainingType, Name); + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + if (type == typeof(global::SubType)) + { + validatableInfo = CreateSubType(); + return true; + } + if (type == typeof(global::SubTypeWithInheritance)) + { + validatableInfo = CreateSubTypeWithInheritance(); + return true; + } + if (type == typeof(global::ComplexType)) + { + validatableInfo = CreateComplexType(); + return true; + } + + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + + private ValidatableTypeInfo CreateSubType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::SubType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubType), + propertyType: typeof(string), + name: "RequiredProperty", + displayName: "RequiredProperty" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubType), + propertyType: typeof(string), + name: "StringWithLength", + displayName: "StringWithLength" + ), + ] + ); + } + private ValidatableTypeInfo CreateSubTypeWithInheritance() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::SubTypeWithInheritance), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubTypeWithInheritance), + propertyType: typeof(string), + name: "EmailString", + displayName: "EmailString" + ), + ] + ); + } + private ValidatableTypeInfo CreateComplexType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::ComplexType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithRange", + displayName: "IntegerWithRange" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithRangeAndDisplayName", + displayName: "Valid identifier" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::SubType), + name: "PropertyWithMemberAttributes", + displayName: "PropertyWithMemberAttributes" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::SubType), + name: "PropertyWithoutMemberAttributes", + displayName: "PropertyWithoutMemberAttributes" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::SubTypeWithInheritance), + name: "PropertyWithInheritance", + displayName: "PropertyWithInheritance" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::System.Collections.Generic.List), + name: "ListOfSubTypes", + displayName: "ListOfSubTypes" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithCustomValidationAttribute", + displayName: "IntegerWithCustomValidationAttribute" + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "PropertyWithMultipleAttributes", + displayName: "PropertyWithMultipleAttributes" + ), + ] + ); + } + + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [InterceptsLocation] + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type ContainingType, string PropertyName); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes( + global::System.Type containingType, + string propertyName) + { + var key = new CacheKey(containingType, propertyName); + return _cache.GetOrAdd(key, static k => + { + var property = k.ContainingType.GetProperty(k.PropertyName); + if (property == null) + { + return []; + } + + return [.. global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(property, inherit: true)]; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.DoesNotEmitForExemptTypes#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.DoesNotEmitForExemptTypes#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..61e47e74a630 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.DoesNotEmitForExemptTypes#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,135 @@ +//HintName: ValidatableInfoResolver.g.cs +#nullable enable annotations +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName) : base(containingType, propertyType, name, displayName) + { + ContainingType = containingType; + Name = name; + } + + internal global::System.Type ContainingType { get; } + internal string Name { get; } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() + => ValidationAttributeCache.GetValidationAttributes(ContainingType, Name); + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + if (type == typeof(global::ComplexType)) + { + validatableInfo = CreateComplexType(); + return true; + } + + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + + private ValidatableTypeInfo CreateComplexType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::ComplexType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithRange", + displayName: "IntegerWithRange" + ), + ] + ); + } + + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [InterceptsLocation] + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type ContainingType, string PropertyName); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes( + global::System.Type containingType, + string propertyName) + { + var key = new CacheKey(containingType, propertyName); + return _cache.GetOrAdd(key, static k => + { + var property = k.ContainingType.GetProperty(k.PropertyName); + if (property == null) + { + return []; + } + + return [.. global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(property, inherit: true)]; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http/perf/Microbenchmarks/ValidatableTypesBenchmark.cs b/src/Http/Http/perf/Microbenchmarks/ValidatableTypesBenchmark.cs new file mode 100644 index 000000000000..75abf445ec30 --- /dev/null +++ b/src/Http/Http/perf/Microbenchmarks/ValidatableTypesBenchmark.cs @@ -0,0 +1,365 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Http.Validation; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Http.Microbenchmarks; + +public class ValidatableTypeInfoBenchmark +{ + private IValidatableInfo _simpleTypeInfo = null!; + private IValidatableInfo _complexTypeInfo = null!; + private IValidatableInfo _hierarchicalTypeInfo = null!; + private IValidatableInfo _ivalidatableObjectTypeInfo = null!; + + private ValidateContext _context = null!; + private SimpleModel _simpleModel = null!; + private ComplexModel _complexModel = null!; + private HierarchicalModel _hierarchicalModel = null!; + private ValidatableObjectModel _validatableObjectModel = null!; + + [GlobalSetup] + public void Setup() + { + var services = new ServiceCollection(); + var mockResolver = new MockValidatableTypeInfoResolver(); + + services.AddValidation(options => + { + // Register our mock resolver + options.Resolvers.Insert(0, mockResolver); + }); + + var serviceProvider = services.BuildServiceProvider(); + var validationOptions = serviceProvider.GetRequiredService>().Value; + + _context = new ValidateContext + { + ValidationOptions = validationOptions, + ValidationContext = new ValidationContext(new object(), serviceProvider, null), + ValidationErrors = new Dictionary(StringComparer.Ordinal) + }; + + // Create the model instances + _simpleModel = new SimpleModel + { + Id = 1, + Name = "Test Name", + Email = "test@example.com" + }; + + _complexModel = new ComplexModel + { + Id = 1, + Name = "Complex Model", + Properties = new Dictionary + { + ["Prop1"] = "Value1", + ["Prop2"] = "Value2" + }, + Items = ["Item1", "Item2", "Item3"], + CreatedOn = DateTime.UtcNow + }; + + _hierarchicalModel = new HierarchicalModel + { + Id = 1, + Name = "Parent Model", + Child = new ChildModel + { + Id = 2, + Name = "Child Model", + ParentId = 1 + }, + Siblings = + [ + new SimpleModel { Id = 3, Name = "Sibling 1", Email = "sibling1@example.com" }, + new SimpleModel { Id = 4, Name = "Sibling 2", Email = "sibling2@example.com" } + ] + }; + + _validatableObjectModel = new ValidatableObjectModel + { + Id = 1, + Name = "Validatable Model", + CustomField = "Valid Value" + }; + + // Get the type info instances from validation options using the mock resolver + validationOptions.TryGetValidatableTypeInfo(typeof(SimpleModel), out _simpleTypeInfo); + validationOptions.TryGetValidatableTypeInfo(typeof(ComplexModel), out _complexTypeInfo); + validationOptions.TryGetValidatableTypeInfo(typeof(HierarchicalModel), out _hierarchicalTypeInfo); + validationOptions.TryGetValidatableTypeInfo(typeof(ValidatableObjectModel), out _ivalidatableObjectTypeInfo); + + // Ensure we have all type infos (this should not be needed with our mock resolver) + if (_simpleTypeInfo == null || _complexTypeInfo == null || + _hierarchicalTypeInfo == null || _ivalidatableObjectTypeInfo == null) + { + throw new InvalidOperationException("Failed to register one or more type infos with mock resolver"); + } + } + + [Benchmark(Description = "Validate Simple Model")] + [BenchmarkCategory("Simple")] + public async Task ValidateSimpleModel() + { + _context.ValidationErrors.Clear(); + await _simpleTypeInfo.ValidateAsync(_simpleModel, _context, default); + } + + [Benchmark(Description = "Validate Complex Model")] + [BenchmarkCategory("Complex")] + public async Task ValidateComplexModel() + { + _context.ValidationErrors.Clear(); + await _complexTypeInfo.ValidateAsync(_complexModel, _context, default); + } + + [Benchmark(Description = "Validate Hierarchical Model")] + [BenchmarkCategory("Hierarchical")] + public async Task ValidateHierarchicalModel() + { + _context.ValidationErrors.Clear(); + await _hierarchicalTypeInfo.ValidateAsync(_hierarchicalModel, _context, default); + } + + [Benchmark(Description = "Validate IValidatableObject Model")] + [BenchmarkCategory("IValidatableObject")] + public async Task ValidateIValidatableObjectModel() + { + _context.ValidationErrors.Clear(); + await _ivalidatableObjectTypeInfo.ValidateAsync(_validatableObjectModel, _context, default); + } + + [Benchmark(Description = "Validate invalid Simple Model")] + [BenchmarkCategory("Invalid")] + public async Task ValidateInvalidSimpleModel() + { + _context.ValidationErrors.Clear(); + _simpleModel.Email = "invalid-email"; + await _simpleTypeInfo.ValidateAsync(_simpleModel, _context, default); + } + + [Benchmark(Description = "Validate invalid IValidatableObject Model")] + [BenchmarkCategory("Invalid")] + public async Task ValidateInvalidIValidatableObjectModel() + { + _context.ValidationErrors.Clear(); + _validatableObjectModel.CustomField = "Invalid"; + await _ivalidatableObjectTypeInfo.ValidateAsync(_validatableObjectModel, _context, default); + } + + #region Helper methods to create type info instances manually if needed + + private ValidatablePropertyInfo CreatePropertyInfo(string name, Type type, params ValidationAttribute[] attributes) + { + return new MockValidatablePropertyInfo( + typeof(SimpleModel), + type, + name, + name, + attributes); + } + + #endregion + + #region Test Models + + public class SimpleModel + { + public int Id { get; set; } + + [Required] + public string Name { get; set; } + + [EmailAddress] + public string Email { get; set; } + } + + public class ComplexModel + { + public int Id { get; set; } + + [Required] + public string Name { get; set; } + + public Dictionary Properties { get; set; } + + public List Items { get; set; } + + public DateTime CreatedOn { get; set; } + } + + public class ChildModel + { + public int Id { get; set; } + + [Required] + public string Name { get; set; } + + public int ParentId { get; set; } + } + + public class HierarchicalModel + { + public int Id { get; set; } + + [Required] + public string Name { get; set; } + + public ChildModel Child { get; set; } + + public List Siblings { get; set; } + } + + public class ValidatableObjectModel : IValidatableObject + { + public int Id { get; set; } + + [Required] + public string Name { get; set; } + + public string CustomField { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (CustomField == "Invalid") + { + yield return new ValidationResult("CustomField has an invalid value", new[] { nameof(CustomField) }); + } + } + } + + #endregion + + #region Mock Implementations for Testing + + private class MockValidatableTypeInfo(Type type, ValidatablePropertyInfo[] members) : ValidatableTypeInfo(type, members) + { + } + + private class MockValidatablePropertyInfo( + Type containingType, + Type propertyType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) : ValidatablePropertyInfo(containingType, propertyType, name, displayName) + { + private readonly ValidationAttribute[] _validationAttributes = validationAttributes; + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + #endregion + + #region Mock Resolver Implementation + + private class MockValidatableTypeInfoResolver : IValidatableInfoResolver + { + private readonly Dictionary _typeInfoCache = []; + + public MockValidatableTypeInfoResolver() + { + // Initialize the cache with our test models + _typeInfoCache[typeof(SimpleModel)] = CreateSimpleModelTypeInfo(); + _typeInfoCache[typeof(ComplexModel)] = CreateComplexModelTypeInfo(); + _typeInfoCache[typeof(HierarchicalModel)] = CreateHierarchicalModelTypeInfo(); + _typeInfoCache[typeof(ValidatableObjectModel)] = CreateValidatableObjectModelTypeInfo(); + + // Add child models that might be validated separately + _typeInfoCache[typeof(ChildModel)] = CreateChildModelTypeInfo(); + } + + private ValidatableTypeInfo CreateSimpleModelTypeInfo() + { + return new MockValidatableTypeInfo( + typeof(SimpleModel), + [ + CreatePropertyInfo(typeof(SimpleModel), "Id", typeof(int)), + CreatePropertyInfo(typeof(SimpleModel), "Name", typeof(string)), + CreatePropertyInfo(typeof(SimpleModel), "Email", typeof(string), new EmailAddressAttribute()) + ]); + } + + private ValidatableTypeInfo CreateComplexModelTypeInfo() + { + return new MockValidatableTypeInfo( + typeof(ComplexModel), + [ + CreatePropertyInfo(typeof(ComplexModel), "Id", typeof(int)), + CreatePropertyInfo(typeof(ComplexModel), "Name", typeof(string)), + CreatePropertyInfo(typeof(ComplexModel), "Properties", typeof(Dictionary)), + CreatePropertyInfo(typeof(ComplexModel), "Items", typeof(List)), + CreatePropertyInfo(typeof(ComplexModel), "CreatedOn", typeof(DateTime)) + ]); + } + + private ValidatableTypeInfo CreateChildModelTypeInfo() + { + return new MockValidatableTypeInfo( + typeof(ChildModel), + [ + CreatePropertyInfo(typeof(ChildModel), "Id", typeof(int)), + CreatePropertyInfo(typeof(ChildModel), "Name", typeof(string)), + CreatePropertyInfo(typeof(ChildModel), "ParentId", typeof(int)) + ]); + } + + private ValidatableTypeInfo CreateHierarchicalModelTypeInfo() + { + return new MockValidatableTypeInfo( + typeof(HierarchicalModel), + [ + CreatePropertyInfo(typeof(HierarchicalModel), "Id", typeof(int)), + CreatePropertyInfo(typeof(HierarchicalModel), "Name", typeof(string)), + CreatePropertyInfo(typeof(HierarchicalModel), "Child", typeof(ChildModel)), + CreatePropertyInfo(typeof(HierarchicalModel), "Siblings", typeof(List)) + ]); + } + + private ValidatableTypeInfo CreateValidatableObjectModelTypeInfo() + { + return new MockValidatableTypeInfo( + typeof(ValidatableObjectModel), + [ + CreatePropertyInfo(typeof(ValidatableObjectModel), "Id", typeof(int)), + CreatePropertyInfo(typeof(ValidatableObjectModel), "Name", typeof(string)), + CreatePropertyInfo(typeof(ValidatableObjectModel), "CustomField", typeof(string)) + ]); + } + + private ValidatablePropertyInfo CreatePropertyInfo(Type containingType, string name, Type type, params ValidationAttribute[] attributes) + { + return new MockValidatablePropertyInfo( + containingType, + type, + name, + name, // Use name as display name + attributes); + } + + public bool TryGetValidatableTypeInfo(Type type, out IValidatableInfo validatableInfo) + { + if (_typeInfoCache.TryGetValue(type, out var typeInfo)) + { + validatableInfo = typeInfo; + return true; + } + validatableInfo = null; + return false; + } + + public bool TryGetValidatableParameterInfo(ParameterInfo parameterInfo, out IValidatableInfo validatableInfo) + { + validatableInfo = null; + return false; + } + } + #endregion +} diff --git a/src/Http/Routing/src/Builder/ValidationRouteHandlerBuilderExtensions.cs b/src/Http/Routing/src/Builder/ValidationRouteHandlerBuilderExtensions.cs new file mode 100644 index 000000000000..6ab1792eec99 --- /dev/null +++ b/src/Http/Routing/src/Builder/ValidationRouteHandlerBuilderExtensions.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Http.Metadata; + +namespace Microsoft.AspNetCore.Builder; + +/// +/// Extension methods for to interact with +/// parameter validation features. +/// +public static class ValidationEndpointConventionBuilderExtensions +{ + /// + /// Disables validation for the specified endpoint. + /// + /// The type of the builder. + /// The endpoint convention builder. + /// + /// The for chaining. + /// + public static TBuilder DisableValidation(this TBuilder builder) + where TBuilder : IEndpointConventionBuilder + { + builder.WithMetadata(new DisableValidationMetadata()); + return builder; + } + + private sealed class DisableValidationMetadata : IDisableValidationMetadata + { + } +} diff --git a/src/Http/Routing/src/PublicAPI.Unshipped.txt b/src/Http/Routing/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..0612dc9ff2b0 100644 --- a/src/Http/Routing/src/PublicAPI.Unshipped.txt +++ b/src/Http/Routing/src/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.AspNetCore.Builder.ValidationEndpointConventionBuilderExtensions +static Microsoft.AspNetCore.Builder.ValidationEndpointConventionBuilderExtensions.DisableValidation(this TBuilder builder) -> TBuilder diff --git a/src/Http/Routing/src/RouteEndpointDataSource.cs b/src/Http/Routing/src/RouteEndpointDataSource.cs index 2ed6ff242276..59bcd699a59d 100644 --- a/src/Http/Routing/src/RouteEndpointDataSource.cs +++ b/src/Http/Routing/src/RouteEndpointDataSource.cs @@ -2,12 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Linq; using System.Reflection; using System.Runtime.CompilerServices; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Metadata; +using Microsoft.AspNetCore.Http.Validation; using Microsoft.AspNetCore.Routing.Patterns; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.FileProviders; +using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Routing; @@ -100,7 +105,7 @@ public override IReadOnlyList Endpoints public override IReadOnlyList GetGroupedEndpoints(RouteGroupContext context) { var endpoints = new RouteEndpoint[_routeEntries.Count]; - for (int i = 0; i < _routeEntries.Count; i++) + for (var i = 0; i < _routeEntries.Count; i++) { endpoints[i] = (RouteEndpoint)CreateRouteEndpointBuilder(_routeEntries[i], context.Prefix, context.Conventions, context.FinallyConventions).Build(); } @@ -155,7 +160,7 @@ private RouteEndpointBuilder CreateRouteEndpointBuilder( // If we're not a route handler, we started with a fully realized (although unfiltered) RequestDelegate, so we can just redirect to that // while running any conventions. We'll put the original back if it remains unfiltered right before building the endpoint. - RequestDelegate? factoryCreatedRequestDelegate = isRouteHandler ? null : (RequestDelegate)entry.RouteHandler; + var factoryCreatedRequestDelegate = isRouteHandler ? null : (RequestDelegate)entry.RouteHandler; // Let existing conventions capture and call into builder.RequestDelegate as long as they do so after it has been created. RequestDelegate redirectRequestDelegate = context => @@ -232,6 +237,15 @@ private RouteEndpointBuilder CreateRouteEndpointBuilder( entrySpecificConvention(builder); } + // Initialize this route endpoint builder with validation convention if validation options + // are registered and validation is not disabled on the endpoint. + var hasValidationResolvers = builder.ApplicationServices.GetService>() is { Value: { } options } && options.Resolvers.Count > 0; + var hasDisableValidationMetadata = builder.Metadata.OfType().FirstOrDefault() is not null; + if (hasValidationResolvers && !hasDisableValidationMetadata) + { + builder.FilterFactories.Insert(0, ValidationEndpointFilterFactory.Create); + } + // If no convention has modified builder.RequestDelegate, we can use the RequestDelegate returned by the RequestDelegateFactory directly. var conventionOverriddenRequestDelegate = ReferenceEquals(builder.RequestDelegate, redirectRequestDelegate) ? null : builder.RequestDelegate; diff --git a/src/Http/Routing/src/ValidationEndpointFilterFactory.cs b/src/Http/Routing/src/ValidationEndpointFilterFactory.cs new file mode 100644 index 000000000000..189acc52cf30 --- /dev/null +++ b/src/Http/Routing/src/ValidationEndpointFilterFactory.cs @@ -0,0 +1,99 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Http.Validation; + +internal static class ValidationEndpointFilterFactory +{ + private const string ValidationContextJustification = "The DisplayName property is always statically initialized in the ValidationContext through this codepath."; + + public static EndpointFilterDelegate Create(EndpointFilterFactoryContext context, EndpointFilterDelegate next) + { + var parameters = context.MethodInfo.GetParameters(); + var options = context.ApplicationServices.GetService>()?.Value; + if (options is null || options.Resolvers.Count == 0) + { + return next; + } + + var parameterCount = parameters.Length; + var validatableParameters = new IValidatableInfo[parameterCount]; + var parameterDisplayNames = new string[parameterCount]; + var hasValidatableParameters = false; + + for (var i = 0; i < parameterCount; i++) + { + if (options.TryGetValidatableParameterInfo(parameters[i], out var validatableParameter)) + { + validatableParameters[i] = validatableParameter; + parameterDisplayNames[i] = GetDisplayName(parameters[i]); + hasValidatableParameters = true; + } + } + + if (!hasValidatableParameters) + { + return next; + } + + return async (context) => + { + var validatableContext = new ValidateContext { ValidationOptions = options }; + + for (var i = 0; i < context.Arguments.Count; i++) + { + var validatableParameter = validatableParameters[i]; + var displayName = parameterDisplayNames[i]; + + var argument = context.Arguments[i]; + if (argument is null || validatableParameter is null) + { + continue; + } + // ValidationContext is not trim-friendly in codepaths that don't + // initialize an explicit DisplayName. We can suppress the warning here. + // Eventually, this can be removed when the code is updated to + // use https://github.com/dotnet/runtime/issues/113134. + var validationContext = CreateValidationContext(argument, displayName, context.HttpContext.RequestServices); + validatableContext.ValidationContext = validationContext; + await validatableParameter.ValidateAsync(argument, validatableContext, context.HttpContext.RequestAborted); + } + + if (validatableContext.ValidationErrors is { Count: > 0 }) + { + context.HttpContext.Response.StatusCode = StatusCodes.Status400BadRequest; + context.HttpContext.Response.ContentType = "application/problem+json"; + return await ValueTask.FromResult(new HttpValidationProblemDetails(validatableContext.ValidationErrors)); + } + + return await next(context); + }; + } + + /// + /// ValidationContext is not trim-friendly in codepaths that don't + /// initialize an explicit DisplayName. We can suppress the warning here. + /// Eventually, this can be removed when the code is updated to + /// use https://github.com/dotnet/runtime/issues/113134. + /// + [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = ValidationContextJustification)] + private static ValidationContext CreateValidationContext(object argument, string displayName, IServiceProvider serviceProvider) + => new(argument, serviceProvider, items: null) { DisplayName = displayName }; + + private static string GetDisplayName(ParameterInfo parameterInfo) + { + var displayAttribute = parameterInfo.GetCustomAttribute(); + if (displayAttribute != null) + { + return displayAttribute.Name ?? parameterInfo.Name!; + } + + return parameterInfo.Name!; + } +} diff --git a/src/Http/samples/MinimalValidationSample/MinimalValidationSample.csproj b/src/Http/samples/MinimalValidationSample/MinimalValidationSample.csproj new file mode 100644 index 000000000000..24690f4dbd35 --- /dev/null +++ b/src/Http/samples/MinimalValidationSample/MinimalValidationSample.csproj @@ -0,0 +1,27 @@ + + + + $(DefaultNetCoreTargetFramework) + enable + true + $(InterceptorsNamespaces);Microsoft.AspNetCore.Http.Validation.Generated + + + + + + + + + + + + + + + + + + diff --git a/src/Http/samples/MinimalValidationSample/MinimalValidationSample.http b/src/Http/samples/MinimalValidationSample/MinimalValidationSample.http new file mode 100644 index 000000000000..1966d2733ff8 --- /dev/null +++ b/src/Http/samples/MinimalValidationSample/MinimalValidationSample.http @@ -0,0 +1,114 @@ +### Valid customer ID request +GET http://localhost:5021/customers/42 +Accept: application/json + +### Invalid customer ID request (ID must be >= 1) +GET http://localhost:5021/customers/0 +Accept: application/json + +### Valid customer POST request +POST http://localhost:5021/customers +Content-Type: application/json + +{ + "name": "John Doe", + "email": "john.doe@example.com", + "age": 30, + "homeAddress": { + "street": "123 Main St", + "city": "Anytown", + "zipCode": "12345" + } +} + +### Invalid customer POST request (missing required fields) +POST http://localhost:5021/customers +Content-Type: application/json + +{ + "age": 15 +} + +### Invalid customer POST request (invalid email format) +POST http://localhost:5021/customers +Content-Type: application/json + +{ + "name": "John Doe", + "email": "not-an-email", + "age": 30 +} + +### Invalid customer POST request (age out of range) +POST http://localhost:5021/customers +Content-Type: application/json + +{ + "name": "John Doe", + "email": "john.doe@example.com", + "age": 15 +} + +### Invalid customer POST request (invalid zipCode length) +POST http://localhost:5021/customers +Content-Type: application/json + +{ + "name": "John Doe", + "email": "john.doe@example.com", + "age": 30, + "homeAddress": { + "street": "123 Main St", + "city": "Anytown", + "zipCode": "1234567" + } +} + +### Valid order POST request +POST http://localhost:5021/orders +Content-Type: application/json + +{ + "orderId": 12345, + "productName": "Sample Product", + "quantity": 5 +} + +### Invalid order POST request (missing required field) +POST http://localhost:5021/orders +Content-Type: application/json + +{ + "orderId": 12345, + "quantity": 5 +} + +### Invalid order POST request (IValidatableObject validation failure) +POST http://localhost:5021/orders +Content-Type: application/json + +{ + "orderId": 12345, + "productName": "Sample Product", + "quantity": 0 +} + +### Invalid order POST request (negative orderId) +POST http://localhost:5021/orders +Content-Type: application/json + +{ + "orderId": -1, + "productName": "Sample Product", + "quantity": 5 +} + +### Valid product POST request (validation disabled) +# This endpoint has DisableValidation() applied, so even invalid data should be accepted +POST http://localhost:5021/products?productId=2&name=TestProduct +Content-Type: application/json + +### Invalid product POST request (validation disabled) +# This has an odd productId and is missing name, but should still work because validation is disabled +POST http://localhost:5021/products?productId=3 +Content-Type: application/json diff --git a/src/Http/samples/MinimalValidationSample/Program.cs b/src/Http/samples/MinimalValidationSample/Program.cs new file mode 100644 index 000000000000..91be84ed0978 --- /dev/null +++ b/src/Http/samples/MinimalValidationSample/Program.cs @@ -0,0 +1,97 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using Microsoft.AspNetCore.Http.Validation; + +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddValidation(); + +var app = builder.Build(); + +// ValidationEndpointFilterFactory is implicitly enabled on all endpoints +app.MapGet("/customers/{id}", ([Range(1, int.MaxValue)] int id) => + $"Getting customer with ID: {id}"); + +app.MapPost("/customers", (Customer customer) => TypedResults.Created($"/customers/{customer.Name}", customer)); + +app.MapPost("/orders", (Order order) => TypedResults.Created($"/orders/{order.OrderId}", order)); + +app.MapPost("/products", + ([EvenNumber(ErrorMessage = "Product ID must be even")] int productId, [Required] string name) + => TypedResults.Ok(new { productId, name })) + .DisableValidation(); + +app.Run(); + +// Define validatable types with the ValidatableType attribute +[ValidatableType] +public class Customer +{ + [Required] + public required string Name { get; set; } + + [EmailAddress] + public required string Email { get; set; } + + [Range(18, 120)] + [Display(Name = "Customer Age")] + public int Age { get; set; } + + // Complex property with nested validation + public Address HomeAddress { get; set; } = new Address + { + Street = "123 Main St", + City = "Anytown", + ZipCode = "12345" + }; +} + +public class Address +{ + [Required] + public required string Street { get; set; } + + [Required] + public required string City { get; set; } + + [StringLength(5)] + public required string ZipCode { get; set; } +} + +// Define a type implementing IValidatableObject for custom validation +public class Order : IValidatableObject +{ + [Range(1, int.MaxValue)] + public int OrderId { get; set; } + + [Required] + public required string ProductName { get; set; } + + public int Quantity { get; set; } + + // Custom validation logic using IValidatableObject + public IEnumerable Validate(ValidationContext validationContext) + { + if (Quantity <= 0) + { + yield return new ValidationResult( + "Quantity must be greater than zero", + [nameof(Quantity)]); + } + } +} + +// Use a custom validation attribute +public class EvenNumberAttribute : ValidationAttribute +{ + public override bool IsValid(object? value) + { + if (value is int number) + { + return number % 2 == 0; + } + return false; + } +} diff --git a/src/Http/samples/MinimalValidationSample/Properties/launchSettings.json b/src/Http/samples/MinimalValidationSample/Properties/launchSettings.json new file mode 100644 index 000000000000..6e42095c6bd3 --- /dev/null +++ b/src/Http/samples/MinimalValidationSample/Properties/launchSettings.json @@ -0,0 +1,13 @@ +{ + "profiles": { + "HttpApiSampleApp": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": true, + "applicationUrl": "https://localhost:5022;http://localhost:5021", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} diff --git a/src/Http/samples/MinimalValidationSample/appsettings.Development.json b/src/Http/samples/MinimalValidationSample/appsettings.Development.json new file mode 100644 index 000000000000..8983e0fc1c5e --- /dev/null +++ b/src/Http/samples/MinimalValidationSample/appsettings.Development.json @@ -0,0 +1,9 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft": "Warning", + "Microsoft.Hosting.Lifetime": "Information" + } + } +} diff --git a/src/Http/samples/MinimalValidationSample/appsettings.json b/src/Http/samples/MinimalValidationSample/appsettings.json new file mode 100644 index 000000000000..d9d9a9bff6fd --- /dev/null +++ b/src/Http/samples/MinimalValidationSample/appsettings.json @@ -0,0 +1,10 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft": "Warning", + "Microsoft.Hosting.Lifetime": "Information" + } + }, + "AllowedHosts": "*" +}