Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections.Immutable;
using System.Linq;
using Microsoft.AspNetCore.App.Analyzers.Infrastructure;
using Microsoft.CodeAnalysis;

namespace Microsoft.AspNetCore.Http.ValidationsGenerator;
Expand Down Expand Up @@ -90,17 +91,17 @@ internal static bool ImplementsInterface(this ITypeSymbol type, ITypeSymbol inte

// Types exempted here have special binding rules in RDF and RDG and are not validatable
// types themselves so we short-circuit on them.
internal static bool IsExemptType(this ITypeSymbol type, RequiredSymbols requiredSymbols)
internal static bool IsExemptType(this ITypeSymbol type, WellKnownTypes wellKnownTypes)
{
return SymbolEqualityComparer.Default.Equals(type, requiredSymbols.HttpContext)
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.HttpRequest)
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.HttpResponse)
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.CancellationToken)
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.IFormCollection)
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.IFormFileCollection)
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.IFormFile)
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.Stream)
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.PipeReader);
return SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.Microsoft_AspNetCore_Http_HttpContext))
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.Microsoft_AspNetCore_Http_HttpRequest))
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.Microsoft_AspNetCore_Http_HttpResponse))
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_Threading_CancellationToken))
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.Microsoft_AspNetCore_Http_IFormCollection))
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.Microsoft_AspNetCore_Http_IFormFileCollection))
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.Microsoft_AspNetCore_Http_IFormFile))
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_IO_Stream))
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_IO_Pipelines_PipeReader));
}

