diff --git a/src/Generators/AzureFunctions/DurableFunction.cs b/src/Generators/AzureFunctions/DurableFunction.cs index 2f1af7dc..17c7417e 100644 --- a/src/Generators/AzureFunctions/DurableFunction.cs +++ b/src/Generators/AzureFunctions/DurableFunction.cs @@ -22,13 +22,15 @@ public class DurableFunction public DurableFunctionKind Kind { get; } public TypedParameter Parameter { get; } public string ReturnType { get; } + public bool ReturnsVoid { get; } public DurableFunction( string fullTypeName, string name, DurableFunctionKind kind, TypedParameter parameter, - ITypeSymbol returnType, + ITypeSymbol? returnType, + bool returnsVoid, HashSet requiredNamespaces) { this.FullTypeName = fullTypeName; @@ -36,7 +38,8 @@ public DurableFunction( this.Name = name; this.Kind = kind; this.Parameter = parameter; - this.ReturnType = SyntaxNodeUtility.GetRenderedTypeExpression(returnType, false); + this.ReturnType = returnType != null ? SyntaxNodeUtility.GetRenderedTypeExpression(returnType, false) : string.Empty; + this.ReturnsVoid = returnsVoid; } public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method, out DurableFunction? function) @@ -59,12 +62,54 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method, return false; } - INamedTypeSymbol taskSymbol = model.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1")!; - INamedTypeSymbol returnSymbol = (INamedTypeSymbol)model.GetTypeInfo(returnType).Type!; - if (SymbolEqualityComparer.Default.Equals(returnSymbol.OriginalDefinition, taskSymbol)) + ITypeSymbol? returnTypeSymbol = model.GetTypeInfo(returnType).Type; + if (returnTypeSymbol == null || returnTypeSymbol.TypeKind == TypeKind.Error) { - // this is a Task return value, lets pull out the generic. - returnSymbol = (INamedTypeSymbol)returnSymbol.TypeArguments[0]; + function = null; + return false; + } + + bool returnsVoid = false; + INamedTypeSymbol? returnSymbol = null; + + // Check if it's a void return type + if (returnTypeSymbol.SpecialType == SpecialType.System_Void) + { + returnsVoid = true; + // returnSymbol is left as null since void has no type to track + } + // Check if it's Task (non-generic) + else if (returnTypeSymbol is INamedTypeSymbol namedReturn) + { + INamedTypeSymbol? nonGenericTaskSymbol = model.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task"); + if (nonGenericTaskSymbol != null && SymbolEqualityComparer.Default.Equals(namedReturn, nonGenericTaskSymbol)) + { + returnsVoid = true; + // returnSymbol is left as null since Task (non-generic) has no return type to track + } + // Check if it's Task + else + { + INamedTypeSymbol? taskSymbol = model.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1"); + returnSymbol = namedReturn; + if (taskSymbol != null && SymbolEqualityComparer.Default.Equals(returnSymbol.OriginalDefinition, taskSymbol)) + { + // this is a Task return value, lets pull out the generic. + ITypeSymbol typeArg = returnSymbol.TypeArguments[0]; + if (typeArg is not INamedTypeSymbol namedTypeArg) + { + function = null; + return false; + } + returnSymbol = namedTypeArg; + } + } + } + else + { + // returnTypeSymbol is not INamedTypeSymbol, which is unexpected + function = null; + return false; } if (!SyntaxNodeUtility.TryGetParameter(model, method, kind, out TypedParameter? parameter) || parameter == null) @@ -79,12 +124,18 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method, return false; } + // Build list of types used for namespace resolution List usedTypes = new() { - returnSymbol, parameter.Type }; + // Only include return type if it's not void + if (returnSymbol != null) + { + usedTypes.Add(returnSymbol); + } + if (!SyntaxNodeUtility.TryGetRequiredNamespaces(usedTypes, out HashSet? requiredNamespaces)) { function = null; @@ -93,7 +144,7 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method, requiredNamespaces!.UnionWith(GetRequiredGlobalNamespaces()); - function = new DurableFunction(fullTypeName!, name, kind, parameter, returnSymbol, requiredNamespaces); + function = new DurableFunction(fullTypeName!, name, kind, parameter, returnSymbol, returnsVoid, requiredNamespaces); return true; } diff --git a/src/Generators/AzureFunctions/TypedParameter.cs b/src/Generators/AzureFunctions/TypedParameter.cs index f6bc7ed8..4860905a 100644 --- a/src/Generators/AzureFunctions/TypedParameter.cs +++ b/src/Generators/AzureFunctions/TypedParameter.cs @@ -19,7 +19,17 @@ public TypedParameter(INamedTypeSymbol type, string name) public override string ToString() { - return $"{SyntaxNodeUtility.GetRenderedTypeExpression(this.Type, false)} {this.Name}"; + // Use the type as-is, preserving the nullability annotation from the source + string typeExpression = SyntaxNodeUtility.GetRenderedTypeExpression(this.Type, false); + + // Special case: if the type is exactly System.Object (not a nullable object), make it nullable + // This is because object parameters are typically nullable in the context of Durable Functions + if (this.Type.SpecialType == SpecialType.System_Object && this.Type.NullableAnnotation != NullableAnnotation.Annotated) + { + typeExpression = "object?"; + } + + return $"{typeExpression} {this.Name}"; } } } diff --git a/src/Generators/DurableTaskSourceGenerator.cs b/src/Generators/DurableTaskSourceGenerator.cs index 0b4e717e..29de835e 100644 --- a/src/Generators/DurableTaskSourceGenerator.cs +++ b/src/Generators/DurableTaskSourceGenerator.cs @@ -435,7 +435,21 @@ static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableTaskTypeIn static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableFunction activity) { - sourceBuilder.AppendLine($@" + if (activity.ReturnsVoid) + { + sourceBuilder.AppendLine($@" + /// + /// Calls the activity. + /// + /// + public static Task Call{activity.Name}Async(this TaskOrchestrationContext ctx, {activity.Parameter}, TaskOptions? options = null) + {{ + return ctx.CallActivityAsync(""{activity.Name}"", {activity.Parameter.Name}, options); + }}"); + } + else + { + sourceBuilder.AppendLine($@" /// /// Calls the activity. /// @@ -444,6 +458,7 @@ static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableFunction a {{ return ctx.CallActivityAsync<{activity.ReturnType}>(""{activity.Name}"", {activity.Parameter.Name}, options); }}"); + } } static void AddEventWaitMethod(StringBuilder sourceBuilder, DurableEventTypeInfo eventInfo) diff --git a/test/Generators.Tests/AzureFunctionsTests.cs b/test/Generators.Tests/AzureFunctionsTests.cs index d9d7fad0..3a02eeee 100644 --- a/test/Generators.Tests/AzureFunctionsTests.cs +++ b/test/Generators.Tests/AzureFunctionsTests.cs @@ -117,6 +117,79 @@ await TestHelpers.RunTestAsync( isDurableFunctions: true); } + [Fact] + public async Task Activities_SimpleFunctionTrigger_VoidReturn() + { + string code = @" +using Microsoft.Azure.Functions.Worker; +using Microsoft.DurableTask; + +public class Activities +{ + [Function(nameof(FlakeyActivity))] + public static void FlakeyActivity([ActivityTrigger] object _) + { + throw new System.ApplicationException(""Kah-BOOOOM!!!""); + } +}"; + + string expectedOutput = TestHelpers.WrapAndFormat( + GeneratedClassName, + methodList: @" +/// +/// Calls the activity. +/// +/// +public static Task CallFlakeyActivityAsync(this TaskOrchestrationContext ctx, object? _, TaskOptions? options = null) +{ + return ctx.CallActivityAsync(""FlakeyActivity"", _, options); +}", + isDurableFunctions: true); + + await TestHelpers.RunTestAsync( + GeneratedFileName, + code, + expectedOutput, + isDurableFunctions: true); + } + + [Fact] + public async Task Activities_SimpleFunctionTrigger_TaskReturn() + { + string code = @" +using System.Threading.Tasks; +using Microsoft.Azure.Functions.Worker; +using Microsoft.DurableTask; + +public class Activities +{ + [Function(nameof(FlakeyActivity))] + public static Task FlakeyActivity([ActivityTrigger] object _) + { + throw new System.ApplicationException(""Kah-BOOOOM!!!""); + } +}"; + + string expectedOutput = TestHelpers.WrapAndFormat( + GeneratedClassName, + methodList: @" +/// +/// Calls the activity. +/// +/// +public static Task CallFlakeyActivityAsync(this TaskOrchestrationContext ctx, object? _, TaskOptions? options = null) +{ + return ctx.CallActivityAsync(""FlakeyActivity"", _, options); +}", + isDurableFunctions: true); + + await TestHelpers.RunTestAsync( + GeneratedFileName, + code, + expectedOutput, + isDurableFunctions: true); + } + /// /// Verifies that using the class-based activity syntax generates a /// extension method as well as an function definition.