diff --git a/src/Dapr.Actors.Generators/ActorClientGenerator.cs b/src/Dapr.Actors.Generators/ActorClientGenerator.cs index 0f064e801..282dc682b 100644 --- a/src/Dapr.Actors.Generators/ActorClientGenerator.cs +++ b/src/Dapr.Actors.Generators/ActorClientGenerator.cs @@ -164,6 +164,65 @@ private static void GenerateActorClientCode(SourceProductionContext context, Act var actorClientClassTypeParameters = descriptor.InterfaceType.TypeParameters .Select(x => SyntaxFactory.TypeParameter(x.ToString())); + // Create constraint clauses for type parameters + var constraintClauses = new List(); + + // For each type parameter, create constraint clauses based on the interface's constraints + foreach (var typeParam in descriptor.InterfaceType.TypeParameters) + { + if (typeParam.HasReferenceTypeConstraint || + typeParam.HasValueTypeConstraint || + typeParam.HasUnmanagedTypeConstraint || + typeParam.HasNotNullConstraint || + typeParam.ConstraintTypes.Length > 0) + { + var constraints = new List(); + + // Add class/struct constraints + if (typeParam.HasReferenceTypeConstraint) + { + constraints.Add(SyntaxFactory.ClassOrStructConstraint(SyntaxKind.ClassConstraint)); + } + else if (typeParam.HasValueTypeConstraint) + { + constraints.Add(SyntaxFactory.ClassOrStructConstraint(SyntaxKind.StructConstraint)); + else if (typeParam.HasValueTypeConstraint && !typeParam.HasUnmanagedTypeConstraint) + { + constraints.Add(SyntaxFactory.ClassOrStructConstraint(SyntaxKind.StructConstraint)); + } + + // Add unmanaged constraint + if (typeParam.HasUnmanagedTypeConstraint) + { + constraints.Add(SyntaxFactory.TypeConstraint(SyntaxFactory.IdentifierName("unmanaged"))); + } + + // Add type constraints (e.g., where T : IInterface) + foreach (var constraintType in typeParam.ConstraintTypes) + { + constraints.Add(SyntaxFactory.TypeConstraint( + SyntaxFactory.ParseTypeName(constraintType.ToString()))); + } + + // Add notnull constraint + if (typeParam.HasNotNullConstraint) + { + constraints.Add(SyntaxFactory.TypeConstraint(SyntaxFactory.IdentifierName("notnull"))); + } + + // Add new() constraint - must be last + if (typeParam.HasConstructorConstraint) + { + constraints.Add(SyntaxFactory.ConstructorConstraint()); + } + + constraintClauses.Add( + SyntaxFactory.TypeParameterConstraintClause( + SyntaxFactory.IdentifierName(typeParam.Name), + SyntaxFactory.SeparatedList(constraints))); + } + } + var actorClientClassDeclaration = (actorClientClassTypeParameters.Count() == 0) ? SyntaxFactory.ClassDeclaration(descriptor.ClientTypeName) .WithModifiers(SyntaxFactory.TokenList(actorClientClassModifiers)) @@ -174,6 +233,7 @@ private static void GenerateActorClientCode(SourceProductionContext context, Act : SyntaxFactory.ClassDeclaration(descriptor.ClientTypeName) .WithModifiers(SyntaxFactory.TokenList(actorClientClassModifiers)) .WithTypeParameterList(SyntaxFactory.TypeParameterList(SyntaxFactory.SeparatedList(actorClientClassTypeParameters))) + .WithConstraintClauses(SyntaxFactory.List(constraintClauses)) // Add constraint clauses to the class .WithMembers(SyntaxFactory.List(actorMembers)) .WithBaseList(SyntaxFactory.BaseList( SyntaxFactory.Token(SyntaxKind.ColonToken), diff --git a/test/Dapr.Actors.Generators.Test/ActorClientGeneratorTests.cs b/test/Dapr.Actors.Generators.Test/ActorClientGeneratorTests.cs index 3515bc8b0..2841f7b68 100644 --- a/test/Dapr.Actors.Generators.Test/ActorClientGeneratorTests.cs +++ b/test/Dapr.Actors.Generators.Test/ActorClientGeneratorTests.cs @@ -877,4 +877,120 @@ public interface ITestActor await test.RunAsync(); } + + [Fact] + public async Task TestGenericWithConstraints() + { + var originalSource = @" +using Dapr.Actors.Generators; +using System.Threading.Tasks; + +namespace Test +{ + public interface ITrait + { + string GetName(); + } + + [GenerateActorClient] + public interface ITestActor where TTrait : ITrait + { + Task SetTrait(TTrait trait); + Task GetTrait(string name); + } +}"; + + var generatedSource = @"// +#nullable enable +namespace Test +{ + public sealed class TestActorClient : Test.ITestActor where TTrait : Test.ITrait + { + private readonly Dapr.Actors.Client.ActorProxy actorProxy; + public TestActorClient(Dapr.Actors.Client.ActorProxy actorProxy) + { + if (actorProxy is null) + { + throw new System.ArgumentNullException(nameof(actorProxy)); + } + + this.actorProxy = actorProxy; + } + + public System.Threading.Tasks.Task GetTrait(string name) + { + return this.actorProxy.InvokeMethodAsync(""GetTrait"", name); + } + + public System.Threading.Tasks.Task SetTrait(TTrait trait) + { + return this.actorProxy.InvokeMethodAsync(""SetTrait"", trait); + } + } +}"; + + await CreateTest(originalSource, "Test.TestActorClient.g.cs", generatedSource).RunAsync(); + } + + [Fact] + public async Task TestGenericWithMultipleTypeParametersAndConstraints() + { + var originalSource = @" +using Dapr.Actors.Generators; +using System.Threading.Tasks; + +namespace Test +{ + public interface ITrait + { + string GetName(); + } + + public interface IValidator + { + bool Validate(T item); + } + + [GenerateActorClient] + public interface ITestActor + where TTrait : ITrait, new() + where TValidator : class, IValidator + { + Task SetTrait(TTrait trait); + Task GetValidatedTrait(string name); + } +}"; + + var generatedSource = @"// +#nullable enable +namespace Test +{ + public sealed class TestActorClient : Test.ITestActor where TTrait : Test.ITrait, new() + where TValidator : class, Test.IValidator + { + private readonly Dapr.Actors.Client.ActorProxy actorProxy; + public TestActorClient(Dapr.Actors.Client.ActorProxy actorProxy) + { + if (actorProxy is null) + { + throw new System.ArgumentNullException(nameof(actorProxy)); + } + + this.actorProxy = actorProxy; + } + + public System.Threading.Tasks.Task GetValidatedTrait(string name) + { + return this.actorProxy.InvokeMethodAsync(""GetValidatedTrait"", name); + } + + public System.Threading.Tasks.Task SetTrait(TTrait trait) + { + return this.actorProxy.InvokeMethodAsync(""SetTrait"", trait); + } + } +}"; + + await CreateTest(originalSource, "Test.TestActorClient.g.cs", generatedSource).RunAsync(); + } }