Skip to content

Commit ed2115c

Browse files
Fix type inference in foreach statements (#14)
Determine `GetEnumerable` and `get_Current` method by detecting `IEnumerable<>`, `IDictionary` or `IEnumerable` implementations.
1 parent f2b331b commit ed2115c

File tree

6 files changed

+202
-25
lines changed

6 files changed

+202
-25
lines changed

src/PSLambda/CompileVisitor.cs

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -431,20 +431,36 @@ public object VisitForEachStatement(ForEachStatementAst forEachStatementAst)
431431
{
432432
using (_loops.NewScope())
433433
{
434-
var enumerator = Call(
435-
ReflectionCache.LanguagePrimitives_GetEnumerator,
436-
forEachStatementAst.Condition.Compile(this));
434+
var condition = forEachStatementAst.Condition.Compile(this);
435+
var canEnumerate = TryGetEnumeratorMethod(
436+
condition.Type,
437+
out MethodInfo getEnumeratorMethod,
438+
out MethodInfo getCurrentMethod);
439+
440+
if (!canEnumerate)
441+
{
442+
Errors.ReportParseError(
443+
forEachStatementAst.Condition.Extent,
444+
nameof(ErrorStrings.ForEachInvalidEnumerable),
445+
string.Format(
446+
CultureInfo.CurrentCulture,
447+
ErrorStrings.ForEachInvalidEnumerable,
448+
condition.Type));
449+
450+
Errors.ThrowIfAnyErrors();
451+
return Empty();
452+
}
437453

438454
using (_scopeStack.NewScope())
439455
{
440456
var enumeratorRef = _scopeStack.GetVariable(
441457
Strings.ForEachVariableName,
442-
typeof(IEnumerator));
458+
getEnumeratorMethod.ReturnType);
443459
try
444460
{
445461
return Block(
446462
_scopeStack.GetVariables(),
447-
Assign(enumeratorRef, enumerator),
463+
Assign(enumeratorRef, Call(condition, getEnumeratorMethod)),
448464
Loop(
449465
IfThenElse(
450466
test: Call(enumeratorRef, ReflectionCache.IEnumerator_MoveNext),
@@ -454,8 +470,8 @@ public object VisitForEachStatement(ForEachStatementAst forEachStatementAst)
454470
Assign(
455471
_scopeStack.GetVariable(
456472
forEachStatementAst.Variable.VariablePath.UserPath,
457-
typeof(object)),
458-
Property(enumeratorRef, ReflectionCache.IEnumerator_Current)),
473+
getCurrentMethod.ReturnType),
474+
Call(enumeratorRef, getCurrentMethod)),
459475
forEachStatementAst.Body.Compile(this)
460476
}),
461477
ifFalse: Break(_loops.Break)),
@@ -1906,5 +1922,61 @@ private MemberBinder GetBinder(Ast ast)
19061922
_binder = new MemberBinder(BindingFlags.Public, namespaces.ToArray());
19071923
return _binder;
19081924
}
1925+
1926+
private bool TryGetEnumeratorMethod(
1927+
Type type,
1928+
out MethodInfo getEnumeratorMethod,
1929+
out MethodInfo getCurrentMethod)
1930+
{
1931+
var canFallbackToEnumerable = false;
1932+
var canFallbackToIDictionary = false;
1933+
var interfaces = type.GetInterfaces();
1934+
for (var i = 0; i < interfaces.Length; i++)
1935+
{
1936+
if (interfaces[i] == typeof(IEnumerable))
1937+
{
1938+
canFallbackToEnumerable = true;
1939+
continue;
1940+
}
1941+
1942+
if (interfaces[i] == typeof(IDictionary))
1943+
{
1944+
canFallbackToIDictionary = true;
1945+
continue;
1946+
}
1947+
1948+
if (interfaces[i].IsConstructedGenericType &&
1949+
interfaces[i].GetGenericTypeDefinition() == typeof(IEnumerable<>))
1950+
{
1951+
getEnumeratorMethod = interfaces[i].GetMethod(
1952+
Strings.GetEnumeratorMethodName,
1953+
Type.EmptyTypes);
1954+
1955+
getCurrentMethod =
1956+
getEnumeratorMethod.ReturnType.GetMethod(
1957+
Strings.EnumeratorGetCurrentMethodName,
1958+
Type.EmptyTypes);
1959+
return true;
1960+
}
1961+
}
1962+
1963+
if (canFallbackToIDictionary)
1964+
{
1965+
getEnumeratorMethod = ReflectionCache.IDictionary_GetEnumerator;
1966+
getCurrentMethod = ReflectionCache.IDictionaryEnumerator_get_Entry;
1967+
return true;
1968+
}
1969+
1970+
if (canFallbackToEnumerable)
1971+
{
1972+
getEnumeratorMethod = ReflectionCache.IEnumerable_GetEnumerator;
1973+
getCurrentMethod = ReflectionCache.IEnumerator_get_Current;
1974+
return true;
1975+
}
1976+
1977+
getEnumeratorMethod = null;
1978+
getCurrentMethod = null;
1979+
return false;
1980+
}
19091981
}
19101982
}

