Skip to content

Commit a674cdf

Browse files
Generate dependency injection extension methods (Azure#50351)
* Generate dependency injection extension methods * refactor * copilot typo fixes * fix tests * Add multiple clients test case
1 parent 18ef7f7 commit a674cdf

File tree

177 files changed

+2211
-166
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

177 files changed

+2211
-166
lines changed

eng/Packages.Data.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@
440440

441441
<PropertyGroup>
442442
<TestProxyVersion>1.0.0-dev.20250501.1</TestProxyVersion>
443-
<UnbrandedGeneratorVersion>1.0.0-alpha.20250528.1</UnbrandedGeneratorVersion>
443+
<UnbrandedGeneratorVersion>1.0.0-alpha.20250602.1</UnbrandedGeneratorVersion>
444444
<AzureGeneratorVersion>1.0.0-alpha.20250523.1</AzureGeneratorVersion>
445445
</PropertyGroup>
446446
</Project>

eng/packages/http-client-csharp/generator/Azure.Generator/src/AzureClientGenerator.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,6 @@ protected override void Configure()
5353
AddVisitor(new DistributedTracingVisitor());
5454
AddVisitor(new PipelinePropertyVisitor());
5555
AddVisitor(new LroVisitor());
56+
AddVisitor(new ModelFactoryVisitor());
5657
}
5758
}

eng/packages/http-client-csharp/generator/Azure.Generator/src/AzureOutputLibrary.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
using System;
5+
using System.Linq;
46
using Azure.Generator.Providers;
57
using Microsoft.TypeSpec.Generator.ClientModel;
8+
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
69
using Microsoft.TypeSpec.Generator.Providers;
710

811
namespace Azure.Generator
@@ -12,6 +15,15 @@ public class AzureOutputLibrary : ScmOutputLibrary
1215
{
1316
/// <inheritdoc/>
1417
protected override TypeProvider[] BuildTypeProviders()
15-
=> [.. base.BuildTypeProviders(), new RequestContextExtensionsDefinition()];
18+
{
19+
var types = base.BuildTypeProviders();
20+
var clients = types.OfType<ClientProvider>().ToList();
21+
return
22+
[
23+
.. types,
24+
new RequestContextExtensionsDefinition(),
25+
.. clients.Count > 0 ? [new ClientBuilderExtensionsDefinition(clients)] : Array.Empty<TypeProvider>()
26+
];
27+
}
1628
}
1729
}

eng/packages/http-client-csharp/generator/Azure.Generator/src/Primitives/NewAzureProjectScaffolding.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
using System;
88
using System.Collections.Generic;
99
using System.Linq;
10-
using System.Numerics;
1110

