|
6 | 6 |
|
7 | 7 | namespace Hyperbee.XS.Core; |
8 | 8 |
|
9 | | -public sealed class TypeResolver |
| 9 | +public interface ITypeResolver |
| 10 | +{ |
| 11 | + Type ResolveType( string typeName ); |
| 12 | + MethodInfo ResolveMethod( Type type, string methodName, IReadOnlyList<Type> typeArgs, IReadOnlyList<Expression> args ); |
| 13 | + MemberInfo ResolveMember( Type type, string memberName ); |
| 14 | +} |
| 15 | +public interface ITypeRewriter |
| 16 | +{ |
| 17 | + Expression RewriteIndexerExpression( Expression targetExpression, IReadOnlyList<Expression> indexes ); |
| 18 | + Expression RewriteMemberExpression( Expression targetExpression, string name, IReadOnlyList<Type> typeArgs, IReadOnlyList<Expression> args ); |
| 19 | +} |
| 20 | + |
| 21 | +public class TypeResolver : ITypeResolver, ITypeRewriter |
10 | 22 | { |
11 | 23 | public ReferenceManager ReferenceManager { get; } |
12 | 24 |
|
@@ -53,7 +65,7 @@ public TypeResolver( ReferenceManager referenceManager ) |
53 | 65 | ReferenceManager = referenceManager; |
54 | 66 | } |
55 | 67 |
|
56 | | - public Type ResolveType( string typeName ) |
| 68 | + public virtual Type ResolveType( string typeName ) |
57 | 69 | { |
58 | 70 | return _typeCache.GetOrAdd( typeName, _ => |
59 | 71 | { |
@@ -93,7 +105,7 @@ static Type GetTypeFromKeyword( string typeName ) |
93 | 105 | } |
94 | 106 | } |
95 | 107 |
|
96 | | - public MethodInfo ResolveMethod( Type type, string methodName, IReadOnlyList<Type> typeArgs, IReadOnlyList<Expression> args ) |
| 108 | + public virtual MethodInfo ResolveMethod( Type type, string methodName, IReadOnlyList<Type> typeArgs, IReadOnlyList<Expression> args ) |
97 | 109 | { |
98 | 110 | var candidateMethods = GetCandidateMethods( methodName, type ); |
99 | 111 | var callerTypes = GetCallerTypes( type, args ); |
@@ -153,6 +165,106 @@ public MethodInfo ResolveMethod( Type type, string methodName, IReadOnlyList<Typ |
153 | 165 | return bestMatch; |
154 | 166 | } |
155 | 167 |
|
| 168 | + public virtual MemberInfo ResolveMember( Type type, string memberName ) |
| 169 | + { |
| 170 | + const BindingFlags BindingAttr = BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static; |
| 171 | + return type.GetMember( memberName, BindingAttr ).FirstOrDefault(); |
| 172 | + } |
| 173 | + |
| 174 | + public virtual Expression RewriteIndexerExpression( Expression targetExpression, IReadOnlyList<Expression> indexes ) |
| 175 | + { |
| 176 | + var indexers = targetExpression.Type.GetProperties() |
| 177 | + .Where( p => p.GetIndexParameters().Length == indexes.Count ) |
| 178 | + .ToArray(); |
| 179 | + |
| 180 | + if ( indexers.Length == 0 ) |
| 181 | + return Expression.ArrayAccess( targetExpression, indexes ); |
| 182 | + |
| 183 | + // Find the best match based on parameter types |
| 184 | + var indexer = indexers.FirstOrDefault( p => |
| 185 | + p.GetIndexParameters() |
| 186 | + .Select( param => param.ParameterType ) |
| 187 | + .SequenceEqual( indexes.Select( i => i.Type ) ) ); |
| 188 | + |
| 189 | + if ( indexer == null ) |
| 190 | + { |
| 191 | + throw new InvalidOperationException( |
| 192 | + $"No matching indexer found on type '{targetExpression.Type}' with parameter types: " + |
| 193 | + $"{string.Join( ", ", indexes.Select( i => i.Type.Name ) )}." ); |
| 194 | + } |
| 195 | + |
| 196 | + return Expression.Property( targetExpression, indexer, indexes.ToArray() ); |
| 197 | + } |
| 198 | + |
| 199 | + public virtual Expression RewriteMemberExpression( Expression targetExpression, string name, IReadOnlyList<Type> typeArgs, IReadOnlyList<Expression> args ) |
| 200 | + { |
| 201 | + var type = TypeOf( targetExpression ); |
| 202 | + |
| 203 | + // method |
| 204 | + |
| 205 | + if ( args != null ) |
| 206 | + { |
| 207 | + var method = ResolveMethod( type, name, typeArgs, args ); |
| 208 | + |
| 209 | + if ( method == null ) |
| 210 | + throw new InvalidOperationException( $"Method '{name}' not found on type '{type}'." ); |
| 211 | + |
| 212 | + var arguments = GetArgumentsWithDefaults( method, targetExpression, args ); |
| 213 | + |
| 214 | + return method.IsStatic |
| 215 | + ? Expression.Call( method, arguments ) |
| 216 | + : Expression.Call( targetExpression, method, arguments ); |
| 217 | + } |
| 218 | + |
| 219 | + // property or field |
| 220 | + |
| 221 | + var member = ResolveMember( type, name ); |
| 222 | + |
| 223 | + if ( member == null ) |
| 224 | + throw new InvalidOperationException( $"Member '{name}' not found on type '{type}'." ); |
| 225 | + |
| 226 | + return member switch |
| 227 | + { |
| 228 | + PropertyInfo property => Expression.Property( targetExpression, property ), |
| 229 | + FieldInfo field => Expression.Field( targetExpression, field ), |
| 230 | + _ => throw new InvalidOperationException( $"Member '{name}' is not a property or field." ) |
| 231 | + }; |
| 232 | + |
| 233 | + static IReadOnlyList<Expression> GetArgumentsWithDefaults( MethodInfo method, Expression targetExpression, IReadOnlyList<Expression> providedArgs ) |
| 234 | + { |
| 235 | + var parameters = method.GetParameters(); |
| 236 | + var isExtension = method.IsDefined( typeof( ExtensionAttribute ), false ); |
| 237 | + |
| 238 | + var providedOffset = isExtension ? 1 : 0; |
| 239 | + var providedCount = providedArgs.Count; |
| 240 | + var totalParameters = parameters.Length; |
| 241 | + |
| 242 | + if ( providedCount == totalParameters ) |
| 243 | + return providedArgs; |
| 244 | + |
| 245 | + var methodArgs = new Expression[totalParameters]; |
| 246 | + |
| 247 | + // add provided arguments |
| 248 | + if ( isExtension ) |
| 249 | + methodArgs[0] = targetExpression; |
| 250 | + |
| 251 | + for ( var i = 0; i < providedCount; i++ ) |
| 252 | + { |
| 253 | + methodArgs[i + providedOffset] = providedArgs[i]; |
| 254 | + } |
| 255 | + |
| 256 | + // add missing optional parameters |
| 257 | + for ( var i = providedCount + providedOffset; i < totalParameters; i++ ) |
| 258 | + { |
| 259 | + methodArgs[i] = parameters[i].HasDefaultValue |
| 260 | + ? Expression.Constant( parameters[i].DefaultValue, parameters[i].ParameterType ) |
| 261 | + : throw new ArgumentException( $"Missing required parameter: {parameters[i].Name}" ); |
| 262 | + } |
| 263 | + |
| 264 | + return methodArgs; |
| 265 | + } |
| 266 | + } |
| 267 | + |
156 | 268 | public void RegisterExtensionMethods( IEnumerable<Assembly> assemblies ) |
157 | 269 | { |
158 | 270 | Parallel.ForEach( assemblies, assembly => |
@@ -386,5 +498,18 @@ private static int CompareMethods( MethodInfo m1, MethodInfo m2 ) |
386 | 498 |
|
387 | 499 | return p1.Length.CompareTo( p2.Length ); |
388 | 500 | } |
| 501 | + |
| 502 | + [MethodImpl( MethodImplOptions.AggressiveInlining )] |
| 503 | + private static Type TypeOf( Expression expression ) |
| 504 | + { |
| 505 | + ArgumentNullException.ThrowIfNull( expression, nameof( expression ) ); |
| 506 | + |
| 507 | + return expression switch |
| 508 | + { |
| 509 | + ConstantExpression ce => ce.Value as Type ?? ce.Type, |
| 510 | + _ => expression.Type |
| 511 | + }; |
| 512 | + } |
| 513 | + |
389 | 514 | } |
390 | 515 |
|
0 commit comments