Skip to content

Commit e3b230e

Browse files
CopilotYunchuWang
andauthored
Fix source generator for void-returning activity functions (#554)
* Initial plan * Fix source generator to handle void-returning activity functions Co-authored-by: YunchuWang <[email protected]> * Address code review feedback - use robust type comparisons Co-authored-by: YunchuWang <[email protected]> * Address code review comments - improve null safety and avoid fragile semantics Co-authored-by: YunchuWang <[email protected]> * Remove unnecessary cast - GetSpecialType returns INamedTypeSymbol Co-authored-by: YunchuWang <[email protected]> * Address review feedback - add TypeKind.Error check and make returnType nullable Co-authored-by: YunchuWang <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: YunchuWang <[email protected]> Co-authored-by: wangbill <[email protected]>
1 parent d836f18 commit e3b230e

File tree

4 files changed

+160
-11
lines changed

4 files changed

+160
-11
lines changed

src/Generators/AzureFunctions/DurableFunction.cs

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,24 @@ public class DurableFunction
2222
public DurableFunctionKind Kind { get; }
2323
public TypedParameter Parameter { get; }
2424
public string ReturnType { get; }
25+
public bool ReturnsVoid { get; }
2526

2627
public DurableFunction(
2728
string fullTypeName,
2829
string name,
2930
DurableFunctionKind kind,
3031
TypedParameter parameter,
31-
ITypeSymbol returnType,
32+
ITypeSymbol? returnType,
33+
bool returnsVoid,
3234
HashSet<string> requiredNamespaces)
3335
{
3436
this.FullTypeName = fullTypeName;
3537
this.RequiredNamespaces = requiredNamespaces;
3638
this.Name = name;
3739
this.Kind = kind;
3840
this.Parameter = parameter;
39-
this.ReturnType = SyntaxNodeUtility.GetRenderedTypeExpression(returnType, false);
41+
this.ReturnType = returnType != null ? SyntaxNodeUtility.GetRenderedTypeExpression(returnType, false) : string.Empty;
42+
this.ReturnsVoid = returnsVoid;
4043
}
4144

4245
public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method, out DurableFunction? function)
@@ -59,12 +62,54 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method,
5962
return false;
6063
}
6164

62-
INamedTypeSymbol taskSymbol = model.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1")!;
63-
INamedTypeSymbol returnSymbol = (INamedTypeSymbol)model.GetTypeInfo(returnType).Type!;
64-
if (SymbolEqualityComparer.Default.Equals(returnSymbol.OriginalDefinition, taskSymbol))
65+
ITypeSymbol? returnTypeSymbol = model.GetTypeInfo(returnType).Type;
66+
if (returnTypeSymbol == null || returnTypeSymbol.TypeKind == TypeKind.Error)
6567
{
66-
// this is a Task<T> return value, lets pull out the generic.
67-
returnSymbol = (INamedTypeSymbol)returnSymbol.TypeArguments[0];
68+
function = null;
69+
return false;
70+
}
71+
72+
bool returnsVoid = false;
73+
INamedTypeSymbol? returnSymbol = null;
74+
75+
// Check if it's a void return type
76+
if (returnTypeSymbol.SpecialType == SpecialType.System_Void)
77+
{
78+
returnsVoid = true;
79+
// returnSymbol is left as null since void has no type to track
80+
}
81+
// Check if it's Task (non-generic)
82+
else if (returnTypeSymbol is INamedTypeSymbol namedReturn)
83+
{
84+
INamedTypeSymbol? nonGenericTaskSymbol = model.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task");
85+
if (nonGenericTaskSymbol != null && SymbolEqualityComparer.Default.Equals(namedReturn, nonGenericTaskSymbol))
86+
{
87+
returnsVoid = true;
88+
// returnSymbol is left as null since Task (non-generic) has no return type to track
89+
}
90+
// Check if it's Task<T>
91+
else
92+
{
93+
INamedTypeSymbol? taskSymbol = model.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1");
94+
returnSymbol = namedReturn;
95+
if (taskSymbol != null && SymbolEqualityComparer.Default.Equals(returnSymbol.OriginalDefinition, taskSymbol))
96+
{
97+
// this is a Task<T> return value, lets pull out the generic.
98+
ITypeSymbol typeArg = returnSymbol.TypeArguments[0];
99+
if (typeArg is not INamedTypeSymbol namedTypeArg)
100+
{
101+
function = null;
102+
return false;
103+
}
104+
returnSymbol = namedTypeArg;
105+
}
106+
}
107+
}
108+
else
109+
{
110+
// returnTypeSymbol is not INamedTypeSymbol, which is unexpected
111+
function = null;
112+
return false;
68113
}
69114

