Skip to content
203 changes: 194 additions & 9 deletions src/Analyzers/Activities/FunctionNotFoundAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ public sealed class FunctionNotFoundAnalyzer : DiagnosticAnalyzer
/// </summary>
public const string SubOrchestrationNotFoundDiagnosticId = "DURABLE2004";

// System assemblies to skip when scanning referenced assemblies for performance
static readonly HashSet<string> SystemAssemblyNames =
[
"mscorlib",
"System",
"netstandard"
];

static readonly string[] SystemAssemblyPrefixes =
[
"System.",
"Microsoft.CodeAnalysis",
"Microsoft.CSharp",
"Microsoft.VisualBasic"
];

static readonly LocalizableString ActivityNotFoundTitle = new LocalizableResourceString(nameof(Resources.ActivityNotFoundAnalyzerTitle), Resources.ResourceManager, typeof(Resources));
static readonly LocalizableString ActivityNotFoundMessageFormat = new LocalizableResourceString(nameof(Resources.ActivityNotFoundAnalyzerMessageFormat), Resources.ResourceManager, typeof(Resources));

Expand Down Expand Up @@ -129,19 +145,13 @@ public override void Initialize(AnalysisContext context)
}

// Check for Activity defined via [ActivityTrigger]
if (knownSymbols.ActivityTriggerAttribute != null &&
methodSymbol.ContainsAttributeInAnyMethodArguments(knownSymbols.ActivityTriggerAttribute) &&
knownSymbols.FunctionNameAttribute != null &&
methodSymbol.TryGetSingleValueFromAttribute(knownSymbols.FunctionNameAttribute, out string functionName))
if (IsActivityMethod(methodSymbol, knownSymbols, out string functionName))
{
activityNames.Add(functionName);
}

// Check for Orchestrator defined via [OrchestrationTrigger]
if (knownSymbols.FunctionOrchestrationAttribute != null &&
methodSymbol.ContainsAttributeInAnyMethodArguments(knownSymbols.FunctionOrchestrationAttribute) &&
knownSymbols.FunctionNameAttribute != null &&
methodSymbol.TryGetSingleValueFromAttribute(knownSymbols.FunctionNameAttribute, out string orchestratorFunctionName))
if (IsOrchestratorMethod(methodSymbol, knownSymbols, out string orchestratorFunctionName))
{
orchestratorNames.Add(orchestratorFunctionName);
}
Expand Down Expand Up @@ -173,7 +183,7 @@ public override void Initialize(AnalysisContext context)

