diff --git a/src/Injectio.Attributes/RegistrationStrategy.cs b/src/Injectio.Attributes/RegistrationStrategy.cs index f249f97..3355135 100644 --- a/src/Injectio.Attributes/RegistrationStrategy.cs +++ b/src/Injectio.Attributes/RegistrationStrategy.cs @@ -16,5 +16,11 @@ public enum RegistrationStrategy /// /// Registers each matching concrete type as all of its implemented interfaces and itself /// - SelfWithInterfaces = 2 + SelfWithInterfaces = 2, + /// + /// Registers each matching concrete type as all of its implemented interfaces and itself. + /// For the interfaces a proxy-factory resolves the service from its type-name, so only one instance is created per lifetime + /// + /// For open-generic registrations, this behaves like + SelfWithProxyFactory = 3 } diff --git a/src/Injectio.Generators/KnownTypes.cs b/src/Injectio.Generators/KnownTypes.cs index d3dcc34..77efad4 100644 --- a/src/Injectio.Generators/KnownTypes.cs +++ b/src/Injectio.Generators/KnownTypes.cs @@ -37,29 +37,33 @@ public static class KnownTypes public const string ServiceLifetimeTransientFullName = "Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient"; - public static readonly int DuplicateStrategySkipValue = 0; + public const int DuplicateStrategySkipValue = 0; public const string DuplicateStrategySkipShortName = "Skip"; public const string DuplicateStrategySkipTypeName = $"DuplicateStrategy.{DuplicateStrategySkipShortName}"; - public static readonly int DuplicateStrategyReplaceValue = 1; + public const int DuplicateStrategyReplaceValue = 1; public const string DuplicateStrategyReplaceShortName = "Replace"; public const string DuplicateStrategyReplaceTypeName = $"DuplicateStrategy.{DuplicateStrategyReplaceShortName}"; - public static readonly int DuplicateStrategyAppendValue = 2; + public const int DuplicateStrategyAppendValue = 2; public const string DuplicateStrategyAppendShortName = "Append"; public const string DuplicateStrategyAppendTypeName = $"DuplicateStrategy.{DuplicateStrategyAppendShortName}"; - public static readonly int RegistrationStrategySelfValue = 0; + public const int RegistrationStrategySelfValue = 0; public const string RegistrationStrategySelfShortName = "Self"; public const string RegistrationStrategySelfTypeName = $"RegistrationStrategy.{RegistrationStrategySelfShortName}"; - public static readonly int RegistrationStrategyImplementedInterfacesValue = 1; + public const int RegistrationStrategyImplementedInterfacesValue = 1; public const string RegistrationStrategyImplementedInterfacesShortName = "ImplementedInterfaces"; public const string RegistrationStrategyImplementedInterfacesTypeName = $"RegistrationStrategy.{RegistrationStrategyImplementedInterfacesShortName}"; - public static readonly int RegistrationStrategySelfWithInterfacesValue = 2; + public const int RegistrationStrategySelfWithInterfacesValue = 2; public const string RegistrationStrategySelfWithInterfacesShortName = "SelfWithInterfaces"; public const string RegistrationStrategySelfWithInterfacesTypeName = $"RegistrationStrategy.{RegistrationStrategySelfWithInterfacesShortName}"; + public const int RegistrationStrategySelfWithProxyFactoryValue = 3; + public const string RegistrationStrategySelfWithProxyFactoryShortName = "SelfWithProxyFactory"; + public const string RegistrationStrategySelfWithProxyFactoryTypeName = $"RegistrationStrategy.{RegistrationStrategySelfWithProxyFactoryShortName}"; + } diff --git a/src/Injectio.Generators/ServiceRegistrationGenerator.cs b/src/Injectio.Generators/ServiceRegistrationGenerator.cs index cc86341..3d3b250 100644 --- a/src/Injectio.Generators/ServiceRegistrationGenerator.cs +++ b/src/Injectio.Generators/ServiceRegistrationGenerator.cs @@ -336,7 +336,7 @@ private static (EquatableArray diagnostics, bool hasServiceCollectio && implementationType == null && serviceTypes.Count == 0) { - registrationStrategy = KnownTypes.RegistrationStrategySelfWithInterfacesShortName; + registrationStrategy = KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName; } // no implementation type set, use class attribute is on @@ -348,7 +348,9 @@ private static (EquatableArray diagnostics, bool hasServiceCollectio } // add implemented interfaces - bool includeInterfaces = registrationStrategy is KnownTypes.RegistrationStrategyImplementedInterfacesShortName or KnownTypes.RegistrationStrategySelfWithInterfacesShortName; + bool includeInterfaces = registrationStrategy is KnownTypes.RegistrationStrategyImplementedInterfacesShortName + or KnownTypes.RegistrationStrategySelfWithInterfacesShortName + or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName; if (includeInterfaces) { foreach (var implementedInterface in classSymbol.AllInterfaces) @@ -366,10 +368,18 @@ private static (EquatableArray diagnostics, bool hasServiceCollectio } // add class attribute is on; default service type if not set - bool includeSelf = registrationStrategy is KnownTypes.RegistrationStrategySelfShortName or KnownTypes.RegistrationStrategySelfWithInterfacesShortName; + bool includeSelf = registrationStrategy is KnownTypes.RegistrationStrategySelfShortName + or KnownTypes.RegistrationStrategySelfWithInterfacesShortName + or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName; if (includeSelf || serviceTypes.Count == 0) serviceTypes.Add(implementationType!); + if (registrationStrategy is null && serviceTypes.Contains(implementationType!)) + registrationStrategy = KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName; + + if (registrationStrategy is KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName && isOpenGeneric) + registrationStrategy = KnownTypes.RegistrationStrategySelfWithInterfacesShortName; + return new ServiceRegistration( Lifetime: serviceLifetime, ImplementationType: implementationType!, @@ -529,9 +539,9 @@ private static string ResolveDuplicateStrategy(object? value) { int v => v switch { - 0 => KnownTypes.DuplicateStrategySkipShortName, - 1 => KnownTypes.DuplicateStrategyReplaceShortName, - 2 => KnownTypes.DuplicateStrategyAppendShortName, + KnownTypes.DuplicateStrategySkipValue => KnownTypes.DuplicateStrategySkipShortName, + KnownTypes.DuplicateStrategyReplaceValue => KnownTypes.DuplicateStrategyReplaceShortName, + KnownTypes.DuplicateStrategyAppendValue => KnownTypes.DuplicateStrategyAppendShortName, _ => KnownTypes.DuplicateStrategySkipShortName }, string text => text, @@ -545,13 +555,14 @@ private static string ResolveRegistrationStrategy(object? value) { int v => v switch { - 0 => KnownTypes.RegistrationStrategySelfShortName, - 1 => KnownTypes.RegistrationStrategyImplementedInterfacesShortName, - 2 => KnownTypes.RegistrationStrategySelfWithInterfacesShortName, - _ => KnownTypes.RegistrationStrategySelfWithInterfacesShortName + KnownTypes.RegistrationStrategySelfValue => KnownTypes.RegistrationStrategySelfShortName, + KnownTypes.RegistrationStrategyImplementedInterfacesValue => KnownTypes.RegistrationStrategyImplementedInterfacesShortName, + KnownTypes.RegistrationStrategySelfWithInterfacesValue => KnownTypes.RegistrationStrategySelfWithInterfacesShortName, + KnownTypes.RegistrationStrategySelfWithProxyFactoryValue => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName, + _ => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName }, string text => text, - _ => KnownTypes.RegistrationStrategySelfWithInterfacesShortName + _ => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName }; } } diff --git a/src/Injectio.Generators/ServiceRegistrationWriter.cs b/src/Injectio.Generators/ServiceRegistrationWriter.cs index bbe0d37..f945c57 100644 --- a/src/Injectio.Generators/ServiceRegistrationWriter.cs +++ b/src/Injectio.Generators/ServiceRegistrationWriter.cs @@ -291,6 +291,23 @@ private static void WriteServiceGeneric( .AppendIf(".", !hasNamespace) .Append(serviceRegistration.Factory); } + else if (serviceRegistration.Registration == KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName + && serviceRegistration.ImplementationType != serviceType) + { + codeBuilder + .AppendIf(", ", serviceRegistration.ServiceKey.HasValue()) + .Append("(serviceProvider") + .AppendIf(", key", serviceRegistration.ServiceKey.HasValue()) + .Append(") => global::Microsoft.Extensions.DependencyInjection.ServiceProvider") + .Append(serviceRegistration.ServiceKey.HasValue() + ? "KeyedServiceExtensions.GetRequiredKeyedService<" + : "ServiceExtensions.GetRequiredService<") + .AppendIf("global::", !serviceRegistration.ImplementationType.StartsWith("global::")) + .Append(serviceRegistration.ImplementationType) + .Append(">(serviceProvider") + .AppendIf(", key", serviceRegistration.ServiceKey.HasValue()) + .Append(")"); + } codeBuilder .AppendLine(")") diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterScopedSelfWithInterfaces.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterScopedSelfWithInterfaces.verified.txt index 1f85549..2944137 100644 --- a/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterScopedSelfWithInterfaces.verified.txt +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterScopedSelfWithInterfaces.verified.txt @@ -20,7 +20,7 @@ namespace Microsoft.Extensions.DependencyInjection global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( serviceCollection, - global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Scoped() + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Scoped((serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider)) ); global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterSingletonSelfAsClosedGeneric.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterSingletonSelfAsClosedGeneric.verified.txt index f94cdba..0079f09 100644 --- a/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterSingletonSelfAsClosedGeneric.verified.txt +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterSingletonSelfAsClosedGeneric.verified.txt @@ -20,7 +20,7 @@ namespace Microsoft.Extensions.DependencyInjection global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( serviceCollection, - global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton, global::Injectio.Sample.Service>() + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton, global::Injectio.Sample.Service>((serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider)) ); global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterSingletonTags.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterSingletonTags.verified.txt index 0699b5e..83d1c10 100644 --- a/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterSingletonTags.verified.txt +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterSingletonTags.verified.txt @@ -22,7 +22,7 @@ namespace Microsoft.Extensions.DependencyInjection { global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( serviceCollection, - global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton() + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton((serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider)) ); global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterTransientSelfWithInterfaces.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterTransientSelfWithInterfaces.verified.txt index d00106b..a8d610f 100644 --- a/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterTransientSelfWithInterfaces.verified.txt +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterTransientSelfWithInterfaces.verified.txt @@ -20,7 +20,7 @@ namespace Microsoft.Extensions.DependencyInjection global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( serviceCollection, - global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Transient() + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Transient((serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider)) ); global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd(