Skip to content

Commit 3d4e294

Browse files
authored
Improve type inference for $_ (PowerShell#17716)
1 parent e8c66a3 commit 3d4e294

File tree

2 files changed

+156
-52
lines changed

2 files changed

+156
-52
lines changed

src/System.Management.Automation/engine/parser/TypeInferenceVisitor.cs

Lines changed: 71 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,82 +1840,101 @@ private void InferTypeFrom(VariableExpressionAst variableExpressionAst, List<PST
18401840
(SpecialVariables.IsUnderbar(astVariablePath.UserPath)
18411841
|| astVariablePath.UserPath.EqualsOrdinalIgnoreCase(SpecialVariables.PSItem)))
18421842
{
1843-
// $_ is special, see if we're used in a script block in some pipeline.
1844-
while (parent != null)
1845-
{
1846-
if (parent is ScriptBlockExpressionAst || parent is CatchClauseAst)
1843+
// The automatic variable $_ is assigned a value in scriptblocks, Switch loops and Catch/Trap statements
1844+
// This loop will find whichever Ast that determines the value of $_
1845+
// The value in scriptblocks is determined by the parents of that scriptblock, the only interesting scenarios are:
1846+
// 1: MemberInvocation like: $Collection.Where({$_})
1847+
// 2: Command pipelines like: dir | where {$_}
1848+
// The value in a Switch loop is whichever item is in the condition part of the statement.
1849+
// The value in Catch/Trap statements is always an error record.
1850+
bool hasSeenScriptBlock = false;
1851+
while (parent is not null)
1852+
{
1853+
if (parent is CatchClauseAst or TrapStatementAst)
18471854
{
18481855
break;
18491856
}
1850-
1851-
parent = parent.Parent;
1852-
}
1853-
1854-
if (parent != null)
1855-
{
1856-
if (parent.Parent is CommandExpressionAst && parent.Parent.Parent is PipelineAst)
1857+
else if (parent is SwitchStatementAst switchStatement)
18571858
{
1858-
// Script block in a hash table, could be something like:
1859-
// dir | ft @{ Expression = { $_ } }
1860-
1861-
if (parent.Parent.Parent.Parent is HashtableAst)
1862-
{
1863-
parent = parent.Parent.Parent.Parent;
1864-
}
1865-
else if (parent.Parent.Parent.Parent is ArrayLiteralAst && parent.Parent.Parent.Parent.Parent is HashtableAst)
1859+
parent = switchStatement.Condition;
1860+
break;
1861+
}
1862+
else if (parent is ErrorStatementAst switchErrorStatement && switchErrorStatement.Kind?.Kind == TokenKind.Switch)
1863+
{
1864+
if (switchErrorStatement.Conditions?.Count > 0)
18661865
{
1867-
parent = parent.Parent.Parent.Parent.Parent;
1866+
parent = switchErrorStatement.Conditions[0];
18681867
}
1868+
break;
18691869
}
1870-
1871-
if (parent.Parent is CommandParameterAst)
1870+
else if (parent is ScriptBlockExpressionAst)
18721871
{
1873-
parent = parent.Parent;
1872+
hasSeenScriptBlock = true;
18741873
}
1875-
1876-
if (parent is CatchClauseAst catchBlock)
1874+
else if (hasSeenScriptBlock)
18771875
{
1878-
if (catchBlock.CatchTypes.Count > 0)
1876+
if (parent is InvokeMemberExpressionAst invokeMember)
18791877
{
1880-
foreach (TypeConstraintAst catchType in catchBlock.CatchTypes)
1878+
parent = invokeMember.Expression;
1879+
break;
1880+
}
1881+
else if (parent is CommandAst cmdAst && cmdAst.Parent is PipelineAst pipeline && pipeline.PipelineElements.Count > 1)
1882+
{
1883+
// We've found a pipeline with multiple commands, now we need to determine what command came before the command with the scriptblock:
1884+
// eg Get-Partition in this example: Get-Disk | Get-Partition | Where {$_}
1885+
var indexOfPreviousCommand = pipeline.PipelineElements.IndexOf(cmdAst) - 1;
1886+
if (indexOfPreviousCommand >= 0)
18811887
{
1882-
Type exceptionType = catchType.TypeName.GetReflectionType();
1883-
if (exceptionType != null && typeof(Exception).IsAssignableFrom(exceptionType))
1884-
{
1885-
inferredTypes.Add(new PSTypeName(typeof(ErrorRecord<>).MakeGenericType(exceptionType)));
1886-
}
1888+
parent = pipeline.PipelineElements[indexOfPreviousCommand];
1889+
break;
18871890
}
18881891
}
1889-
else
1892+
}
1893+
1894+
parent = parent.Parent;
1895+
}
1896+
1897+
if (parent is CatchClauseAst catchBlock)
1898+
{
1899+
if (catchBlock.CatchTypes.Count > 0)
1900+
{
1901+
foreach (TypeConstraintAst catchType in catchBlock.CatchTypes)
18901902
{
1891-
inferredTypes.Add(new PSTypeName(typeof(ErrorRecord)));
1903+
Type exceptionType = catchType.TypeName.GetReflectionType();
1904+
if (typeof(Exception).IsAssignableFrom(exceptionType))
1905+
{
1906+
inferredTypes.Add(new PSTypeName(typeof(ErrorRecord<>).MakeGenericType(exceptionType)));
1907+
}
18921908
}
1893-
1894-
return;
18951909
}
18961910

1897-
if (parent.Parent is CommandAst commandAst)
1911+
// Either no type constraint was specified, or all the specified catch types were unavailable but we still know it's an error record.
1912+
if (inferredTypes.Count == 0)
1913+
{
1914+
inferredTypes.Add(new PSTypeName(typeof(ErrorRecord)));
1915+
}
1916+
}
1917+
else if (parent is TrapStatementAst trap)
1918+
{
1919+
if (trap.TrapType is not null)
18981920
{
1899-
// We found a command, see if there is a previous command in the pipeline.
1900-
PipelineAst pipelineAst = (PipelineAst)commandAst.Parent;
1901-
var previousCommandIndex = pipelineAst.PipelineElements.IndexOf(commandAst) - 1;
1902-
if (previousCommandIndex < 0)
1921+
Type exceptionType = trap.TrapType.TypeName.GetReflectionType();
1922+
if (typeof(Exception).IsAssignableFrom(exceptionType))
19031923
{
1904-
return;
1924+
inferredTypes.Add(new PSTypeName(typeof(ErrorRecord<>).MakeGenericType(exceptionType)));
19051925
}
1906-
1907-
AddInferredTypesForDollarUnderbar(pipelineAst.PipelineElements[0], inferredTypes);
1908-
1909-
return;
19101926
}
1911-
1912-
if (parent.Parent is InvokeMemberExpressionAst memberExpression)
1927+
if (inferredTypes.Count == 0)
19131928
{
1914-
AddInferredTypesForDollarUnderbar(memberExpression.Expression, inferredTypes);
1915-
1916-
return;
1929+
inferredTypes.Add(new PSTypeName(typeof(ErrorRecord)));
19171930
}
19181931
}
1932+
else if (parent is not null)
1933+
{
1934+
AddInferredTypesForDollarUnderbar(parent, inferredTypes);
1935+
}
1936+
1937+
return;
19191938
}
19201939

19211940
// For certain variables, we always know their type, well at least we can assume we know.
@@ -2072,7 +2091,7 @@ private void AddInferredTypesForDollarUnderbar(Ast parentExpression, List<PSType
20722091
continue;
20732092
}
20742093

2075-
if (typeof(IEnumerable).IsAssignableFrom(result.Type))
2094+
if (result.Type != typeof(string) && typeof(IEnumerable).IsAssignableFrom(result.Type))
20762095
{
20772096
// We can't deduce much from IEnumerable, but we can if it's generic.
20782097
var enumerableInterfaces = result.Type.GetInterfaces();

test/powershell/engine/Api/TypeInference.Tests.ps1

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,42 @@ Describe "Type inference Tests" -tags "CI" {
10791079
$res.Name | Should -Be System.Exception
10801080
}
10811081

1082+
It 'Infers type of variable $_ in pipeline with more than one element' {
1083+
$memberAst = { Get-Date | New-Guid | Select-Object -Property {$_} }.Ast.Find({ param($a) $a -is [System.Management.Automation.Language.VariableExpressionAst] }, $true)
1084+
$res = [AstTypeInference]::InferTypeOf($memberAst)
1085+
1086+
$res | Should -HaveCount 1
1087+
$res.Name | Should -Be System.Guid
1088+
}
1089+
It 'Infers type of variable $_ in array of calculated properties' {
1090+
$variableAst = { New-TimeSpan | Select-Object -Property Day,@{n="min";e={$_}} }.Ast.Find({ param($a) $a -is [System.Management.Automation.Language.VariableExpressionAst] }, $true)
1091+
$res = [AstTypeInference]::InferTypeOf($variableAst)
1092+
1093+
$res | Should -HaveCount 1
1094+
$res.Name | Should -Be System.TimeSpan
1095+
}
1096+
1097+
It 'Infers type of variable $_ in switch statement' {
1098+
$variableAst = {
1099+
switch ("Hello","World")
1100+
{
1101+
'Hello'
1102+
{
1103+
$_
1104+
}
1105+
} }.Ast.Find({ param($a) $a -is [System.Management.Automation.Language.VariableExpressionAst] }, $true)
1106+
$res = [AstTypeInference]::InferTypeOf($variableAst)
1107+
1108+
$res | Should -HaveCount 1
1109+
$res.Name | Should -Be System.String
1110+
}
1111+
1112+
It 'Does not infer string in pipeline as char' {
1113+
$variableAst = { "Hello" | Select-Object -Property @{n="min";e={$_}} }.Ast.Find({ param($a) $a -is [System.Management.Automation.Language.VariableExpressionAst] }, $true)
1114+
$res = [AstTypeInference]::InferTypeOf($variableAst)
1115+
$res.Name | Should -Be System.String
1116+
}
1117+
10821118
$catchClauseTypes = @(
10831119
@{ Type = 'System.ArgumentException' }
10841120
@{ Type = 'System.ArgumentNullException' }
@@ -1147,6 +1183,55 @@ Describe "Type inference Tests" -tags "CI" {
11471183
$res[1].Name | Should -Be System.Exception
11481184
}
11491185

1186+
It 'falls back to a generic ErrorRecord if catch exception type is invalid' {
1187+
$VariableAst = {
1188+
try {}
1189+
catch [ThisTypeDoesNotExist] { $_ }
1190+
}.Ast.Find(
1191+
{ param($a) $a -is [System.Management.Automation.Language.VariableExpressionAst] },
1192+
$true
1193+
)
1194+
$res = [AstTypeInference]::InferTypeOf($VariableAst)
1195+
1196+
$res.Name | Should -Be System.Management.Automation.ErrorRecord
1197+
}
1198+
1199+
It 'Infers type of trap statement' {
1200+
$VariableAst = {
1201+
trap { $_ }
1202+
}.Ast.Find(
1203+
{ param($a) $a -is [System.Management.Automation.Language.VariableExpressionAst] },
1204+
$true
1205+
)
1206+
$res = [AstTypeInference]::InferTypeOf($VariableAst)
1207+
1208+
$res.Name | Should -Be System.Management.Automation.ErrorRecord
1209+
}
1210+
1211+
It 'Infers type of exception in typed trap statement' {
1212+
$memberAst = {
1213+
trap [System.DivideByZeroException] { $_.Exception }
1214+
}.Ast.Find(
1215+
{ param($a) $a -is [System.Management.Automation.Language.MemberExpressionAst] },
1216+
$true
1217+
)
1218+
$res = [AstTypeInference]::InferTypeOf($memberAst)
1219+
1220+
$res.Name | Should -Be System.DivideByZeroException
1221+
}
1222+
1223+
It 'falls back to a generic ErrorRecord if trap exception type is invalid' {
1224+
$VariableAst = {
1225+
trap [ThisTypeDoesNotExist] { $_ }
1226+
}.Ast.Find(
1227+
{ param($a) $a -is [System.Management.Automation.Language.VariableExpressionAst] },
1228+
$true
1229+
)
1230+
$res = [AstTypeInference]::InferTypeOf($VariableAst)
1231+
1232+
$res.Name | Should -Be System.Management.Automation.ErrorRecord
1233+
}
1234+
11501235
It 'Infers type of function member' {
11511236
$res = [AstTypeInference]::InferTypeOf( {
11521237
class X {

0 commit comments

Comments
 (0)