Skip to content

Commit a3db615

Browse files
authored
Improved class precedence resolution. (#73)
* Refactor. * Refactor. * Add tests for generic ProducesResponse attribute (#72) * Refactored. * Cleanup. * Bump version.
1 parent e918813 commit a3db615

File tree

37 files changed

+1270
-589
lines changed

37 files changed

+1270
-589
lines changed
Lines changed: 8 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Text;
1+
using System.Collections.Immutable;
2+
using System.Text;
23
using GeneratedEndpoints.Common;
34
using Microsoft.CodeAnalysis;
45
using Microsoft.CodeAnalysis.Text;
@@ -12,12 +13,15 @@ namespace GeneratedEndpoints;
1213

1314
internal static class AddEndpointHandlersGenerator
1415
{
15-
public static void GenerateSource(SourceProductionContext context, EquatableImmutableArray<RequestHandler> requestHandlers)
16+
public static void GenerateSource(SourceProductionContext context, ImmutableSortedDictionary<RequestHandlerClass, ImmutableArray<RequestHandler>> grouped)
1617
{
1718
context.CancellationToken.ThrowIfCancellationRequested();
1819

19-
var nonStaticClassNames = GetDistinctNonStaticClassNames(requestHandlers);
20-
var source = GetAddEndpointHandlersStringBuilder(nonStaticClassNames);
20+
var nonStaticClassNames = grouped.Keys
21+
.Where(x => !x.IsStatic)
22+
.Select(x => x.Name)
23+
.ToList();
24+
var source = new StringBuilder();
2125
source.AppendLine(FileHeader);
2226

2327
source.AppendLine();
@@ -61,45 +65,4 @@ public static void GenerateSource(SourceProductionContext context, EquatableImmu
6165
var sourceText = StringBuilderPool.ToStringAndReturn(source);
6266
context.AddSource(AddEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8));
6367
}
64-
65-
private static List<string> GetDistinctNonStaticClassNames(EquatableImmutableArray<RequestHandler> requestHandlers)
66-
{
67-
var classNames = new List<string>();
68-
if (requestHandlers.Count == 0)
69-
return classNames;
70-
71-
var seen = new HashSet<string>(StringComparer.Ordinal);
72-
for (var index = 0; index < requestHandlers.Count; index++)
73-
{
74-
var requestHandler = requestHandlers[index];
75-
if (requestHandler.Class.IsStatic)
76-
continue;
77-
78-
var className = requestHandler.Class.Name;
79-
if (seen.Add(className))
80-
classNames.Add(className);
81-
}
82-
83-
return classNames;
84-
}
85-
86-
private static StringBuilder GetAddEndpointHandlersStringBuilder(List<string> nonStaticClassNames)
87-
{
88-
var estimate = 512L;
89-
for (var index = 0; index < nonStaticClassNames.Count; index++)
90-
{
91-
var className = nonStaticClassNames[index];
92-
estimate += 36 + className.Length;
93-
}
94-
95-
estimate += Math.Max(256, nonStaticClassNames.Count * 12);
96-
estimate = (long)(estimate * 1.10);
97-
98-
if (estimate < 512)
99-
estimate = 512;
100-
else if (estimate > int.MaxValue)
101-
estimate = int.MaxValue;
102-
103-
return StringBuilderPool.Get((int)estimate);
104-
}
10568
}

src/GeneratedEndpoints/Common/AttributeDataExtensions.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ internal static class AttributeDataExtensions
5656
return null;
5757
}
5858

59+
public static ITypeSymbol? GetConstructorTypeSymbol(this AttributeData attribute, int position = 0)
60+
{
61+
if (attribute.ConstructorArguments.Length > position && attribute.ConstructorArguments[position].Value is ITypeSymbol typeSymbol)
62+
return typeSymbol;
63+
64+
return null;
65+
}
66+
5967
public static ITypeSymbol? GetNamedTypeSymbol(this AttributeData attribute, string namedParameter)
6068
{
6169
foreach (var namedArg in attribute.NamedArguments)

src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs

Lines changed: 34 additions & 30 deletions
Large diffs are not rendered by default.

src/GeneratedEndpoints/Common/Constants.cs

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
1-
using System.ComponentModel;
2-
31
namespace GeneratedEndpoints.Common;
42

53
internal static partial class Constants
64
{
75
internal const string FallbackHttpMethod = "__FALLBACK__";
86

97
internal const string NameAttributeNamedParameter = "Name";
10-
internal const string ResponseTypeAttributeNamedParameter = "ResponseType";
11-
internal const string RequestTypeAttributeNamedParameter = "RequestType";
128
internal const string IsOptionalAttributeNamedParameter = "IsOptional";
139

1410
internal const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute";
@@ -47,12 +43,6 @@ internal static partial class Constants
4743
internal const string SummaryAttributeName = "SummaryAttribute";
4844
internal const string SummaryAttributeHint = $"{SummaryAttributeFullyQualifiedName}.gs.cs";
4945

50-
internal const string DisplayNameAttributeName = nameof(DisplayNameAttribute);
51-
internal const string DescriptionAttributeName = nameof(DescriptionAttribute);
52-
internal const string AllowAnonymousAttributeName = "AllowAnonymousAttribute";
53-
internal const string TagsAttributeName = "TagsAttribute";
54-
internal const string ExcludeFromDescriptionAttributeName = "ExcludeFromDescriptionAttribute";
55-
5646
internal const string EndpointFilterAttributeName = "EndpointFilterAttribute";
5747
internal const string EndpointFilterAttributeHint = $"{EndpointFilterAttributeFullyQualifiedName}.gs.cs";
5848

@@ -82,15 +72,7 @@ internal static partial class Constants
8272
internal const string AsyncSuffix = "Async";
8373
internal const string ApplicationJsonContentType = "application/json";
8474
internal const string GlobalPrefix = "global::";
85-
internal const string Dot = ".";
86-
87-
internal static readonly string[] AttributesNamespaceParts = AttributesNamespace.Split('.');
88-
internal static readonly string[] AspNetCoreHttpNamespaceParts = ["Microsoft", "AspNetCore", "Http"];
89-
internal static readonly string[] AspNetCoreMvcNamespaceParts = ["Microsoft", "AspNetCore", "Mvc"];
90-
internal static readonly string[] AspNetCoreAuthorizationNamespaceParts = ["Microsoft", "AspNetCore", "Authorization"];
91-
internal static readonly string[] AspNetCoreRoutingNamespaceParts = ["Microsoft", "AspNetCore", "Routing"];
92-
internal static readonly string[] ExtensionsDependencyInjectionNamespaceParts = ["Microsoft", "Extensions", "DependencyInjection"];
93-
internal static readonly string[] ComponentModelNamespaceParts = ["System", "ComponentModel"];
75+
9476
private const string BaseNamespace = "Microsoft.AspNetCore.Generated";
9577
private const string AttributesNamespace = $"{BaseNamespace}.Attributes";
9678
private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}";

src/GeneratedEndpoints/Common/EndpointConfiguration.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,5 @@ internal readonly record struct EndpointConfiguration
2727
public required bool WithRequestTimeout { get; init; }
2828
public required string? RequestTimeoutPolicyName { get; init; }
2929
public required int? Order { get; init; }
30-
public required string? GroupIdentifier { get; init; }
31-
public required string? GroupPattern { get; init; }
32-
public required string? GroupName { get; init; }
30+
public required EndpointGroup? Group { get; init; }
3331
}

src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
using System.Runtime.CompilerServices;
21
using Microsoft.CodeAnalysis;
32
using static GeneratedEndpoints.Common.Constants;
43

54
namespace GeneratedEndpoints.Common;
65

76
internal static class EndpointConfigurationFactory
87
{
9-
private static readonly ConditionalWeakTable<INamedTypeSymbol, GeneratedAttributeKindCacheEntry> GeneratedAttributeKindCache = new();
10-
118
public static EndpointConfiguration Create(ISymbol symbol)
129
{
1310
var attributes = symbol.GetAttributes();
@@ -48,7 +45,7 @@ public static EndpointConfiguration Create(ISymbol symbol)
4845
if (attributeClass is null)
4946
continue;
5047

51-
var attributeKind = GetGeneratedAttributeKind(attributeClass);
48+
var attributeKind = attributeClass.OriginalDefinition.GetRequestHandlerAttributeKind();
5249
switch (attributeKind)
5350
{
5451
case RequestHandlerAttributeKind.ShortCircuit:
@@ -172,9 +169,14 @@ public static EndpointConfiguration Create(ISymbol symbol)
172169
WithRequestTimeout = withRequestTimeout ?? false,
173170
RequestTimeoutPolicyName = requestTimeoutPolicyName,
174171
Order = order,
175-
GroupIdentifier = groupIdentifier,
176-
GroupPattern = groupPattern,
177-
GroupName = groupName,
172+
Group = groupIdentifier is not null && groupPattern is not null
173+
? new EndpointGroup
174+
{
175+
Identifier = groupIdentifier,
176+
Pattern = groupPattern,
177+
Name = groupName,
178+
}
179+
: null,
178180
};
179181
}
180182

@@ -198,16 +200,6 @@ public static EndpointConfiguration Create(ISymbol symbol)
198200
return StringBuilderPool.ToStringAndReturn(builder);
199201
}
200202

201-
private static RequestHandlerAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass)
202-
{
203-
var definition = attributeClass.OriginalDefinition;
204-
var cacheEntry = GeneratedAttributeKindCache.GetValue(
205-
definition, static def => new GeneratedAttributeKindCacheEntry(def.GetRequestHandlerAttributeKind())
206-
);
207-
208-
return cacheEntry.Kind;
209-
}
210-
211203
private static EquatableImmutableArray<T>? ToEquatableOrNull<T>(List<T>? values)
212204
{
213205
return values is { Count: > 0 } ? values.ToEquatableImmutableArray() : null;
@@ -219,9 +211,11 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym
219211
if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 })
220212
requestType = attributeClass.TypeArguments[0]
221213
.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
222-
else if (attribute.GetNamedTypeSymbol(RequestTypeAttributeNamedParameter) is { } requestTypeSymbol)
223-
requestType = requestTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
224214
else
215+
requestType = attribute.GetConstructorTypeSymbol()
216+
?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
217+
218+
if (requestType is null)
225219
return;
226220