internal static IPropertySymbol? FindPropertyIncludingBaseTypes(this INamedTypeSymbol typeSymbol, string propertyName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
<Compile Include="$(SharedSourceRoot)RoslynUtils\CodeWriter.cs" LinkBase="Shared" />
<Compile Include="$(RepoRoot)\src\Http\Http.Extensions\gen\Microsoft.AspNetCore.Http.RequestDelegateGenerator\StaticRouteHandlerModel\InvocationOperationExtensions.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)Diagnostics\AnalyzerDebug.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)RoslynUtils\ParsabilityHelper.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)RoslynUtils\SymbolExtensions.cs" LinkBase="Shared" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Threading;
using Microsoft.AspNetCore.App.Analyzers.Infrastructure;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

Expand All @@ -20,8 +21,8 @@ internal ImmutableArray<ValidatableType> TransformValidatableTypeWithAttribute(G
{
var validatableTypes = new HashSet<ValidatableType>(ValidatableTypeComparer.Instance);
List<ITypeSymbol> visitedTypes = [];
var requiredSymbols = ExtractRequiredSymbols(context.SemanticModel.Compilation, cancellationToken);
if (TryExtractValidatableType((ITypeSymbol)context.TargetSymbol, requiredSymbols, ref validatableTypes, ref visitedTypes))
var wellKnownTypes = WellKnownTypes.GetOrCreate(context.SemanticModel.Compilation);
if (TryExtractValidatableType((ITypeSymbol)context.TargetSymbol, wellKnownTypes, ref validatableTypes, ref visitedTypes))
{
return [..validatableTypes];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Linq;
using System.Threading;
using Microsoft.AspNetCore.Analyzers.Infrastructure;
using Microsoft.AspNetCore.App.Analyzers.Infrastructure;
using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
Expand Down Expand Up @@ -38,10 +39,12 @@ internal bool FindEndpoints(SyntaxNode syntaxNode, CancellationToken cancellatio
: null;
}

internal ImmutableArray<ValidatableType> ExtractValidatableEndpoint((IInvocationOperation? Operation, RequiredSymbols RequiredSymbols) input, CancellationToken cancellationToken)
internal ImmutableArray<ValidatableType> ExtractValidatableEndpoint(IInvocationOperation? operation, CancellationToken cancellationToken)
{
AnalyzerDebug.Assert(input.Operation != null, "Operation should not be null.");
var validatableTypes = ExtractValidatableTypes(input.Operation, input.RequiredSymbols);
AnalyzerDebug.Assert(operation != null, "Operation should not be null.");
AnalyzerDebug.Assert(operation.SemanticModel != null, "Operation should have a semantic model.");
var wellKnownTypes = WellKnownTypes.GetOrCreate(operation.SemanticModel.Compilation);
var validatableTypes = ExtractValidatableTypes(operation, wellKnownTypes);
return validatableTypes;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Immutable;
using System.Linq;
using Microsoft.AspNetCore.Analyzers.Infrastructure;
using Microsoft.AspNetCore.App.Analyzers.Infrastructure;
using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand All @@ -18,7 +19,7 @@ public sealed partial class ValidationsGenerator : IIncrementalGenerator
globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Included,
typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces);

internal ImmutableArray<ValidatableType> ExtractValidatableTypes(IInvocationOperation operation, RequiredSymbols requiredSymbols)
internal ImmutableArray<ValidatableType> ExtractValidatableTypes(IInvocationOperation operation, WellKnownTypes wellKnownTypes)
{
AnalyzerDebug.Assert(operation.SemanticModel != null, "SemanticModel should not be null.");
var parameters = operation.TryGetRouteHandlerMethod(operation.SemanticModel, out var method)
Expand All @@ -28,12 +29,12 @@ internal ImmutableArray<ValidatableType> ExtractValidatableTypes(IInvocationOper
List<ITypeSymbol> visitedTypes = [];
foreach (var parameter in parameters)
{
_ = TryExtractValidatableType(parameter.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes);
_ = TryExtractValidatableType(parameter.Type.UnwrapType(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_Collections_IEnumerable)), wellKnownTypes, ref validatableTypes, ref visitedTypes);
}
return [.. validatableTypes];
}

internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet<ValidatableType> validatableTypes, ref List<ITypeSymbol> visitedTypes)
internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes, ref HashSet<ValidatableType> validatableTypes, ref List<ITypeSymbol> visitedTypes)
{
if (typeSymbol.SpecialType != SpecialType.None)
{
Expand All @@ -45,7 +46,7 @@ internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols
return true;
}

if (typeSymbol.IsExemptType(requiredSymbols))
if (typeSymbol.IsExemptType(wellKnownTypes))
{
return false;
}
Expand All @@ -57,19 +58,23 @@ internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols
var hasValidatableBaseType = false;
while (current != null && current.SpecialType != SpecialType.System_Object)
{
hasValidatableBaseType |= TryExtractValidatableType(current, requiredSymbols, ref validatableTypes, ref visitedTypes);
hasValidatableBaseType |= TryExtractValidatableType(current, wellKnownTypes, ref validatableTypes, ref visitedTypes);
current = current.BaseType;
}

// Extract validatable types discovered in members of this type and add them to the top-level list.
var members = ExtractValidatableMembers(typeSymbol, requiredSymbols, ref validatableTypes, ref visitedTypes);
ImmutableArray<ValidatableProperty> members = [];
if (ParsabilityHelper.GetParsability(typeSymbol, wellKnownTypes) is Parsability.NotParsable)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the change in ValidationsGenerator are the only real changes in this PR?

And this is removing TryParseable types from the validatable types? But somehow they are still validatable?

Copy link
Member Author

@captainsafia captainsafia Apr 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the change in ValidationsGenerator are the only real changes in this PR?

Yep, this is the main change. All the WellKnownType stuff is a reaction to the fact that this API needs it.

And this is removing TryParseable types from the validatable types? But somehow they are still validatable?

We'll still generate ValidatablePropertyInfo instances for them so that we'll validate attributes on the properties. This change does mean that we won't validate properties with validation attributes inside the parsable types. None of the properties in the case below get validated. I'd have to check how MVC handles this...if it does...

The right pattern hear might be to require that custom parsable types implement IValidatableObject to handle validation.

public class Person
{
    [Required]
    [StringLength(100, MinimumLength = 2)]
    public string FirstName { get; private set; } = string.Empty;

    [Required]
    [StringLength(100, MinimumLength = 2)]
    public string LastName { get; private set; } = string.Empty;

    [Range(0, 120)]
    public int Age { get; private set; }

    // Static TryParse method
    public static bool TryParse(string input, out Person? person) { }    
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update: it looks like MVC doesn't validate properties on a TryParse by default:

#:sdk Microsoft.NET.Sdk.Web
#:property TargetFramework net10.0

using Microsoft.AspNetCore.Mvc;
using System.ComponentModel.DataAnnotations;

var builder = WebApplication.CreateBuilder(args);

builder.Services.AddControllers();

var app = builder.Build();

app.MapControllers();

app.Run();

[ApiController]
[Route("[controller]")]
public class PersonController : ControllerBase
{
    // This endpoint binds PersonInput from query string
    [HttpGet]
    public IActionResult CreatePerson([FromQuery] PersonInput personInput)
    {
        if (!ModelState.IsValid)
        {
            return BadRequest(ModelState);
        }

        return Ok(new { Message = $"Created person: {personInput.FirstName} {personInput.LastName}, Age {personInput.Age}" });
    }
}

public class PersonInput
{
    [Required]
    [StringLength(100, MinimumLength = 2)]
    public string FirstName { get; init; } = string.Empty;

    [Required]
    [StringLength(100, MinimumLength = 2)]
    public string LastName { get; init; } = string.Empty;

    [Range(0, 120)]
    public int Age { get; init; }

    // TryParse for parsing raw CSV strings (no validation inside)
    public static bool TryParse(string input, out PersonInput? person)
    {
        person = null;

        if (string.IsNullOrWhiteSpace(input))
            return false;

        var parts = input.Split(',', StringSplitOptions.TrimEntries);
        if (parts.Length != 3)
            return false;

        if (!int.TryParse(parts[2], out int age))
            return false;

        person = new PersonInput
        {
            FirstName = parts[0],
            LastName = parts[1],
            Age = age
        };

        return true;
    }
}
$ curl http://localhost:5000/person?personInput=S,A,128
{
  "message": "Created person: S A, Age 128"
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The right pattern hear might be to require that custom parsable types implement IValidatableObject to handle validation.

Does this work now, or are you saying we might want to make this work in the future?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't work now. We would need to add a check to see if the type implements IValidatableObject and if so generate an IValidatableInfo instance for that member type.

{
members = ExtractValidatableMembers(typeSymbol, wellKnownTypes, ref validatableTypes, ref visitedTypes);
}

// Extract the validatable types discovered in the JsonDerivedTypeAttributes of this type and add them to the top-level list.
var derivedTypes = typeSymbol.GetJsonDerivedTypes(requiredSymbols.JsonDerivedTypeAttribute);
var derivedTypes = typeSymbol.GetJsonDerivedTypes(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_Text_Json_Serialization_JsonDerivedTypeAttribute));
var hasValidatableDerivedTypes = false;
foreach (var derivedType in derivedTypes ?? [])
{
hasValidatableDerivedTypes |= TryExtractValidatableType(derivedType, requiredSymbols, ref validatableTypes, ref visitedTypes);
hasValidatableDerivedTypes |= TryExtractValidatableType(derivedType, wellKnownTypes, ref validatableTypes, ref visitedTypes);
}

// No validatable members or derived types found, so we don't need to add this type.
Expand All @@ -86,7 +91,7 @@ internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols
return true;
}

internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet<ValidatableType> validatableTypes, ref List<ITypeSymbol> visitedTypes)
internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes, ref HashSet<ValidatableType> validatableTypes, ref List<ITypeSymbol> visitedTypes)
{
var members = new List<ValidatableProperty>();
var resolvedRecordProperty = new List<IPropertySymbol>();
Expand Down Expand Up @@ -121,17 +126,17 @@ internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymb
// Check if the property's type is validatable, this resolves
// validatable types in the inheritance hierarchy
var hasValidatableType = TryExtractValidatableType(
correspondingProperty.Type.UnwrapType(requiredSymbols.IEnumerable),
requiredSymbols,
correspondingProperty.Type.UnwrapType(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_Collections_IEnumerable)),
wellKnownTypes,
ref validatableTypes,
ref visitedTypes);

members.Add(new ValidatableProperty(
ContainingType: correspondingProperty.ContainingType,
Type: correspondingProperty.Type,
Name: correspondingProperty.Name,
DisplayName: parameter.GetDisplayName(requiredSymbols.DisplayAttribute) ??
correspondingProperty.GetDisplayName(requiredSymbols.DisplayAttribute),
DisplayName: parameter.GetDisplayName(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_DisplayAttribute)) ??
correspondingProperty.GetDisplayName(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_DisplayAttribute)),
Attributes: []));
}
}
Expand All @@ -148,8 +153,8 @@ internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymb
continue;
}

