Skip to content

Commit cec591f

Browse files
committed
Initial Work to Add COM transformations
1 parent 452bda9 commit cec591f

File tree

4 files changed

+277
-8
lines changed

4 files changed

+277
-8
lines changed

generator.json

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
"Jobs": {
33
"Microsoft": {
44
"Solution": "Silk.NET.sln",
5-
"SourceProject": "TODO.csproj",
6-
"TestProject": "tests/TODO.csproj",
5+
"SourceProject": "sources/Win32/Win32/Silk.NET.Win32.csproj",
6+
"TestProject": "tests/Win32/Win32/Silk.NET.Win32.UnitTests.csproj",
77
"DefaultLicenseHeader": "eng/silktouch/header.txt",
88
"Mods": [
9-
"AddIncludes",
109
"ClangScraper",
11-
"ChangeNamespace"
10+
"ChangeNamespace",
11+
"TransformCOM"
1212
],
1313
"ClangScraper": {
1414
"ClangSharpResponseFiles": [
@@ -24,7 +24,8 @@
2424
"InputTestRoot": "eng/submodules/terrafx.interop.windows/tests/Interop/Windows",
2525
"SkipScrapeIf": [
2626
"!win"
27-
]
27+
],
28+
"CacheOutput": false
2829
},
2930
"ChangeNamespace": {
3031
"Mappings": {

sources/SilkTouch/SilkTouch/Clang/ClangScraper.cs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
@@ -59,6 +59,12 @@ public record Configuration
5959
/// </summary>
6060
public required string[] ClangSharpResponseFiles { get; init; }
6161

62+
/// <summary>
63+
/// Whether or not to Cache the output when scraping
64+
/// Caching currently fails on large multithreaded jobs
65+
/// </summary>
66+
public bool CacheOutput { get; init; } = true;
67+
6268
/// <summary>
6369
/// Manual overrides for ClangSharp outputs (i.e. manual tweaks of generated output) that should still flow through
6470
/// the SilkTouch pipeline as if it came from ClangSharp.
@@ -300,7 +306,9 @@ static MemoryStream Reopen(MemoryStream ms) =>
300306
.TrimEnd('/');
301307

302308
// Cache the output.
303-
if (cacheKey is not null && !hasErrors)
309+
//TODO: Refactor for better Parallelisation
310+
//Breaks with high concurrency
311+
if (cacheKey is not null && !hasErrors && cfg.CacheOutput)
304312
{
305313
cacheDir ??= (
306314
await cacheProvider!.GetDirectory(
@@ -381,6 +389,7 @@ await tree.GetRootAsync(ct),
381389
filePath: src.FullPath(fname)
382390
)
383391
.Project;
392+
logger.LogDebug($"Add Src Document {fname}");
384393
}
385394

386395
job.SourceProject = src;

sources/SilkTouch/SilkTouch/Mods/Common/ModLoader.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System;
1+
using System;
22
using Silk.NET.SilkTouch.Clang;
33

44
namespace Silk.NET.SilkTouch.Mods;
@@ -29,6 +29,7 @@ public class ModLoader
2929
nameof(ExtractNestedTyping) => typeof(ExtractNestedTyping),
3030
nameof(TransformProperties) => typeof(TransformProperties),
3131
nameof(ClangScraper) => typeof(ClangScraper),
32+
nameof(TransformCOM) => typeof(TransformCOM),
3233
_ => null,
3334
};
3435
}
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
using System.Text;
8+
using System.Threading.Tasks;
9+
using Microsoft.CodeAnalysis.CSharp.Syntax;
10+
using Microsoft.CodeAnalysis.CSharp;
11+
using Microsoft.CodeAnalysis;
12+
using Microsoft.CodeAnalysis.FindSymbols;
13+
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
14+
15+
namespace Silk.NET.SilkTouch.Mods
16+
{
17+
/// <summary>
18+
/// A mod to modify COM interface structs into opaque structs that act like ComPtr objects
19+
/// </summary>
20+
[ModConfiguration<Config>]
21+
public class TransformCOM : Mod
22+
{
23+
/// <summary>
24+
/// The configuration for the <see cref="TransformCOM"/> mod.
25+
/// </summary>
26+
public class Config
27+
{
28+
29+
}
30+
31+
/// <inheritdoc />
32+
public override async Task ExecuteAsync(IModContext ctx, CancellationToken ct = default)
33+
{
34+
await base.ExecuteAsync(ctx, ct);
35+
36+
var firstPass = new TypeDiscoverer();
37+
var proj = ctx.SourceProject;
38+
if (proj is null)
39+
{
40+
return;
41+
}
42+
Compilation? comp = await proj.GetCompilationAsync();
43+
IEnumerable<ISymbol>? unknowns = comp?.GetSymbolsWithName("IUnknown");
44+
//foreach (var doc in ctx.SourceProject?.Documents ?? [])
45+
//{
46+
// if (await doc.GetSyntaxRootAsync(ct) is { } root)
47+
// {
48+
// firstPass.Visit(root);
49+
// }
50+
//}
51+
52+
if (unknowns is null || unknowns.Count() == 0)
53+
{
54+
return;
55+
}
56+
57+
IEnumerable<ISymbol>? symbols = await SymbolFinder.FindImplementationsAsync(unknowns.First(), proj.Solution, [proj], ct);
58+
59+
List<(string, bool)> COMTypes = firstPass.FoundCOMTypes;
60+
61+
Dictionary<string, CompilationUnitSyntax> duplicates = new();
62+
63+
var rewriter = new Rewriter(COMTypes);
64+
foreach (var docId in proj?.DocumentIds ?? [])
65+
{
66+
var doc =
67+
proj?.GetDocument(docId) ?? throw new InvalidOperationException("Document missing");
68+
if (await doc.GetSyntaxRootAsync(ct) is not { } root)
69+
{
70+
continue;
71+
}
72+
73+
doc = doc.WithSyntaxRoot(rewriter.Visit(root).NormalizeWhitespace());
74+
75+
proj = doc.Project;
76+
}
77+
78+
ctx.SourceProject = proj;
79+
}
80+
81+
class TypeDiscoverer : CSharpSyntaxWalker
82+
{
83+
private Dictionary<string, List<(string, bool)>> _interfaceParenting = new Dictionary<string, List<(string, bool)>>();
84+
85+
/// <summary>
86+
/// The list of known COM interface types
87+
/// (name of type, is it a struct?)
88+
/// </summary>
89+
public List<(string, bool)> FoundCOMTypes = [];
90+
91+
public override void VisitStructDeclaration(StructDeclarationSyntax node)
92+
{
93+
base.VisitStructDeclaration(node);
94+
95+
var bases = node.BaseList;
96+
97+
if (bases is null)
98+
{
99+
return;
100+
}
101+
102+
var className = $"{node.Identifier}";
103+
104+
CheckBases((className, true), bases);
105+
}
106+
107+
108+
public override void VisitInterfaceDeclaration(InterfaceDeclarationSyntax node)
109+
{
110+
base.VisitInterfaceDeclaration(node);
111+
112+
var bases = node.BaseList;
113+
114+
if (bases is null)
115+
{
116+
return;
117+
}
118+
119+
string name = $"{node.Identifier}";
120+
if (name == "Interface")
121+
{
122+
var parent = node.Parent as StructDeclarationSyntax;
123+
if (parent is not null)
124+
name = $"{parent.Identifier}.{name}";
125+
}
126+
127+
CheckBases((name, false), bases);
128+
}
129+
130+
private void CheckBases((string, bool) className, BaseListSyntax bases)
131+
{
132+
if (bases.Types.Any(baseType => baseType.Type.ToString() == "IUnknown.Interface" || FoundCOMTypes.Any(val => val.Item1 == $"{baseType.Type}")))
133+
{
134+
COMTypeValidated(className);
135+
return;
136+
}
137+
138+
foreach (BaseTypeSyntax baseType in bases.Types)
139+
{
140+
string fullName = $"{baseType.Type}";
141+
if (!_interfaceParenting.ContainsKey(fullName))
142+
_interfaceParenting.Add(fullName, new());
143+
144+
_interfaceParenting[fullName].Add(className);
145+
}
146+
}
147+
148+
private void COMTypeValidated((string, bool) typeName)
149+
{
150+
if (FoundCOMTypes.Contains(typeName))
151+
{
152+
return;
153+
}
154+
155+
FoundCOMTypes.Add(typeName);
156+
157+
if (!_interfaceParenting.TryGetValue(typeName.Item1, out List<(string, bool)>? children))
158+
return;
159+
160+
foreach ((string, bool) childName in children)
161+
{
162+
COMTypeValidated(childName);
163+
}
164+
}
165+
}
166+
167+
class Rewriter(List<(string, bool)> ComTypes)
168+
: CSharpSyntaxRewriter
169+
{
170+
public override SyntaxNode? VisitPointerType(PointerTypeSyntax node)
171+
{
172+
for (int i = 0; i < ComTypes.Count; i++)
173+
{
174+
(string, bool) val = ComTypes[i];
175+
if (val.Item1 == node.ElementType.ToString() && val.Item2)
176+
{
177+
return IdentifierName(val.Item1);
178+
}
179+
}
180+
181+
return base.VisitPointerType(node);
182+
}
183+
184+
public override SyntaxNode VisitGenericName(GenericNameSyntax node) => node;
185+
186+
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
187+
{
188+
if (node.Identifier.ToString() == "lpVtbl")
189+
{
190+
return ParenthesizedExpression(PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, node));
191+
}
192+
193+
return base.VisitIdentifierName(node);
194+
}
195+
196+
public override SyntaxNode? VisitVariableDeclaration(VariableDeclarationSyntax node)
197+
{
198+
if (node.Type.ToString() == "void**" && node.Variables.First().Identifier.ToString() == "lpVtbl")
199+
{
200+
return VariableDeclaration(PointerType(PointerType(PointerType(IdentifierName("void")))))
201+
.AddVariables(VariableDeclarator("lpVtbl"));
202+
}
203+
204+
return base.VisitVariableDeclaration(node);
205+
}
206+
207+
public override SyntaxNode? VisitInterfaceDeclaration(InterfaceDeclarationSyntax node)
208+
{
209+
var ret = base.VisitInterfaceDeclaration(node);
210+
211+
212+
213+
if (ret is InterfaceDeclarationSyntax inter && inter.BaseList is not null && inter.BaseList.Types.Any(baseType => baseType.Type.ToString().StartsWith("I") && baseType.Type.ToString().EndsWith(".Interface")))
214+
{
215+
List<BaseTypeSyntax> baseTypes = [];
216+
foreach (BaseTypeSyntax baseType in inter.BaseList.Types)
217+
{
218+
if (ComTypes.Any(com => $"{com.Item1}.Interface" == baseType.Type.ToString()))
219+
{
220+
baseTypes.Add(SimpleBaseType(IdentifierName(baseType.Type.ToString())));
221+
}
222+
else
223+
{
224+
baseTypes.Add(baseType);
225+
}
226+
}
227+
228+
ret = inter.WithBaseList(BaseList(SeparatedList(baseTypes)));
229+
}
230+
231+
return ret;
232+
}
233+
234+
public override SyntaxNode? VisitCastExpression(CastExpressionSyntax node)
235+
{
236+
var castType = node.Type.ToString();
237+
238+
var expression = node.Expression.ToString();
239+
240+
if (expression != "Unsafe.AsPointer(ref this)")
241+
{
242+
return base.VisitCastExpression(node);
243+
}
244+
245+
for (int i = 0; i < ComTypes.Count; i++)
246+
{
247+
(string, bool) val = ComTypes[i];
248+
if (castType == $"{val.Item1}*" && val.Item2)
249+
{
250+
return ThisExpression();
251+
}
252+
}
253+
254+
return base.VisitCastExpression(node);
255+
}
256+
}
257+
}
258+
}

0 commit comments

Comments
 (0)