Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions src/Injectio.Generators/ServiceRegistrationGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ private static (EquatableArray<Diagnostic> diagnostics, bool hasServiceCollectio
// no implementation type set, use class attribute is on
if (implementationType.IsNullOrWhiteSpace())
{
implementationType = classSymbol.ToDisplayString(_fullyQualifiedNullableFormat);
implementationType = ToNamedTypeWithoutPlaceholders(classSymbol).ToDisplayString(_fullyQualifiedNullableFormat);
}

// add implemented interfaces
Expand All @@ -345,7 +345,7 @@ private static (EquatableArray<Diagnostic> diagnostics, bool hasServiceCollectio
{
// This interface is typically not injected into services and, more specifically, record types auto-implement it.
if(implementedInterface.ConstructedFrom.ToString() == "System.IEquatable<T>") continue;
var interfaceName = implementedInterface.ToDisplayString(_fullyQualifiedNullableFormat);
var interfaceName = ToNamedTypeWithoutPlaceholders(implementedInterface).ToDisplayString(_fullyQualifiedNullableFormat);
serviceTypes.Add(interfaceName);
}
}
Expand All @@ -366,6 +366,26 @@ private static (EquatableArray<Diagnostic> diagnostics, bool hasServiceCollectio
tags.ToArray());
}

private static INamedTypeSymbol ToNamedTypeWithoutPlaceholders(INamedTypeSymbol typeSymbol)
{
if (!typeSymbol.IsGenericType
|| typeSymbol.IsUnboundGenericType)
{
return typeSymbol;
}

foreach (var typeArgument in typeSymbol.TypeArguments)
{
// If TypeKind is TypeParameter, it's actually the name of a locally declared type-parameter -> placeholder
if (typeArgument.TypeKind != TypeKind.TypeParameter)
{
return typeSymbol;
}
}

return typeSymbol.ConstructUnboundGenericType();
}

private static bool IsKnownAttribute(AttributeData attribute, out string serviceLifetime)
{
if (IsSingletonAttribute(attribute))
Expand Down
1 change: 1 addition & 0 deletions tests/Injectio.Tests.Console/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
var module = provider.GetRequiredService<IModuleService>();

var generic = provider.GetRequiredService<IOpenGeneric<string>>();
var generic2 = provider.GetRequiredService<IOpenGeneric2<string>>();
var tagService = provider.GetService<IServiceTag>();

var alpaService = provider.GetKeyedService<IServiceKeyed>("Alpha");
Expand Down
9 changes: 9 additions & 0 deletions tests/Injectio.Tests.Library/Service.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ public class OpenGeneric<T> : IOpenGeneric<T>
{
}

public interface IOpenGeneric2<T>
{
}

[RegisterSingleton]
public class OpenGeneric2<T> : IOpenGeneric2<T>
{
}

public interface IServiceTag
{
}
Expand Down
52 changes: 52 additions & 0 deletions tests/Injectio.Tests/ServiceRegistrationGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,58 @@ public class OpenGeneric<T> : IOpenGeneric<T>
.ScrubLinesContaining("GeneratedCodeAttribute");
}

[Fact]
public Task GenerateRegisterSingletonSelfAsOpenGeneric()
{
var source = @"
using Injectio.Attributes;

namespace Injectio.Sample;

public interface IOpenGeneric<T>
{ }

[RegisterSingleton]
public class OpenGeneric<T> : IOpenGeneric<T>
{ }
";

var (diagnostics, output) = GetGeneratedOutput<ServiceRegistrationGenerator>(source);

diagnostics.Should().BeEmpty();

return Verifier
.Verify(output)
.UseDirectory("Snapshots")
.ScrubLinesContaining("GeneratedCodeAttribute");
}

[Fact]
public Task GenerateRegisterSingletonSelfAsClosedGeneric()
{
var source = @"
using Injectio.Attributes;

namespace Injectio.Sample;

public interface IClosedGeneric<T>
{ }

[RegisterSingleton]
public class Service : IClosedGeneric<object>
{ }
";

var (diagnostics, output) = GetGeneratedOutput<ServiceRegistrationGenerator>(source);

diagnostics.Should().BeEmpty();

return Verifier
.Verify(output)
.UseDirectory("Snapshots")
.ScrubLinesContaining("GeneratedCodeAttribute");
}

[Fact]
public Task GenerateRegisterSingletonTags()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// <auto-generated />
#nullable enable

namespace Microsoft.Extensions.DependencyInjection
{
/// <summary>
/// Extension methods for discovered service registrations
/// </summary>
public static class DiscoveredServicesExtensions
{
/// <summary>
/// Adds discovered services from Test.Generator to the specified service collection
/// </summary>
/// <param name="serviceCollection">The service collection.</param>
/// <param name="tags">The service registration tags to include.</param>
/// <returns>The service collection</returns>
public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags)
{
var tagSet = new global::System.Collections.Generic.HashSet<string>(tags ?? global::System.Linq.Enumerable.Empty<string>());

global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd(
serviceCollection,
global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Describe(
typeof(global::Injectio.Sample.IClosedGeneric<object>),
typeof(global::Injectio.Sample.Service),
global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton
)
);

global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd(
serviceCollection,
global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Describe(
typeof(global::Injectio.Sample.Service),
typeof(global::Injectio.Sample.Service),
global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton
)
);

return serviceCollection;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// <auto-generated />
#nullable enable

namespace Microsoft.Extensions.DependencyInjection
{
/// <summary>
/// Extension methods for discovered service registrations
/// </summary>
public static class DiscoveredServicesExtensions
{
/// <summary>
/// Adds discovered services from Test.Generator to the specified service collection
/// </summary>
/// <param name="serviceCollection">The service collection.</param>
/// <param name="tags">The service registration tags to include.</param>
/// <returns>The service collection</returns>
public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags)
{
var tagSet = new global::System.Collections.Generic.HashSet<string>(tags ?? global::System.Linq.Enumerable.Empty<string>());

global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd(
serviceCollection,
global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Describe(
typeof(global::Injectio.Sample.IOpenGeneric<>),
typeof(global::Injectio.Sample.OpenGeneric<>),
global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton
)
);

global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd(
serviceCollection,
global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Describe(
typeof(global::Injectio.Sample.OpenGeneric<>),
typeof(global::Injectio.Sample.OpenGeneric<>),
global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton
)
);

return serviceCollection;
}
}
}