Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterIn
file static class GeneratedServiceCollectionExtensions
{
{{addValidation.GetInterceptsLocationAttributeSyntax()}}
public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action<ValidationOptions>? configureOptions = null)
public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action<global::Microsoft.AspNetCore.Http.Validation.ValidationOptions>? configureOptions = null)
{
// Use non-extension method to avoid infinite recursion.
return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options =>
Expand All @@ -133,13 +133,39 @@ private sealed record CacheKey(global::System.Type ContainingType, string Proper
var key = new CacheKey(containingType, propertyName);
return _cache.GetOrAdd(key, static k =>
{
var results = new global::System.Collections.Generic.List<global::System.ComponentModel.DataAnnotations.ValidationAttribute>();

// Get attributes from the property
var property = k.ContainingType.GetProperty(k.PropertyName);
if (property == null)
if (property != null)
{
var propertyAttributes = global::System.Reflection.CustomAttributeExtensions
.GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(property, inherit: true);

results.AddRange(propertyAttributes);
}

// Check constructors for parameters that match the property name to handle
// record scenarios
foreach (var constructor in k.ContainingType.GetConstructors())
{
return [];
// Look for parameter with matching name (case insensitive)
var parameter = global::System.Linq.Enumerable.FirstOrDefault(
constructor.GetParameters(),
p => string.Equals(p.Name, k.PropertyName, global::System.StringComparison.OrdinalIgnoreCase));

if (parameter != null)
{
var paramAttributes = global::System.Reflection.CustomAttributeExtensions
.GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(parameter, inherit: true);

results.AddRange(paramAttributes);

break;
}
}

return [.. global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(property, inherit: true)];
return results.ToArray();
});
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;

namespace Microsoft.AspNetCore.Http.ValidationsGenerator;
Expand Down Expand Up @@ -101,4 +103,39 @@ internal static bool IsExemptType(this ITypeSymbol type, RequiredSymbols require
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.Stream)
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.PipeReader);
}

internal static IPropertySymbol? FindPropertyIncludingBaseTypes(this INamedTypeSymbol typeSymbol, string propertyName)
{
var property = typeSymbol.GetMembers()
.OfType<IPropertySymbol>()
.FirstOrDefault(p => string.Equals(p.Name, propertyName, System.StringComparison.OrdinalIgnoreCase));

if (property != null)
{
return property;
}

// If not found, recursively search base types
if (typeSymbol.BaseType is INamedTypeSymbol baseType)
{
return FindPropertyIncludingBaseTypes(baseType, propertyName);
}

return null;
}

// Helper method to get all properties including inherited ones
internal static IEnumerable<IPropertySymbol> GetAllProperties(this ITypeSymbol typeSymbol)
{
var current = typeSymbol;
var properties = new List<IPropertySymbol>();

while (current != null && current.SpecialType != SpecialType.System_Object)
{
properties.AddRange(current.GetMembers().OfType<IPropertySymbol>());
current = current.BaseType;
}

return properties;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,74 @@ internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols
internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet<ValidatableType> validatableTypes, ref List<ITypeSymbol> visitedTypes)
{
var members = new List<ValidatableProperty>();
var resolvedRecordProperty = new List<IPropertySymbol>();

// Special handling for record types to extract properties from
// the primary constructor.
if (typeSymbol is INamedTypeSymbol { IsRecord: true } namedType)
{
// Find the primary constructor for the record, account
// for members that are in base types to account for
// record inheritance scenarios
var primaryConstructor = namedType.Constructors
.FirstOrDefault(c => c.Parameters.Length > 0 && c.Parameters.All(p =>
namedType.FindPropertyIncludingBaseTypes(p.Name) != null));

if (primaryConstructor != null)
{
// Process all parameters in constructor order to maintain parameter ordering
foreach (var parameter in primaryConstructor.Parameters)
{
// Find the corresponding property in this type, we ignore
// base types here since that will be handled by the inheritance
// checks in the default ValidatableTypeInfo implementation.
var correspondingProperty = typeSymbol.GetMembers()
.OfType<IPropertySymbol>()
.FirstOrDefault(p => string.Equals(p.Name, parameter.Name, System.StringComparison.OrdinalIgnoreCase));

if (correspondingProperty != null)
{
resolvedRecordProperty.Add(correspondingProperty);

// Check if the property's type is validatable, this resolves
// validatable types in the inheritance hierarchy
var hasValidatableType = TryExtractValidatableType(
correspondingProperty.Type.UnwrapType(requiredSymbols.IEnumerable),
requiredSymbols,
ref validatableTypes,
ref visitedTypes);

members.Add(new ValidatableProperty(
ContainingType: correspondingProperty.ContainingType,
Type: correspondingProperty.Type,
Name: correspondingProperty.Name,
DisplayName: parameter.GetDisplayName(requiredSymbols.DisplayAttribute) ??
correspondingProperty.GetDisplayName(requiredSymbols.DisplayAttribute),
Attributes: []));
}
}
}
}

// Handle properties for classes and any properties not handled by the constructor
foreach (var member in typeSymbol.GetMembers().OfType<IPropertySymbol>())
{
// Skip compiler generated properties and properties already processed via
// the record processing logic above.
if (member.IsImplicitlyDeclared || resolvedRecordProperty.Contains(member, SymbolEqualityComparer.Default))
{
continue;
}

var hasValidatableType = TryExtractValidatableType(member.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes);
var attributes = ExtractValidationAttributes(member, requiredSymbols, out var isRequired);

// If the member has no validation attributes or validatable types and is not required, skip it.
if (attributes.IsDefaultOrEmpty && !hasValidatableType && !isRequired)
{
continue;
}

members.Add(new ValidatableProperty(
ContainingType: member.ContainingType,
Type: member.Type,
Expand Down
Loading
Loading