Skip to content

Commit c5e5d17

Browse files
Client request id visitor (Azure#50755)
* Special headers handling * Fix return parameter handling * revert * comment * Use service method params
1 parent 76b7396 commit c5e5d17

File tree

10 files changed

+246
-21
lines changed

10 files changed

+246
-21
lines changed

eng/Packages.Data.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@
450450

451451
<PropertyGroup>
452452
<TestProxyVersion>1.0.0-dev.20250501.1</TestProxyVersion>
453-
<UnbrandedGeneratorVersion>1.0.0-alpha.20250619.2</UnbrandedGeneratorVersion>
453+
<UnbrandedGeneratorVersion>1.0.0-alpha.20250620.1</UnbrandedGeneratorVersion>
454454
<AzureGeneratorVersion>1.0.0-alpha.20250527.2</AzureGeneratorVersion>
455455
</PropertyGroup>
456456
</Project>

eng/packages/http-client-csharp/generator/Azure.Generator/src/AzureClientGenerator.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,6 @@ protected override void Configure()
5454
AddVisitor(new PipelinePropertyVisitor());
5555
AddVisitor(new LroVisitor());
5656
AddVisitor(new ModelFactoryVisitor());
57+
AddVisitor(new SpecialHeadersVisitor());
5758
}
5859
}

eng/packages/http-client-csharp/generator/Azure.Generator/src/AzureTypeFactory.cs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,6 @@ public class AzureTypeFactory : ScmTypeFactory
8686
return base.CreateCSharpTypeCore(inputType);
8787
}
8888

