Skip to content

Commit 34d7295

Browse files
committed
Don't emit registrations with generics that don't satisfy constraints
We were emitting instantiations of generic interfaces (implemented by the service) for covariant type parameters that didn't satisfy the generic type parameter constraints.
1 parent d211847 commit 34d7295

File tree

4 files changed

+124
-2
lines changed

4 files changed

+124
-2
lines changed

src/DependencyInjection.Tests/DependencyInjection.Tests.csproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
<ItemGroup>
1212
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.9.0" />
13+
<PackageReference Include="Moq" Version="4.20.72" />
14+
<PackageReference Include="Spectre.Console.Cli" Version="0.50.0" />
1315
<PackageReference Include="xunit" Version="2.7.0" />
1416
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.7" />
1517
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="8.0.0" />
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using System.Threading.Tasks;
6+
using Microsoft.Extensions.DependencyInjection;
7+
using Moq;
8+
using Spectre.Console.Cli;
9+
10+
namespace Tests.Regressions;
11+
12+
public class Regressions
13+
{
14+
[Fact]
15+
public void CovariantRegistrationSatisfiesIntefaceConstraints()
16+
{
17+
var collection = new ServiceCollection();
18+
collection.AddServices(typeof(ICommand));
19+
20+
var provider = collection.BuildServiceProvider();
21+
22+
var command = provider.GetRequiredService<MyCommand>();
23+
24+
Assert.Equal(0, command.Execute(new CommandContext([], Mock.Of<IRemainingArguments>(), "my", null),
25+
new MySetting { Base = "", Name = "" }));
26+
}
27+
}
28+
29+
public interface ISetting
30+
{
31+
string Name { get; set; }
32+
}
33+
34+
public class BaseSetting : CommandSettings
35+
{
36+
[CommandArgument(0, "<BASE>")]
37+
public required string Base { get; init; }
38+
}
39+
40+
public class MySetting : BaseSetting, ISetting
41+
{
42+
[CommandOption("--name")]
43+
public required string Name { get; set; }
44+
}
45+
46+
public class MyCommand : BaseCommand<MySetting> { }
47+
48+
public abstract class BaseCommand<TSettings> : Command<TSettings> where TSettings : BaseSetting, ISetting
49+
{
50+
public override int Execute(CommandContext context, TSettings settings)
51+
{
52+
Console.WriteLine($"Base: {settings.Base}, Name: {settings.Name}");
53+
return 0;
54+
}
55+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Reflection.Metadata;
5+
using System.Text;
6+
using Microsoft.CodeAnalysis;
7+
8+
namespace Devlooped.Extensions.DependencyInjection;
9+
10+
static class ConstraintsChecker
11+
{
12+
public static bool SatisfiesConstraints(this ITypeSymbol typeArgument, ITypeParameterSymbol typeParameter)
13+
{
14+
// Check reference type constraint
15+
if (typeParameter.HasReferenceTypeConstraint && !typeArgument.IsReferenceType)
16+
return false;
17+
18+
// Check value type constraint
19+
if (typeParameter.HasValueTypeConstraint && !typeArgument.IsValueType)
20+
return false;
21+
22+
// Check base class and interface constraints
23+
foreach (var constraint in typeParameter.ConstraintTypes)
24+
{
25+
if (constraint.TypeKind == TypeKind.Class)
26+
{
27+
if (!typeArgument.GetBaseTypes().Any(baseType => SymbolEqualityComparer.Default.Equals(baseType, constraint)))
28+
return false;
29+
}
30+
else if (constraint.TypeKind == TypeKind.Interface)
31+
{
32+
if (!typeArgument.AllInterfaces.Any(interfaceSymbol => SymbolEqualityComparer.Default.Equals(interfaceSymbol, constraint)))
33+
return false;
34+
}
35+
}
36+
37+
// Constructor constraint (optional, not typically needed here)
38+
if (typeParameter.HasConstructorConstraint)
39+
{
40+
// Check for parameterless constructor (simplified)
41+
var hasParameterlessConstructor = typeArgument.GetMembers(".ctor")
42+
.OfType<IMethodSymbol>()
43+
.Any(ctor => ctor.Parameters.Length == 0);
44+
if (!hasParameterlessConstructor)
45+
return false;
46+
}
47+
48+
return true;
49+
}
50+
51+
static IEnumerable<ITypeSymbol> GetBaseTypes(this ITypeSymbol typeSymbol)
52+
{
53+
var currentType = typeSymbol.BaseType;
54+
while (currentType != null && currentType.SpecialType != SpecialType.System_Object)
55+
{
56+
yield return currentType;
57+
currentType = currentType.BaseType;
58+
}
59+
}
60+
}

src/DependencyInjection/IncrementalGenerator.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,9 @@ void AddServices(IEnumerable<INamedTypeSymbol> services, Compilation compilation
452452

453453
foreach (var iface in type.AllInterfaces)
454454
{
455+
if (!compilation.HasImplicitConversion(type, iface))
456+
continue;
457+
455458
var ifaceName = iface.ToFullName(compilation);
456459
if (!registered.Contains(ifaceName))
457460
{
@@ -476,9 +479,11 @@ void AddServices(IEnumerable<INamedTypeSymbol> services, Compilation compilation
476479
baseType = baseType.BaseType;
477480
}
478481

479-
foreach (var candidate in candidates.Select(x => iface.ConstructedFrom.Construct(x))
482+
foreach (var candidate in candidates
483+
.Where(x => x.SatisfiesConstraints(iface.TypeParameters[0]))
484+
.Select(x => iface.ConstructedFrom.Construct(x))
485+
.Where(x => x != null && compilation.HasImplicitConversion(type, x))
480486
.ToImmutableHashSet(SymbolEqualityComparer.Default)
481-
.Where(x => x != null)
482487
.Select(x => x!.ToFullName(compilation)))
483488
{
484489
if (!registered.Contains(candidate))

0 commit comments

Comments
 (0)