Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/Abstractions/DurableTaskAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ namespace Microsoft.DurableTask;
/// </summary>
/// <remarks>
/// This attribute is meant to be used on class definitions that derive from
/// <see cref="TaskOrchestrator{TInput, TOutput}"/> or <see cref="TaskActivity{TInput, TOutput}"/>.
/// <see cref="TaskOrchestrator{TInput, TOutput}"/>, <see cref="TaskActivity{TInput, TOutput}"/>,
/// or TaskEntity{TState} from the Microsoft.DurableTask.Entities namespace.
/// It is used specifically by build-time source generators to generate type-safe methods for invoking
/// orchestrations or activities.
/// orchestrations, activities, or registering entities.
/// </remarks>
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = false)]
public sealed class DurableTaskAttribute : Attribute
Expand Down
3 changes: 2 additions & 1 deletion src/Generators/AzureFunctions/DurableFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ public enum DurableFunctionKind
{
Unknown,
Orchestration,
Activity
Activity,
Entity
}

public class DurableFunction
Expand Down
9 changes: 8 additions & 1 deletion src/Generators/AzureFunctions/SyntaxNodeUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ public static bool TryGetFunctionKind(MethodDeclarationSyntax method, out Durabl
kind = DurableFunctionKind.Activity;
return true;
}

if (attribute.ToString().Equals("EntityTrigger", StringComparison.Ordinal))
{
kind = DurableFunctionKind.Entity;
return true;
}
}
}

Expand Down Expand Up @@ -125,7 +131,8 @@ public static bool TryGetParameter(
{
string attributeName = attribute.Name.ToString();
if ((kind == DurableFunctionKind.Activity && attributeName == "ActivityTrigger") ||
(kind == DurableFunctionKind.Orchestration && attributeName == "OrchestratorTrigger"))
(kind == DurableFunctionKind.Orchestration && attributeName == "OrchestratorTrigger") ||
(kind == DurableFunctionKind.Entity && attributeName == "EntityTrigger"))
{
TypeInfo info = model.GetTypeInfo(methodParam.Type);
if (info.Type is INamedTypeSymbol named)
Expand Down
99 changes: 80 additions & 19 deletions src/Generators/DurableTaskSourceGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

string className = classType.ToDisplayString();
INamedTypeSymbol? taskType = null;
bool isActivity = false;
DurableTaskKind kind = DurableTaskKind.Orchestrator;

INamedTypeSymbol? baseType = classType.BaseType;
while (baseType != null)
Expand All @@ -105,27 +105,51 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
if (baseType.Name == "TaskActivity")
{
taskType = baseType;
isActivity = true;
kind = DurableTaskKind.Activity;
break;
}
else if (baseType.Name == "TaskOrchestrator")
{
taskType = baseType;
isActivity = false;
kind = DurableTaskKind.Orchestrator;
break;
}
else if (baseType.Name == "TaskEntity")
{
taskType = baseType;
kind = DurableTaskKind.Entity;
break;
}
}

baseType = baseType.BaseType;
}

if (taskType == null || taskType.TypeParameters.Length <= 1)
// TaskEntity has 1 type parameter (TState), while TaskActivity and TaskOrchestrator have 2 (TInput, TOutput)
if (taskType == null)
{
return null;
}

ITypeSymbol inputType = taskType.TypeArguments.First();
ITypeSymbol outputType = taskType.TypeArguments.Last();
if (kind == DurableTaskKind.Entity)
{
// Entity only has a single TState type parameter
if (taskType.TypeParameters.Length < 1)
{
return null;
}
}
else
{
// Orchestrator and Activity have TInput and TOutput type parameters
if (taskType.TypeParameters.Length <= 1)
{
return null;
}
}

ITypeSymbol? inputType = kind == DurableTaskKind.Entity ? null : taskType.TypeArguments.First();
ITypeSymbol? outputType = kind == DurableTaskKind.Entity ? null : taskType.TypeArguments.Last();

string taskName = classType.Name;
if (attribute.ArgumentList?.Arguments.Count > 0)
Expand All @@ -134,7 +158,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
taskName = context.SemanticModel.GetConstantValue(expression).ToString();
}

return new DurableTaskTypeInfo(className, taskName, inputType, outputType, isActivity);
return new DurableTaskTypeInfo(className, taskName, inputType, outputType, kind);
}