70115
if (!SyntaxNodeUtility.TryGetParameter(model, method, kind, out TypedParameter? parameter) || parameter == null)
@@ -79,12 +124,18 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method,
79124
return false;
80125
}
81126

127+
// Build list of types used for namespace resolution
82128
List<INamedTypeSymbol> usedTypes = new()
83129
{
84-
returnSymbol,
85130
parameter.Type
86131
};
87132

133+
// Only include return type if it's not void
134+
if (returnSymbol != null)
135+
{
136+
usedTypes.Add(returnSymbol);
137+
}
138+
88139
if (!SyntaxNodeUtility.TryGetRequiredNamespaces(usedTypes, out HashSet<string>? requiredNamespaces))
89140
{
90141
function = null;
@@ -93,7 +144,7 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method,
93144

94145
requiredNamespaces!.UnionWith(GetRequiredGlobalNamespaces());
95146

96-
function = new DurableFunction(fullTypeName!, name, kind, parameter, returnSymbol, requiredNamespaces);
147+
function = new DurableFunction(fullTypeName!, name, kind, parameter, returnSymbol, returnsVoid, requiredNamespaces);
97148
return true;
98149
}
99150

src/Generators/AzureFunctions/TypedParameter.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,17 @@ public TypedParameter(INamedTypeSymbol type, string name)
1919

2020
public override string ToString()
2121
{
22-
return $"{SyntaxNodeUtility.GetRenderedTypeExpression(this.Type, false)} {this.Name}";
22+
// Use the type as-is, preserving the nullability annotation from the source
23+
string typeExpression = SyntaxNodeUtility.GetRenderedTypeExpression(this.Type, false);
24+
25+
// Special case: if the type is exactly System.Object (not a nullable object), make it nullable
26+
// This is because object parameters are typically nullable in the context of Durable Functions
27+
if (this.Type.SpecialType == SpecialType.System_Object && this.Type.NullableAnnotation != NullableAnnotation.Annotated)
28+
{
29+
typeExpression = "object?";
30+
}
31+
32+
return $"{typeExpression} {this.Name}";
2333
}
2434
}
2535
}

src/Generators/DurableTaskSourceGenerator.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,21 @@ static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableTaskTypeIn
435435