var hasValidatableType = TryExtractValidatableType(member.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes);
var attributes = ExtractValidationAttributes(member, requiredSymbols, out var isRequired);
var hasValidatableType = TryExtractValidatableType(member.Type.UnwrapType(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_Collections_IEnumerable)), wellKnownTypes, ref validatableTypes, ref visitedTypes);
var attributes = ExtractValidationAttributes(member, wellKnownTypes, out var isRequired);

// If the member has no validation attributes or validatable types and is not required, skip it.
if (attributes.IsDefaultOrEmpty && !hasValidatableType && !isRequired)
Expand All @@ -161,14 +166,14 @@ internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymb
ContainingType: member.ContainingType,
Type: member.Type,
Name: member.Name,
DisplayName: member.GetDisplayName(requiredSymbols.DisplayAttribute),
DisplayName: member.GetDisplayName(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_DisplayAttribute)),
Attributes: attributes));
}

return [.. members];
}

internal static ImmutableArray<ValidationAttribute> ExtractValidationAttributes(ISymbol symbol, RequiredSymbols requiredSymbols, out bool isRequired)
internal static ImmutableArray<ValidationAttribute> ExtractValidationAttributes(ISymbol symbol, WellKnownTypes wellKnownTypes, out bool isRequired)
{
var attributes = symbol.GetAttributes();
if (attributes.Length == 0)
Expand All @@ -179,15 +184,15 @@ internal static ImmutableArray<ValidationAttribute> ExtractValidationAttributes(

var validationAttributes = attributes
.Where(attribute => attribute.AttributeClass != null)
.Where(attribute => attribute.AttributeClass!.ImplementsValidationAttribute(requiredSymbols.ValidationAttribute));
isRequired = validationAttributes.Any(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, requiredSymbols.RequiredAttribute));
.Where(attribute => attribute.AttributeClass!.ImplementsValidationAttribute(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_ValidationAttribute)));
isRequired = validationAttributes.Any(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_RequiredAttribute)));
return [.. validationAttributes
.Where(attr => !SymbolEqualityComparer.Default.Equals(attr.AttributeClass, requiredSymbols.ValidationAttribute))
.Where(attr => !SymbolEqualityComparer.Default.Equals(attr.AttributeClass, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_ValidationAttribute)))
.Select(attribute => new ValidationAttribute(
Name: symbol.Name + attribute.AttributeClass!.Name,
ClassName: attribute.AttributeClass!.ToDisplayString(_symbolDisplayFormat),
Arguments: [.. attribute.ConstructorArguments.Select(a => a.ToCSharpString())],
NamedArguments: attribute.NamedArguments.ToDictionary(namedArgument => namedArgument.Key, namedArgument => namedArgument.Value.ToCSharpString()),
IsCustomValidationAttribute: SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, requiredSymbols.CustomValidationAttribute)))];
IsCustomValidationAttribute: SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_CustomValidationAttribute))))];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ public sealed partial class ValidationsGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Resolve the symbols that will be required when making comparisons
// in future steps.
var requiredSymbols = context.CompilationProvider.Select(ExtractRequiredSymbols);

// Find the builder.Services.AddValidation() call in the application.
var addValidation = context.SyntaxProvider.CreateSyntaxProvider(
predicate: FindAddValidation,
Expand All @@ -34,7 +30,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Where(endpoint => endpoint is not null);
// Extract validatable types from all endpoints.
var validatableTypesFromEndpoints = endpoints
.Combine(requiredSymbols)
.Select(ExtractValidatableEndpoint);
// Join all validatable types encountered in the type graph.
var validatableTypes = validatableTypesWithAttribute
Expand Down
Loading
Loading