src/PSLambda/ExpressionUtils.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ public static Expression PSConvertAllTo<T>(Expression source)
185185
collectionVar,
186186
Strings.AddMethodName,
187187
Type.EmptyTypes,
188-
PSConvertTo<T>(Property(enumeratorVar, ReflectionCache.IEnumerator_Current))),
188+
PSConvertTo<T>(Call(enumeratorVar, ReflectionCache.IEnumerator_get_Current))),
189189
Break(breakLabel)),
190190
breakLabel),
191191
Call(collectionVar, Strings.ToArrayMethodName, Type.EmptyTypes));
@@ -305,7 +305,7 @@ public static Expression PSIsIn(Expression item, Expression items, bool isCaseSe
305305
IfThen(
306306
PSEquals(
307307
item,
308-
Property(enumeratorVar, ReflectionCache.IEnumerator_Current),
308+
Call(enumeratorVar, ReflectionCache.IEnumerator_get_Current),
309309
isCaseSensitive),
310310
Return(returnLabel, SpecialVariables.Constants[Strings.TrueVariableName])),
311311
Return(returnLabel, SpecialVariables.Constants[Strings.FalseVariableName]))),

src/PSLambda/ReflectionCache.cs

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections;
3+
using System.Collections.Generic;
34
using System.Linq;
45
using System.Management.Automation;
56
using System.Reflection;
@@ -189,10 +190,16 @@ internal static class ReflectionCache
189190
typeof(Hashtable).GetConstructor(new[] { typeof(int), typeof(IEqualityComparer) });
190191

191192
/// <summary>
192-
/// Resolves to <see cref="IEnumerator.Current" />
193+
/// Resolves to <see cref="IEnumerator.get_Current" />
193194
/// </summary>
194-
public static readonly PropertyInfo IEnumerator_Current =
195-
typeof(IEnumerator).GetProperty("Current");
195+
public static readonly MethodInfo IEnumerator_get_Current =
196+
typeof(IEnumerator).GetMethod(Strings.EnumeratorGetCurrentMethodName, Type.EmptyTypes);
197+
198+
/// <summary>
199+
/// Resolves to <see cref="IEnumerable.GetEnumerator" />
200+
/// </summary>
201+
public static readonly MethodInfo IEnumerable_GetEnumerator =
202+
typeof(IEnumerable).GetMethod(Strings.GetEnumeratorMethodName, Type.EmptyTypes);
196203

197204
/// <summary>
198205
/// Resolves to <see cref="StringComparer.CurrentCultureIgnoreCase" />.
@@ -225,5 +232,29 @@ internal static class ReflectionCache
225232
/// </summary>
226233
public static readonly MethodInfo Monitor_Exit =
227234
typeof(System.Threading.Monitor).GetMethod("Exit", new[] { typeof(object) });
235+
236+
/// <summary>
237+
/// Resolves to <see cref="IEnumerable{T}.GetEnumerator" />.
238+
/// </summary>
239+
public static readonly MethodInfo IEnumerable_T_GetEnumerator =
240+
typeof(IEnumerable<>).GetMethod(Strings.GetEnumeratorMethodName, Type.EmptyTypes);
241+
242+
/// <summary>
243+
/// Resolves to <see cref="IEnumerator{T}.get_Current" />.
244+
/// </summary>
245+
public static readonly MethodInfo IEnumerator_T_get_Current =
246+
typeof(IEnumerator<>).GetMethod(Strings.EnumeratorGetCurrentMethodName, Type.EmptyTypes);
247+
248+
/// <summary>
249+
/// Resolves to <see cref="IDictionary.GetEnumerator" />.
250+
/// </summary>
251+
public static readonly MethodInfo IDictionary_GetEnumerator =
252+
typeof(IDictionary).GetMethod(Strings.GetEnumeratorMethodName, Type.EmptyTypes);
253+
254+
/// <summary>
255+
/// Resolves to <see cref="IDictionaryEnumerator.get_Entry" />.
256+
/// </summary>
257+
public static readonly MethodInfo IDictionaryEnumerator_get_Entry =
258+
typeof(IDictionaryEnumerator).GetMethod("get_Entry", Type.EmptyTypes);
228259
}
229260
}

