Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ namespace CommunityToolkit.Extensions.DependencyInjection.SourceGenerators.Model
/// A model for a singleton service registration.
/// </summary>
/// <param name="RegistrationKind">The registration kind for the service.</param>
/// <param name="ImplementationTypeName">The type name of the implementation type.</param>
/// <param name="ImplementationFullyQualifiedTypeName">The fully qualified type name of the implementation type.</param>
/// <param name="RequiredServiceFullyQualifiedTypeNames">The fully qualified type names of dependent services for <paramref name="ImplementationFullyQualifiedTypeName"/>.</param>
/// <param name="ServiceFullyQualifiedTypeNames">The fully qualified type names for the services to register for <paramref name="ImplementationFullyQualifiedTypeName"/>.</param>
internal sealed record RegisteredServiceInfo(
ServiceRegistrationKind RegistrationKind,
string ImplementationTypeName,
string ImplementationFullyQualifiedTypeName,
EquatableArray<string> RequiredServiceFullyQualifiedTypeNames,
EquatableArray<string> ServiceFullyQualifiedTypeNames);
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ public static bool IsSyntaxTarget(SyntaxNode syntaxNode, CancellationToken token
// Create the model fully describing the current service registration
serviceInfo.Add(new RegisteredServiceInfo(
RegistrationKind: registrationKind,
ImplementationTypeName: implementationType.Name,
ImplementationFullyQualifiedTypeName: implementationTypeName,
ServiceFullyQualifiedTypeNames: serviceTypeNames,
RequiredServiceFullyQualifiedTypeNames: constructorArgumentTypes));
Expand All @@ -166,8 +167,11 @@ public static bool IsSyntaxTarget(SyntaxNode syntaxNode, CancellationToken token
/// <returns>A <see cref="CompilationUnitSyntax"/> instance with the gathered info.</returns>
public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
{
using ImmutableArrayBuilder<LocalFunctionStatementSyntax> localFunctions = ImmutableArrayBuilder<LocalFunctionStatementSyntax>.Rent();
using ImmutableArrayBuilder<StatementSyntax> registrationStatements = ImmutableArrayBuilder<StatementSyntax>.Rent();

int index = -1;

foreach (RegisteredServiceInfo serviceInfo in info.Services)
{
// The first service type always acts as "main" registration, and should always be present
Expand All @@ -176,6 +180,9 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
continue;
}

// Increment the index we use to disambiguate the generated local function names (starting from 0)
index++;

using ImmutableArrayBuilder<ArgumentSyntax> constructorArguments = ImmutableArrayBuilder<ArgumentSyntax>.Rent();

// Prepare the dependent services for the implementation type
Expand All @@ -199,54 +206,55 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
// Prepare the method name, either AddSingleton or AddTransient
string registrationMethod = $"Add{serviceInfo.RegistrationKind}";

// Special case when the service is a singleton and no dependent services are present, just use eager instantiation instead:
// Prepare the name of the factory local function
string factoryMethod = $"Get{serviceInfo.ImplementationTypeName}_{index}";

// Prepare the local function for the registration (to improve lambda caching):
//
// global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton(<PARAMETER_NAME>, typeof(<ROOT_SERVICE_TYPE>), new <IMPLEMENTATION_TYPE>());
if (serviceInfo.RegistrationKind == ServiceRegistrationKind.Singleton && constructorArguments.Count == 0)
{
registrationStatements.Add(
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"),
IdentifierName("AddSingleton")))
.AddArgumentListArguments(
Argument(IdentifierName(info.Method.ServiceCollectionParameterName)),
Argument(TypeOfExpression(IdentifierName(rootServiceTypeName))),
Argument(
ObjectCreationExpression(IdentifierName(serviceInfo.ImplementationFullyQualifiedTypeName))
.WithArgumentList(ArgumentList())))));
}
else
{
// Register the main implementation type when at least a dependent service is needed:
//
// global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.<REGISTRATION_METHOD>(<PARAMETER_NAME>, typeof(<ROOT_SERVICE_TYPE>), static services => new <IMPLEMENTATION_TYPE>(<CONSTRUCTOR_ARGUMENTS>));
registrationStatements.Add(
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"),
IdentifierName(registrationMethod)))
.AddArgumentListArguments(
Argument(IdentifierName(info.Method.ServiceCollectionParameterName)),
Argument(TypeOfExpression(IdentifierName(rootServiceTypeName))),
Argument(
SimpleLambdaExpression(Parameter(Identifier("services")))
.AddModifiers(Token(SyntaxKind.StaticKeyword))
.WithExpressionBody(
ObjectCreationExpression(IdentifierName(serviceInfo.ImplementationFullyQualifiedTypeName))
.AddArgumentListArguments(constructorArguments.ToArray()))))));
}
// static object <FACTORY_METHOD>(global::System.IServiceProvider services)
// {
// return new <IMPLEMENTATION_TYPE>(<CONSTRUCTOR_ARGUMENTS>);
// }
localFunctions.Add(
LocalFunctionStatement(
PredefinedType(Token(SyntaxKind.ObjectKeyword)),
Identifier(factoryMethod))
.AddModifiers(Token(SyntaxKind.StaticKeyword))
.AddParameterListParameters(
Parameter(Identifier("services"))
.WithType(IdentifierName("global::System.IserviceProvider")))
.AddBodyStatements(
ReturnStatement(
ObjectCreationExpression(IdentifierName(serviceInfo.ImplementationFullyQualifiedTypeName))
.AddArgumentListArguments(constructorArguments.ToArray()))));

