17
17
using System . Xml . Linq ;
18
18
using Microsoft . Extensions . Logging ;
19
19
using Silk . NET . SilkTouch . Clang ;
20
+ using Microsoft . Extensions . Options ;
20
21
21
22
namespace Silk . NET . SilkTouch . Mods
22
23
{
23
24
/// <summary>
24
25
/// A mod to modify COM interface structs into opaque structs that act like ComPtr objects
25
26
/// </summary>
26
27
/// <param name="logger">The logger to use.</param>
27
- [ ModConfiguration < Config > ]
28
- public class TransformCOM ( ILogger < TransformCOM > logger ) : Mod
28
+ /// <param name="config">The configuration to use.</param>
29
+ [ ModConfiguration < Configuration > ]
30
+ public class TransformCOM (
31
+ ILogger < TransformCOM > logger ,
32
+ IOptionsSnapshot < TransformCOM . Configuration > config ) : Mod
29
33
{
30
34
/// <summary>
31
35
/// The configuration for the <see cref="TransformCOM"/> mod.
32
36
/// </summary>
33
- public class Config
37
+ public class Configuration
34
38
{
35
-
39
+ /// <summary>
40
+ /// The base type to consider the base of the com tree
41
+ /// Usually this is IUknown.Interface
42
+ /// </summary>
43
+ public string ? BaseType { get ; init ; }
36
44
}
37
45
38
46
/// <inheritdoc />
39
47
public override async Task ExecuteAsync ( IModContext ctx , CancellationToken ct = default )
40
48
{
41
49
await base . ExecuteAsync ( ctx , ct ) ;
42
50
43
- var firstPass = new TypeDiscoverer ( ) ;
51
+ var firstPass = new TypeDiscoverer ( config . Value . BaseType ?? "IUnknown.Interface" ) ;
44
52
var proj = ctx . SourceProject ;
45
53
if ( proj is null )
46
54
{
@@ -101,7 +109,11 @@ public override async Task ExecuteAsync(IModContext ctx, CancellationToken ct =
101
109
. OfType < MemberAccessExpressionSyntax > ( )
102
110
. Where ( m => m . Expression is PrefixUnaryExpressionSyntax pues && pues . IsKind ( SyntaxKind . PointerMemberAccessExpression ) ) ;
103
111
104
- if ( memberAccesses . Count ( ) == 0 )
112
+ var nullAssignments = root . DescendantNodes ( )
113
+ . OfType < AssignmentExpressionSyntax > ( )
114
+ . Where ( aes => aes . Right . IsKind ( SyntaxKind . NullLiteralExpression ) ) ;
115
+
116
+ if ( memberAccesses . Count ( ) == 0 && nullAssignments . Count ( ) == 0 )
105
117
{
106
118
continue ;
107
119
}
@@ -120,6 +132,16 @@ public override async Task ExecuteAsync(IModContext ctx, CancellationToken ct =
120
132
}
121
133
}
122
134
135
+ foreach ( var nullAssignment in nullAssignments )
136
+ {
137
+ var typeInfo = semanticModel . GetTypeInfo ( nullAssignment . Left ) ;
138
+ // Check if the type is a ComType
139
+ if ( typeInfo . Type != null && firstPass . FoundCOMTypes . Any ( type => type . Item1 == typeInfo . Type . ToDisplayString ( ) ) )
140
+ {
141
+ var newNullAssignment = SyntaxFactory . AssignmentExpression ( nullAssignment . Kind ( ) , nullAssignment . Left , LiteralExpression ( SyntaxKind . DefaultExpression ) ) ;
142
+ editor . ReplaceNode ( nullAssignment , newNullAssignment ) ;
143
+ }
144
+ }
123
145
proj = doc . Project ;
124
146
125
147
logger . LogInformation ( "COM Object Usage Update for {0} Complete ({1}/{2})" , doc . Name , index , count ) ;
@@ -128,7 +150,7 @@ public override async Task ExecuteAsync(IModContext ctx, CancellationToken ct =
128
150
ctx . SourceProject = proj ;
129
151
}
130
152
131
- class TypeDiscoverer : CSharpSyntaxWalker
153
+ class TypeDiscoverer ( string BaseType ) : CSharpSyntaxWalker
132
154
{
133
155
private Dictionary < string , List < ( string , bool ) > > _interfaceParenting = new Dictionary < string , List < ( string , bool ) > > ( ) ;
134
156
@@ -179,7 +201,7 @@ public override void VisitInterfaceDeclaration(InterfaceDeclarationSyntax node)
179
201
180
202
private void CheckBases ( ( string , bool ) className , BaseListSyntax bases )
181
203
{
182
- if ( bases . Types . Any ( baseType => baseType . Type . ToString ( ) == "IUnknown.Interface" || FoundCOMTypes . Any ( val => val . Item1 == $ "{ baseType . Type } ") ) )
204
+ if ( bases . Types . Any ( baseType => baseType . Type . ToString ( ) == BaseType || FoundCOMTypes . Any ( val => val . Item1 == $ "{ baseType . Type } ") ) )
183
205
{
184
206
COMTypeValidated ( className ) ;
185
207
return ;
@@ -303,6 +325,20 @@ class Rewriter(List<(string, bool)> ComTypes)
303
325
304
326
return base . VisitCastExpression ( node ) ;
305
327
}
328
+
329
+ public override SyntaxNode ? VisitParameter ( ParameterSyntax node )
330
+ {
331
+ var visited = base . VisitParameter ( node ) ;
332
+ var visitedParameter = visited as ParameterSyntax ;
333
+ if ( visitedParameter is null || visitedParameter . Default is null || visitedParameter . Type is null ||
334
+ visitedParameter . Default . Value . IsKind ( SyntaxKind . NullLiteralExpression ) ||
335
+ ! ComTypes . Any ( com => visitedParameter . Type ? . ToString ( ) == com . Item1 ) )
336
+ {
337
+ return visited ;
338
+ }
339
+
340
+ return Parameter ( visitedParameter . AttributeLists , visitedParameter . Modifiers , visitedParameter . Type , visitedParameter . Identifier , EqualsValueClause ( LiteralExpression ( SyntaxKind . DefaultExpression ) ) ) ;
341
+ }
306
342
}
307
343
}
308
344
}
0 commit comments