static DurableFunction? GetDurableFunction(GeneratorSyntaxContext context)
Expand Down Expand Up @@ -165,23 +189,28 @@ static void Execute(
bool isDurableFunctions = compilation.ReferencedAssemblyNames.Any(
assembly => assembly.Name.Equals("Microsoft.Azure.Functions.Worker.Extensions.DurableTask", StringComparison.OrdinalIgnoreCase));

// Separate tasks into orchestrators and activities
// Separate tasks into orchestrators, activities, and entities
List<DurableTaskTypeInfo> orchestrators = new();
List<DurableTaskTypeInfo> activities = new();
List<DurableTaskTypeInfo> entities = new();

foreach (DurableTaskTypeInfo task in allTasks)
{
if (task.IsActivity)
{
activities.Add(task);
}
else if (task.IsEntity)
{
entities.Add(task);
}
else
{
orchestrators.Add(task);
}
}

int found = activities.Count + orchestrators.Count + allFunctions.Length;
int found = activities.Count + orchestrators.Count + entities.Count + allFunctions.Length;
if (found == 0)
{
return;
Expand Down Expand Up @@ -264,7 +293,8 @@ public static class GeneratedDurableTaskExtensions
AddRegistrationMethodForAllTasks(
sourceBuilder,
orchestrators,
activities);
activities,
entities);
}

sourceBuilder.AppendLine(" }").AppendLine("}");
Expand Down Expand Up @@ -368,7 +398,8 @@ public GeneratedActivityContext(TaskName name, string instanceId)
static void AddRegistrationMethodForAllTasks(
StringBuilder sourceBuilder,
IEnumerable<DurableTaskTypeInfo> orchestrators,
IEnumerable<DurableTaskTypeInfo> activities)
IEnumerable<DurableTaskTypeInfo> activities,
IEnumerable<DurableTaskTypeInfo> entities)
{
// internal so it does not conflict with other projects with this generated file.
sourceBuilder.Append($@"
Expand All @@ -387,39 +418,69 @@ internal static DurableTaskRegistry AddAllGeneratedTasks(this DurableTaskRegistr
builder.AddActivity<{taskInfo.TypeName}>();");
}

foreach (DurableTaskTypeInfo taskInfo in entities)
{
sourceBuilder.Append($@"
builder.AddEntity<{taskInfo.TypeName}>();");
}

sourceBuilder.AppendLine($@"
return builder;
}}");
}

enum DurableTaskKind
{
Orchestrator,
Activity,
Entity
}

class DurableTaskTypeInfo
{
public DurableTaskTypeInfo(
string taskType,
string taskName,
ITypeSymbol? inputType,
ITypeSymbol? outputType,
bool isActivity)
DurableTaskKind kind)
{
this.TypeName = taskType;
this.TaskName = taskName;
this.InputType = GetRenderedTypeExpression(inputType);
this.InputParameter = this.InputType + " input";
if (this.InputType[this.InputType.Length - 1] == '?')
this.Kind = kind;

// Entities only have a state type parameter, not input/output
if (kind == DurableTaskKind.Entity)
{
this.InputParameter += " = default";
this.InputType = string.Empty;
this.InputParameter = string.Empty;
this.OutputType = string.Empty;
}
else
{
this.InputType = GetRenderedTypeExpression(inputType);
this.InputParameter = this.InputType + " input";
if (this.InputType[this.InputType.Length - 1] == '?')
{
this.InputParameter += " = default";
}

this.OutputType = GetRenderedTypeExpression(outputType);
this.IsActivity = isActivity;
this.OutputType = GetRenderedTypeExpression(outputType);
}
}

public string TypeName { get; }
public string TaskName { get; }
public string InputType { get; }
public string InputParameter { get; }
public string OutputType { get; }
public bool IsActivity { get; }
public DurableTaskKind Kind { get; }

public bool IsActivity => this.Kind == DurableTaskKind.Activity;

public bool IsOrchestrator => this.Kind == DurableTaskKind.Orchestrator;

public bool IsEntity => this.Kind == DurableTaskKind.Entity;

static string GetRenderedTypeExpression(ITypeSymbol? symbol)
{
Expand Down
Loading
Loading