227221
var contentType = attribute.GetConstructorStringValue() ?? ApplicationJsonContentType;
@@ -236,23 +230,29 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym
236230

237231
private static void TryAddProducesMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List<ProducesMetadata>? produces)
238232
{
239-
string? responseType;
233+
ProducesMetadata? producesMetadata;
240234
if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 })
241-
responseType = attributeClass.TypeArguments[0]
235+
{
236+
var responseType = attributeClass.TypeArguments[0]
242237
.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
243-
else if (attribute.GetNamedTypeSymbol(ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol)
244-
responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
238+
var statusCode = attribute.GetConstructorIntValue(0) ?? 200;
239+
var contentType = attribute.GetConstructorStringValue(1);
240+
var additionalContentTypes = attribute.GetConstructorStringArray(2);
241+
producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes);
242+
}
245243
else
246-
return;
247-
248-
var statusCode = attribute.GetConstructorIntValue() ?? 200;
249-
var contentType = attribute.GetConstructorStringValue(1);
250-
var additionalContentTypes = attribute.GetConstructorStringArray(2);
251-
252-
var producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes);
244+
{
245+
var responseType = attribute.GetConstructorTypeSymbol()
246+
?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
247+
?? "";
248+
var statusCode = attribute.GetConstructorIntValue(1) ?? 200;
249+
var contentType = attribute.GetConstructorStringValue(2);
250+
var additionalContentTypes = attribute.GetConstructorStringArray(3);
251+
producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes);
252+
}
253253

