diff --git a/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/Models/RegisteredServiceInfo.cs b/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/Models/RegisteredServiceInfo.cs
index c25e2896d..b8d57112a 100644
--- a/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/Models/RegisteredServiceInfo.cs
+++ b/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/Models/RegisteredServiceInfo.cs
@@ -10,11 +10,13 @@ namespace CommunityToolkit.Extensions.DependencyInjection.SourceGenerators.Model
/// A model for a singleton service registration.
///
/// The registration kind for the service.
+/// The type name of the implementation type.
/// The fully qualified type name of the implementation type.
/// The fully qualified type names of dependent services for .
/// The fully qualified type names for the services to register for .
internal sealed record RegisteredServiceInfo(
ServiceRegistrationKind RegistrationKind,
+ string ImplementationTypeName,
string ImplementationFullyQualifiedTypeName,
EquatableArray RequiredServiceFullyQualifiedTypeNames,
EquatableArray ServiceFullyQualifiedTypeNames);
diff --git a/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/ServiceProviderGenerator.Execute.cs b/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/ServiceProviderGenerator.Execute.cs
index efce7ff24..de4cccf86 100644
--- a/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/ServiceProviderGenerator.Execute.cs
+++ b/components/Extensions.DependencyInjection/CommunityToolkit.Extensions.DependencyInjection.SourceGenerators/ServiceProviderGenerator.Execute.cs
@@ -143,6 +143,7 @@ public static bool IsSyntaxTarget(SyntaxNode syntaxNode, CancellationToken token
// Create the model fully describing the current service registration
serviceInfo.Add(new RegisteredServiceInfo(
RegistrationKind: registrationKind,
+ ImplementationTypeName: implementationType.Name,
ImplementationFullyQualifiedTypeName: implementationTypeName,
ServiceFullyQualifiedTypeNames: serviceTypeNames,
RequiredServiceFullyQualifiedTypeNames: constructorArgumentTypes));
@@ -166,8 +167,11 @@ public static bool IsSyntaxTarget(SyntaxNode syntaxNode, CancellationToken token
/// A instance with the gathered info.
public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
{
+ using ImmutableArrayBuilder localFunctions = ImmutableArrayBuilder.Rent();
using ImmutableArrayBuilder registrationStatements = ImmutableArrayBuilder.Rent();
+ int index = -1;
+
foreach (RegisteredServiceInfo serviceInfo in info.Services)
{
// The first service type always acts as "main" registration, and should always be present
@@ -176,6 +180,9 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
continue;
}
+ // Increment the index we use to disambiguate the generated local function names (starting from 0)
+ index++;
+
using ImmutableArrayBuilder constructorArguments = ImmutableArrayBuilder.Rent();
// Prepare the dependent services for the implementation type
@@ -199,54 +206,55 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
// Prepare the method name, either AddSingleton or AddTransient
string registrationMethod = $"Add{serviceInfo.RegistrationKind}";
- // Special case when the service is a singleton and no dependent services are present, just use eager instantiation instead:
+ // Prepare the name of the factory local function
+ string factoryMethod = $"Get{serviceInfo.ImplementationTypeName}_{index}";
+
+ // Prepare the local function for the registration (to improve lambda caching):
//
- // global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton(, typeof(), new ());
- if (serviceInfo.RegistrationKind == ServiceRegistrationKind.Singleton && constructorArguments.Count == 0)
- {
- registrationStatements.Add(
- ExpressionStatement(
- InvocationExpression(
- MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"),
- IdentifierName("AddSingleton")))
- .AddArgumentListArguments(
- Argument(IdentifierName(info.Method.ServiceCollectionParameterName)),
- Argument(TypeOfExpression(IdentifierName(rootServiceTypeName))),
- Argument(
- ObjectCreationExpression(IdentifierName(serviceInfo.ImplementationFullyQualifiedTypeName))
- .WithArgumentList(ArgumentList())))));
- }
- else
- {
- // Register the main implementation type when at least a dependent service is needed:
- //
- // global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.(, typeof(), static services => new ());
- registrationStatements.Add(
- ExpressionStatement(
- InvocationExpression(
- MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"),
- IdentifierName(registrationMethod)))
- .AddArgumentListArguments(
- Argument(IdentifierName(info.Method.ServiceCollectionParameterName)),
- Argument(TypeOfExpression(IdentifierName(rootServiceTypeName))),
- Argument(
- SimpleLambdaExpression(Parameter(Identifier("services")))
- .AddModifiers(Token(SyntaxKind.StaticKeyword))
- .WithExpressionBody(
- ObjectCreationExpression(IdentifierName(serviceInfo.ImplementationFullyQualifiedTypeName))
- .AddArgumentListArguments(constructorArguments.ToArray()))))));
- }
+ // static object (global::System.IServiceProvider services)
+ // {
+ // return new ();
+ // }
+ localFunctions.Add(
+ LocalFunctionStatement(
+ PredefinedType(Token(SyntaxKind.ObjectKeyword)),
+ Identifier(factoryMethod))
+ .AddModifiers(Token(SyntaxKind.StaticKeyword))
+ .AddParameterListParameters(
+ Parameter(Identifier("services"))
+ .WithType(IdentifierName("global::System.IServiceProvider")))
+ .AddBodyStatements(
+ ReturnStatement(
+ ObjectCreationExpression(IdentifierName(serviceInfo.ImplementationFullyQualifiedTypeName))
+ .AddArgumentListArguments(constructorArguments.ToArray()))));
+
+ // Register the main implementation type:
+ //
+ // global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.(, typeof(), new global::Func());
+ registrationStatements.Add(
+ ExpressionStatement(
+ InvocationExpression(
+ MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"),
+ IdentifierName(registrationMethod)))
+ .AddArgumentListArguments(
+ Argument(IdentifierName(info.Method.ServiceCollectionParameterName)),
+ Argument(TypeOfExpression(IdentifierName(rootServiceTypeName))),
+ Argument(
+ ObjectCreationExpression(
+ GenericName(Identifier("global::System.Func"))
+ .AddTypeArgumentListArguments(
+ IdentifierName("global::System.IServiceProvider"),
+ PredefinedType(Token(SyntaxKind.ObjectKeyword))))
+ .AddArgumentListArguments(Argument(IdentifierName(factoryMethod)))))));
// Register all secondary services, if any
foreach (string dependentServiceType in dependentServiceTypeNames)
{
// Register the main implementation type:
//
- // global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.(, typeof(), static services => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredServices(services));
+ // global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.(, typeof(), new global::System.Func(global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredServices));
registrationStatements.Add(
ExpressionStatement(
InvocationExpression(
@@ -258,16 +266,17 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
Argument(IdentifierName(info.Method.ServiceCollectionParameterName)),
Argument(TypeOfExpression(IdentifierName(dependentServiceType))),
Argument(
- SimpleLambdaExpression(Parameter(Identifier("services")))
- .AddModifiers(Token(SyntaxKind.StaticKeyword))
- .WithExpressionBody(
- InvocationExpression(
- MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions"),
- GenericName(Identifier("GetRequiredService"))
- .AddTypeArgumentListArguments(IdentifierName(rootServiceTypeName))))
- .AddArgumentListArguments(Argument(IdentifierName("services"))))))));
+ ObjectCreationExpression(
+ GenericName("global::System.Func")
+ .AddTypeArgumentListArguments(
+ IdentifierName("global::System.IServiceProvider"),
+ PredefinedType(Token(SyntaxKind.ObjectKeyword))))
+ .AddArgumentListArguments(Argument(
+ MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions"),
+ GenericName(Identifier("GetRequiredService"))
+ .AddTypeArgumentListArguments(IdentifierName(rootServiceTypeName)))))))));
}
}
@@ -294,6 +303,7 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
// [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
// (global::Microsoft.Extensions.DependencyInjection.IServiceCollection )
// {
+ //
//
// }
MethodDeclarationSyntax configureServicesMethodDeclaration =
@@ -302,6 +312,7 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
.AddParameterListParameters(
Parameter(Identifier(info.Method.ServiceCollectionParameterName))
.WithType(IdentifierName("global::Microsoft.Extensions.DependencyInjection.IServiceCollection")))
+ .AddBodyStatements(localFunctions.ToArray())
.AddBodyStatements(registrationStatements.ToArray())
.AddAttributeLists(
AttributeList(SingletonSeparatedList(