@@ -431,20 +431,36 @@ public object VisitForEachStatement(ForEachStatementAst forEachStatementAst)
431431 {
432432 using ( _loops . NewScope ( ) )
433433 {
434- var enumerator = Call (
435- ReflectionCache . LanguagePrimitives_GetEnumerator ,
436- forEachStatementAst . Condition . Compile ( this ) ) ;
434+ var condition = forEachStatementAst . Condition . Compile ( this ) ;
435+ var canEnumerate = TryGetEnumeratorMethod (
436+ condition . Type ,
437+ out MethodInfo getEnumeratorMethod ,
438+ out MethodInfo getCurrentMethod ) ;
439+
440+ if ( ! canEnumerate )
441+ {
442+ Errors . ReportParseError (
443+ forEachStatementAst . Condition . Extent ,
444+ nameof ( ErrorStrings . ForEachInvalidEnumerable ) ,
445+ string . Format (
446+ CultureInfo . CurrentCulture ,
447+ ErrorStrings . ForEachInvalidEnumerable ,
448+ condition . Type ) ) ;
449+
450+ Errors . ThrowIfAnyErrors ( ) ;
451+ return Empty ( ) ;
452+ }
437453
438454 using ( _scopeStack . NewScope ( ) )
439455 {
440456 var enumeratorRef = _scopeStack . GetVariable (
441457 Strings . ForEachVariableName ,
442- typeof ( IEnumerator ) ) ;
458+ getEnumeratorMethod . ReturnType ) ;
443459 try
444460 {
445461 return Block (
446462 _scopeStack . GetVariables ( ) ,
447- Assign ( enumeratorRef , enumerator ) ,
463+ Assign ( enumeratorRef , Call ( condition , getEnumeratorMethod ) ) ,
448464 Loop (
449465 IfThenElse (
450466 test : Call ( enumeratorRef , ReflectionCache . IEnumerator_MoveNext ) ,
@@ -454,8 +470,8 @@ public object VisitForEachStatement(ForEachStatementAst forEachStatementAst)
454470 Assign (
455471 _scopeStack . GetVariable (
456472 forEachStatementAst . Variable . VariablePath . UserPath ,
457- typeof ( object ) ) ,
458- Property ( enumeratorRef , ReflectionCache . IEnumerator_Current ) ) ,
473+ getCurrentMethod . ReturnType ) ,
474+ Call ( enumeratorRef , getCurrentMethod ) ) ,
459475 forEachStatementAst . Body . Compile ( this )
460476 } ) ,
461477 ifFalse : Break ( _loops . Break ) ) ,
@@ -1906,5 +1922,61 @@ private MemberBinder GetBinder(Ast ast)
19061922 _binder = new MemberBinder ( BindingFlags . Public , namespaces . ToArray ( ) ) ;
19071923 return _binder ;
19081924 }
1925+
1926+ private bool TryGetEnumeratorMethod (
1927+ Type type ,
1928+ out MethodInfo getEnumeratorMethod ,
1929+ out MethodInfo getCurrentMethod )
1930+ {
1931+ var canFallbackToEnumerable = false ;
1932+ var canFallbackToIDictionary = false ;
1933+ var interfaces = type . GetInterfaces ( ) ;
1934+ for ( var i = 0 ; i < interfaces . Length ; i ++ )
1935+ {
1936+ if ( interfaces [ i ] == typeof ( IEnumerable ) )
1937+ {
1938+ canFallbackToEnumerable = true ;
1939+ continue ;
1940+ }
1941+
1942+ if ( interfaces [ i ] == typeof ( IDictionary ) )
1943+ {
1944+ canFallbackToIDictionary = true ;
1945+ continue ;
1946+ }
1947+
1948+ if ( interfaces [ i ] . IsConstructedGenericType &&
1949+ interfaces [ i ] . GetGenericTypeDefinition ( ) == typeof ( IEnumerable < > ) )
1950+ {
1951+ getEnumeratorMethod = interfaces [ i ] . GetMethod (
1952+ Strings . GetEnumeratorMethodName ,
1953+ Type . EmptyTypes ) ;
1954+
1955+ getCurrentMethod =
1956+ getEnumeratorMethod . ReturnType . GetMethod (
1957+ Strings . EnumeratorGetCurrentMethodName ,
1958+ Type . EmptyTypes ) ;
1959+ return true ;
1960+ }
1961+ }
1962+
1963+ if ( canFallbackToIDictionary )
1964+ {
1965+ getEnumeratorMethod = ReflectionCache . IDictionary_GetEnumerator ;
1966+ getCurrentMethod = ReflectionCache . IDictionaryEnumerator_get_Entry ;
1967+ return true ;
1968+ }
1969+
1970+ if ( canFallbackToEnumerable )
1971+ {
1972+ getEnumeratorMethod = ReflectionCache . IEnumerable_GetEnumerator ;
1973+ getCurrentMethod = ReflectionCache . IEnumerator_get_Current ;
1974+ return true ;
1975+ }
1976+
1977+ getEnumeratorMethod = null ;
1978+ getCurrentMethod = null ;
1979+ return false ;
1980+ }
19091981 }
19101982}
0 commit comments