src/PSLambda/Strings.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@ namespace PSLambda
55
/// </summary>
66
internal class Strings
77
{
8+
/// <summary>
9+
/// Constant containing a string similar to "GetEnumerator".
10+
/// </summary>
11+
public const string GetEnumeratorMethodName = "GetEnumerator";
12+
13+
/// <summary>
14+
/// Constant containing a string similar to "get_Current".
15+
/// </summary>
16+
public const string EnumeratorGetCurrentMethodName = "get_Current";
17+
818
/// <summary>
919
/// Constant containing a string similar to "psdelegate".
1020
/// </summary>

src/PSLambda/resources/ErrorStrings.resx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,7 @@
168168
<data name="NoMemberArgumentMatch" xml:space="preserve">
169169
<value>'{0}' does not contain a definition for a method named '{1}' that takes the specified arguments.</value>
170170
</data>
171+
<data name="ForEachInvalidEnumerable" xml:space="preserve">
172+
<value>The foreach statement cannot operate on variables of type '{0}' because '{0}' does not contain a public definition for 'GetEnumerator'</value>
173+
</data>
171174
</root>

test/Loops.Tests.ps1

Lines changed: 74 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,93 @@ $manifestPath = "$PSScriptRoot\..\Release\$moduleName\*\$moduleName.psd1"
44
Import-Module $manifestPath -Force
55

66
Describe 'basic loop functionality' {
7-
It 'for statement' {
8-
$delegate = New-PSDelegate {
9-
[int] $total = 0
10-
for ([int] $i = 0; $i -lt 10; $i++) {
11-
$total = $i + $total
7+
Context 'foreach statement tests' {
8+
It 'can enumerate IEnumerable<>' {
9+
$delegate = New-PSDelegate {
10+
$total = 0
11+
foreach($item in 0..10) {
12+
$total = $item + $total
13+
}
14+
15+
return $total
1216
}
1317

14-
return $total
18+
$delegate.Invoke() | Should -Be 55
1519
}
1620

17-
$delegate.Invoke() | Should -Be 45
21+
It 'can enumerate IDictionary' {
22+
$delegate = New-PSDelegate {
23+
$hashtable = @{
24+
one = 'two'
25+
three = 'four'
26+
}
27+
28+
$sb = [System.Text.StringBuilder]::new()
29+
foreach($item in $hashtable) {
30+
$sb.Append($item.Value.ToString())
31+
}
32+
33+
return $sb.ToString()
34+
}
35+
36+
$delegate.Invoke() | Should -Be twofour
37+
}
38+
39+
It 'prioritizes IEnumerable<> over IDictionary' {
40+
$delegate = New-PSDelegate {
41+
$map = [System.Collections.Generic.Dictionary[string, int]]::new()
42+
$map.Add('test', 10)
43+
$map.Add('test2', 30)
44+
45+
$results = [System.Collections.Generic.List[int]]::new()
46+
foreach ($item in $map) {
47+
$results.Add($item.Value)
48+
}
49+
50+
return $results
51+
}
52+
53+
$delegate.Invoke() | Should -Be 10, 30
54+
}
55+
56+
It 'can enumerable IEnumerable' {
57+
$delegate = New-PSDelegate {
58+
$list = [System.Collections.ArrayList]::new()
59+
$list.Add([object]10)
60+
$list.Add('test2')
61+
62+
$results = [System.Collections.Generic.List[string]]::new()
63+
foreach ($item in $list) {
64+
$results.Add($item.ToString())
65+
}
66+
67+
return $results
68+
}
69+
70+
$delegate.Invoke() | Should -Be 10, test2
71+
}
72+
73+
It 'throws the correct message when target is not IEnumerable' {
74+
$expectedMsg =
75+
"The foreach statement cannot operate on variables of type " +
76+
"'System.Int32' because 'System.Int32' does not contain a " +
77+
"public definition for 'GetEnumerator'"
78+
79+
{ New-PSDelegate { foreach ($a in 10) {}}} | Should -Throw $expectedMsg
80+
}
1881
}
1982

20-
It 'foreach statement' {
83+
It 'for statement' {
2184
$delegate = New-PSDelegate {
22-
[int[]] $numbers = 1, 2, 3, 4
2385
[int] $total = 0
24-
25-
foreach($item in $numbers) {
26-
$total = [int]$item + [int]$total
86+
for ([int] $i = 0; $i -lt 10; $i++) {
87+
$total = $i + $total
2788
}
2889

2990
return $total
3091
}
3192

32-
$delegate.Invoke() | Should -Be 10
93+
$delegate.Invoke() | Should -Be 45
3394
}
3495

3596
It 'while statement' {

0 commit comments

Comments
 (0)