Skip to content

Commit 5f17c4e

Browse files
committed
TransformCOM update
-Setup to run using new mod system properly -Added pointer member accessor to member accessor conversion -Added Logging
1 parent 5593b03 commit 5f17c4e

File tree

1 file changed

+75
-29
lines changed

1 file changed

+75
-29
lines changed

sources/SilkTouch/SilkTouch/Mods/TransformCOM.cs

Lines changed: 75 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,21 @@
1111
using Microsoft.CodeAnalysis;
1212
using Microsoft.CodeAnalysis.FindSymbols;
1313
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
14+
using System.Diagnostics;
15+
using System.Reflection.Metadata;
16+
using Microsoft.CodeAnalysis.Editing;
17+
using System.Xml.Linq;
18+
using Microsoft.Extensions.Logging;
19+
using Silk.NET.SilkTouch.Clang;
1420

1521
namespace Silk.NET.SilkTouch.Mods
1622
{
1723
/// <summary>
1824
/// A mod to modify COM interface structs into opaque structs that act like ComPtr objects
1925
/// </summary>
26+
/// <param name="logger">The logger to use.</param>
2027
[ModConfiguration<Config>]
21-
public class TransformCOM : Mod
28+
public class TransformCOM(ILogger<TransformCOM> logger) : Mod
2229
{
2330
/// <summary>
2431
/// The configuration for the <see cref="TransformCOM"/> mod.
@@ -39,40 +46,79 @@ public override async Task ExecuteAsync(IModContext ctx, CancellationToken ct =
3946
{
4047
return;
4148
}
42-
Compilation? comp = await proj.GetCompilationAsync();
43-
IEnumerable<ISymbol>? unknowns = comp?.GetSymbolsWithName("IUnknown");
44-
//foreach (var doc in ctx.SourceProject?.Documents ?? [])
45-
//{
46-
// if (await doc.GetSyntaxRootAsync(ct) is { } root)
47-
// {
48-
// firstPass.Visit(root);
49-
// }
50-
//}
5149

52-
if (unknowns is null || unknowns.Count() == 0)
50+
logger.LogInformation("Starting COM Object Collection");
51+
foreach (var docId in proj?.DocumentIds ?? [])
5352
{
54-
return;
53+
var doc =
54+
proj?.GetDocument(docId) ?? throw new InvalidOperationException("Document missing");
55+
if (await doc.GetSyntaxRootAsync(ct) is not { } root)
56+
{
57+
continue;
58+
}
59+
60+
firstPass.Visit(root);
5561
}
5662

57-
IEnumerable<ISymbol>? symbols = await SymbolFinder.FindImplementationsAsync(unknowns.First(), proj.Solution, [proj], ct);
58-
59-
List<(string, bool)> COMTypes = firstPass.FoundCOMTypes;
63+
firstPass.FoundCOMTypes = firstPass.FoundCOMTypes.Where(val => val.Item2).ToList();
64+
65+
var rewriter = new Rewriter(firstPass.FoundCOMTypes);
66+
int index = 0;
67+
int count = proj?.DocumentIds.Count ?? 0;
68+
logger.LogInformation("Starting COM Object Rewrite");
69+
foreach (var docId in proj?.DocumentIds ?? [])
70+
{
71+
index++;
72+
var doc =
73+
proj?.GetDocument(docId) ?? throw new InvalidOperationException("Document missing");
74+
if (await doc.GetSyntaxRootAsync(ct) is not { } root)
75+
{
76+
continue;
77+
}
78+
79+
doc = doc.WithSyntaxRoot(rewriter.Visit(root).NormalizeWhitespace());
80+
81+
proj = doc.Project;
6082

61-
Dictionary<string, CompilationUnitSyntax> duplicates = new();
83+
logger.LogInformation("COM Object Rewrite for {0} Complete ({1}/{2})", doc.Name, index, count);
84+
}
6285

63-
var rewriter = new Rewriter(COMTypes);
86+
index = 0;
87+
logger.LogInformation("Starting COM Object Usage Update");
6488
foreach (var docId in proj?.DocumentIds ?? [])
6589
{
90+
index++;
6691
var doc =
6792
proj?.GetDocument(docId) ?? throw new InvalidOperationException("Document missing");
6893
if (await doc.GetSyntaxRootAsync(ct) is not { } root)
6994
{
7095
continue;
7196
}
7297

98+
var semanticModel = await doc.GetSemanticModelAsync();
99+
var editor = new SyntaxEditor(root, proj.Solution.Workspace.Services);
100+
// Replace pointer member access -> with regular member access .
101+
var memberAccesses = root.DescendantNodes()
102+
.OfType<MemberAccessExpressionSyntax>()
103+
.Where(m => m.Expression is PrefixUnaryExpressionSyntax pues && pues.IsKind(SyntaxKind.PointerMemberAccessExpression));
104+
105+
foreach (var memberAccess in memberAccesses)
106+
{
107+
var pointerIndirection = (PrefixUnaryExpressionSyntax)memberAccess.Expression;
108+
var typeInfo = semanticModel.GetTypeInfo(pointerIndirection.Operand);
109+
// Check if the type is a ComType
110+
if (typeInfo.Type != null && firstPass.FoundCOMTypes.Any(type => type.Item1 == typeInfo.Type.ToDisplayString()))
111+
{
112+
var newMemberAccess = SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, pointerIndirection.Operand, memberAccess.Name);
113+
editor.ReplaceNode(memberAccess, newMemberAccess);
114+
}
115+
}
116+
73117
doc = doc.WithSyntaxRoot(rewriter.Visit(root).NormalizeWhitespace());
74118

75119
proj = doc.Project;
120+
121+
logger.LogInformation("COM Object Usage Update for {0} Complete ({1}/{2})", doc.Name, index, count);
76122
}
77123

78124
ctx.SourceProject = proj;
@@ -172,7 +218,8 @@ class Rewriter(List<(string, bool)> ComTypes)
172218
for (int i = 0; i < ComTypes.Count; i++)
173219
{
174220
(string, bool) val = ComTypes[i];
175-
if (val.Item1 == node.ElementType.ToString() && val.Item2)
221+
222+
if (val.Item1 == node.ElementType.ToString())
176223
{
177224
return IdentifierName(val.Item1);
178225
}
@@ -183,15 +230,15 @@ class Rewriter(List<(string, bool)> ComTypes)
183230

184231
public override SyntaxNode VisitGenericName(GenericNameSyntax node) => node;
185232

186-
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
187-
{
188-
if (node.Identifier.ToString() == "lpVtbl")
189-
{
190-
return ParenthesizedExpression(PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, node));
191-
}
233+
//public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
234+
//{
235+
// if (node.Identifier.ToString() == "lpVtbl")
236+
// {
237+
// return ParenthesizedExpression(PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, node));
238+
// }
192239

193-
return base.VisitIdentifierName(node);
194-
}
240+
// return base.VisitIdentifierName(node);
241+
//}
195242

196243
public override SyntaxNode? VisitVariableDeclaration(VariableDeclarationSyntax node)
197244
{
@@ -208,8 +255,6 @@ class Rewriter(List<(string, bool)> ComTypes)
208255
{
209256
var ret = base.VisitInterfaceDeclaration(node);
210257

211-
212-
213258
if (ret is InterfaceDeclarationSyntax inter && inter.BaseList is not null && inter.BaseList.Types.Any(baseType => baseType.Type.ToString().StartsWith("I") && baseType.Type.ToString().EndsWith(".Interface")))
214259
{
215260
List<BaseTypeSyntax> baseTypes = [];
@@ -245,7 +290,8 @@ class Rewriter(List<(string, bool)> ComTypes)
245290
for (int i = 0; i < ComTypes.Count; i++)
246291
{
247292
(string, bool) val = ComTypes[i];
248-
if (castType == $"{val.Item1}*" && val.Item2)
293+
294+
if (castType == $"{val.Item1}*")
249295
{
250296
return ThisExpression();
251297
}

0 commit comments

Comments
 (0)