436436
static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableFunction activity)
437437
{
438-
sourceBuilder.AppendLine($@"
438+
if (activity.ReturnsVoid)
439+
{
440+
sourceBuilder.AppendLine($@"
441+
/// <summary>
442+
/// Calls the <see cref=""{activity.FullTypeName}""/> activity.
443+
/// </summary>
444+
/// <inheritdoc cref=""TaskOrchestrationContext.CallActivityAsync(TaskName, object?, TaskOptions?)""/>
445+
public static Task Call{activity.Name}Async(this TaskOrchestrationContext ctx, {activity.Parameter}, TaskOptions? options = null)
446+
{{
447+
return ctx.CallActivityAsync(""{activity.Name}"", {activity.Parameter.Name}, options);
448+
}}");
449+
}
450+
else
451+
{
452+
sourceBuilder.AppendLine($@"
439453
/// <summary>
440454
/// Calls the <see cref=""{activity.FullTypeName}""/> activity.
441455
/// </summary>
@@ -444,6 +458,7 @@ static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableFunction a
444458
{{
445459
return ctx.CallActivityAsync<{activity.ReturnType}>(""{activity.Name}"", {activity.Parameter.Name}, options);
446460
}}");
461+
}
447462
}
448463

449464
static void AddEventWaitMethod(StringBuilder sourceBuilder, DurableEventTypeInfo eventInfo)

test/Generators.Tests/AzureFunctionsTests.cs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,79 @@ await TestHelpers.RunTestAsync<DurableTaskSourceGenerator>(
117117
isDurableFunctions: true);
118118
}
119119

120+
[Fact]
121+
public async Task Activities_SimpleFunctionTrigger_VoidReturn()
122+
{
123+
string code = @"
124+
using Microsoft.Azure.Functions.Worker;
125+
using Microsoft.DurableTask;
126+
127+
public class Activities
128+
{
129+
[Function(nameof(FlakeyActivity))]
130+
public static void FlakeyActivity([ActivityTrigger] object _)
131+
{
132+
throw new System.ApplicationException(""Kah-BOOOOM!!!"");
133+
}
134+
}";
135+
136+
string expectedOutput = TestHelpers.WrapAndFormat(
137+
GeneratedClassName,
138+
methodList: @"
139+
/// <summary>
140+
/// Calls the <see cref=""Activities.FlakeyActivity""/> activity.
141+
/// </summary>
142+
/// <inheritdoc cref=""TaskOrchestrationContext.CallActivityAsync(TaskName, object?, TaskOptions?)""/>
143+
public static Task CallFlakeyActivityAsync(this TaskOrchestrationContext ctx, object? _, TaskOptions? options = null)
144+
{
145+
return ctx.CallActivityAsync(""FlakeyActivity"", _, options);
146+
}",
147+
isDurableFunctions: true);
148+
149+
await TestHelpers.RunTestAsync<DurableTaskSourceGenerator>(
150+
GeneratedFileName,
151+
code,
152+
expectedOutput,
153+
isDurableFunctions: true);
154+
}
155+
156+
[Fact]
157+
public async Task Activities_SimpleFunctionTrigger_TaskReturn()
158+
{
159+
string code = @"
160+
using System.Threading.Tasks;
161+
using Microsoft.Azure.Functions.Worker;
162+
using Microsoft.DurableTask;
163+
164+
public class Activities
165+
{
166+
[Function(nameof(FlakeyActivity))]
167+
public static Task FlakeyActivity([ActivityTrigger] object _)
168+
{
169+
throw new System.ApplicationException(""Kah-BOOOOM!!!"");
170+
}
171+
}";
172+
173+
string expectedOutput = TestHelpers.WrapAndFormat(
174+
GeneratedClassName,
175+
methodList: @"
176+
/// <summary>
177+
/// Calls the <see cref=""Activities.FlakeyActivity""/> activity.
178+
/// </summary>
179+
/// <inheritdoc cref=""TaskOrchestrationContext.CallActivityAsync(TaskName, object?, TaskOptions?)""/>
180+
public static Task CallFlakeyActivityAsync(this TaskOrchestrationContext ctx, object? _, TaskOptions? options = null)
181+
{
182+
return ctx.CallActivityAsync(""FlakeyActivity"", _, options);
183+
}",
184+
isDurableFunctions: true);
185+
186+
await TestHelpers.RunTestAsync<DurableTaskSourceGenerator>(
187+
GeneratedFileName,
188+
code,
189+
expectedOutput,
190+
isDurableFunctions: true);
191+
}
192+
120193
/// <summary>
121194
/// Verifies that using the class-based activity syntax generates a <see cref="TaskOrchestrationContext"/>
122195
/// extension method as well as an <see cref="ActivityTriggerAttribute"/> function definition.

0 commit comments

Comments
 (0)