@@ -95,7 +95,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
9595
9696 string className = classType . ToDisplayString ( ) ;
9797 INamedTypeSymbol ? taskType = null ;
98- bool isActivity = false ;
98+ DurableTaskKind kind = DurableTaskKind . Orchestrator ;
9999
100100 INamedTypeSymbol ? baseType = classType . BaseType ;
101101 while ( baseType != null )
@@ -105,27 +105,51 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
105105 if ( baseType . Name == "TaskActivity" )
106106 {
107107 taskType = baseType ;
108- isActivity = true ;
108+ kind = DurableTaskKind . Activity ;
109109 break ;
110110 }
111111 else if ( baseType . Name == "TaskOrchestrator" )
112112 {
113113 taskType = baseType ;
114- isActivity = false ;
114+ kind = DurableTaskKind . Orchestrator ;
115+ break ;
116+ }
117+ else if ( baseType . Name == "TaskEntity" )
118+ {
119+ taskType = baseType ;
120+ kind = DurableTaskKind . Entity ;
115121 break ;
116122 }
117123 }
118124
119125 baseType = baseType . BaseType ;
120126 }
121127
122- if ( taskType == null || taskType . TypeParameters . Length <= 1 )
128+ // TaskEntity has 1 type parameter (TState), while TaskActivity and TaskOrchestrator have 2 (TInput, TOutput)
129+ if ( taskType == null )
123130 {
124131 return null ;
125132 }
126133
127- ITypeSymbol inputType = taskType . TypeArguments . First ( ) ;
128- ITypeSymbol outputType = taskType . TypeArguments . Last ( ) ;
134+ if ( kind == DurableTaskKind . Entity )
135+ {
136+ // Entity only has a single TState type parameter
137+ if ( taskType . TypeParameters . Length < 1 )
138+ {
139+ return null ;
140+ }
141+ }
142+ else
143+ {
144+ // Orchestrator and Activity have TInput and TOutput type parameters
145+ if ( taskType . TypeParameters . Length <= 1 )
146+ {
147+ return null ;
148+ }
149+ }
150+
151+ ITypeSymbol ? inputType = kind == DurableTaskKind . Entity ? null : taskType . TypeArguments . First ( ) ;
152+ ITypeSymbol ? outputType = kind == DurableTaskKind . Entity ? null : taskType . TypeArguments . Last ( ) ;
129153
130154 string taskName = classType . Name ;
131155 if ( attribute . ArgumentList ? . Arguments . Count > 0 )
@@ -134,7 +158,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
134158 taskName = context . SemanticModel . GetConstantValue ( expression ) . ToString ( ) ;
135159 }
136160
137- return new DurableTaskTypeInfo ( className , taskName , inputType , outputType , isActivity ) ;
161+ return new DurableTaskTypeInfo ( className , taskName , inputType , outputType , kind ) ;
138162 }
139163
140164 static DurableFunction ? GetDurableFunction ( GeneratorSyntaxContext context )
@@ -165,23 +189,28 @@ static void Execute(
165189 bool isDurableFunctions = compilation . ReferencedAssemblyNames . Any (
166190 assembly => assembly . Name . Equals ( "Microsoft.Azure.Functions.Worker.Extensions.DurableTask" , StringComparison . OrdinalIgnoreCase ) ) ;
167191
168- // Separate tasks into orchestrators and activities
192+ // Separate tasks into orchestrators, activities, and entities
169193 List < DurableTaskTypeInfo > orchestrators = new ( ) ;
170194 List < DurableTaskTypeInfo > activities = new ( ) ;
195+ List < DurableTaskTypeInfo > entities = new ( ) ;
171196
172197 foreach ( DurableTaskTypeInfo task in allTasks )
173198 {
174199 if ( task . IsActivity )
175200 {
176201 activities . Add ( task ) ;
177202 }
203+ else if ( task . IsEntity )
204+ {
205+ entities . Add ( task ) ;
206+ }
178207 else
179208 {
180209 orchestrators . Add ( task ) ;
181210 }
182211 }
183212
184- int found = activities . Count + orchestrators . Count + allFunctions . Length ;
213+ int found = activities . Count + orchestrators . Count + entities . Count + allFunctions . Length ;
185214 if ( found == 0 )
186215 {
187216 return ;
@@ -264,7 +293,8 @@ public static class GeneratedDurableTaskExtensions
264293 AddRegistrationMethodForAllTasks (
265294 sourceBuilder ,
266295 orchestrators ,
267- activities ) ;
296+ activities ,
297+ entities ) ;
268298 }
269299
270300 sourceBuilder . AppendLine ( " }" ) . AppendLine ( "}" ) ;
@@ -368,7 +398,8 @@ public GeneratedActivityContext(TaskName name, string instanceId)
368398 static void AddRegistrationMethodForAllTasks (
369399 StringBuilder sourceBuilder ,
370400 IEnumerable < DurableTaskTypeInfo > orchestrators ,
371- IEnumerable < DurableTaskTypeInfo > activities )
401+ IEnumerable < DurableTaskTypeInfo > activities ,
402+ IEnumerable < DurableTaskTypeInfo > entities )
372403 {
373404 // internal so it does not conflict with other projects with this generated file.
374405 sourceBuilder . Append ( $@ "
@@ -387,39 +418,69 @@ internal static DurableTaskRegistry AddAllGeneratedTasks(this DurableTaskRegistr
387418 builder.AddActivity<{ taskInfo . TypeName } >();" ) ;
388419 }
389420
421+ foreach ( DurableTaskTypeInfo taskInfo in entities )
422+ {
423+ sourceBuilder . Append ( $@ "
424+ builder.AddEntity<{ taskInfo . TypeName } >();" ) ;
425+ }
426+
390427 sourceBuilder . AppendLine ( $@ "
391428 return builder;
392429 }}" ) ;
393430 }
394431
432+ enum DurableTaskKind
433+ {
434+ Orchestrator ,
435+ Activity ,
436+ Entity
437+ }
438+
395439 class DurableTaskTypeInfo
396440 {
397441 public DurableTaskTypeInfo (
398442 string taskType ,
399443 string taskName ,
400444 ITypeSymbol ? inputType ,
401445 ITypeSymbol ? outputType ,
402- bool isActivity )
446+ DurableTaskKind kind )
403447 {
404448 this . TypeName = taskType ;
405449 this . TaskName = taskName ;
406- this . InputType = GetRenderedTypeExpression ( inputType ) ;
407- this . InputParameter = this . InputType + " input" ;
408- if ( this . InputType [ this . InputType . Length - 1 ] == '?' )
450+ this . Kind = kind ;
451+
452+ // Entities only have a state type parameter, not input/output
453+ if ( kind == DurableTaskKind . Entity )
409454 {
410- this . InputParameter += " = default" ;
455+ this . InputType = string . Empty ;
456+ this . InputParameter = string . Empty ;
457+ this . OutputType = string . Empty ;
411458 }
459+ else
460+ {
461+ this . InputType = GetRenderedTypeExpression ( inputType ) ;
462+ this . InputParameter = this . InputType + " input" ;
463+ if ( this . InputType [ this . InputType . Length - 1 ] == '?' )
464+ {
465+ this . InputParameter += " = default" ;
466+ }
412467
413- this . OutputType = GetRenderedTypeExpression ( outputType ) ;
414- this . IsActivity = isActivity ;
468+ this . OutputType = GetRenderedTypeExpression ( outputType ) ;
469+ }
415470 }
416471
417472 public string TypeName { get ; }
418473 public string TaskName { get ; }
419474 public string InputType { get ; }
420475 public string InputParameter { get ; }
421476 public string OutputType { get ; }
422- public bool IsActivity { get ; }
477+ public DurableTaskKind Kind { get ; }
478+
479+ public bool IsActivity => this . Kind == DurableTaskKind . Activity ;
480+
481+ public bool IsOrchestrator => this . Kind == DurableTaskKind . Orchestrator ;
482+
483+ public bool IsEntity => this . Kind == DurableTaskKind . Entity ;
423484
424485 static string GetRenderedTypeExpression ( ITypeSymbol ? symbol )
425486 {
0 commit comments