diff --git a/src/Framework/AspNetCoreAnalyzers/src/Analyzers/RouteEmbeddedLanguage/FrameworkParametersCompletionProvider.cs b/src/Framework/AspNetCoreAnalyzers/src/Analyzers/RouteEmbeddedLanguage/FrameworkParametersCompletionProvider.cs index cbf3846cc212..3983dd27d1a1 100644 --- a/src/Framework/AspNetCoreAnalyzers/src/Analyzers/RouteEmbeddedLanguage/FrameworkParametersCompletionProvider.cs +++ b/src/Framework/AspNetCoreAnalyzers/src/Analyzers/RouteEmbeddedLanguage/FrameworkParametersCompletionProvider.cs @@ -131,7 +131,7 @@ public override async Task ProvideCompletionsAsync(CompletionContext context) // Don't offer route parameter names when the parameter type can't be bound to route parameters. // e.g. special types like HttpContext, non-primitive types that don't have a static TryParse method. - if (!IsCurrentParameterBindable(token, semanticModel, wellKnownTypes, context.CancellationToken)) + if (!IsCurrentParameterBindable(token, semanticModel, context.CancellationToken)) { return; } @@ -383,7 +383,7 @@ private static bool HasNonRouteAttribute(SyntaxToken token, SemanticModel semant return false; } - private static bool IsCurrentParameterBindable(SyntaxToken token, SemanticModel semanticModel, WellKnownTypes wellKnownTypes, CancellationToken cancellationToken) + private static bool IsCurrentParameterBindable(SyntaxToken token, SemanticModel semanticModel, CancellationToken cancellationToken) { if (token.Parent.IsKind(SyntaxKind.PredefinedType)) { @@ -393,7 +393,7 @@ private static bool IsCurrentParameterBindable(SyntaxToken token, SemanticModel var parameterTypeSymbol = semanticModel.GetSymbolInfo(token.Parent!, cancellationToken).GetAnySymbol(); if (parameterTypeSymbol is INamedTypeSymbol typeSymbol) { - return ParsabilityHelper.GetParsability(typeSymbol, wellKnownTypes) == Parsability.Parsable; + return ParsabilityHelper.GetParsability(typeSymbol) == Parsability.Parsable; } else if (parameterTypeSymbol is IMethodSymbol) diff --git a/src/Framework/AspNetCoreAnalyzers/src/Analyzers/RouteHandlers/DisallowNonParsableComplexTypesOnParameters.cs b/src/Framework/AspNetCoreAnalyzers/src/Analyzers/RouteHandlers/DisallowNonParsableComplexTypesOnParameters.cs index 1d5f8550a1a5..b1694a4ad4ec 100644 --- a/src/Framework/AspNetCoreAnalyzers/src/Analyzers/RouteHandlers/DisallowNonParsableComplexTypesOnParameters.cs +++ b/src/Framework/AspNetCoreAnalyzers/src/Analyzers/RouteHandlers/DisallowNonParsableComplexTypesOnParameters.cs @@ -67,7 +67,7 @@ private static void DisallowNonParsableComplexTypesOnParameters( if (IsRouteParameter(routeUsage, handlerDelegateParameter)) { - var parsability = ParsabilityHelper.GetParsability(parameterTypeSymbol, wellKnownTypes); + var parsability = ParsabilityHelper.GetParsability(parameterTypeSymbol); if (parsability != Parsability.Parsable) { @@ -97,7 +97,7 @@ static bool IsRouteParameter(RouteUsageModel routeUsage, IParameterSymbol handle static bool ReportFromAttributeDiagnostic(OperationAnalysisContext context, WellKnownType fromMetadataInterfaceType, WellKnownTypes wellKnownTypes, IParameterSymbol parameter, INamedTypeSymbol parameterTypeSymbol, Location location) { var fromMetadataInterfaceTypeSymbol = wellKnownTypes.Get(fromMetadataInterfaceType); - var parsability = ParsabilityHelper.GetParsability(parameterTypeSymbol, wellKnownTypes); + var parsability = ParsabilityHelper.GetParsability(parameterTypeSymbol); if (parameter.HasAttributeImplementingInterface(fromMetadataInterfaceTypeSymbol) && parsability != Parsability.Parsable) { context.ReportDiagnostic(Diagnostic.Create( diff --git a/src/Http/Http.Extensions/gen/RequestDelegateGenerator.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator.cs index 9b904e9d12dc..3e1c65fa1a55 100644 --- a/src/Http/Http.Extensions/gen/RequestDelegateGenerator.cs +++ b/src/Http/Http.Extensions/gen/RequestDelegateGenerator.cs @@ -5,14 +5,11 @@ using System.Globalization; using System.IO; using System.Linq; -using System.Text; using Microsoft.AspNetCore.Analyzers.Infrastructure; -using Microsoft.AspNetCore.App.Analyzers.Infrastructure; using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel; using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel.Emitters; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.Operations; namespace Microsoft.AspNetCore.Http.RequestDelegateGenerator; @@ -26,10 +23,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) transform: static (context, token) => { var operation = context.SemanticModel.GetOperation(context.Node, token); - var wellKnownTypes = WellKnownTypes.GetOrCreate(context.SemanticModel.Compilation); - if (operation.IsValidOperation(wellKnownTypes, out var invocationOperation)) + if (operation.IsValidOperation(out var invocationOperation)) { - return new Endpoint(invocationOperation, wellKnownTypes, context.SemanticModel); + return new Endpoint(invocationOperation, context.SemanticModel); } return null; }) diff --git a/src/Http/Http.Extensions/gen/RequestDelegateGeneratorSuppressor.cs b/src/Http/Http.Extensions/gen/RequestDelegateGeneratorSuppressor.cs index cbde9e7e4373..623cca67e2d0 100644 --- a/src/Http/Http.Extensions/gen/RequestDelegateGeneratorSuppressor.cs +++ b/src/Http/Http.Extensions/gen/RequestDelegateGeneratorSuppressor.cs @@ -56,10 +56,9 @@ public override void ReportSuppressions(SuppressionAnalysisContext context) var semanticModel = context.GetSemanticModel(sourceTree); var operation = semanticModel.GetOperation(node, context.CancellationToken); - var wellKnownTypes = WellKnownTypes.GetOrCreate(semanticModel.Compilation); - if (operation.IsValidOperation(wellKnownTypes, out var invocationOperation)) + if (operation.IsValidOperation(out var invocationOperation)) { - var endpoint = new Endpoint(invocationOperation, wellKnownTypes, semanticModel); + var endpoint = new Endpoint(invocationOperation, semanticModel); if (endpoint.Diagnostics.Count == 0) { var targetSuppression = diagnostic.Id == SuppressRUCDiagnostic.SuppressedDiagnosticId diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Endpoint.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Endpoint.cs index b16be0047a2d..8b3d9da5b4f3 100644 --- a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Endpoint.cs +++ b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Endpoint.cs @@ -15,7 +15,7 @@ namespace Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerM internal class Endpoint { - public Endpoint(IInvocationOperation operation, WellKnownTypes wellKnownTypes, SemanticModel semanticModel) + public Endpoint(IInvocationOperation operation, SemanticModel semanticModel) { Operation = operation; Location = GetLocation(operation); @@ -28,7 +28,7 @@ public Endpoint(IInvocationOperation operation, WellKnownTypes wellKnownTypes, S return; } - Response = new EndpointResponse(method, wellKnownTypes); + Response = new EndpointResponse(method); Response.EmitRequiredDiagnostics(Diagnostics, Operation.Syntax.GetLocation()); IsAwaitable = Response?.IsAwaitable == true; @@ -56,7 +56,7 @@ public Endpoint(IInvocationOperation operation, WellKnownTypes wellKnownTypes, S { continue; } - var parameter = new EndpointParameter(this, parameterSymbol, wellKnownTypes); + var parameter = new EndpointParameter(this, parameterSymbol); switch (parameter.Source) { diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameter.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameter.cs index 5bc8b8471113..8668a1252e74 100644 --- a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameter.cs +++ b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameter.cs @@ -9,26 +9,24 @@ using System.Text; using Microsoft.AspNetCore.Analyzers.Infrastructure; using Microsoft.AspNetCore.Analyzers.RouteEmbeddedLanguage.Infrastructure; -using Microsoft.AspNetCore.App.Analyzers.Infrastructure; using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel.Emitters; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; -using WellKnownType = Microsoft.AspNetCore.App.Analyzers.Infrastructure.WellKnownTypeData.WellKnownType; namespace Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel; internal class EndpointParameter { - public EndpointParameter(Endpoint endpoint, IParameterSymbol parameter, WellKnownTypes wellKnownTypes): this(endpoint, parameter.Type, parameter.Name, wellKnownTypes) + public EndpointParameter(Endpoint endpoint, IParameterSymbol parameter): this(endpoint, parameter.Type, parameter.Name) { Ordinal = parameter.Ordinal; IsOptional = parameter.IsOptional(); HasDefaultValue = parameter.HasExplicitDefaultValue; DefaultValue = parameter.GetDefaultValueString(); - ProcessEndpointParameterSource(endpoint, parameter, parameter.GetAttributes(), wellKnownTypes); + ProcessEndpointParameterSource(endpoint, parameter, parameter.GetAttributes()); } - private EndpointParameter(Endpoint endpoint, IPropertySymbol property, IParameterSymbol? parameter, WellKnownTypes wellKnownTypes) : this(endpoint, property.Type, property.Name, wellKnownTypes) + private EndpointParameter(Endpoint endpoint, IPropertySymbol property, IParameterSymbol? parameter) : this(endpoint, property.Type, property.Name) { Ordinal = parameter?.Ordinal ?? 0; IsProperty = true; @@ -48,10 +46,10 @@ private EndpointParameter(Endpoint endpoint, IPropertySymbol property, IParamete ? $"new PropertyAsParameterInfo({(IsOptional ? "true" : "false")}, {propertyInfo}, {parameter.GetParameterInfoFromConstructorCode()})" : $"new PropertyAsParameterInfo({(IsOptional ? "true" : "false")}, {propertyInfo})"; endpoint.EmitterContext.RequiresPropertyAsParameterInfo = IsProperty && IsEndpointParameterMetadataProvider; - ProcessEndpointParameterSource(endpoint, property, attributeBuilder.ToImmutable(), wellKnownTypes); + ProcessEndpointParameterSource(endpoint, property, attributeBuilder.ToImmutable()); } - private EndpointParameter(Endpoint endpoint, ITypeSymbol typeSymbol, string parameterName, WellKnownTypes wellKnownTypes) + private EndpointParameter(Endpoint endpoint, ITypeSymbol typeSymbol, string parameterName) { Type = typeSymbol; SymbolName = parameterName; @@ -59,51 +57,51 @@ private EndpointParameter(Endpoint endpoint, ITypeSymbol typeSymbol, string para Source = EndpointParameterSource.Unknown; IsArray = TryGetArrayElementType(typeSymbol, out var elementType); ElementType = elementType; - IsEndpointMetadataProvider = ImplementsIEndpointMetadataProvider(typeSymbol, wellKnownTypes); - IsEndpointParameterMetadataProvider = ImplementsIEndpointParameterMetadataProvider(typeSymbol, wellKnownTypes); + IsEndpointMetadataProvider = ImplementsIEndpointMetadataProvider(typeSymbol); + IsEndpointParameterMetadataProvider = ImplementsIEndpointParameterMetadataProvider(typeSymbol); endpoint.EmitterContext.HasEndpointParameterMetadataProvider |= IsEndpointParameterMetadataProvider; endpoint.EmitterContext.HasEndpointMetadataProvider |= IsEndpointMetadataProvider; } - private void ProcessEndpointParameterSource(Endpoint endpoint, ISymbol symbol, ImmutableArray attributes, WellKnownTypes wellKnownTypes) + private void ProcessEndpointParameterSource(Endpoint endpoint, ISymbol symbol, ImmutableArray attributes) { - if (attributes.TryGetAttributeImplementingInterface(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IFromRouteMetadata), out var fromRouteAttribute)) + if (attributes.TryGetAttributeImplementingInterface(["Microsoft", "AspNetCore", "Http", "Metadata", "IFromRouteMetadata"], out var fromRouteAttribute)) { Source = EndpointParameterSource.Route; LookupName = GetEscapedParameterName(fromRouteAttribute, symbol.Name); - IsParsable = TryGetParsability(Type, wellKnownTypes, out var parsingBlockEmitter); + IsParsable = TryGetParsability(Type, out var parsingBlockEmitter); ParsingBlockEmitter = parsingBlockEmitter; } - else if (attributes.TryGetAttributeImplementingInterface(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IFromQueryMetadata), out var fromQueryAttribute)) + else if (attributes.TryGetAttributeImplementingInterface(["Microsoft", "AspNetCore", "Http", "Metadata", "IFromQueryMetadata"], out var fromQueryAttribute)) { Source = EndpointParameterSource.Query; LookupName = GetEscapedParameterName(fromQueryAttribute, symbol.Name); - IsParsable = TryGetParsability(Type, wellKnownTypes, out var parsingBlockEmitter); + IsParsable = TryGetParsability(Type, out var parsingBlockEmitter); ParsingBlockEmitter = parsingBlockEmitter; } - else if (attributes.TryGetAttributeImplementingInterface(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IFromHeaderMetadata), out var fromHeaderAttribute)) + else if (attributes.TryGetAttributeImplementingInterface(["Microsoft", "AspNetCore", "Http", "Metadata", "IFromHeaderMetadata"], out var fromHeaderAttribute)) { Source = EndpointParameterSource.Header; LookupName = GetEscapedParameterName(fromHeaderAttribute, symbol.Name); - IsParsable = TryGetParsability(Type, wellKnownTypes, out var parsingBlockEmitter); + IsParsable = TryGetParsability(Type, out var parsingBlockEmitter); ParsingBlockEmitter = parsingBlockEmitter; } - else if (attributes.TryGetAttributeImplementingInterface(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IFromFormMetadata), out var fromFormAttribute)) + else if (attributes.TryGetAttributeImplementingInterface(["Microsoft", "AspNetCore", "Http", "Metadata", "IFromFormMetadata"], out var fromFormAttribute)) { endpoint.IsAwaitable = true; Source = EndpointParameterSource.FormBody; LookupName = GetEscapedParameterName(fromFormAttribute, symbol.Name); - if (SymbolEqualityComparer.Default.Equals(Type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_IFormFileCollection))) + if (Type.EqualsByName(["Microsoft", "AspNetCore", "Http", "IFormFileCollection"])) { IsFormFile = true; AssigningCode = "httpContext.Request.Form.Files"; } - else if (SymbolEqualityComparer.Default.Equals(Type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_IFormFile))) + else if (Type.EqualsByName(["Microsoft", "AspNetCore", "Http", "IFormFile"])) { IsFormFile = true; AssigningCode = $"httpContext.Request.Form.Files[{SymbolDisplay.FormatLiteral(LookupName, true)}]"; } - else if (SymbolEqualityComparer.Default.Equals(Type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_IFormCollection))) + else if (Type.EqualsByName(["Microsoft", "AspNetCore", "Http", "IFormCollection"])) { AssigningCode = "httpContext.Request.Form"; } @@ -112,18 +110,18 @@ private void ProcessEndpointParameterSource(Endpoint endpoint, ISymbol symbol, I AssigningCode = !IsArray ? $"(string?)httpContext.Request.Form[{SymbolDisplay.FormatLiteral(LookupName, true)}]" : $"httpContext.Request.Form[{SymbolDisplay.FormatLiteral(LookupName, true)}].ToArray()"; - IsParsable = TryGetParsability(Type, wellKnownTypes, out var parsingBlockEmitter); + IsParsable = TryGetParsability(Type, out var parsingBlockEmitter); ParsingBlockEmitter = parsingBlockEmitter; } } - else if (TryGetExplicitFromJsonBody(symbol, attributes, wellKnownTypes, out var isOptional)) + else if (TryGetExplicitFromJsonBody(symbol, attributes, out var isOptional)) { - if (SymbolEqualityComparer.Default.Equals(Type, wellKnownTypes.Get(WellKnownType.System_IO_Stream))) + if (Type.EqualsByName(["System", "IO", "Stream"])) { Source = EndpointParameterSource.SpecialType; AssigningCode = "httpContext.Request.Body"; } - else if (SymbolEqualityComparer.Default.Equals(Type, wellKnownTypes.Get(WellKnownType.System_IO_Pipelines_PipeReader))) + else if (Type.EqualsByName(["System", "IO", "Pipelines", "PipeReader"])) { Source = EndpointParameterSource.SpecialType; AssigningCode = "httpContext.Request.BodyReader"; @@ -135,22 +133,22 @@ private void ProcessEndpointParameterSource(Endpoint endpoint, ISymbol symbol, I } IsOptional = isOptional; } - else if (attributes.HasAttributeImplementingInterface(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IFromServiceMetadata))) + else if (attributes.HasAttributeImplementingInterface(["Microsoft", "AspNetCore", "Http", "Metadata", "IFromServiceMetadata"])) { Source = EndpointParameterSource.Service; - if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute)) + if (attributes.TryGetAttribute(["Microsoft", "Extensions", "DependencyInjection", "FromKeyedServicesAttribute"], out var keyedServicesAttribute)) { var location = endpoint.Operation.Syntax.GetLocation(); endpoint.Diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.KeyedAndNotKeyedServiceAttributesNotSupported, location)); } } - else if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute)) + else if (attributes.TryGetAttribute(["Microsoft", "Extensions", "DependencyInjection", "FromKeyedServicesAttribute"], out var keyedServicesAttribute)) { Source = EndpointParameterSource.KeyedService; var constructorArgument = keyedServicesAttribute.ConstructorArguments.FirstOrDefault(); KeyedServiceKey = SymbolDisplay.FormatPrimitive(constructorArgument.Value!, true, true); } - else if (attributes.HasAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_AsParametersAttribute))) + else if (attributes.HasAttribute(["Microsoft", "AspNetCore", "Http", "AsParametersAttribute"])) { Source = EndpointParameterSource.AsParameters; var location = endpoint.Operation.Syntax.GetLocation(); @@ -168,7 +166,7 @@ Type is not INamedTypeSymbol namedTypeSymbol || } return; } - EndpointParameters = matchedProperties.Select(matchedParameter => new EndpointParameter(endpoint, matchedParameter.Property, matchedParameter.Parameter, wellKnownTypes)); + EndpointParameters = matchedProperties.Select(matchedParameter => new EndpointParameter(endpoint, matchedParameter.Property, matchedParameter.Parameter)); if (isDefaultConstructor == true) { var parameterList = string.Join(", ", EndpointParameters.Select(p => $"{p.LookupName} = {p.EmitHandlerArgument()}")); @@ -180,33 +178,33 @@ Type is not INamedTypeSymbol namedTypeSymbol || AssigningCode = $"new {namedTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}({parameterList})"; } } - else if (TryGetSpecialTypeAssigningCode(Type, wellKnownTypes, out var specialTypeAssigningCode)) + else if (TryGetSpecialTypeAssigningCode(Type, out var specialTypeAssigningCode)) { Source = EndpointParameterSource.SpecialType; AssigningCode = specialTypeAssigningCode; } - else if (SymbolEqualityComparer.Default.Equals(Type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_IFormFileCollection))) + else if (Type.EqualsByName(["Microsoft", "AspNetCore", "Http", "IFormFileCollection"])) { endpoint.IsAwaitable = true; Source = EndpointParameterSource.FormBody; IsFormFile = true; AssigningCode = "httpContext.Request.Form.Files"; } - else if (SymbolEqualityComparer.Default.Equals(Type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_IFormFile))) + else if (Type.EqualsByName(["Microsoft", "AspNetCore", "Http", "IFormFile"])) { endpoint.IsAwaitable = true; Source = EndpointParameterSource.FormBody; IsFormFile = true; AssigningCode = $"httpContext.Request.Form.Files[{SymbolDisplay.FormatLiteral(LookupName, true)}]"; } - else if (SymbolEqualityComparer.Default.Equals(Type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_IFormCollection))) + else if (Type.EqualsByName(["Microsoft", "AspNetCore", "Http", "IFormCollection"])) { endpoint.IsAwaitable = true; Source = EndpointParameterSource.FormBody; LookupName = symbol.Name; AssigningCode = "httpContext.Request.Form"; } - else if (HasBindAsync(Type, wellKnownTypes, out var bindMethod, out var bindMethodSymbol)) + else if (HasBindAsync(Type, out var bindMethod, out var bindMethodSymbol)) { endpoint.IsAwaitable = true; endpoint.EmitterContext.RequiresPropertyAsParameterInfo = IsProperty && bindMethod is BindabilityMethod.BindAsyncWithParameter or BindabilityMethod.IBindableFromHttpContext; @@ -223,12 +221,12 @@ Type is not INamedTypeSymbol namedTypeSymbol || endpoint.IsAwaitable = true; Source = EndpointParameterSource.JsonBodyOrQuery; } - else if (SymbolEqualityComparer.Default.Equals(Type, wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_Primitives_StringValues))) + else if (Type.EqualsByName(["Microsoft", "Extensions", "Primitives", "StringValues"])) { Source = EndpointParameterSource.Query; IsStringValues = true; } - else if (TryGetParsability(Type, wellKnownTypes, out var parsingBlockEmitter)) + else if (TryGetParsability(Type, out var parsingBlockEmitter)) { Source = EndpointParameterSource.RouteOrQuery; IsParsable = true; @@ -248,11 +246,11 @@ Type is not INamedTypeSymbol namedTypeSymbol || endpoint.EmitterContext.HasJsonBodyOrQuery |= Source == EndpointParameterSource.JsonBodyOrQuery; } - private static bool ImplementsIEndpointMetadataProvider(ITypeSymbol type, WellKnownTypes wellKnownTypes) - => type.Implements(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IEndpointMetadataProvider)); + private static bool ImplementsIEndpointMetadataProvider(ITypeSymbol type) + => type.Implements(["Microsoft", "AspNetCore", "Http", "Metadata", "IEndpointMetadataProvider"]); - private static bool ImplementsIEndpointParameterMetadataProvider(ITypeSymbol type, WellKnownTypes wellKnownTypes) - => type.Implements(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IEndpointParameterMetadataProvider)); + private static bool ImplementsIEndpointParameterMetadataProvider(ITypeSymbol type) + => type.Implements(["Microsoft", "AspNetCore", "Http", "Metadata", "IEndpointParameterMetadataProvider"]); public ITypeSymbol Type { get; } public ITypeSymbol ElementType { get; } @@ -285,13 +283,13 @@ private static bool ImplementsIEndpointParameterMetadataProvider(ITypeSymbol typ public BindabilityMethod? BindMethod { get; set; } public IMethodSymbol? BindableMethodSymbol { get; set; } - private static bool HasBindAsync(ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes, [NotNullWhen(true)] out BindabilityMethod? bindMethod, [NotNullWhen(true)] out IMethodSymbol? bindMethodSymbol) + private static bool HasBindAsync(ITypeSymbol typeSymbol, [NotNullWhen(true)] out BindabilityMethod? bindMethod, [NotNullWhen(true)] out IMethodSymbol? bindMethodSymbol) { var parameterType = typeSymbol.UnwrapTypeSymbol(unwrapArray: true, unwrapNullable: true); - return ParsabilityHelper.GetBindability(parameterType, wellKnownTypes, out bindMethod, out bindMethodSymbol) == Bindability.Bindable; + return ParsabilityHelper.GetBindability(parameterType, out bindMethod, out bindMethodSymbol) == Bindability.Bindable; } - private static bool TryGetArrayElementType(ITypeSymbol type, [NotNullWhen(true)]out ITypeSymbol elementType) + private static bool TryGetArrayElementType(ITypeSymbol type, [NotNullWhen(true)] out ITypeSymbol elementType) { if (type.TypeKind == TypeKind.Array) { @@ -305,7 +303,7 @@ private static bool TryGetArrayElementType(ITypeSymbol type, [NotNullWhen(true)] } } - private bool TryGetParsability(ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes, [NotNullWhen(true)] out Action? parsingBlockEmitter) + private bool TryGetParsability(ITypeSymbol typeSymbol, [NotNullWhen(true)] out Action? parsingBlockEmitter) { var parameterType = typeSymbol.UnwrapTypeSymbol(unwrapArray: true, unwrapNullable: true); @@ -314,7 +312,7 @@ private bool TryGetParsability(ITypeSymbol typeSymbol, WellKnownTypes wellKnownT // support usage in the code generator an optional out parameter has been added to hint at what variant of the various // TryParse methods should be used (this implies that the preferences are baked into ParsabilityHelper). If we aren't // parsable at all we bail. - if (ParsabilityHelper.GetParsability(parameterType, wellKnownTypes, out var parsabilityMethod) != Parsability.Parsable) + if (ParsabilityHelper.GetParsability(parameterType, out var parsabilityMethod) != Parsability.Parsable) { parsingBlockEmitter = null; return false; @@ -343,11 +341,11 @@ private bool TryGetParsability(ITypeSymbol typeSymbol, WellKnownTypes wellKnownT { preferredTryParseInvocation = (string inputArgument, string outputArgument) => $$"""{{parameterType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}}.TryParse({{inputArgument}}!, CultureInfo.InvariantCulture, DateTimeStyles.AdjustToUniversal | DateTimeStyles.AllowWhiteSpaces, out var {{outputArgument}})"""; } - else if (SymbolEqualityComparer.Default.Equals(parameterType, wellKnownTypes.Get(WellKnownType.System_DateTimeOffset))) + else if (parameterType.EqualsByName(["System", "DateTimeOffset"])) { preferredTryParseInvocation = (string inputArgument, string outputArgument) => $$"""{{parameterType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}}.TryParse({{inputArgument}}!, CultureInfo.InvariantCulture, DateTimeStyles.AssumeUniversal | DateTimeStyles.AllowWhiteSpaces, out var {{outputArgument}})"""; } - else if (SymbolEqualityComparer.Default.Equals(parameterType, wellKnownTypes.Get(WellKnownType.System_DateOnly))) + else if (parameterType.EqualsByName(["System", "DateOnly"])) { preferredTryParseInvocation = (string inputArgument, string outputArgument) => $$"""{{parameterType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}}.TryParse({{inputArgument}}!, CultureInfo.InvariantCulture, DateTimeStyles.AllowWhiteSpaces, out var {{outputArgument}})"""; } @@ -411,40 +409,40 @@ private bool TryGetParsability(ITypeSymbol typeSymbol, WellKnownTypes wellKnownT return true; } - private static bool TryGetSpecialTypeAssigningCode(ITypeSymbol type, WellKnownTypes wellKnownTypes, [NotNullWhen(true)] out string? callingCode) + private static bool TryGetSpecialTypeAssigningCode(ITypeSymbol type, [NotNullWhen(true)] out string? callingCode) { callingCode = null; - if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_HttpContext))) + if (type.EqualsByName(["Microsoft", "AspNetCore", "Http", "HttpContext"])) { callingCode = "httpContext"; return true; } - if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_HttpRequest))) + if (type.EqualsByName(["Microsoft", "AspNetCore", "Http", "HttpRequest"])) { callingCode = "httpContext.Request"; return true; } - if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_HttpResponse))) + if (type.EqualsByName(["Microsoft", "AspNetCore", "Http", "HttpResponse"])) { callingCode = "httpContext.Response"; return true; } - if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.System_IO_Pipelines_PipeReader))) + if (type.EqualsByName(["System", "IO", "Pipelines", "PipeReader"])) { callingCode = "httpContext.Request.BodyReader"; return true; } - if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.System_IO_Stream))) + if (type.EqualsByName(["System", "IO", "Stream"])) { callingCode = "httpContext.Request.Body"; return true; } - if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.System_Security_Claims_ClaimsPrincipal))) + if (type.EqualsByName(["System", "Security", "Claims", "ClaimsPrincipal"])) { callingCode = "httpContext.User"; return true; } - if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.System_Threading_CancellationToken))) + if (type.EqualsByName(["System", "Threading", "CancellationToken"])) { callingCode = "httpContext.RequestAborted"; return true; @@ -455,11 +453,10 @@ private static bool TryGetSpecialTypeAssigningCode(ITypeSymbol type, WellKnownTy private static bool TryGetExplicitFromJsonBody(ISymbol typeSymbol, ImmutableArray attributes, - WellKnownTypes wellKnownTypes, out bool isOptional) { isOptional = false; - if (!attributes.TryGetAttributeImplementingInterface(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IFromBodyMetadata), out var fromBodyAttribute)) + if (!attributes.TryGetAttributeImplementingInterface(["Microsoft", "AspNetCore", "Http", "Metadata", "IFromBodyMetadata"], out var fromBodyAttribute)) { return false; } diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointResponse.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointResponse.cs index 881903f1ea0f..11028f47d76e 100644 --- a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointResponse.cs +++ b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointResponse.cs @@ -3,14 +3,11 @@ using System; using Microsoft.AspNetCore.Analyzers.RouteEmbeddedLanguage.Infrastructure; -using Microsoft.AspNetCore.App.Analyzers.Infrastructure; using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel.Emitters; using Microsoft.CodeAnalysis; namespace Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel; -using WellKnownType = WellKnownTypeData.WellKnownType; - internal class EndpointResponse { public ITypeSymbol? ResponseType { get; set; } @@ -21,11 +18,9 @@ internal class EndpointResponse public bool IsIResult { get; set; } public bool IsSerializable { get; set; } public bool IsEndpointMetadataProvider { get; set; } - private WellKnownTypes WellKnownTypes { get; init; } - internal EndpointResponse(IMethodSymbol method, WellKnownTypes wellKnownTypes) + internal EndpointResponse(IMethodSymbol method) { - WellKnownTypes = wellKnownTypes; ResponseType = UnwrapResponseType(method, out bool isAwaitable, out bool awaitableIsVoid); WrappedResponseType = method.ReturnType.ToDisplayString(EmitterConstants.DisplayFormat); IsAwaitable = isAwaitable; @@ -33,35 +28,25 @@ internal EndpointResponse(IMethodSymbol method, WellKnownTypes wellKnownTypes) IsIResult = GetIsIResult(); IsSerializable = GetIsSerializable(); ContentType = GetContentType(); - IsEndpointMetadataProvider = ImplementsIEndpointMetadataProvider(ResponseType, wellKnownTypes); + IsEndpointMetadataProvider = ImplementsIEndpointMetadataProvider(ResponseType); } - private static bool ImplementsIEndpointMetadataProvider(ITypeSymbol? responseType, WellKnownTypes wellKnownTypes) - => responseType == null ? false : responseType.Implements(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IEndpointMetadataProvider)); + private static bool ImplementsIEndpointMetadataProvider(ITypeSymbol? responseType) + => responseType == null ? false : responseType.Implements(["Microsoft", "AspNetCore", "Http", "Metadata", "IEndpointMetadataProvider"]); private ITypeSymbol? UnwrapResponseType(IMethodSymbol method, out bool isAwaitable, out bool awaitableIsVoid) { isAwaitable = false; awaitableIsVoid = false; var returnType = method.ReturnType; - var task = WellKnownTypes.Get(WellKnownType.System_Threading_Tasks_Task); - var taskOfT = WellKnownTypes.Get(WellKnownType.System_Threading_Tasks_Task_T); - var valueTask = WellKnownTypes.Get(WellKnownType.System_Threading_Tasks_ValueTask); - var valueTaskOfT = WellKnownTypes.Get(WellKnownType.System_Threading_Tasks_ValueTask_T); - if (returnType.OriginalDefinition.Equals(taskOfT, SymbolEqualityComparer.Default) || - returnType.OriginalDefinition.Equals(valueTaskOfT, SymbolEqualityComparer.Default)) - { - isAwaitable = true; - awaitableIsVoid = false; - return ((INamedTypeSymbol)returnType).TypeArguments[0]; - } - - if (returnType.OriginalDefinition.Equals(task, SymbolEqualityComparer.Default) || - returnType.OriginalDefinition.Equals(valueTask, SymbolEqualityComparer.Default)) + if (returnType.OriginalDefinition.EqualsByName(["System", "Threading", "Tasks", "Task"]) || + returnType.OriginalDefinition.EqualsByName(["System", "Threading", "Tasks", "ValueTask"])) { isAwaitable = true; - awaitableIsVoid = true; - return null; + awaitableIsVoid = returnType is INamedTypeSymbol { IsGenericType: false }; + return returnType is INamedTypeSymbol { IsGenericType: true } namedReturnType + ? namedReturnType.TypeArguments[0] + : null; } return returnType; @@ -76,9 +61,8 @@ private bool GetIsSerializable() => private bool GetIsIResult() { - var resultType = WellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_IResult); - return WellKnownTypes.Implements(ResponseType, resultType) || - SymbolEqualityComparer.Default.Equals(ResponseType, resultType); + return ResponseType is not null && + (ResponseType.Implements(["Microsoft", "AspNetCore", "Http", "IResult"]) || ResponseType.EqualsByName(["Microsoft", "AspNetCore", "Http", "IResult"])); } private string? GetContentType() diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/InvocationOperationExtensions.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/InvocationOperationExtensions.cs index 88af1ffae722..ce8cf875a9ae 100644 --- a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/InvocationOperationExtensions.cs +++ b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/InvocationOperationExtensions.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics.CodeAnalysis; +using Microsoft.AspNetCore.Analyzers.RouteEmbeddedLanguage.Infrastructure; using Microsoft.AspNetCore.App.Analyzers.Infrastructure; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -23,13 +24,13 @@ internal static class InvocationOperationExtensions "MapFallback" }; - public static bool IsValidOperation(this IOperation? operation, WellKnownTypes wellKnownTypes, [NotNullWhen(true)] out IInvocationOperation? invocationOperation) + public static bool IsValidOperation(this IOperation? operation, [NotNullWhen(true)] out IInvocationOperation? invocationOperation) { invocationOperation = null; if (operation is IInvocationOperation targetOperation && targetOperation.TryGetRouteHandlerArgument(out var routeHandlerParameter) && routeHandlerParameter is { Parameter.Type: {} delegateType } && - SymbolEqualityComparer.Default.Equals(delegateType, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_Delegate))) + delegateType.EqualsByName(["System", "Delegate"])) { invocationOperation = targetOperation; return true; @@ -83,12 +84,12 @@ public static bool TryGetMapMethodName(this SyntaxNode node, out string? methodN IArgumentOperation argument => ResolveMethodFromOperation(argument.Value, semanticModel), IConversionOperation conv => ResolveMethodFromOperation(conv.Operand, semanticModel), IDelegateCreationOperation del => ResolveMethodFromOperation(del.Target, semanticModel), - IFieldReferenceOperation { Field.IsReadOnly: true } f when ResolveDeclarationOperation(f.Field, semanticModel) is IOperation op => - ResolveMethodFromOperation(op, semanticModel), IAnonymousFunctionOperation anon => anon.Symbol, ILocalFunctionOperation local => local.Symbol, IMethodReferenceOperation method => method.Method, IParenthesizedOperation parenthesized => ResolveMethodFromOperation(parenthesized.Operand, semanticModel), + IFieldReferenceOperation { Field.IsReadOnly: true } f when ResolveDeclarationOperation(f.Field, semanticModel) is IOperation op => + ResolveMethodFromOperation(op, semanticModel), _ => null }; diff --git a/src/Shared/RoslynUtils/ParsabilityHelper.cs b/src/Shared/RoslynUtils/ParsabilityHelper.cs index a4e06e44388d..2fe421efe5f5 100644 --- a/src/Shared/RoslynUtils/ParsabilityHelper.cs +++ b/src/Shared/RoslynUtils/ParsabilityHelper.cs @@ -2,25 +2,19 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis; -using System.Threading; using System.Linq; -using Microsoft.AspNetCore.App.Analyzers.Infrastructure; using Microsoft.AspNetCore.Analyzers.RouteEmbeddedLanguage.Infrastructure; -using System.ComponentModel; using System.Diagnostics.CodeAnalysis; namespace Microsoft.AspNetCore.Analyzers.Infrastructure; -using WellKnownType = WellKnownTypeData.WellKnownType; - internal static class ParsabilityHelper { private static readonly BoundedCacheWithFactory BindabilityCache = new(); private static readonly BoundedCacheWithFactory ParsabilityCache = new(); - private static bool IsTypeAlwaysParsable(ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes, [NotNullWhen(true)] out ParsabilityMethod? parsabilityMethod) + private static bool IsTypeAlwaysParsable(ITypeSymbol typeSymbol, [NotNullWhen(true)] out ParsabilityMethod? parsabilityMethod) { // Any enum is valid. if (typeSymbol.TypeKind == TypeKind.Enum) @@ -30,7 +24,7 @@ private static bool IsTypeAlwaysParsable(ITypeSymbol typeSymbol, WellKnownTypes } // Uri is valid. - if (SymbolEqualityComparer.Default.Equals(typeSymbol, wellKnownTypes.Get(WellKnownType.System_Uri))) + if (typeSymbol.EqualsByName(["System", "Uri"])) { parsabilityMethod = ParsabilityMethod.Uri; return true; @@ -47,25 +41,25 @@ private static bool IsTypeAlwaysParsable(ITypeSymbol typeSymbol, WellKnownTypes return false; } - internal static Parsability GetParsability(ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes) + internal static Parsability GetParsability(ITypeSymbol typeSymbol) { - return GetParsability(typeSymbol, wellKnownTypes, out var _); + return GetParsability(typeSymbol, out var _); } - internal static Parsability GetParsability(ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes, [NotNullWhen(false)] out ParsabilityMethod? parsabilityMethod) + internal static Parsability GetParsability(ITypeSymbol typeSymbol, [NotNullWhen(false)] out ParsabilityMethod? parsabilityMethod) { var parsability = Parsability.NotParsable; parsabilityMethod = null; (parsability, parsabilityMethod) = ParsabilityCache.GetOrCreateValue(typeSymbol, (typeSymbol) => { - if (IsTypeAlwaysParsable(typeSymbol, wellKnownTypes, out var parsabilityMethod)) + if (IsTypeAlwaysParsable(typeSymbol, out var parsabilityMethod)) { return (Parsability.Parsable, parsabilityMethod); } // MyType : IParsable() - if (IsParsableViaIParsable(typeSymbol, wellKnownTypes)) + if (IsParsableViaIParsable(typeSymbol)) { return (Parsability.Parsable, ParsabilityMethod.IParsable); } @@ -75,7 +69,7 @@ internal static Parsability GetParsability(ITypeSymbol typeSymbol, WellKnownType .SelectMany(t => t.GetMembers("TryParse")) .OfType(); - if (tryParseMethods.Any(m => IsTryParseWithFormat(m, wellKnownTypes))) + if (tryParseMethods.Any(m => IsTryParseWithFormat(m))) { return (Parsability.Parsable, ParsabilityMethod.TryParseWithFormatProvider); } @@ -101,65 +95,59 @@ private static bool IsTryParse(IMethodSymbol methodSymbol) methodSymbol.Parameters[1].RefKind == RefKind.Out; } - private static bool IsTryParseWithFormat(IMethodSymbol methodSymbol, WellKnownTypes wellKnownTypes) + private static bool IsTryParseWithFormat(IMethodSymbol methodSymbol) { return methodSymbol.DeclaredAccessibility == Accessibility.Public && methodSymbol.IsStatic && methodSymbol.ReturnType.SpecialType == SpecialType.System_Boolean && methodSymbol.Parameters.Length == 3 && methodSymbol.Parameters[0].Type.SpecialType == SpecialType.System_String && - SymbolEqualityComparer.Default.Equals(methodSymbol.Parameters[1].Type, wellKnownTypes.Get(WellKnownType.System_IFormatProvider)) && + methodSymbol.Parameters[1].Type.EqualsByName(["System", "IFormatProvider"]) && methodSymbol.Parameters[2].RefKind == RefKind.Out; } - internal static bool IsParsableViaIParsable(ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes) - { - var iParsableTypeSymbol = wellKnownTypes.Get(WellKnownType.System_IParsable_T); - var implementsIParsable = typeSymbol.AllInterfaces.Any( - i => SymbolEqualityComparer.Default.Equals(i.ConstructedFrom, iParsableTypeSymbol) - ); - return implementsIParsable; - } + internal static bool IsParsableViaIParsable(ITypeSymbol typeSymbol) => + typeSymbol.AllInterfaces.Any(i => i.ConstructedFrom.EqualsByName(["System", "IParsable"])); - private static bool IsBindableViaIBindableFromHttpContext(ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes) + private static bool IsBindableViaIBindableFromHttpContext(ITypeSymbol typeSymbol) { - var iBindableFromHttpContextTypeSymbol = wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_IBindableFromHttpContext_T); var constructedTypeSymbol = typeSymbol.AllInterfaces.FirstOrDefault( - i => SymbolEqualityComparer.Default.Equals(i.ConstructedFrom, iBindableFromHttpContextTypeSymbol) + i => i.ConstructedFrom.EqualsByName(["Microsoft", "AspNetCore", "Http", "IBindableFromHttpContext"]) ); return constructedTypeSymbol != null && SymbolEqualityComparer.Default.Equals(constructedTypeSymbol.TypeArguments[0].UnwrapTypeSymbol(unwrapNullable: true), typeSymbol); } - private static bool IsBindAsync(IMethodSymbol methodSymbol, ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes) + private static bool IsBindAsync(IMethodSymbol methodSymbol, ITypeSymbol typeSymbol) { return methodSymbol.DeclaredAccessibility == Accessibility.Public && methodSymbol.IsStatic && methodSymbol.Parameters.Length == 1 && - SymbolEqualityComparer.Default.Equals(methodSymbol.Parameters[0].Type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_HttpContext)) && + methodSymbol.Parameters[0].Type.EqualsByName(["Microsoft", "AspNetCore", "Http", "HttpContext"]) && methodSymbol.ReturnType is INamedTypeSymbol returnType && - SymbolEqualityComparer.Default.Equals(returnType.ConstructedFrom, wellKnownTypes.Get(WellKnownType.System_Threading_Tasks_ValueTask_T)) && + returnType.IsGenericType && + returnType.ConstructedFrom.EqualsByName(["System", "Threading", "Tasks", "ValueTask"]) && SymbolEqualityComparer.Default.Equals(returnType.TypeArguments[0], typeSymbol); } - private static bool IsBindAsyncWithParameter(IMethodSymbol methodSymbol, ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes) + private static bool IsBindAsyncWithParameter(IMethodSymbol methodSymbol, ITypeSymbol typeSymbol) { return methodSymbol.DeclaredAccessibility == Accessibility.Public && methodSymbol.IsStatic && methodSymbol.Parameters.Length == 2 && - SymbolEqualityComparer.Default.Equals(methodSymbol.Parameters[0].Type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_HttpContext)) && - SymbolEqualityComparer.Default.Equals(methodSymbol.Parameters[1].Type, wellKnownTypes.Get(WellKnownType.System_Reflection_ParameterInfo)) && + methodSymbol.Parameters[0].Type.EqualsByName(["Microsoft", "AspNetCore", "Http", "HttpContext"]) && + methodSymbol.Parameters[1].Type.EqualsByName(["System", "Reflection", "ParameterInfo"]) && methodSymbol.ReturnType is INamedTypeSymbol returnType && - IsReturningValueTaskOfTOrNullableT(returnType, typeSymbol, wellKnownTypes); + IsReturningValueTaskOfTOrNullableT(returnType, typeSymbol); } - private static bool IsReturningValueTaskOfTOrNullableT(INamedTypeSymbol returnType, ITypeSymbol containingType, WellKnownTypes wellKnownTypes) + private static bool IsReturningValueTaskOfTOrNullableT(INamedTypeSymbol returnType, ITypeSymbol containingType) { - return SymbolEqualityComparer.Default.Equals(returnType.ConstructedFrom, wellKnownTypes.Get(WellKnownType.System_Threading_Tasks_ValueTask_T)) && + return returnType.IsGenericType && returnType.ConstructedFrom.EqualsByName(["System", "Threading", "Tasks", "ValueTask"]) && SymbolEqualityComparer.Default.Equals(returnType.TypeArguments[0].UnwrapTypeSymbol(unwrapNullable: true), containingType); } - internal static Bindability GetBindability(ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes, out BindabilityMethod? bindabilityMethod, out IMethodSymbol? bindMethodSymbol) + internal static Bindability GetBindability(ITypeSymbol typeSymbol, out BindabilityMethod? bindabilityMethod, out IMethodSymbol? bindMethodSymbol) { bindabilityMethod = null; bindMethodSymbol = null; @@ -169,7 +157,7 @@ internal static Bindability GetBindability(ITypeSymbol typeSymbol, WellKnownType { BindabilityMethod? bindabilityMethod = null; IMethodSymbol? bindMethodSymbol = null; - if (IsBindableViaIBindableFromHttpContext(typeSymbol, wellKnownTypes)) + if (IsBindableViaIBindableFromHttpContext(typeSymbol)) { return (BindabilityMethod.IBindableFromHttpContext, null); } @@ -185,13 +173,13 @@ internal static Bindability GetBindability(ITypeSymbol typeSymbol, WellKnownType if (methodSymbolCandidate is IMethodSymbol methodSymbol) { bindAsyncMethod ??= methodSymbol; - if (IsBindAsyncWithParameter(methodSymbol, typeSymbol, wellKnownTypes)) + if (IsBindAsyncWithParameter(methodSymbol, typeSymbol)) { bindabilityMethod = BindabilityMethod.BindAsyncWithParameter; bindMethodSymbol = methodSymbol; break; } - if (IsBindAsync(methodSymbol, typeSymbol, wellKnownTypes)) + if (IsBindAsync(methodSymbol, typeSymbol)) { bindabilityMethod = BindabilityMethod.BindAsync; bindMethodSymbol = methodSymbol; @@ -211,7 +199,7 @@ internal static Bindability GetBindability(ITypeSymbol typeSymbol, WellKnownType // See if we can give better guidance on why the BindAsync method is no good. if (bindAsyncMethod is not null) { - if (bindAsyncMethod.ReturnType is INamedTypeSymbol returnType && !IsReturningValueTaskOfTOrNullableT(returnType, typeSymbol, wellKnownTypes)) + if (bindAsyncMethod.ReturnType is INamedTypeSymbol returnType && !IsReturningValueTaskOfTOrNullableT(returnType, typeSymbol)) { return Bindability.InvalidReturnType; } diff --git a/src/Shared/RoslynUtils/SymbolExtensions.cs b/src/Shared/RoslynUtils/SymbolExtensions.cs index 58ec2bfe36d6..fa32e13671d3 100644 --- a/src/Shared/RoslynUtils/SymbolExtensions.cs +++ b/src/Shared/RoslynUtils/SymbolExtensions.cs @@ -66,6 +66,11 @@ public static bool HasAttribute(this ImmutableArray attributes, I return attributes.TryGetAttribute(attributeType, out _); } + public static bool HasAttribute(this ImmutableArray attributes, string[] attributeType) + { + return attributes.TryGetAttribute(attributeType, out _); + } + public static bool TryGetAttribute(this ImmutableArray attributes, INamedTypeSymbol attributeType, [NotNullWhen(true)] out AttributeData? matchedAttribute) { foreach (var attributeData in attributes) @@ -81,6 +86,21 @@ public static bool TryGetAttribute(this ImmutableArray attributes return false; } + public static bool TryGetAttribute(this ImmutableArray attributes, string[] attributeName, [NotNullWhen(true)] out AttributeData? matchedAttribute) + { + foreach (var attributeData in attributes) + { + if (attributeData.AttributeClass?.EqualsByName(attributeName) == true) + { + matchedAttribute = attributeData; + return true; + } + } + + matchedAttribute = null; + return false; + } + public static bool HasAttributeImplementingInterface(this ISymbol symbol, INamedTypeSymbol interfaceType) { return symbol.TryGetAttributeImplementingInterface(interfaceType, out var _); @@ -106,6 +126,11 @@ public static bool HasAttributeImplementingInterface(this ImmutableArray attributes, string[] interfaceName) + { + return attributes.TryGetAttributeImplementingInterface(interfaceName, out var _); + } + public static bool TryGetAttributeImplementingInterface(this ImmutableArray attributes, INamedTypeSymbol interfaceType, [NotNullWhen(true)] out AttributeData? matchedAttribute) { foreach (var attributeData in attributes) @@ -121,6 +146,61 @@ public static bool TryGetAttributeImplementingInterface(this ImmutableArray attributes, string[] interfaceName, [NotNullWhen(true)] out AttributeData? matchedAttribute) + { + foreach (var attributeData in attributes) + { + if (attributeData.AttributeClass is not null && attributeData.AttributeClass.Implements(interfaceName)) + { + matchedAttribute = attributeData; + return true; + } + } + + matchedAttribute = null; + return false; + } + + public static bool Implements(this ITypeSymbol type, params string[] interfaceName) + { + foreach (var t in type.AllInterfaces) + { + if (t.EqualsByName(interfaceName)) + { + return true; + } + } + return false; + } + + public static bool EqualsByName(this ITypeSymbol type, params string[] name) + { + var length = name.Length; + // Check that the type name matches what we expect + if (type.Name != name[length - 1]) + { + return false; + } + // Enumerate the containing namespaces to ensure they match + var targetNamespace = type.ContainingNamespace; + for (var i = length - 2; i >= 0; i--) + { + if (targetNamespace.Name != name[i]) + { + return false; + } + targetNamespace = targetNamespace.ContainingNamespace; + } + // Once all namespace parts have been enumerated + // we should be in the global namespace + if (targetNamespace.IsGlobalNamespace) + { + return true; + } + + return false; + } + public static bool Implements(this ITypeSymbol type, ITypeSymbol interfaceType) { foreach (var t in type.AllInterfaces)