254254
var producesList = produces ??= [];
255-
producesList.Add(producesMetadata);
255+
producesList.Add(producesMetadata.Value);
256256
}
257257

258258
private static void TryAddEndpointFilter(
@@ -291,9 +291,4 @@ private static void TryAddEndpointFilterType(ITypeSymbol? typeSymbol, ref List<s
291291
endpointFilters ??= [];
292292
endpointFilters.Add(displayString);
293293
}
294-
295-
private sealed class GeneratedAttributeKindCacheEntry(RequestHandlerAttributeKind kind)
296-
{
297-
public RequestHandlerAttributeKind Kind { get; } = kind;
298-
}
299294
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
namespace GeneratedEndpoints.Common;
2+
3+
internal readonly record struct EndpointGroup
4+
{
5+
public required string Identifier { get; init; }
6+
public required string Pattern { get; init; }
7+
public required string? Name { get; init; }
8+
}

src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using System.Collections.Immutable;
22
using Microsoft.CodeAnalysis;
3-
using static GeneratedEndpoints.Common.AttributeSymbolMatcher;
43
using static GeneratedEndpoints.Common.Constants;
54

65
namespace GeneratedEndpoints.Common;
@@ -45,7 +44,7 @@ private static string GetBindingPrefix(IParameterSymbol parameter)
4544
if (attributeClass is null)
4645
continue;
4746

48-
var attributeSource = GetBindingSourceFromAttributeClass(attributeClass);
47+
var attributeSource = attributeClass.GetBindingSource();
4948
if (attributeSource == BindingSource.None)
5049
continue;
5150

@@ -69,25 +68,6 @@ private static string GetBindingPrefix(IParameterSymbol parameter)
6968
return bindingPrefix;
7069
}
7170

72-
private static BindingSource GetBindingSourceFromAttributeClass(INamedTypeSymbol attributeClass)
73-
{
74-
var definition = attributeClass.OriginalDefinition;
75-
var namespaceSymbol = definition.ContainingNamespace;
76-
77-
return definition.Name switch
78-
{
79-
"FromRouteAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromRoute,
80-
"FromQueryAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromQuery,
81-
"FromHeaderAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromHeader,
82-
"FromBodyAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromBody,
83-
"FromFormAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromForm,
84-
"FromServicesAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromServices,
85-
"FromKeyedServicesAttribute" when IsInNamespace(namespaceSymbol, ExtensionsDependencyInjectionNamespaceParts) => BindingSource.FromKeyedServices,
86-
"AsParametersAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreHttpNamespaceParts) => BindingSource.AsParameters,
87-
_ => BindingSource.None,
88-
};
89-
}
90-
9171
private static string GetBindingSourceAttribute(BindingSource source, TypedConstant? typedKey, string? bindingName)
9272
{
9373
switch (source)

src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs

Lines changed: 0 additions & 51 deletions
This file was deleted.

0 commit comments

Comments
 (0)