Skip to content

Commit 494b0d9

Browse files
committed
Filter symbol references to type matches
1 parent 7471401 commit 494b0d9

File tree

4 files changed

+32
-58
lines changed

4 files changed

+32
-58
lines changed

src/PowerShellEditorServices/Services/Symbols/ReferenceTable.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,13 @@ public void TagAsChanged()
4242
/// </summary>
4343
private bool IsInitialized => !_symbolReferences.IsEmpty || _isInited;
4444

45-
internal bool TryGetReferences(string command, out ConcurrentBag<SymbolReference>? references)
45+
internal IEnumerable<SymbolReference> TryGetReferences(SymbolReference? symbol)
4646
{
4747
EnsureInitialized();
48-
return _symbolReferences.TryGetValue(command, out references);
48+
return symbol is not null
49+
&& _symbolReferences.TryGetValue(symbol.SymbolName, out ConcurrentBag<SymbolReference>? bag)
50+
? bag.Where(i => SymbolTypeUtils.SymbolTypeMatches(symbol.SymbolType, i.SymbolType))
51+
: Enumerable.Empty<SymbolReference>();
4952
}
5053

5154
internal SymbolReference? TryGetSymbolAtPosition(int line, int column) => GetAllReferences()

src/PowerShellEditorServices/Services/Symbols/SymbolType.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,5 +99,15 @@ internal static SymbolKind GetSymbolKind(SymbolType symbolType)
9999
_ => SymbolKind.Variable,
100100
};
101101
}
102+
103+
// Provides a partial equivalence between type constraints and custom types.
104+
internal static bool SymbolTypeMatches(SymbolType left, SymbolType right)
105+
{
106+
return left == right
107+
|| (left is SymbolType.Class or SymbolType.Enum or SymbolType.Type
108+
&& right is SymbolType.Class or SymbolType.Enum or SymbolType.Type)
109+
|| (left is SymbolType.EnumMember or SymbolType.Property
110+
&& right is SymbolType.EnumMember or SymbolType.Property);
111+
}
102112
}
103113
}

src/PowerShellEditorServices/Services/Symbols/SymbolsService.cs

Lines changed: 14 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -188,14 +188,7 @@ public async Task<IEnumerable<SymbolReference>> ScanForReferencesOfSymbolAsync(
188188
{
189189
await Task.Yield();
190190
cancellationToken.ThrowIfCancellationRequested();
191-
192-
_ = file.References.TryGetReferences(targetIdentifier, out ConcurrentBag<SymbolReference>? references);
193-
if (references is null)
194-
{
195-
continue;
196-
}
197-
198-
symbols.AddRange(references);
191+
symbols.AddRange(file.References.TryGetReferences(symbol with { SymbolName = targetIdentifier }));
199192
}
200193
}
201194

@@ -205,19 +198,10 @@ public async Task<IEnumerable<SymbolReference>> ScanForReferencesOfSymbolAsync(
205198
/// <summary>
206199
/// Finds all the occurrences of a symbol in the script given a file location.
207200
/// </summary>
208-
public static IEnumerable<SymbolReference>? FindOccurrencesInFile(
209-
ScriptFile scriptFile, int line, int column)
210-
{
211-
SymbolReference? symbol = FindSymbolAtLocation(scriptFile, line, column);
212-
213-
if (symbol is null)
214-
{
215-
return null;
216-
}
217-
218-
scriptFile.References.TryGetReferences(symbol.SymbolName, out ConcurrentBag<SymbolReference>? references);
219-
return references;
220-
}
201+
public static IEnumerable<SymbolReference> FindOccurrencesInFile(
202+
ScriptFile scriptFile, int line, int column) => scriptFile
203+
.References
204+
.TryGetReferences(FindSymbolAtLocation(scriptFile, line, column));
221205

222206
/// <summary>
223207
/// Finds the symbol at the location and returns it if it's a declaration.
@@ -236,15 +220,9 @@ public async Task<IEnumerable<SymbolReference>> ScanForReferencesOfSymbolAsync(
236220
ScriptFile scriptFile, int line, int column)
237221
{
238222
SymbolReference? symbol = FindSymbolAtLocation(scriptFile, line, column);
239-
if (symbol is null)
240-
{
241-
return Task.FromResult<SymbolDetails?>(null);
242-
}
243-
244-
return SymbolDetails.CreateAsync(
245-
symbol,
246-
_runspaceContext.CurrentRunspace,
247-
_executionService);
223+
return symbol is null
224+
? Task.FromResult<SymbolDetails?>(null)
225+
: SymbolDetails.CreateAsync(symbol, _runspaceContext.CurrentRunspace, _executionService);
248226
}
249227

250228
/// <summary>
@@ -310,34 +288,18 @@ public async Task<IEnumerable<SymbolReference>> GetDefinitionOfSymbolAsync(
310288
CancellationToken cancellationToken = default)
311289
{
312290
List<SymbolReference> declarations = new();
313-
_ = scriptFile.References.TryGetReferences(symbol.SymbolName, out ConcurrentBag<SymbolReference>? symbols);
314-
if (symbols is not null)
315-
{
316-
foreach (SymbolReference foundReference in symbols)
317-
{
318-
if (foundReference.IsDeclaration)
319-
{
320-
_logger.LogDebug($"Found possible declaration in same file ${foundReference}");
321-
declarations.Add(foundReference);
322-
}
323-
}
324-
}
325-
291+
declarations.AddRange(scriptFile.References.TryGetReferences(symbol).Where(i => i.IsDeclaration));
326292
if (declarations.Any())
327293
{
294+
_logger.LogDebug($"Found possible declaration in same file ${declarations}");
328295
return declarations;
329296
}
330297

331-
foreach (SymbolReference foundReference in await ScanForReferencesOfSymbolAsync(
332-
symbol, cancellationToken).ConfigureAwait(false))
333-
{
334-
if (foundReference.IsDeclaration)
335-
{
336-
_logger.LogDebug($"Found possible declaration in workspace ${foundReference}");
337-
declarations.Add(foundReference);
338-
}
339-
}
298+
IEnumerable<SymbolReference> references =
299+
await ScanForReferencesOfSymbolAsync(symbol, cancellationToken).ConfigureAwait(false);
300+
declarations.AddRange(references.Where(i => i.IsDeclaration));
340301

302+
_logger.LogDebug($"Found possible declaration in workspace ${declarations}");
341303
return declarations;
342304
}
343305

test/PowerShellEditorServices.Test/Services/Symbols/AstOperationsTests.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT License.
33

4-
using System.Collections.Concurrent;
4+
using System.Collections.Generic;
55
using System.Linq;
66
using Microsoft.Extensions.Logging.Abstractions;
77
using Microsoft.PowerShell.EditorServices.Services;
@@ -46,9 +46,8 @@ public void CanFindReferencesOfSymbolAtPosition(int line, int column, Range[] sy
4646
{
4747
SymbolReference symbol = scriptFile.References.TryGetSymbolAtPosition(line, column);
4848

49-
Assert.True(scriptFile.References.TryGetReferences(
50-
symbol.SymbolName,
51-
out ConcurrentBag<SymbolReference> references));
49+
IEnumerable<SymbolReference> references = scriptFile.References.TryGetReferences(symbol);
50+
Assert.NotEmpty(references);
5251

5352
int positionsIndex = 0;
5453
foreach (SymbolReference reference in references.OrderBy((i) => i.ScriptRegion.ToRange().Start))

0 commit comments

Comments
 (0)