Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,32 @@ public static IncrementalValuesProvider<T> Concat<T>(
});
}

public static IncrementalValuesProvider<T> Concat<T>(
this IncrementalValuesProvider<ImmutableArray<T>> first,
IncrementalValuesProvider<T> second)
{
return first.Collect()
.Combine(second.Collect())
.SelectMany((tuple, _) =>
{
if (tuple.Left.IsEmpty)
{
return tuple.Right;
}

var results = ImmutableArray.CreateBuilder<T>(tuple.Left.Length + tuple.Right.Length);
for (var i = 0; i < tuple.Left.Length; i++)
{
results.AddRange(tuple.Left[i]);
}
for (var i = 0; i < tuple.Right.Length; i++)
{
results.AddRange(tuple.Right[i]);
}
return results.DrainToImmutable();
});
}

private sealed class ImmutableArrayEqualityComparer<T> : IEqualityComparer<ImmutableArray<T>>
{
public static readonly ImmutableArrayEqualityComparer<T> Instance = new();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Threading;
Expand All @@ -24,7 +23,7 @@ internal ImmutableArray<ValidatableType> TransformValidatableTypeWithAttribute(G
var wellKnownTypes = WellKnownTypes.GetOrCreate(context.SemanticModel.Compilation);
if (TryExtractValidatableType((ITypeSymbol)context.TargetSymbol, wellKnownTypes, ref validatableTypes, ref visitedTypes))
{
return [..validatableTypes];
return [.. validatableTypes];
}
return [];
}
Expand Down
31 changes: 26 additions & 5 deletions src/Validation/gen/ValidationsGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Immutable;
using System.Linq;
using System.Runtime.InteropServices;
using Microsoft.AspNetCore.App.Analyzers.Infrastructure;
using Microsoft.CodeAnalysis;

namespace Microsoft.Extensions.Validation;
Expand All @@ -16,24 +20,40 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
predicate: FindAddValidation,
transform: TransformAddValidation
);
// Extract types that have been marked with [ValidatableType].
var validatableTypesWithAttribute = context.SyntaxProvider.ForAttributeWithMetadataName(

// Extract types that have been marked with framework [ValidatableType].
var frameworkValidatableTypes = context.SyntaxProvider.ForAttributeWithMetadataName(
"Microsoft.Extensions.Validation.ValidatableTypeAttribute",
predicate: ShouldTransformSymbolWithAttribute,
transform: TransformValidatableTypeWithAttribute
);

// Extract types that have been marked with generated [ValidatableType].
var generatedValidatableTypes = context.SyntaxProvider.ForAttributeWithMetadataName(
"Microsoft.Extensions.Validation.Embedded.ValidatableTypeAttribute",
predicate: ShouldTransformSymbolWithAttribute,
transform: TransformValidatableTypeWithAttribute
);

// Combine both sources of validatable types
var validatableTypesWithAttribute = frameworkValidatableTypes.Concat(generatedValidatableTypes);

// Extract all minimal API endpoints in the application.
var endpoints = context.SyntaxProvider
.CreateSyntaxProvider(
predicate: FindEndpoints,
transform: TransformEndpoints)
.Where(endpoint => endpoint is not null);

// Extract validatable types from all endpoints.
var validatableTypesFromEndpoints = endpoints
.Select(ExtractValidatableEndpoint);

// Join all validatable types encountered in the type graph.
var validatableTypes = validatableTypesWithAttribute
.Concat(validatableTypesFromEndpoints)
var allValidatableTypesProviders = validatableTypesFromEndpoints
.Concat(validatableTypesWithAttribute);

var validatableTypes = allValidatableTypesProviders
.Distinct(ValidatableTypeComparer.Instance)
.Collect();

Expand All @@ -42,6 +62,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

// Emit the IValidatableInfo resolver injection and
// ValidatableTypeInfo for all validatable types.
context.RegisterSourceOutput(emitInputs, Emit);
context.RegisterSourceOutput(emitInputs, (context, emitInputs) =>
Emit(context, (emitInputs.Left, emitInputs.Right)));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.Extensions.Validation.GeneratorTests;
using VerifyXunit;
using Xunit;

namespace Microsoft.Extensions.Validation.GeneratorTests;

[UsesVerify]
public partial class ValidationsGeneratorTests : ValidationsGeneratorTestBase
{
private const string GeneratedAttributeSource = """
// <auto-generated/>
namespace Microsoft.CodeAnalysis
{
[global::System.AttributeUsage(global::System.AttributeTargets.All, AllowMultiple = true, Inherited = false)]
internal sealed class EmbeddedAttribute : global::System.Attribute
{
}
}

namespace Microsoft.Extensions.Validation.Embedded
{
[global::Microsoft.CodeAnalysis.EmbeddedAttribute]
[global::System.AttributeUsage(global::System.AttributeTargets.Class)]
internal sealed class ValidatableTypeAttribute : global::System.Attribute
{
}
}
""";

[Fact]
public async Task CanDiscoverGeneratedValidatableTypeAttribute()
{
var source = """

namespace MyApp
{
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.DependencyInjection;
using System.ComponentModel.DataAnnotations;
using Microsoft.Extensions.Validation.Embedded;

public class Program
{
public static void Main(string[] args)
{
var builder = WebApplication.CreateBuilder(args);
builder.Services.AddValidation();
var app = builder.Build();
app.MapPost("/customers", (Customer customer) => "OK");
app.Run();
}
}

[ValidatableType]
public class Customer
{
[Required]
public string Name { get; set; } = "";

[EmailAddress]
public string Email { get; set; } = "";
}
}
""";

// Combine the generated attribute with the user's source
var combinedSource = GeneratedAttributeSource + "\n" + source;

await Verify(combinedSource, out var compilation);
}

[Fact]
public async Task CanUseBothFrameworkAndGeneratedValidatableTypeAttributes()
{
var source = """
namespace MyApp
{
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.DependencyInjection;
using System.ComponentModel.DataAnnotations;
using Microsoft.Extensions.Validation.Embedded;

public class Program
{
public static void Main(string[] args)
{
var builder = WebApplication.CreateBuilder(args);
builder.Services.AddValidation();
var app = builder.Build();
app.MapPost("/customers", (Customer customer) => "OK");
app.Run();
}
}

// Using framework attribute
[Microsoft.Extensions.Validation.ValidatableType]
public class Customer
{
[Required]
public string Name { get; set; } = "";

[EmailAddress]
public string Email { get; set; } = "";
}

// Using generated attribute
[ValidatableType]
public class Product
{
[Required]
public string ProductName { get; set; } = "";

[Range(0, double.MaxValue)]
public decimal Price { get; set; }
}
}
""";

// Combine the generated attribute with the user's source
var combinedSource = GeneratedAttributeSource + "\n" + source;

await Verify(combinedSource, out var compilation);
}
}
Loading
Loading