diff --git a/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/Models/RegisteredServiceInfo.cs b/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/Models/RegisteredServiceInfo.cs index c25e2896d..b8d57112a 100644 --- a/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/Models/RegisteredServiceInfo.cs +++ b/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/Models/RegisteredServiceInfo.cs @@ -10,11 +10,13 @@ namespace CommunityToolkit.Extensions.DependencyInjection.SourceGenerators.Model /// A model for a singleton service registration. /// /// The registration kind for the service. +/// The type name of the implementation type. /// The fully qualified type name of the implementation type. /// The fully qualified type names of dependent services for . /// The fully qualified type names for the services to register for . internal sealed record RegisteredServiceInfo( ServiceRegistrationKind RegistrationKind, + string ImplementationTypeName, string ImplementationFullyQualifiedTypeName, EquatableArray RequiredServiceFullyQualifiedTypeNames, EquatableArray ServiceFullyQualifiedTypeNames); diff --git a/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/ServiceProviderGenerator.Execute.cs b/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/ServiceProviderGenerator.Execute.cs index efce7ff24..de4cccf86 100644 --- a/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/ServiceProviderGenerator.Execute.cs +++ b/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/ServiceProviderGenerator.Execute.cs @@ -143,6 +143,7 @@ public static bool IsSyntaxTarget(SyntaxNode syntaxNode, CancellationToken token // Create the model fully describing the current service registration serviceInfo.Add(new RegisteredServiceInfo( RegistrationKind: registrationKind, + ImplementationTypeName: implementationType.Name, ImplementationFullyQualifiedTypeName: implementationTypeName, ServiceFullyQualifiedTypeNames: serviceTypeNames, RequiredServiceFullyQualifiedTypeNames: constructorArgumentTypes)); @@ -166,8 +167,11 @@ public static bool IsSyntaxTarget(SyntaxNode syntaxNode, CancellationToken token /// A instance with the gathered info. public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info) { + using ImmutableArrayBuilder localFunctions = ImmutableArrayBuilder.Rent(); using ImmutableArrayBuilder registrationStatements = ImmutableArrayBuilder.Rent(); + int index = -1; + foreach (RegisteredServiceInfo serviceInfo in info.Services) { // The first service type always acts as "main" registration, and should always be present @@ -176,6 +180,9 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info) continue; } + // Increment the index we use to disambiguate the generated local function names (starting from 0) + index++; + using ImmutableArrayBuilder constructorArguments = ImmutableArrayBuilder.Rent(); // Prepare the dependent services for the implementation type @@ -199,54 +206,55 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info) // Prepare the method name, either AddSingleton or AddTransient string registrationMethod = $"Add{serviceInfo.RegistrationKind}"; - // Special case when the service is a singleton and no dependent services are present, just use eager instantiation instead: + // Prepare the name of the factory local function + string factoryMethod = $"Get{serviceInfo.ImplementationTypeName}_{index}"; + + // Prepare the local function for the registration (to improve lambda caching): // - // global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton(, typeof(), new ()); - if (serviceInfo.RegistrationKind == ServiceRegistrationKind.Singleton && constructorArguments.Count == 0) - { - registrationStatements.Add( - ExpressionStatement( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"), - IdentifierName("AddSingleton"))) - .AddArgumentListArguments( - Argument(IdentifierName(info.Method.ServiceCollectionParameterName)), - Argument(TypeOfExpression(IdentifierName(rootServiceTypeName))), - Argument( - ObjectCreationExpression(IdentifierName(serviceInfo.ImplementationFullyQualifiedTypeName)) - .WithArgumentList(ArgumentList()))))); - } - else - { - // Register the main implementation type when at least a dependent service is needed: - // - // global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.(, typeof(), static services => new ()); - registrationStatements.Add( - ExpressionStatement( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"), - IdentifierName(registrationMethod))) - .AddArgumentListArguments( - Argument(IdentifierName(info.Method.ServiceCollectionParameterName)), - Argument(TypeOfExpression(IdentifierName(rootServiceTypeName))), - Argument( - SimpleLambdaExpression(Parameter(Identifier("services"))) - .AddModifiers(Token(SyntaxKind.StaticKeyword)) - .WithExpressionBody( - ObjectCreationExpression(IdentifierName(serviceInfo.ImplementationFullyQualifiedTypeName)) - .AddArgumentListArguments(constructorArguments.ToArray())))))); - } + // static object (global::System.IServiceProvider services) + // { + // return new (); + // } + localFunctions.Add( + LocalFunctionStatement( + PredefinedType(Token(SyntaxKind.ObjectKeyword)), + Identifier(factoryMethod)) + .AddModifiers(Token(SyntaxKind.StaticKeyword)) + .AddParameterListParameters( + Parameter(Identifier("services")) + .WithType(IdentifierName("global::System.IServiceProvider"))) + .AddBodyStatements( + ReturnStatement( + ObjectCreationExpression(IdentifierName(serviceInfo.ImplementationFullyQualifiedTypeName)) + .AddArgumentListArguments(constructorArguments.ToArray())))); + + // Register the main implementation type: + // + // global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.(, typeof(), new global::Func()); + registrationStatements.Add( + ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"), + IdentifierName(registrationMethod))) + .AddArgumentListArguments( + Argument(IdentifierName(info.Method.ServiceCollectionParameterName)), + Argument(TypeOfExpression(IdentifierName(rootServiceTypeName))), + Argument( + ObjectCreationExpression( + GenericName(Identifier("global::System.Func")) + .AddTypeArgumentListArguments( + IdentifierName("global::System.IServiceProvider"), + PredefinedType(Token(SyntaxKind.ObjectKeyword)))) + .AddArgumentListArguments(Argument(IdentifierName(factoryMethod))))))); // Register all secondary services, if any foreach (string dependentServiceType in dependentServiceTypeNames) { // Register the main implementation type: // - // global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.(, typeof(), static services => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredServices(services)); + // global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.(, typeof(), new global::System.Func(global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredServices)); registrationStatements.Add( ExpressionStatement( InvocationExpression( @@ -258,16 +266,17 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info) Argument(IdentifierName(info.Method.ServiceCollectionParameterName)), Argument(TypeOfExpression(IdentifierName(dependentServiceType))), Argument( - SimpleLambdaExpression(Parameter(Identifier("services"))) - .AddModifiers(Token(SyntaxKind.StaticKeyword)) - .WithExpressionBody( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions"), - GenericName(Identifier("GetRequiredService")) - .AddTypeArgumentListArguments(IdentifierName(rootServiceTypeName)))) - .AddArgumentListArguments(Argument(IdentifierName("services")))))))); + ObjectCreationExpression( + GenericName("global::System.Func") + .AddTypeArgumentListArguments( + IdentifierName("global::System.IServiceProvider"), + PredefinedType(Token(SyntaxKind.ObjectKeyword)))) + .AddArgumentListArguments(Argument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions"), + GenericName(Identifier("GetRequiredService")) + .AddTypeArgumentListArguments(IdentifierName(rootServiceTypeName))))))))); } } @@ -294,6 +303,7 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info) // [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] // (global::Microsoft.Extensions.DependencyInjection.IServiceCollection ) // { + // // // } MethodDeclarationSyntax configureServicesMethodDeclaration = @@ -302,6 +312,7 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info) .AddParameterListParameters( Parameter(Identifier(info.Method.ServiceCollectionParameterName)) .WithType(IdentifierName("global::Microsoft.Extensions.DependencyInjection.IServiceCollection"))) + .AddBodyStatements(localFunctions.ToArray()) .AddBodyStatements(registrationStatements.ToArray()) .AddAttributeLists( AttributeList(SingletonSeparatedList(