@@ -12,7 +12,11 @@ namespace AsyncGenerator.Core.Extensions.Internal
1212{
1313 internal static class SymbolExtensions
1414 {
15- private static readonly Func < IMethodSymbol , IEnumerable > GetHiddenMembersFunc ;
15+ private static readonly Func < object , IEnumerable > GetHiddenMembersFunc ;
16+ #if ! LEGACY
17+ private static readonly Func < object , ISymbol > GetPublicSymbolFunc ;
18+ private static readonly Func < ISymbol , object > GetInternalSymbolFunc ;
19+ #endif
1620
1721 static SymbolExtensions ( )
1822 {
@@ -22,6 +26,7 @@ static SymbolExtensions()
2226 {
2327 throw new InvalidOperationException ( $ "Type { methodSymbolFullName } does not exist") ;
2428 }
29+
2530 var overriddenOrHiddenMembersGetter = type . GetProperty ( "OverriddenOrHiddenMembers" , BindingFlags . NonPublic | BindingFlags . Instance ) ? . GetMethod ;
2631 if ( overriddenOrHiddenMembersGetter == null )
2732 {
@@ -34,22 +39,92 @@ static SymbolExtensions()
3439 throw new InvalidOperationException ( $ "Property HiddenMembers of type { overriddenOrHiddenMembersGetter . ReturnType } does not exist.") ;
3540 }
3641
37- var param1 = Expression . Parameter ( typeof ( IMethodSymbol ) , "methodSymbol" ) ;
38- var convertToMethodSymbol = Expression . Convert ( param1 , type ) ;
39- var callOverriddenOrHiddenMembersGetter = Expression . Call ( convertToMethodSymbol , overriddenOrHiddenMembersGetter ) ;
42+ var symbolParameter = Expression . Parameter ( typeof ( object ) ) ;
43+ var callOverriddenOrHiddenMembersGetter = Expression . Call (
44+ Expression . Convert ( symbolParameter , type ) ,
45+ overriddenOrHiddenMembersGetter ) ;
4046 var callHiddenMembersGetter = Expression . Call ( callOverriddenOrHiddenMembersGetter , hiddenMembersGetter ) ;
4147 var lambdaParams = new List < ParameterExpression >
4248 {
43- param1
49+ symbolParameter
4450 } ;
45- GetHiddenMembersFunc = Expression . Lambda < Func < IMethodSymbol , IEnumerable > > (
51+ GetHiddenMembersFunc = Expression . Lambda < Func < object , IEnumerable > > (
4652 Expression . Convert ( callHiddenMembersGetter , typeof ( IEnumerable ) ) , lambdaParams )
4753 . Compile ( ) ;
54+ #if ! LEGACY
55+ GetPublicSymbolFunc = CreateGetPublicSymbolFunction ( ) ;
56+ GetInternalSymbolFunc = CreateGetInternalSymbolFunction ( ) ;
57+ #endif
58+ }
59+
60+ #if ! LEGACY
61+ private static Func < object , ISymbol > CreateGetPublicSymbolFunction ( )
62+ {
63+ const string symbolInternalFullName = "Microsoft.CodeAnalysis.Symbols.ISymbolInternal, Microsoft.CodeAnalysis" ;
64+ var symbolInternalType = Type . GetType ( symbolInternalFullName ) ;
65+ if ( symbolInternalType == null )
66+ {
67+ throw new InvalidOperationException ( $ "Type { symbolInternalFullName } does not exist") ;
68+ }
69+
70+ var getISymbolMethod = symbolInternalType . GetMethod ( "GetISymbol" ) ;
71+ if ( getISymbolMethod == null )
72+ {
73+ throw new InvalidOperationException ( $ "Method GetISymbol of type { symbolInternalFullName } does not exist.") ;
74+ }
75+
76+ var wrapperParameter = Expression . Parameter ( typeof ( object ) ) ;
77+ return Expression . Lambda < Func < object , ISymbol > > (
78+ Expression . Call (
79+ Expression . Convert ( wrapperParameter , symbolInternalType ) ,
80+ getISymbolMethod
81+ ) ,
82+ wrapperParameter
83+ ) . Compile ( ) ;
84+ }
85+
86+ private static Func < ISymbol , object > CreateGetInternalSymbolFunction ( )
87+ {
88+ const string publicSymbolFullName = "Microsoft.CodeAnalysis.CSharp.Symbols.PublicModel.Symbol, Microsoft.CodeAnalysis.CSharp" ;
89+ var publicSymbolType = Type . GetType ( publicSymbolFullName ) ;
90+ if ( publicSymbolType == null )
91+ {
92+ throw new InvalidOperationException ( $ "Type { publicSymbolFullName } does not exist") ;
93+ }
94+
95+ var underlyingSymbolProperty = publicSymbolType . GetProperty ( "UnderlyingSymbol" , BindingFlags . NonPublic | BindingFlags . Instance ) ;
96+ if ( underlyingSymbolProperty == null )
97+ {
98+ throw new InvalidOperationException ( $ "Property UnderlyingSymbol of type { publicSymbolFullName } does not exist.") ;
99+ }
100+
101+ var symbolParameter = Expression . Parameter ( typeof ( ISymbol ) ) ;
102+ return Expression . Lambda < Func < ISymbol , object > > (
103+ Expression . Property (
104+ Expression . Convert ( symbolParameter , publicSymbolType ) ,
105+ underlyingSymbolProperty
106+ ) ,
107+ symbolParameter
108+ ) . Compile ( ) ;
48109 }
110+ #endif
49111
50112 internal static IEnumerable < IMethodSymbol > GetHiddenMethods ( this IMethodSymbol methodSymbol )
51113 {
114+ #if LEGACY
52115 return GetHiddenMembersFunc ( methodSymbol ) . OfType < IMethodSymbol > ( ) ;
116+ #else
117+ foreach ( var item in GetHiddenMembersFunc ( GetInternalSymbolFunc ( methodSymbol ) ) )
118+ {
119+ var hiddenMethod = GetPublicSymbolFunc ( item ) as IMethodSymbol ;
120+ if ( hiddenMethod == null )
121+ {
122+ continue ;
123+ }
124+
125+ yield return hiddenMethod ;
126+ }
127+ #endif
53128 }
54129
55130 /// <summary>
@@ -146,6 +221,15 @@ private static bool AreEqual(ImmutableArray<ITypeSymbol> types, ImmutableArray<I
146221 return true ;
147222 }
148223
224+ internal static bool EqualTo ( this ISymbol symbol , ISymbol symbolToCompare )
225+ {
226+ #if LEGACY
227+ return symbol . Equals ( symbolToCompare ) ;
228+ #else
229+ return symbol . Equals ( symbolToCompare , SymbolEqualityComparer . Default ) ;
230+ #endif
231+ }
232+
149233 /// <summary>
150234 /// Check if the return type matches, valid cases: <see cref="Void"/> to <see cref="System.Threading.Tasks.Task"/> Task, TResult to <see cref="System.Threading.Tasks.Task{TResult}"/> and
151235 /// also equals return types are ok when there is at least one delegate that can be converted to async (eg. Task.Run(<see cref="Action"/>) and Task.Run(<see cref="Func{Task}"/>))
@@ -156,7 +240,7 @@ private static bool AreEqual(ImmutableArray<ITypeSymbol> types, ImmutableArray<I
156240 internal static bool IsAsyncCandidateForReturnType ( this IMethodSymbol syncMethod , IMethodSymbol candidateAsyncMethod )
157241 {
158242 // Original definition is used for matching generic types
159- if ( syncMethod . ReturnType . OriginalDefinition . Equals ( candidateAsyncMethod . ReturnType . OriginalDefinition ) )
243+ if ( syncMethod . ReturnType . OriginalDefinition . EqualTo ( candidateAsyncMethod . ReturnType . OriginalDefinition ) )
160244 {
161245 return true ;
162246 }
@@ -220,7 +304,7 @@ internal static List<int> GetAsyncDelegateArgumentIndexes(this IMethodSymbol syn
220304 continue ;
221305 }
222306 var candidateDelegate = candidateTypeSymbol . DelegateInvokeMethod ;
223- if ( origDelegate . Equals ( candidateDelegate ) )
307+ if ( origDelegate . EqualTo ( candidateDelegate ) )
224308 {
225309 continue ;
226310 }
@@ -239,7 +323,7 @@ internal static List<int> GetAsyncDelegateArgumentIndexes(this IMethodSymbol syn
239323 /// <returns></returns>
240324 internal static bool AreEqual ( this ITypeSymbol type , ITypeSymbol toCompare , ITypeSymbol canBeDerivedFromType = null )
241325 {
242- if ( type . Equals ( toCompare ) )
326+ if ( type . EqualTo ( toCompare ) )
243327 {
244328 return true ;
245329 }
@@ -267,10 +351,10 @@ internal static bool AreEqual(this ITypeSymbol type, ITypeSymbol toCompare, ITyp
267351 return false ;
268352 }
269353 }
270- var equals = typeNamedType . OriginalDefinition . Equals ( toCompareNamedType . OriginalDefinition ) ;
354+ var equals = typeNamedType . OriginalDefinition . EqualTo ( toCompareNamedType . OriginalDefinition ) ;
271355 if ( ! equals && canBeDerivedFromType != null )
272356 {
273- equals = new [ ] { canBeDerivedFromType } . Concat ( canBeDerivedFromType . AllInterfaces ) . Any ( o => toCompareNamedType . OriginalDefinition . Equals ( o . OriginalDefinition ) ) ;
357+ equals = new [ ] { canBeDerivedFromType } . Concat ( canBeDerivedFromType . AllInterfaces ) . Any ( o => toCompareNamedType . OriginalDefinition . EqualTo ( o . OriginalDefinition ) ) ;
274358 }
275359 return equals ;
276360 }
@@ -286,22 +370,22 @@ internal static bool InheritsFromOrEquals(this ITypeSymbol type, ITypeSymbol bas
286370 return InheritsFromOrEquals ( type , baseType ) ;
287371 }
288372
289- return type . GetBaseTypesAndThis ( ) . Concat ( type . AllInterfaces ) . Any ( t => t . Equals ( baseType ) ) ;
373+ return type . GetBaseTypesAndThis ( ) . Concat ( type . AllInterfaces ) . Any ( t => t . EqualTo ( baseType ) ) ;
290374 }
291375
292376 // Determine if "type" inherits from "baseType", ignoring constructed types and interfaces, dealing
293377 // only with original types.
294378 internal static bool InheritsFromOrEquals ( this ITypeSymbol type , ITypeSymbol baseType )
295379 {
296- return type . GetBaseTypesAndThis ( ) . Any ( t => t . Equals ( baseType ) ) ;
380+ return type . GetBaseTypesAndThis ( ) . Any ( t => t . EqualTo ( baseType ) ) ;
297381 }
298382
299383 // Determine if "type" inherits from "baseType", ignoring constructed types, and dealing
300384 // only with original types.
301385 internal static bool InheritsFromOrEqualsIgnoringConstruction ( this ITypeSymbol type , ITypeSymbol baseType )
302386 {
303387 var originalBaseType = baseType . OriginalDefinition ;
304- return type . GetBaseTypesAndThis ( ) . Any ( t => t . OriginalDefinition . Equals ( originalBaseType ) ) ;
388+ return type . GetBaseTypesAndThis ( ) . Any ( t => t . OriginalDefinition . EqualTo ( originalBaseType ) ) ;
305389 }
306390
307391 internal static IEnumerable < ITypeSymbol > GetBaseTypesAndThis ( this ITypeSymbol type )
0 commit comments