11
11
using Microsoft . CodeAnalysis ;
12
12
using Microsoft . CodeAnalysis . FindSymbols ;
13
13
using static Microsoft . CodeAnalysis . CSharp . SyntaxFactory ;
14
+ using System . Diagnostics ;
15
+ using System . Reflection . Metadata ;
16
+ using Microsoft . CodeAnalysis . Editing ;
17
+ using System . Xml . Linq ;
18
+ using Microsoft . Extensions . Logging ;
19
+ using Silk . NET . SilkTouch . Clang ;
14
20
15
21
namespace Silk . NET . SilkTouch . Mods
16
22
{
17
23
/// <summary>
18
24
/// A mod to modify COM interface structs into opaque structs that act like ComPtr objects
19
25
/// </summary>
26
+ /// <param name="logger">The logger to use.</param>
20
27
[ ModConfiguration < Config > ]
21
- public class TransformCOM : Mod
28
+ public class TransformCOM ( ILogger < TransformCOM > logger ) : Mod
22
29
{
23
30
/// <summary>
24
31
/// The configuration for the <see cref="TransformCOM"/> mod.
@@ -39,40 +46,79 @@ public override async Task ExecuteAsync(IModContext ctx, CancellationToken ct =
39
46
{
40
47
return ;
41
48
}
42
- Compilation ? comp = await proj . GetCompilationAsync ( ) ;
43
- IEnumerable < ISymbol > ? unknowns = comp ? . GetSymbolsWithName ( "IUnknown" ) ;
44
- //foreach (var doc in ctx.SourceProject?.Documents ?? [])
45
- //{
46
- // if (await doc.GetSyntaxRootAsync(ct) is { } root)
47
- // {
48
- // firstPass.Visit(root);
49
- // }
50
- //}
51
49
52
- if ( unknowns is null || unknowns . Count ( ) == 0 )
50
+ logger . LogInformation ( "Starting COM Object Collection" ) ;
51
+ foreach ( var docId in proj ? . DocumentIds ?? [ ] )
53
52
{
54
- return ;
53
+ var doc =
54
+ proj ? . GetDocument ( docId ) ?? throw new InvalidOperationException ( "Document missing" ) ;
55
+ if ( await doc . GetSyntaxRootAsync ( ct ) is not { } root )
56
+ {
57
+ continue ;
58
+ }
59
+
60
+ firstPass . Visit ( root ) ;
55
61
}
56
62
57
- IEnumerable < ISymbol > ? symbols = await SymbolFinder . FindImplementationsAsync ( unknowns . First ( ) , proj . Solution , [ proj ] , ct ) ;
58
-
59
- List < ( string , bool ) > COMTypes = firstPass . FoundCOMTypes ;
63
+ firstPass . FoundCOMTypes = firstPass . FoundCOMTypes . Where ( val => val . Item2 ) . ToList ( ) ;
64
+
65
+ var rewriter = new Rewriter ( firstPass . FoundCOMTypes ) ;
66
+ int index = 0 ;
67
+ int count = proj ? . DocumentIds . Count ?? 0 ;
68
+ logger . LogInformation ( "Starting COM Object Rewrite" ) ;
69
+ foreach ( var docId in proj ? . DocumentIds ?? [ ] )
70
+ {
71
+ index ++ ;
72
+ var doc =
73
+ proj ? . GetDocument ( docId ) ?? throw new InvalidOperationException ( "Document missing" ) ;
74
+ if ( await doc . GetSyntaxRootAsync ( ct ) is not { } root )
75
+ {
76
+ continue ;
77
+ }
78
+
79
+ doc = doc . WithSyntaxRoot ( rewriter . Visit ( root ) . NormalizeWhitespace ( ) ) ;
80
+
81
+ proj = doc . Project ;
60
82
61
- Dictionary < string , CompilationUnitSyntax > duplicates = new ( ) ;
83
+ logger . LogInformation ( "COM Object Rewrite for {0} Complete ({1}/{2})" , doc . Name , index , count ) ;
84
+ }
62
85
63
- var rewriter = new Rewriter ( COMTypes ) ;
86
+ index = 0 ;
87
+ logger . LogInformation ( "Starting COM Object Usage Update" ) ;
64
88
foreach ( var docId in proj ? . DocumentIds ?? [ ] )
65
89
{
90
+ index ++ ;
66
91
var doc =
67
92
proj ? . GetDocument ( docId ) ?? throw new InvalidOperationException ( "Document missing" ) ;
68
93
if ( await doc . GetSyntaxRootAsync ( ct ) is not { } root )
69
94
{
70
95
continue ;
71
96
}
72
97
98
+ var semanticModel = await doc . GetSemanticModelAsync ( ) ;
99
+ var editor = new SyntaxEditor ( root , proj . Solution . Workspace . Services ) ;
100
+ // Replace pointer member access -> with regular member access .
101
+ var memberAccesses = root . DescendantNodes ( )
102
+ . OfType < MemberAccessExpressionSyntax > ( )
103
+ . Where ( m => m . Expression is PrefixUnaryExpressionSyntax pues && pues . IsKind ( SyntaxKind . PointerMemberAccessExpression ) ) ;
104
+
105
+ foreach ( var memberAccess in memberAccesses )
106
+ {
107
+ var pointerIndirection = ( PrefixUnaryExpressionSyntax ) memberAccess . Expression ;
108
+ var typeInfo = semanticModel . GetTypeInfo ( pointerIndirection . Operand ) ;
109
+ // Check if the type is a ComType
110
+ if ( typeInfo . Type != null && firstPass . FoundCOMTypes . Any ( type => type . Item1 == typeInfo . Type . ToDisplayString ( ) ) )
111
+ {
112
+ var newMemberAccess = SyntaxFactory . MemberAccessExpression ( SyntaxKind . SimpleMemberAccessExpression , pointerIndirection . Operand , memberAccess . Name ) ;
113
+ editor . ReplaceNode ( memberAccess , newMemberAccess ) ;
114
+ }
115
+ }
116
+
73
117
doc = doc . WithSyntaxRoot ( rewriter . Visit ( root ) . NormalizeWhitespace ( ) ) ;
74
118
75
119
proj = doc . Project ;
120
+
121
+ logger . LogInformation ( "COM Object Usage Update for {0} Complete ({1}/{2})" , doc . Name , index , count ) ;
76
122
}
77
123
78
124
ctx . SourceProject = proj ;
@@ -172,7 +218,8 @@ class Rewriter(List<(string, bool)> ComTypes)
172
218
for ( int i = 0 ; i < ComTypes . Count ; i ++ )
173
219
{
174
220
( string , bool ) val = ComTypes [ i ] ;
175
- if ( val . Item1 == node . ElementType . ToString ( ) && val . Item2 )
221
+
222
+ if ( val . Item1 == node . ElementType . ToString ( ) )
176
223
{
177
224
return IdentifierName ( val . Item1 ) ;
178
225
}
@@ -183,15 +230,15 @@ class Rewriter(List<(string, bool)> ComTypes)
183
230
184
231
public override SyntaxNode VisitGenericName ( GenericNameSyntax node ) => node ;
185
232
186
- public override SyntaxNode ? VisitIdentifierName ( IdentifierNameSyntax node )
187
- {
188
- if ( node . Identifier . ToString ( ) == "lpVtbl" )
189
- {
190
- return ParenthesizedExpression ( PrefixUnaryExpression ( SyntaxKind . PointerIndirectionExpression , node ) ) ;
191
- }
233
+ // public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
234
+ // {
235
+ // if (node.Identifier.ToString() == "lpVtbl")
236
+ // {
237
+ // return ParenthesizedExpression(PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, node));
238
+ // }
192
239
193
- return base . VisitIdentifierName ( node ) ;
194
- }
240
+ // return base.VisitIdentifierName(node);
241
+ // }
195
242
196
243
public override SyntaxNode ? VisitVariableDeclaration ( VariableDeclarationSyntax node )
197
244
{
@@ -208,8 +255,6 @@ class Rewriter(List<(string, bool)> ComTypes)
208
255
{
209
256
var ret = base . VisitInterfaceDeclaration ( node ) ;
210
257
211
-
212
-
213
258
if ( ret is InterfaceDeclarationSyntax inter && inter . BaseList is not null && inter . BaseList . Types . Any ( baseType => baseType . Type . ToString ( ) . StartsWith ( "I" ) && baseType . Type . ToString ( ) . EndsWith ( ".Interface" ) ) )
214
259
{
215
260
List < BaseTypeSyntax > baseTypes = [ ] ;
@@ -245,7 +290,8 @@ class Rewriter(List<(string, bool)> ComTypes)
245
290
for ( int i = 0 ; i < ComTypes . Count ; i ++ )
246
291
{
247
292
( string , bool ) val = ComTypes [ i ] ;
248
- if ( castType == $ "{ val . Item1 } *" && val . Item2 )
293
+
294
+ if ( castType == $ "{ val . Item1 } *")
249
295
{
250
296
return ThisExpression ( ) ;
251
297
}
0 commit comments