89-
/// <inheritdoc/>
90-
protected override ParameterProvider? CreateParameterCore(InputParameter parameter)
91-
{
92-
// Skip the x-ms-client-request-id parameter as it is handled as part of the Azure.Core pipeline.
93-
if (parameter.NameInRequest == "x-ms-client-request-id")
94-
{
95-
return null;
96-
}
97-
98-
return base.CreateParameterCore(parameter);
99-
}
100-
10189
private CSharpType? CreateKnownPrimitiveType(InputPrimitiveType inputType)
10290
{
10391
InputPrimitiveType? primitiveType = inputType;

eng/packages/http-client-csharp/generator/Azure.Generator/src/Providers/Abstraction/HttpRequestProvider.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
using Microsoft.TypeSpec.Generator.Expressions;
77
using Microsoft.TypeSpec.Generator.Statements;
88
using System;
9-
using System.ClientModel.Primitives;
109
using System.Collections.Generic;
1110
using static Microsoft.TypeSpec.Generator.Snippets.Snippet;
1211

@@ -30,10 +29,10 @@ public override HttpRequestApi FromExpression(ValueExpression original)
3029
=> new HttpRequestProvider(original);
3130

3231
public override MethodBodyStatement SetHeaders(IReadOnlyList<ValueExpression> arguments)
33-
=> Original.Property(nameof(PipelineRequest.Headers)).Invoke(nameof(RequestHeaders.SetValue), arguments).Terminate();
32+
=> Original.Property(nameof(Request.Headers)).Invoke(nameof(RequestHeaders.SetValue), arguments).Terminate();
3433

3534
public override MethodBodyStatement SetMethod(string httpMethod)
36-
=> Original.Property(nameof(PipelineRequest.Method)).Assign(CreateRequestMethod(httpMethod)).Terminate();
35+
=> Original.Property(nameof(Request.Method)).Assign(CreateRequestMethod(httpMethod)).Terminate();
3736

3837
public override MethodBodyStatement SetUri(ValueExpression value)
3938
=> Original.Property("Uri").Assign(value).Terminate();
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System.ClientModel.Primitives;
5+
using Azure.Core;
6+
using Microsoft.TypeSpec.Generator.Expressions;
7+
using Microsoft.TypeSpec.Generator.Snippets;
8+
using Microsoft.TypeSpec.Generator.Statements;
9+
using static Microsoft.TypeSpec.Generator.Snippets.Snippet;
10+
11+
namespace Azure.Generator.Snippets
12+
{
13+
internal static class RequestSnippets
14+
{
15+
public static MethodBodyStatement SetHeaderValue(this ScopedApi<Request> request, string name, ValueExpression value)
16+
=> request.Property(nameof(PipelineRequest.Headers)).Invoke(nameof(RequestHeaders.SetValue), Literal(name), value).Terminate();
17+
}
18+
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System.Collections.Generic;
5+
using System.Linq;
6+
using Azure.Core;
7+
using Azure.Generator.Snippets;
8+
using Microsoft.TypeSpec.Generator.ClientModel;
9+
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
10+
using Microsoft.TypeSpec.Generator.Expressions;
11+
using Microsoft.TypeSpec.Generator.Input;
12+
using Microsoft.TypeSpec.Generator.Statements;
13+
using static Microsoft.TypeSpec.Generator.Snippets.Snippet;
14+
15+
namespace Azure.Generator.Visitors
16+
{
17+
/// <summary>
18+
/// Visitor to handle removing special header parameters from service methods and adding them to the request. Note,
19+
/// "x-ms-client-request-id" is not added to the request as it is handled by the Azure.Core pipeline.
20+
/// </summary>
21+
internal class SpecialHeadersVisitor : ScmLibraryVisitor
22+
{
23+
protected override ScmMethodProviderCollection? Visit(
24+
InputServiceMethod serviceMethod,
25+
ClientProvider client,
26+
ScmMethodProviderCollection? methods)
27+
{
28+
var clientRequestIdParameter =
29+
serviceMethod.Parameters.FirstOrDefault(p => p.NameInRequest == "client-request-id");
30+
var returnClientRequestIdParameter =
31+
serviceMethod.Parameters.FirstOrDefault(p => p.NameInRequest == "return-client-request-id");
32+
var xMsClientRequestIdParameter =
33+
serviceMethod.Parameters.FirstOrDefault(p => p.NameInRequest == "x-ms-client-request-id");
34+
35+
if (clientRequestIdParameter != null || returnClientRequestIdParameter != null || xMsClientRequestIdParameter != null)
36+
{
37+
serviceMethod.Update(parameters: serviceMethod.Parameters
38+
.Where(p => p != clientRequestIdParameter && p != returnClientRequestIdParameter && p != xMsClientRequestIdParameter)
39+
.ToList());
40+
serviceMethod.Operation.Update(parameters: serviceMethod.Parameters);
41+
42+
// Create a new method collection with the updated service method
43+
methods = new ScmMethodProviderCollection(serviceMethod, client);
44+
45+
// Reset the rest client so that its methods are rebuilt.
46+
client.RestClient.Reset();
47+
var createRequestMethod = client.RestClient.GetCreateRequestMethod(serviceMethod.Operation);
48+
49+
var originalBodyStatements = createRequestMethod.BodyStatements!.ToList();
50+
51+
// Exclude the last statement which is the return statement. We will add it back later.
52+
var newStatements = new List<MethodBodyStatement>(originalBodyStatements[..^1]);
53+
54+
// Find the request variable
55+
VariableExpression? requestVariable = null;
56+
foreach (var statement in newStatements)
57+
{
58+
if (statement is ExpressionStatement
59+
{
60+
Expression: AssignmentExpression { Variable: DeclarationExpression declaration }
61+
})
62+
{
63+
var variable = declaration.Variable;
64+
if (variable.Type.Equals(typeof(Request)))
65+
{
66+
requestVariable = variable;
67+
}
68+
}
69+
}
70+
71+
if (clientRequestIdParameter != null)
72+
{
73+
// Set the client-request-id header
74+
newStatements.Add(requestVariable!.As<Request>().SetHeaderValue(
75+
clientRequestIdParameter.NameInRequest, requestVariable.Property(nameof(Request.ClientRequestId))));
76+
}
77+
78+
if (returnClientRequestIdParameter?.DefaultValue?.Value != null)
79+
{
80+
if (bool.TryParse(returnClientRequestIdParameter.DefaultValue.Value.ToString(), out bool value))
81+
{
82+
// Set the return-client-request-id header
83+
newStatements.Add(requestVariable!.As<Request>().SetHeaderValue(
84+
returnClientRequestIdParameter.NameInRequest,
85+
Literal(value.ToString().ToLowerInvariant())));
86+
}
87+
}
88+
89+
// Add the return statement back
90+
newStatements.Add(originalBodyStatements[^1]);
91+
92+
createRequestMethod.Update(bodyStatements: newStatements);
93+
}
94+
95+
return methods;
96+
}
97+
}
98+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System.Collections.Generic;
5+
using System.Linq;
6+
using Azure.Generator.Tests.Common;
7+
using Azure.Generator.Tests.TestHelpers;
8+
using Azure.Generator.Visitors;
9+
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
10+
using Microsoft.TypeSpec.Generator.Input;
11+
using Microsoft.TypeSpec.Generator.Primitives;
12+
using NUnit.Framework;
13+
14+
namespace Azure.Generator.Tests.Visitors
15+
{
16+
public class SpecialHeadersVisitorTests
17+
{
18+
[Test]
19+
public void RemovesSpecialHeaderParametersFromServiceMethods()
20+
{
21+
var visitor = new TestSpecialHeadersVisitor();
22+
List<InputParameter> parameters =
23+
[
24+
InputFactory.Parameter(
25+
"client-request-id",
26+
type: InputPrimitiveType.String,
27+
nameInRequest: "client-request-id",
28+
location: InputRequestLocation.Header),
29+
InputFactory.Parameter(
30+
"return-client-request-id",
31+
type: new InputLiteralType("return-client-request-id", "ns", InputPrimitiveType.Boolean, true),
32+
defaultValue: new InputConstant(true, InputPrimitiveType.Boolean),
33+
nameInRequest: "return-client-request-id",
34+
location: InputRequestLocation.Header),
35+
InputFactory.Parameter(
36+
"x-ms-client-request-id",
37+
type: InputPrimitiveType.String,
38+
nameInRequest: "x-ms-client-request-id",
39+
location: InputRequestLocation.Header),
40+
];
41+
var responseModel = InputFactory.Model("foo");
42+
var operation = InputFactory.Operation(
43+
"foo",
44+
parameters: parameters,
45+
responses: [InputFactory.OperationResponse(bodytype: responseModel)]);
46+
var serviceMethod = InputFactory.LongRunningServiceMethod(
47+
"foo",
48+
operation,
49+
parameters: parameters,
50+
response: InputFactory.ServiceMethodResponse(responseModel, ["result"]));
51+
var inputClient = InputFactory.Client("TestClient", methods: [serviceMethod]);
52+
MockHelpers.LoadMockPlugin(clients: () => [inputClient]);
53+
54+
var clientProvider = AzureClientGenerator.Instance.TypeFactory.CreateClient(inputClient);
55+
Assert.IsNotNull(clientProvider);
56+
57+
var responseModelProvider = AzureClientGenerator.Instance.TypeFactory.CreateModel(responseModel);
58+
Assert.IsNotNull(responseModelProvider);
59+
60+
var methodCollection = new ScmMethodProviderCollection(serviceMethod, clientProvider!);
61+
methodCollection = visitor.InvokeVisitServiceMethod(serviceMethod, clientProvider!, methodCollection);
62+
63+
foreach (var method in methodCollection!)
64+
{
65+
Assert.IsFalse(method.Signature.Parameters.Any(p => p.Name == "client-request-id"));
66+
Assert.IsFalse(method.Signature.Parameters.Any(p => p.Name == "return-client-request-id"));
67+
// This header should not be added in the rest method because it is added by the Core pipeline.
68+
Assert.IsFalse(method.Signature.Parameters.Any(p => p.Name == "x-ms-client-request-id"));
69+
}
70+
71+
var writer = new TypeProviderWriter(clientProvider!.RestClient);
72+
var file = writer.Write();
73+
74+
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
75+
}
76+
77+
private class TestSpecialHeadersVisitor : SpecialHeadersVisitor
78+
{
79+
public ScmMethodProviderCollection? InvokeVisitServiceMethod(
80+
InputServiceMethod serviceMethod,
81+
ClientProvider client,
82+
ScmMethodProviderCollection? methodCollection)
83+
{
84+
return base.Visit(serviceMethod, client, methodCollection);
85+
}
86+
}
87+
}
88+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
// <auto-generated/>
5+
6+
#nullable disable
7+
8+
using Azure;
9+
using Azure.Core;
10+
11+
namespace Samples
12+
{
13+
/// <summary></summary>
14+
public partial class TestClient
15+
{
16+
private static global::Azure.Core.ResponseClassifier _pipelineMessageClassifier200;
17+
18+
private static global::Azure.Core.ResponseClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 = new global::Azure.Core.StatusCodeClassifier(stackalloc ushort[] { 200 });
19+
20+
internal global::Azure.Core.HttpMessage CreateFooRequest(global::Azure.RequestContext context)
21+
{
22+
global::Azure.Core.HttpMessage message = Pipeline.CreateMessage(context, PipelineMessageClassifier200);
23+
global::Azure.Core.Request request = message.Request;
24+
request.Method = global::Azure.Core.RequestMethod.Get;
25+
global::Azure.Core.RawRequestUriBuilder uri = new global::Azure.Core.RawRequestUriBuilder();
26+
uri.Reset(_endpoint);
27+
request.Uri = uri;
28+
request.Headers.SetValue("client-request-id", request.ClientRequestId);
29+
request.Headers.SetValue("return-client-request-id", "true");
30+
return message;
31+
}
32+
}
33+
}

eng/packages/http-client-csharp/package-lock.json

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

eng/packages/http-client-csharp/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
"dist/generator/**"
3939
],
4040
"dependencies": {
41-
"@typespec/http-client-csharp": "1.0.0-alpha.20250619.2"
41+
"@typespec/http-client-csharp": "1.0.0-alpha.20250620.1"
4242
},
4343
"devDependencies": {
4444
"@azure-tools/azure-http-specs": "0.1.0-alpha.19",

0 commit comments

Comments
 (0)