@@ -22,21 +22,24 @@ public class DurableFunction
2222 public DurableFunctionKind Kind { get ; }
2323 public TypedParameter Parameter { get ; }
2424 public string ReturnType { get ; }
25+ public bool ReturnsVoid { get ; }
2526
2627 public DurableFunction (
2728 string fullTypeName ,
2829 string name ,
2930 DurableFunctionKind kind ,
3031 TypedParameter parameter ,
31- ITypeSymbol returnType ,
32+ ITypeSymbol ? returnType ,
33+ bool returnsVoid ,
3234 HashSet < string > requiredNamespaces )
3335 {
3436 this . FullTypeName = fullTypeName ;
3537 this . RequiredNamespaces = requiredNamespaces ;
3638 this . Name = name ;
3739 this . Kind = kind ;
3840 this . Parameter = parameter ;
39- this . ReturnType = SyntaxNodeUtility . GetRenderedTypeExpression ( returnType , false ) ;
41+ this . ReturnType = returnType != null ? SyntaxNodeUtility . GetRenderedTypeExpression ( returnType , false ) : string . Empty ;
42+ this . ReturnsVoid = returnsVoid ;
4043 }
4144
4245 public static bool TryParse ( SemanticModel model , MethodDeclarationSyntax method , out DurableFunction ? function )
@@ -59,12 +62,54 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method,
5962 return false ;
6063 }
6164
62- INamedTypeSymbol taskSymbol = model . Compilation . GetTypeByMetadataName ( "System.Threading.Tasks.Task`1" ) ! ;
63- INamedTypeSymbol returnSymbol = ( INamedTypeSymbol ) model . GetTypeInfo ( returnType ) . Type ! ;
64- if ( SymbolEqualityComparer . Default . Equals ( returnSymbol . OriginalDefinition , taskSymbol ) )
65+ ITypeSymbol ? returnTypeSymbol = model . GetTypeInfo ( returnType ) . Type ;
66+ if ( returnTypeSymbol == null || returnTypeSymbol . TypeKind == TypeKind . Error )
6567 {
66- // this is a Task<T> return value, lets pull out the generic.
67- returnSymbol = ( INamedTypeSymbol ) returnSymbol . TypeArguments [ 0 ] ;
68+ function = null ;
69+ return false ;
70+ }
71+
72+ bool returnsVoid = false ;
73+ INamedTypeSymbol ? returnSymbol = null ;
74+
75+ // Check if it's a void return type
76+ if ( returnTypeSymbol . SpecialType == SpecialType . System_Void )
77+ {
78+ returnsVoid = true ;
79+ // returnSymbol is left as null since void has no type to track
80+ }
81+ // Check if it's Task (non-generic)
82+ else if ( returnTypeSymbol is INamedTypeSymbol namedReturn )
83+ {
84+ INamedTypeSymbol ? nonGenericTaskSymbol = model . Compilation . GetTypeByMetadataName ( "System.Threading.Tasks.Task" ) ;
85+ if ( nonGenericTaskSymbol != null && SymbolEqualityComparer . Default . Equals ( namedReturn , nonGenericTaskSymbol ) )
86+ {
87+ returnsVoid = true ;
88+ // returnSymbol is left as null since Task (non-generic) has no return type to track
89+ }
90+ // Check if it's Task<T>
91+ else
92+ {
93+ INamedTypeSymbol ? taskSymbol = model . Compilation . GetTypeByMetadataName ( "System.Threading.Tasks.Task`1" ) ;
94+ returnSymbol = namedReturn ;
95+ if ( taskSymbol != null && SymbolEqualityComparer . Default . Equals ( returnSymbol . OriginalDefinition , taskSymbol ) )
96+ {
97+ // this is a Task<T> return value, lets pull out the generic.
98+ ITypeSymbol typeArg = returnSymbol . TypeArguments [ 0 ] ;
99+ if ( typeArg is not INamedTypeSymbol namedTypeArg )
100+ {
101+ function = null ;
102+ return false ;
103+ }
104+ returnSymbol = namedTypeArg ;
105+ }
106+ }
107+ }
108+ else
109+ {
110+ // returnTypeSymbol is not INamedTypeSymbol, which is unexpected
111+ function = null ;
112+ return false ;
68113 }
69114
70115 if ( ! SyntaxNodeUtility . TryGetParameter ( model , method , kind , out TypedParameter ? parameter ) || parameter == null )
@@ -79,12 +124,18 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method,
79124 return false ;
80125 }
81126
127+ // Build list of types used for namespace resolution
82128 List < INamedTypeSymbol > usedTypes = new ( )
83129 {
84- returnSymbol ,
85130 parameter . Type
86131 } ;
87132
133+ // Only include return type if it's not void
134+ if ( returnSymbol != null )
135+ {
136+ usedTypes . Add ( returnSymbol ) ;
137+ }
138+
88139 if ( ! SyntaxNodeUtility . TryGetRequiredNamespaces ( usedTypes , out HashSet < string > ? requiredNamespaces ) )
89140 {
90141 function = null ;
@@ -93,7 +144,7 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method,
93144
94145 requiredNamespaces ! . UnionWith ( GetRequiredGlobalNamespaces ( ) ) ;
95146
96- function = new DurableFunction ( fullTypeName ! , name , kind , parameter , returnSymbol , requiredNamespaces ) ;
147+ function = new DurableFunction ( fullTypeName ! , name , kind , parameter , returnSymbol , returnsVoid , requiredNamespaces ) ;
97148 return true ;
98149 }
99150
0 commit comments