// Register the main implementation type:
//
// global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.<REGISTRATION_METHOD>(<PARAMETER_NAME>, typeof(<ROOT_SERVICE_TYPE>), new global::Func<global::System.IServiceProvider, object>(<FACTORY_METHOD>));
registrationStatements.Add(
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"),
IdentifierName(registrationMethod)))
.AddArgumentListArguments(
Argument(IdentifierName(info.Method.ServiceCollectionParameterName)),
Argument(TypeOfExpression(IdentifierName(rootServiceTypeName))),
Argument(
ObjectCreationExpression(
GenericName(Identifier("global::System.Func"))
.AddTypeArgumentListArguments(
IdentifierName("global::System.IServiceProvider"),
PredefinedType(Token(SyntaxKind.ObjectKeyword))))
.AddArgumentListArguments(Argument(IdentifierName(factoryMethod)))))));

// Register all secondary services, if any
foreach (string dependentServiceType in dependentServiceTypeNames)
{
// Register the main implementation type:
//
// global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.<REGISTRATION_METHOD>(<PARAMETER_NAME>, typeof(<DEPENDENT_SERVICE_TYPE>), static services => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredServices<ROOT_SERVICE_TYPE>(services));
// global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.<REGISTRATION_METHOD>(<PARAMETER_NAME>, typeof(<DEPENDENT_SERVICE_TYPE>), new global::System.Action<global::Microsoft.Extensions.DependencyInjection.IServiceCollection, object>(global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredServices<ROOT_SERVICE_TYPE>));
registrationStatements.Add(
ExpressionStatement(
InvocationExpression(
Expand All @@ -258,16 +266,17 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
Argument(IdentifierName(info.Method.ServiceCollectionParameterName)),
Argument(TypeOfExpression(IdentifierName(dependentServiceType))),
Argument(
SimpleLambdaExpression(Parameter(Identifier("services")))
.AddModifiers(Token(SyntaxKind.StaticKeyword))
.WithExpressionBody(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions"),
GenericName(Identifier("GetRequiredService"))
.AddTypeArgumentListArguments(IdentifierName(rootServiceTypeName))))
.AddArgumentListArguments(Argument(IdentifierName("services"))))))));
ObjectCreationExpression(
GenericName("global::System.Action")
.AddTypeArgumentListArguments(
IdentifierName("global::Microsoft.Extensions.DependencyInjection.IServiceCollection"),
PredefinedType(Token(SyntaxKind.ObjectKeyword))))
.AddArgumentListArguments(Argument(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions"),
GenericName(Identifier("GetRequiredService"))
.AddTypeArgumentListArguments(IdentifierName(rootServiceTypeName)))))))));
}
}

Expand All @@ -294,6 +303,7 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
// [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
// <MODIFIERS> <RETURN_TYPE> <METHOD_NAME>(global::Microsoft.Extensions.DependencyInjection.IServiceCollection <PARAMETER_NAME>)
// {
// <LOCAL_FUNCTIONS>
// <REGISTRATION_STATEMENTS>
// }
MethodDeclarationSyntax configureServicesMethodDeclaration =
Expand All @@ -302,6 +312,7 @@ public static CompilationUnitSyntax GetSyntax(ServiceCollectionInfo info)
.AddParameterListParameters(
Parameter(Identifier(info.Method.ServiceCollectionParameterName))
.WithType(IdentifierName("global::Microsoft.Extensions.DependencyInjection.IServiceCollection")))
.AddBodyStatements(localFunctions.ToArray())
.AddBodyStatements(registrationStatements.ToArray())
.AddAttributeLists(
AttributeList(SingletonSeparatedList(
Expand Down
Loading