1211
namespace Azure.Generator.Primitives
1312
{
@@ -26,8 +25,8 @@ protected override string GetSourceProjectFileContent()
2625
{
2726
var builder = new CSharpProjectWriter()
2827
{
29-
Description = $"This is the {AzureClientGenerator.Instance.TypeFactory.PrimaryNamespace} client library for developing .NET applications with rich experience.",
30-
AssemblyTitle = $"SDK Code Generation {AzureClientGenerator.Instance.TypeFactory.PrimaryNamespace}",
28+
Description = $"This is the {AzureClientGenerator.Instance.Configuration.PackageName} client library for developing .NET applications with rich experience.",
29+
AssemblyTitle = $"SDK Code Generation {AzureClientGenerator.Instance.Configuration.PackageName}",
3130
Version = "1.0.0-beta.1",
3231
PackageTags = AzureClientGenerator.Instance.TypeFactory.PrimaryNamespace,
3332
GenerateDocumentationFile = true,
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Diagnostics.CodeAnalysis;
7+
using System.IO;
8+
using System.Linq;
9+
using Azure.Core;
10+
using Azure.Core.Extensions;
11+
using Azure.Generator.Utilities;
12+
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
13+
using Microsoft.TypeSpec.Generator.Expressions;
14+
using Microsoft.TypeSpec.Generator.Primitives;
15+
using Microsoft.TypeSpec.Generator.Providers;
16+
using Microsoft.TypeSpec.Generator.Statements;
17+
using static Microsoft.TypeSpec.Generator.Snippets.Snippet;
18+
19+
namespace Azure.Generator.Providers
20+
{
21+
/// <summary>
22+
/// Defines the client builder extension methods for integration with Microsoft.Extensions.Azure.
23+
/// </summary>
24+
internal class ClientBuilderExtensionsDefinition : TypeProvider
25+
{
26+
private readonly IEnumerable<ClientProvider> _clients;
27+
private readonly string _resourceProviderName;
28+
29+
public ClientBuilderExtensionsDefinition(IEnumerable<ClientProvider> clients)
30+
{
31+
_clients = clients;
32+
_resourceProviderName = TypeNameUtilities.GetResourceProviderName();
33+
AzureClientGenerator.Instance.AddTypeToKeep(this);
34+
}
35+
36+
protected override string BuildRelativeFilePath() => Path.Combine("src", "Generated", $"{Name}.cs");
37+
38+
protected override string BuildName() => $"{_resourceProviderName}ClientBuilderExtensions";
39+
40+
protected override string BuildNamespace() => "Microsoft.Extensions.Azure";
41+
42+
protected override TypeSignatureModifiers BuildDeclarationModifiers() =>
43+
TypeSignatureModifiers.Public | TypeSignatureModifiers.Static | TypeSignatureModifiers.Partial;
44+
45+
protected override FormattableString Description => $"Extension methods to add clients to <see cref=\"IAzureClientBuilder{{TClient, TOptions}}\"/>.";
46+
47+
protected override MethodProvider[] BuildMethods()
48+
{
49+
var methods = new List<MethodProvider>();
50+
foreach (var client in _clients)
51+
{
52+
if (client.ClientOptionsParameter == null)
53+
{
54+
continue;
55+
}
56+
57+
var tBuilder = typeof(BuilderType<,>).GetGenericArguments()[0];
58+
var tConfiguration = typeof(BuilderType<,>).GetGenericArguments()[1];
59+
var builderParameter = new ParameterProvider("builder", $"The builder to register with.", tBuilder);
60+
var configurationParameter = new ParameterProvider(
61+
"configuration",
62+
$"The configuration to use for the client.",
63+
tConfiguration);
64+
var methodName = $"Add{client.Name}";
65+
FormattableString methodDescription =
66+
$"Registers a <see cref=\"{client.Name}\"/> client with the specified <see cref=\"IAzureClientBuilder{{TClient, TOptions}}\"/>.";
67+
var methodModifiers = MethodSignatureModifiers.Public | MethodSignatureModifiers.Static |
68+
MethodSignatureModifiers.Extension;
69+
var methodReturnType = new CSharpType(typeof(IAzureClientBuilder<,>), client.Type,
70+
client.ClientOptionsParameter.Type);
71+
72+
foreach (var constructor in client.Constructors)
73+
{
74+
if (!constructor.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Public))
75+
{
76+
continue;
77+
}
78+
79+
// only add overloads for the full constructors that include the client options parameter
80+
if (constructor.Signature.Parameters.LastOrDefault()?.Type.Equals(client.ClientOptionsParameter.Type) != true)
81+
{
82+
continue;
83+
}
84+
85+
// get the second to last parameter, which is the location of the auth credential parameter if there is one
86+
var authParameter = constructor.Signature.Parameters[^2];
87+
var isTokenCredential = authParameter?.Type.Equals(typeof(TokenCredential)) == true;
88+
var parameters = new List<ParameterProvider>(constructor.Signature.Parameters.Count + 1);
89+
parameters.Add(builderParameter);
90+
parameters.AddRange(isTokenCredential ? constructor.Signature.Parameters.SkipLast(2) : constructor.Signature.Parameters.SkipLast(1));
91+
var method = new MethodProvider(
92+
new MethodSignature(
93+
methodName,
94+
methodDescription,
95+
methodModifiers,
96+
methodReturnType,
97+
null,
98+
parameters,
99+
GenericArguments: [tBuilder],
100+
GenericParameterConstraints: [Where.Implements(tBuilder, isTokenCredential ?
101+
typeof(IAzureClientFactoryBuilderWithCredential) :
102+
typeof(IAzureClientFactoryBuilder))]),
103+
bodyStatements:
104+
Return(builderParameter.Invoke(
105+
"RegisterClientFactory",
106+
args: [BuildFuncExpression(client, constructor.Signature, isTokenCredential)],
107+
typeArgs: [client.Type, client.ClientOptionsParameter.Type])),
108+
enclosingType: this);
109+
methods.Add(method);
110+
}
111+
112+
// Add the configuration overload
113+
var requiresUnreferencedCodeMessage = Literal("Requires unreferenced code until we opt into EnableConfigurationBindingGenerator.");
114+
methods.Add(new MethodProvider(
115+
new MethodSignature(
116+
methodName,
117+
methodDescription,
118+
methodModifiers,
119+
methodReturnType,
120+
null,
121+
[builderParameter, configurationParameter],
122+
Attributes:
123+
[
124+
new AttributeStatement(
125+
typeof(RequiresUnreferencedCodeAttribute),
126+
requiresUnreferencedCodeMessage),
127+
new AttributeStatement(
128+
typeof(RequiresDynamicCodeAttribute),
129+
requiresUnreferencedCodeMessage)
130+
],
131+
GenericArguments: [tBuilder, tConfiguration],
132+
GenericParameterConstraints:
133+
[
134+
Where.Implements(
135+
tBuilder,
136+
new CSharpType(typeof(IAzureClientFactoryBuilderWithConfiguration<>), tConfiguration))
137+
]),
138+
bodyStatements:
139+
Return(builderParameter.Invoke(
140+
"RegisterClientFactory",
141+
args: [configurationParameter],
142+
typeArgs: [client.Type, client.ClientOptionsParameter.Type])),
143+
enclosingType: this));
144+
}
145+
146+
return [.. methods];
147+
}
148+
149+
private static FuncExpression BuildFuncExpression(ClientProvider client, ConstructorSignature constructorSignature, bool isTokenCredential)
150+
{
151+
var options = constructorSignature.Parameters.Last();
152+
var token = new VariableExpression(typeof(TokenCredential), "credential");
153+
154+
ValueExpression[] ctorArgs = isTokenCredential ?
155+
[.. constructorSignature.Parameters.SkipLast(2), token, options] :
156+
[.. constructorSignature.Parameters];
157+
158+
return new FuncExpression(
159+
isTokenCredential ? [options.AsExpression().Declaration, token.Declaration] : [options.AsExpression().Declaration],
160+
New.Instance(client.Type, ctorArgs));
161+
}
162+
163+
private class BuilderType<TBuilder, TConfiguration>
164+
{
165+
}
166+
}
167+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System.Linq;
5+
6+
namespace Azure.Generator.Utilities
7+
{
8+
internal static class TypeNameUtilities
9+
{
10+
private const string AzurePackageNamespacePrefix = "Azure.";
11+
private const string AzureMgmtPackageNamespacePrefix = "Azure.ResourceManager.";
12+
13+
/// <summary>
14+
/// Returns the name of the RP from the package name using the following:
15+
/// If the package name starts with `Azure.ResourceManager`, returns every segment concatenating after the `Azure.ResourceManager` prefix.
16+
/// If the package name starts with `Azure`, returns every segment concatenating together after the `Azure` prefix.
17+
/// Returns the package name as the RP name if nothing matches.
18+
/// </summary>
19+
public static string GetResourceProviderName()
20+
{
21+
var packageName = AzureClientGenerator.Instance.Configuration.PackageName;
22+
var segments = packageName.Split('.');
23+
if (packageName.StartsWith(AzurePackageNamespacePrefix))
24+
{
25+
if (packageName.StartsWith(AzureMgmtPackageNamespacePrefix))
26+
{
27+
return string.Join("", segments.Skip(2)); // skips "Azure" and "ResourceManager"
28+
}
29+
30+
return string.Join("", segments.Skip(1));
31+
}
32+
return string.Join("", segments);
33+
}
34+
}
35+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using Azure.Generator.Utilities;
5+
using Microsoft.TypeSpec.Generator.ClientModel;
6+
using Microsoft.TypeSpec.Generator.Providers;
7+
8+
namespace Azure.Generator.Visitors
9+
{
10+
internal class ModelFactoryVisitor : ScmLibraryVisitor
11+
{
12+
protected override TypeProvider? VisitType(TypeProvider type)
13+
{
14+
if (type is ModelFactoryProvider && type.CustomCodeView == null)
15+
{
16+
type.Type.Update(name: $"{TypeNameUtilities.GetResourceProviderName()}ModelFactory");
17+
}
18+
19+
return type;
20+
}
21+
}
22+
}

eng/packages/http-client-csharp/generator/Azure.Generator/src/Visitors/NamespaceVisitor.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,7 @@ internal class NamespaceVisitor : ScmLibraryVisitor
4646
{
4747
UpdateModelsNamespace(type);
4848
}
49-
else
50-
{
51-
type.Type.Update(@namespace: AzureClientGenerator.Instance.TypeFactory.PrimaryNamespace);
52-
}
49+
5350
return type;
5451
}
5552

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Linq;
6+
using Azure.Generator.Providers;
7+
using Azure.Generator.Tests.Common;
8+
using Azure.Generator.Tests.TestHelpers;
9+
using Microsoft.TypeSpec.Generator.Input;
10+
using Microsoft.TypeSpec.Generator.Primitives;
11+
using NUnit.Framework;
12+
13+
namespace Azure.Generator.Tests.Providers.ClientBuilderExtensionsDefinitions
14+
{
15+
public class ClientBuilderExtensionsTests
16+
{
17+
[Test]
18+
public void AddsClientExtensionForApiKeyAuth()
19+
{
20+
var client = InputFactory.Client("TestClient", "Samples", "");
21+
var plugin = MockHelpers.LoadMockPlugin(
22+
apiKeyAuth: () => new InputApiKeyAuth("mock", null),
23+
clients: () => [client]);
24+
25+
var builderExtensions = plugin.Object.OutputLibrary.TypeProviders
26+
.OfType<ClientBuilderExtensionsDefinition>().SingleOrDefault();
27+
28+
Assert.IsNotNull(builderExtensions);
29+
var writer = new TypeProviderWriter(builderExtensions!);
30+
var file = writer.Write();
31+
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
32+
}
33+
34+
[Test]
35+
public void AddsClientExtensionForOAuth()
36+
{
37+
var client = InputFactory.Client("TestClient", "Samples", "");
38+
var plugin = MockHelpers.LoadMockPlugin(
39+
oauth2Auth: ()=> new InputOAuth2Auth(["mock"]),
40+
clients: () => [client]);
41+
42+
var builderExtensions = plugin.Object.OutputLibrary.TypeProviders
43+
.OfType<ClientBuilderExtensionsDefinition>().SingleOrDefault();
44+
45+
Assert.IsNotNull(builderExtensions);
46+
var writer = new TypeProviderWriter(builderExtensions!);
47+
var file = writer.Write();
48+
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
49+
}
50+
51+
[Test]
52+
public void AddsClientExtensionForEachAuthMethod()
53+
{
54+
var client = InputFactory.Client("TestClient", "Samples", "");
55+
var plugin = MockHelpers.LoadMockPlugin(
56+
apiKeyAuth: () => new InputApiKeyAuth("mock", null),
57+
oauth2Auth: ()=> new InputOAuth2Auth(["mock"]),
58+
clients: () => [client]);
59+
60+
var builderExtensions = plugin.Object.OutputLibrary.TypeProviders
61+
.OfType<ClientBuilderExtensionsDefinition>().SingleOrDefault();
62+
63+
Assert.IsNotNull(builderExtensions);
64+
var writer = new TypeProviderWriter(builderExtensions!);
65+
var file = writer.Write();
66+
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
67+
}
68+
69+
[Test]
70+
public void AddsClientExtensionForEachAuthMethodMultipleClients()
71+
{
72+
var client1 = InputFactory.Client("TestClient", "Samples", "");
73+
var client2 = InputFactory.Client("TestClient2", "Samples", "");
74+
var plugin = MockHelpers.LoadMockPlugin(
75+
apiKeyAuth: () => new InputApiKeyAuth("mock", null),
76+
oauth2Auth: ()=> new InputOAuth2Auth(["mock"]),
77+
clients: () => [client1, client2]);
78+
79+
var builderExtensions = plugin.Object.OutputLibrary.TypeProviders
80+
.OfType<ClientBuilderExtensionsDefinition>().SingleOrDefault();
81+
82+
Assert.IsNotNull(builderExtensions);
83+
var writer = new TypeProviderWriter(builderExtensions!);
84+
var file = writer.Write();
85+
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
86+
}
87+
}
88+
}

0 commit comments

Comments
 (0)