@@ -51,14 +51,44 @@ record ServiceRegistration(int Lifetime, TypeSyntax? AssignableTo, string? FullN
5151
5252 public void Initialize ( IncrementalGeneratorInitializationContext context )
5353 {
54- var types = context . CompilationProvider . SelectMany ( ( x , c ) =>
54+ var compilation = context . CompilationProvider . Select ( ( compilation , _ ) =>
5555 {
56- var visitor = new TypesVisitor ( s => x . IsSymbolAccessible ( s ) , c ) ;
57- x . GlobalNamespace . Accept ( visitor ) ;
56+ // Add missing types as needed since we depend on the static generator potentially and can't
57+ // rely on its sources being added.
58+ var parse = ( CSharpParseOptions ) compilation . SyntaxTrees . FirstOrDefault ( ) . Options ;
59+
60+ if ( compilation . GetTypeByMetadataName ( "Microsoft.Extensions.DependencyInjection.AddServicesNoReflectionExtension" ) is null )
61+ {
62+ compilation = compilation . AddSyntaxTrees (
63+ CSharpSyntaxTree . ParseText ( ThisAssembly . Resources . AddServicesNoReflectionExtension . Text , parse ) ) ;
64+ }
65+
66+ if ( compilation . GetTypeByMetadataName ( "Microsoft.Extensions.DependencyInjection.ServiceAttribute" ) is null )
67+ {
68+ compilation = compilation . AddSyntaxTrees (
69+ CSharpSyntaxTree . ParseText ( ThisAssembly . Resources . ServiceAttribute . Text , parse ) ,
70+ CSharpSyntaxTree . ParseText ( ThisAssembly . Resources . ServiceAttribute_1 . Text , parse ) ) ;
71+ }
72+
73+ return compilation ;
74+ } ) ;
75+
76+ var types = compilation . Combine ( context . AnalyzerConfigOptionsProvider ) . SelectMany ( ( x , c ) =>
77+ {
78+ ( var compilation , var options ) = x ;
79+
80+ // We won't add any registrations in this case.
81+ if ( ! options . GlobalOptions . TryGetValue ( "build_property.AddServicesExtension" , out var value ) ||
82+ ! bool . TryParse ( value , out var addServices ) || ! addServices )
83+ return Enumerable . Empty < INamedTypeSymbol > ( ) ;
84+
85+ var visitor = new TypesVisitor ( s => compilation . IsSymbolAccessible ( s ) , c ) ;
86+ compilation . GlobalNamespace . Accept ( visitor ) ;
87+
5888 // Also visit aliased references, which will not become part of the global:: namespace
59- foreach ( var symbol in x . References
89+ foreach ( var symbol in compilation . References
6090 . Where ( r => ! r . Properties . Aliases . IsDefaultOrEmpty )
61- . Select ( r => x . GetAssemblyOrModuleSymbol ( r ) ) )
91+ . Select ( r => compilation . GetAssemblyOrModuleSymbol ( r ) ) )
6292 {
6393 symbol ? . Accept ( visitor ) ;
6494 }
@@ -152,8 +182,6 @@ bool IsExport(AttributeData attr)
152182 } )
153183 . Where ( x => x != null ) ;
154184
155- var options = context . AnalyzerConfigOptionsProvider . Combine ( context . CompilationProvider ) ;
156-
157185 // Only requisite is that we define Scoped = 0, Singleton = 1 and Transient = 2.
158186 // This matches https://learn.microsoft.com/en-us/dotnet/api/microsoft.extensions.dependencyinjection.servicelifetime?view=dotnet-plat-ext-6.0#fields
159187
@@ -164,11 +192,18 @@ bool IsExport(AttributeData attr)
164192 . CreateSyntaxProvider (
165193 predicate : static ( node , _ ) => node is InvocationExpressionSyntax invocation && invocation . ArgumentList . Arguments . Count != 0 && GetInvokedMethodName ( invocation ) == nameof ( AddServicesNoReflectionExtension . AddServices ) ,
166194 transform : static ( ctx , _ ) => GetServiceRegistration ( ( InvocationExpressionSyntax ) ctx . Node , ctx . SemanticModel ) )
167- . Where ( details => details != null )
195+ . Combine ( context . AnalyzerConfigOptionsProvider )
196+ . Where ( x =>
197+ {
198+ ( var registration , var options ) = x ;
199+ return options . GlobalOptions . TryGetValue ( "build_property.AddServicesExtension" , out var value ) &&
200+ bool . TryParse ( value , out var addServices ) && addServices && registration is not null ;
201+ } )
202+ . Select ( ( x , _ ) => x . Left )
168203 . Collect ( ) ;
169204
170205 // Project matching service types to register with the given lifetime.
171- var conventionServices = types . Combine ( methodInvocations . Combine ( context . CompilationProvider ) ) . SelectMany ( ( pair , cancellationToken ) =>
206+ var conventionServices = types . Combine ( methodInvocations . Combine ( compilation ) ) . SelectMany ( ( pair , cancellationToken ) =>
172207 {
173208 var ( typeSymbol , ( registrations , compilation ) ) = pair ;
174209 var results = ImmutableArray . CreateBuilder < ServiceSymbol > ( ) ;
@@ -196,33 +231,33 @@ bool IsExport(AttributeData attr)
196231 . SelectMany ( ( tuple , _ ) => ImmutableArray . CreateRange ( [ tuple . Item1 , tuple . Item2 ] ) )
197232 . SelectMany ( ( items , _ ) => items . Distinct ( ) . ToImmutableArray ( ) ) ;
198233
199- RegisterServicesOutput ( context , finalServices , options ) ;
234+ RegisterServicesOutput ( context , finalServices , compilation ) ;
200235 }
201236
202- void RegisterServicesOutput ( IncrementalGeneratorInitializationContext context , IncrementalValuesProvider < ServiceSymbol > services , IncrementalValueProvider < ( AnalyzerConfigOptionsProvider Left , Compilation Right ) > options )
237+ void RegisterServicesOutput ( IncrementalGeneratorInitializationContext context , IncrementalValuesProvider < ServiceSymbol > services , IncrementalValueProvider < Compilation > compilation )
203238 {
204239 context . RegisterImplementationSourceOutput (
205- services . Where ( x => x ! . Lifetime == 0 && x . Key is null ) . Select ( ( x , _ ) => new KeyedService ( x ! . Type , null ) ) . Collect ( ) . Combine ( options ) ,
240+ services . Where ( x => x ! . Lifetime == 0 && x . Key is null ) . Select ( ( x , _ ) => new KeyedService ( x ! . Type , null ) ) . Collect ( ) . Combine ( compilation ) ,
206241 ( ctx , data ) => AddPartial ( "AddSingleton" , ctx , data ) ) ;
207242
208243 context . RegisterImplementationSourceOutput (
209- services . Where ( x => x ! . Lifetime == 1 && x . Key is null ) . Select ( ( x , _ ) => new KeyedService ( x ! . Type , null ) ) . Collect ( ) . Combine ( options ) ,
244+ services . Where ( x => x ! . Lifetime == 1 && x . Key is null ) . Select ( ( x , _ ) => new KeyedService ( x ! . Type , null ) ) . Collect ( ) . Combine ( compilation ) ,
210245 ( ctx , data ) => AddPartial ( "AddScoped" , ctx , data ) ) ;
211246
212247 context . RegisterImplementationSourceOutput (
213- services . Where ( x => x ! . Lifetime == 2 && x . Key is null ) . Select ( ( x , _ ) => new KeyedService ( x ! . Type , null ) ) . Collect ( ) . Combine ( options ) ,
248+ services . Where ( x => x ! . Lifetime == 2 && x . Key is null ) . Select ( ( x , _ ) => new KeyedService ( x ! . Type , null ) ) . Collect ( ) . Combine ( compilation ) ,
214249 ( ctx , data ) => AddPartial ( "AddTransient" , ctx , data ) ) ;
215250
216251 context . RegisterImplementationSourceOutput (
217- services . Where ( x => x ! . Lifetime == 0 && x . Key is not null ) . Select ( ( x , _ ) => new KeyedService ( x ! . Type , x . Key ! ) ) . Collect ( ) . Combine ( options ) ,
252+ services . Where ( x => x ! . Lifetime == 0 && x . Key is not null ) . Select ( ( x , _ ) => new KeyedService ( x ! . Type , x . Key ! ) ) . Collect ( ) . Combine ( compilation ) ,
218253 ( ctx , data ) => AddPartial ( "AddKeyedSingleton" , ctx , data ) ) ;
219254
220255 context . RegisterImplementationSourceOutput (
221- services . Where ( x => x ! . Lifetime == 1 && x . Key is not null ) . Select ( ( x , _ ) => new KeyedService ( x ! . Type , x . Key ! ) ) . Collect ( ) . Combine ( options ) ,
256+ services . Where ( x => x ! . Lifetime == 1 && x . Key is not null ) . Select ( ( x , _ ) => new KeyedService ( x ! . Type , x . Key ! ) ) . Collect ( ) . Combine ( compilation ) ,
222257 ( ctx , data ) => AddPartial ( "AddKeyedScoped" , ctx , data ) ) ;
223258
224259 context . RegisterImplementationSourceOutput (
225- services . Where ( x => x ! . Lifetime == 2 && x . Key is not null ) . Select ( ( x , _ ) => new KeyedService ( x ! . Type , x . Key ! ) ) . Collect ( ) . Combine ( options ) ,
260+ services . Where ( x => x ! . Lifetime == 2 && x . Key is not null ) . Select ( ( x , _ ) => new KeyedService ( x ! . Type , x . Key ! ) ) . Collect ( ) . Combine ( compilation ) ,
226261 ( ctx , data ) => AddPartial ( "AddKeyedTransient" , ctx , data ) ) ;
227262 }
228263
@@ -240,13 +275,22 @@ void RegisterServicesOutput(IncrementalGeneratorInitializationContext context, I
240275
241276 var options = ( CSharpParseOptions ) invocation . SyntaxTree . Options ;
242277
243- // NOTE: we need to add the sources that *another* generator emits (the static files)
244- // because otherwise all invocations will basically have no semantic info since it wasn't there
245- // when the source generations invocations started.
246- var compilation = semanticModel . Compilation . AddSyntaxTrees (
247- CSharpSyntaxTree . ParseText ( ThisAssembly . Resources . ServiceAttribute . Text , options ) ,
248- CSharpSyntaxTree . ParseText ( ThisAssembly . Resources . ServiceAttribute_1 . Text , options ) ,
249- CSharpSyntaxTree . ParseText ( ThisAssembly . Resources . AddServicesNoReflectionExtension . Text , options ) ) ;
278+ var compilation = semanticModel . Compilation ;
279+
280+ // Add missing types as needed since we depend on the static generator potentially and can't
281+ // rely on its sources being added.
282+ if ( compilation . GetTypeByMetadataName ( "Microsoft.Extensions.DependencyInjection.AddServicesNoReflectionExtension" ) is null )
283+ {
284+ compilation = compilation . AddSyntaxTrees (
285+ CSharpSyntaxTree . ParseText ( ThisAssembly . Resources . AddServicesNoReflectionExtension . Text , options ) ) ;
286+ }
287+
288+ if ( compilation . GetTypeByMetadataName ( "Microsoft.Extensions.DependencyInjection.ServiceAttribute" ) is null )
289+ {
290+ compilation = compilation . AddSyntaxTrees (
291+ CSharpSyntaxTree . ParseText ( ThisAssembly . Resources . ServiceAttribute . Text , options ) ,
292+ CSharpSyntaxTree . ParseText ( ThisAssembly . Resources . ServiceAttribute_1 . Text , options ) ) ;
293+ }
250294
251295 var model = compilation . GetSemanticModel ( invocation . SyntaxTree ) ;
252296
@@ -292,46 +336,37 @@ void RegisterServicesOutput(IncrementalGeneratorInitializationContext context, I
292336 return null ;
293337 }
294338
295- void AddPartial ( string methodName , SourceProductionContext ctx , ( ImmutableArray < KeyedService > Types , ( AnalyzerConfigOptionsProvider Config , Compilation Compilation ) Options ) data )
339+ void AddPartial ( string methodName , SourceProductionContext ctx , ( ImmutableArray < KeyedService > Types , Compilation Compilation ) data )
296340 {
297341 var builder = new StringBuilder ( )
298342 . AppendLine ( "// <auto-generated />" ) ;
299343
300- var rootNs = data . Options . Config . GlobalOptions . TryGetValue ( "build_property.AddServicesNamespace" , out var value ) && ! string . IsNullOrEmpty ( value )
301- ? value
302- : "Microsoft.Extensions.DependencyInjection" ;
303-
304- var className = data . Options . Config . GlobalOptions . TryGetValue ( "build_property.AddServicesClassName" , out value ) && ! string . IsNullOrEmpty ( value ) ?
305- value : "AddServicesNoReflectionExtension" ;
306-
307- foreach ( var alias in data . Options . Compilation . References . SelectMany ( r => r . Properties . Aliases ) )
344+ foreach ( var alias in data . Compilation . References . SelectMany ( r => r . Properties . Aliases ) )
308345 {
309346 builder . AppendLine ( $ "extern alias { alias } ;") ;
310347 }
311348
312349 builder . AppendLine (
313350 $$ """
314- #if DDI_ADDSERVICES
315351 using Microsoft.Extensions.DependencyInjection.Extensions;
316352 using System;
317353
318- namespace {{ rootNs }}
354+ namespace Microsoft.Extensions.DependencyInjection
319355 {
320- static partial class {{ className }}
356+ static partial class AddServicesNoReflectionExtension
321357 {
322358 static partial void {{ methodName }} Services(IServiceCollection services)
323359 {
324360 """ ) ;
325361
326- AddServices ( data . Types . Where ( x => x . Key is null ) . Select ( x => x . Type ) , data . Options . Compilation , methodName , builder ) ;
327- AddKeyedServices ( data . Types . Where ( x => x . Key is not null ) , data . Options . Compilation , methodName , builder ) ;
362+ AddServices ( data . Types . Where ( x => x . Key is null ) . Select ( x => x . Type ) , data . Compilation , methodName , builder ) ;
363+ AddKeyedServices ( data . Types . Where ( x => x . Key is not null ) , data . Compilation , methodName , builder ) ;
328364
329365 builder . AppendLine (
330366 """
331367 }
332368 }
333369 }
334- #endif
335370 """ ) ;
336371
337372 ctx . AddSource ( methodName + ".g" , builder . ToString ( ) . Replace ( "\r \n " , "\n " ) . Replace ( "\n " , Environment . NewLine ) ) ;
0 commit comments