// Check for ITaskOrchestrator implementations (class-based orchestrators)
if (knownSymbols.TaskOrchestratorInterface != null &&
classSymbol.AllInterfaces.Any(i => SymbolEqualityComparer.Default.Equals(i, knownSymbols.TaskOrchestratorInterface)))
ImplementsInterface(classSymbol, knownSymbols.TaskOrchestratorInterface))
{
orchestratorNames.Add(classSymbol.Name);
}
Expand Down Expand Up @@ -222,6 +232,17 @@ public override void Initialize(AnalysisContext context)
// At the end of the compilation, we correlate the invocations with the definitions
context.RegisterCompilationEndAction(ctx =>
{
ctx.CancellationToken.ThrowIfCancellationRequested();

// Scan referenced assemblies for activities and orchestrators
ScanReferencedAssemblies(
ctx.Compilation,
knownSymbols,
taskActivityRunAsync,
activityNames,
orchestratorNames,
ctx.CancellationToken);

// Create lookup sets for faster searching
HashSet<string> definedActivities = new(activityNames);
HashSet<string> definedOrchestrators = new(orchestratorNames);
Expand Down Expand Up @@ -270,6 +291,49 @@ public override void Initialize(AnalysisContext context)
return constant.Value?.ToString();
}

static bool IsActivityMethod(IMethodSymbol methodSymbol, KnownTypeSymbols knownSymbols, out string functionName)
{
functionName = string.Empty;

if (knownSymbols.ActivityTriggerAttribute == null ||
!methodSymbol.ContainsAttributeInAnyMethodArguments(knownSymbols.ActivityTriggerAttribute))
{
return false;
}

if (knownSymbols.FunctionNameAttribute == null ||
!methodSymbol.TryGetSingleValueFromAttribute(knownSymbols.FunctionNameAttribute, out functionName))
{
return false;
}

return true;
}

static bool IsOrchestratorMethod(IMethodSymbol methodSymbol, KnownTypeSymbols knownSymbols, out string functionName)
{
functionName = string.Empty;

if (knownSymbols.FunctionOrchestrationAttribute == null ||
!methodSymbol.ContainsAttributeInAnyMethodArguments(knownSymbols.FunctionOrchestrationAttribute))
{
return false;
}

if (knownSymbols.FunctionNameAttribute == null ||
!methodSymbol.TryGetSingleValueFromAttribute(knownSymbols.FunctionNameAttribute, out functionName))
{
return false;
}

return true;
}

static bool ImplementsInterface(INamedTypeSymbol typeSymbol, INamedTypeSymbol interfaceSymbol)
{
return typeSymbol.AllInterfaces.Any(i => SymbolEqualityComparer.Default.Equals(i, interfaceSymbol));
}

static bool ClassOverridesMethod(INamedTypeSymbol classSymbol, IMethodSymbol methodToFind)
{
INamedTypeSymbol? baseType = classSymbol;
Expand All @@ -287,6 +351,127 @@ static bool ClassOverridesMethod(INamedTypeSymbol classSymbol, IMethodSymbol met
return false;
}

static void ScanReferencedAssemblies(
Compilation compilation,
KnownTypeSymbols knownSymbols,
IMethodSymbol? taskActivityRunAsync,
ConcurrentBag<string> activityNames,
ConcurrentBag<string> orchestratorNames,
CancellationToken cancellationToken)
{
// Scan all referenced assemblies for activities and orchestrators
// Skip system assemblies for performance
foreach (MetadataReference reference in compilation.References)
{
cancellationToken.ThrowIfCancellationRequested();

if (compilation.GetAssemblyOrModuleSymbol(reference) is not IAssemblySymbol assembly)
{
continue;
}

if (IsSystemAssembly(assembly))
{
continue;
}

// Scan this assembly - if it doesn't contain Durable functions, nothing will be added
ScanNamespaceForFunctions(
assembly.GlobalNamespace,
knownSymbols,
taskActivityRunAsync,
activityNames,
orchestratorNames,
cancellationToken);
}
}

static bool IsSystemAssembly(IAssemblySymbol assembly)
{
// Skip well-known system assemblies to improve performance
string assemblyName = assembly.Name;

if (SystemAssemblyNames.Contains(assemblyName))
{
return true;
}

if (SystemAssemblyPrefixes.Any(prefix => assemblyName.StartsWith(prefix, StringComparison.Ordinal)))
{
return true;
}

return false;
}

static void ScanNamespaceForFunctions(
INamespaceSymbol namespaceSymbol,
KnownTypeSymbols knownSymbols,
IMethodSymbol? taskActivityRunAsync,
ConcurrentBag<string> activityNames,
ConcurrentBag<string> orchestratorNames,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

// Scan types in this namespace
foreach (INamedTypeSymbol typeSymbol in namespaceSymbol.GetTypeMembers())
{
cancellationToken.ThrowIfCancellationRequested();

// Check for TaskActivity<TInput, TOutput> derived classes
if (knownSymbols.TaskActivityBase != null &&
taskActivityRunAsync != null &&
!typeSymbol.IsAbstract &&
ClassOverridesMethod(typeSymbol, taskActivityRunAsync))
{
activityNames.Add(typeSymbol.Name);
}

// Check for ITaskOrchestrator implementations (class-based orchestrators)
if (knownSymbols.TaskOrchestratorInterface != null &&
ImplementsInterface(typeSymbol, knownSymbols.TaskOrchestratorInterface))
{
orchestratorNames.Add(typeSymbol.Name);
}

// Check methods for [Function] + [ActivityTrigger] or [OrchestrationTrigger]
foreach (ISymbol member in typeSymbol.GetMembers())
{
cancellationToken.ThrowIfCancellationRequested();

if (member is not IMethodSymbol methodSymbol)
{
continue;
}

// Check for Activity defined via [ActivityTrigger]
if (IsActivityMethod(methodSymbol, knownSymbols, out string functionName))
{
activityNames.Add(functionName);
}

// Check for Orchestrator defined via [OrchestrationTrigger]
if (IsOrchestratorMethod(methodSymbol, knownSymbols, out string orchestratorFunctionName))
{
orchestratorNames.Add(orchestratorFunctionName);
}
}
}

// Recursively scan nested namespaces
foreach (INamespaceSymbol nestedNamespace in namespaceSymbol.GetNamespaceMembers())
{
ScanNamespaceForFunctions(
nestedNamespace,
knownSymbols,
taskActivityRunAsync,
activityNames,
orchestratorNames,
cancellationToken);
}
}

readonly struct FunctionInvocation(string name, SyntaxNode invocationSyntaxNode)
{
public string Name { get; } = name;
Expand Down
97 changes: 97 additions & 0 deletions test/Analyzers.Tests/Activities/FunctionNotFoundAnalyzerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,103 @@ async Task ExistingOrchestrator([OrchestrationTrigger] TaskOrchestrationContext
await VerifyCS.VerifyDurableTaskAnalyzerAsync(code, expected);
}

[Fact]
public async Task ActivityInvocationWithActivityDefinedInReferencedAssembly_NoDiagnostic()
{
// Arrange - Orchestrator in main project
string orchestratorCode = Wrapper.WrapDurableFunctionOrchestration(@"
async Task Method(TaskOrchestrationContext context)
{
await context.CallActivityAsync(""SayHello"", ""Tokyo"");
}
");

// Activity in a separate source file (simulates cross-assembly scenario)
string activityCode = @"
using Microsoft.Azure.Functions.Worker;
using Microsoft.DurableTask;

class ActivityFunctions
{
[Function(""SayHello"")]
void SayHello([ActivityTrigger] string name)
{
}
}
";

void configureTest(VerifyCS.Test test) => test.TestState.Sources.Add(activityCode);

// Act & Assert
await VerifyCS.VerifyDurableTaskAnalyzerAsync(orchestratorCode, configureTest);
}

[Fact]
public async Task SubOrchestrationInvocationWithOrchestratorDefinedInReferencedAssembly_NoDiagnostic()
{
// Arrange - Parent orchestrator in main project
string parentOrchestratorCode = Wrapper.WrapDurableFunctionOrchestration(@"
async Task Method(TaskOrchestrationContext context)
{
await context.CallSubOrchestratorAsync(""ChildOrchestration"", ""input"");
}
");

// Child orchestrator in a separate source file (simulates cross-assembly scenario)
string childOrchestratorCode = @"
using System.Threading.Tasks;
using Microsoft.Azure.Functions.Worker;
using Microsoft.DurableTask;

class ChildOrchestrators
{
[Function(""ChildOrchestration"")]
async Task ChildOrchestration([OrchestrationTrigger] TaskOrchestrationContext context)
{
await Task.CompletedTask;
}
}
";

void configureTest(VerifyCS.Test test) => test.TestState.Sources.Add(childOrchestratorCode);

// Act & Assert
await VerifyCS.VerifyDurableTaskAnalyzerAsync(parentOrchestratorCode, configureTest);
}

[Fact]
public async Task ClassBasedActivityInvocationWithActivityDefinedInReferencedAssembly_NoDiagnostic()
{
// Arrange - Orchestrator in main project
string orchestratorCode = Wrapper.WrapTaskOrchestrator(@"
public class Caller {
async Task Method(TaskOrchestrationContext context)
{
await context.CallActivityAsync<string>(nameof(MyActivity), ""Tokyo"");
}
}
");

// Class-based activity in a separate source file (simulates cross-assembly scenario)
string activityCode = @"
using System.Threading.Tasks;
using Microsoft.DurableTask;

public class MyActivity : TaskActivity<string, string>
{
public override Task<string> RunAsync(TaskActivityContext context, string cityName)
{
return Task.FromResult(cityName);
}
}
";

void configureTest(VerifyCS.Test test) => test.TestState.Sources.Add(activityCode);

// Act & Assert
await VerifyCS.VerifyDurableTaskAnalyzerAsync(orchestratorCode, configureTest);
}

static DiagnosticResult BuildActivityNotFoundDiagnostic()
{
return VerifyCS.Diagnostic(FunctionNotFoundAnalyzer.ActivityNotFoundDiagnosticId);
Expand Down
Loading