Skip to content

Commit 7977507

Browse files
committed
fixup metadata query param
1 parent 7970bdf commit 7977507

File tree

6 files changed

+714
-11
lines changed

6 files changed

+714
-11
lines changed

codegen/generator/src/OpenAILibraryGenerator.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ protected override void Configure()
3939
AddVisitor(new ExperimentalAttributeVisitor());
4040
AddVisitor(new ModelDirectoryVisitor());
4141
AddVisitor(new PaginationVisitor());
42+
AddVisitor(new MetadataQueryParamVisitor());
4243
}
4344
}
4445
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using Microsoft.TypeSpec.Generator.ClientModel;
5+
using Microsoft.TypeSpec.Generator.Expressions;
6+
using Microsoft.TypeSpec.Generator.Primitives;
7+
using Microsoft.TypeSpec.Generator.Providers;
8+
using Microsoft.TypeSpec.Generator.Snippets;
9+
using Microsoft.TypeSpec.Generator.Statements;
10+
using static OpenAILibraryPlugin.Visitors.VisitorHelpers;
11+
12+
namespace OpenAILibraryPlugin.Visitors;
13+
14+
/// <summary>
15+
/// This visitor modifies GetRawPagesAsync methods to consider HasMore in addition to LastId when deciding whether to continue pagination.
16+
/// It also replaces specific parameters with an options type for pagination methods.
17+
/// </summary>
18+
public class MetadataQueryParamVisitor : ScmLibraryVisitor
19+
{
20+
21+
private static readonly string[] _chatParamsToReplace = ["after", "before", "limit", "order", "model", "metadata"];
22+
private static readonly Dictionary<string, string> _paramReplacementMap = new()
23+
{
24+
{ "after", "AfterId" },
25+
{ "before", "LastId" },
26+
{ "limit", "PageSizeLimit" },
27+
{ "order", "Order" },
28+
{ "model", "Model" },
29+
{ "metadata", "Metadata" }
30+
};
31+
private static readonly Dictionary<string, (string ReturnType, string OptionsType, string[] ParamsToReplace)> _optionsReplacements = new()
32+
{
33+
{
34+
"GetChatCompletions",
35+
("ChatCompletion", "ChatCompletionCollectionOptions", _chatParamsToReplace)
36+
},
37+
{
38+
"GetChatCompletionsAsync",
39+
("ChatCompletion", "ChatCompletionCollectionOptions", _chatParamsToReplace)
40+
}
41+
};
42+
43+
/// <summary>
44+
/// Visits Create*Request methods to modify how metadata query parameters are handled.
45+
/// It replaces the following statements:
46+
/// <code>
47+
/// List<object> list = new List<object>();
48+
/// foreach (var @param in metadata)
49+
/// {
50+
/// uri.AppendQuery($"metadata[{@param.Key}]", @param.Value, true);
51+
/// list.Add(@param.Key);
52+
/// list.Add(@param.Value);
53+
/// }
54+
/// uri.AppendQueryDelimited("metadata", list, ",", null, true);
55+
/// </code>
56+
/// with:
57+
/// <code>
58+
/// foreach (var @param in metadata)
59+
/// {
60+
/// uri.AppendQuery($"metadata[{@param.Key}]", @param.Value, true);
61+
/// }
62+
/// </summary>
63+
/// <param name="method"></param>
64+
/// <returns></returns>
65+
protected override MethodProvider? VisitMethod(MethodProvider method)
66+
{
67+
// Check if the method is one of the Create*Request methods and has a signature that takes a metadata parameter like IDictionary<string, string> metadata
68+
if (method.Signature.Name.StartsWith("Create") && method.Signature.Name.EndsWith("Request") &&
69+
method.Signature.Parameters.Any(p => p.Type.IsDictionary && p.Name == "metadata"))
70+
{
71+
ValueExpression? uri = null;
72+
var statements = method.BodyStatements?.ToList() ?? new List<MethodBodyStatement>();
73+
VisitExplodedMethodBodyStatements(
74+
statements!,
75+
statement =>
76+
{
77+
// Check if the statement is an assignment to a variable named "uri"
78+
// Capture it if so
79+
if (statement is ExpressionStatement expressionStatement &&
80+
expressionStatement.Expression is AssignmentExpression assignmentExpression &&
81+
assignmentExpression.Variable is DeclarationExpression declarationExpression &&
82+
declarationExpression.Variable is VariableExpression variableExpression &&
83+
variableExpression.Declaration.RequestedName == "uri")
84+
{
85+
uri = variableExpression;
86+
}
87+
// Try to remove the unnecessary list declaration
88+
if (statement is ExpressionStatement expressionStatement2 &&
89+
expressionStatement2.Expression is AssignmentExpression assignmentExpression2 &&
90+
assignmentExpression2.Variable is DeclarationExpression declarationExpression2 &&
91+
declarationExpression2.Variable is VariableExpression variableExpression2 &&
92+
variableExpression2.Declaration.RequestedName == "list" &&
93+
variableExpression2.Type.IsCollection && variableExpression2.Type.IsGenericType)
94+
{
95+
// Remove the list declaration
96+
return new SingleLineCommentStatement("Plugin customization: remove unnecessary list declaration");
97+
}
98+
99+
if (uri is not null &&
100+
statement is ForEachStatement foreachStatement &&
101+
foreachStatement.Enumerable is DictionaryExpression dictionaryExpression &&
102+
dictionaryExpression.Original is VariableExpression variable &&
103+
variable.Declaration.RequestedName == "metadata")
104+
{
105+
var formatString = new FormattableStringExpression("metadata[{0}]", [foreachStatement.ItemVariable.Property("Key")]);
106+
var appendQueryStatement = uri.Invoke("AppendQuery", [formatString, foreachStatement.ItemVariable.Property("Value"), new KeywordExpression("true", null)]);
107+
foreachStatement.Body.Clear();
108+
foreachStatement.Body.Add(new SingleLineCommentStatement("Plugin customization: Properly handle metadata query parameters"));
109+
foreachStatement.Body.Add(new ExpressionStatement(appendQueryStatement));
110+
}
111+
112+
// Remove the call to AppendQueryDelimited for metadata
113+
if (statement is ExpressionStatement expressionStatement3 &&
114+
expressionStatement3.Expression is InvokeMethodExpression invokeMethodExpression &&
115+
invokeMethodExpression.MethodName == "AppendQueryDelimited" &&
116+
invokeMethodExpression.Arguments.Count == 5 &&
117+
invokeMethodExpression.Arguments[0].ToDisplayString() == "\"metadata\"")
118+
{
119+
return new SingleLineCommentStatement("Plugin customization: remove unnecessary AppendQueryDelimited for metadata");
120+
}
121+
return statement;
122+
});
123+
124+
// Rebuild the method body with the modified statements
125+
method.Update(bodyStatements: statements);
126+
}
127+
128+
return base.VisitMethod(method);
129+
}
130+
}

codegen/generator/src/Visitors/PaginationVisitor.cs

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,32 @@ public class PaginationVisitor : ScmLibraryVisitor
4141
};
4242

4343
protected override MethodProvider? VisitMethod(MethodProvider method)
44+
{
45+
// Try to handle pagination methods with options replacement
46+
if (TryHandlePaginationMethodWithOptions(method))
47+
{
48+
return method;
49+
}
50+
51+
// Try to handle GetRawPagesAsync methods for hasMore checks
52+
if (TryHandleGetRawPagesAsyncMethod(method))
53+
{
54+
return method;
55+
}
56+
57+
return base.VisitMethod(method);
58+
}
59+
60+
/// <summary>
61+
/// Handles pagination methods that need their parameters replaced with an options type.
62+
/// </summary>
63+
/// <param name="method">The method to potentially handle. Will be modified in place if handling is successful.</param>
64+
/// <returns>True if the method was handled, false otherwise.</returns>
65+
private bool TryHandlePaginationMethodWithOptions(MethodProvider method)
4466
{
4567
// Check if the method is one of the pagination methods we want to modify.
4668
// If so, we will update its parameters to replace the specified parameters with the options type.
47-
if (
48-
method.Signature.ReturnType is not null &&
69+
if (method.Signature.ReturnType is not null &&
4970
method.Signature.ReturnType.Name.EndsWith("CollectionResult") &&
5071
_optionsReplacements.TryGetValue(method.Signature.Name, out var options) &&
5172
method.Signature.ReturnType.IsGenericType &&
@@ -111,7 +132,6 @@ method.Signature.ReturnType is not null &&
111132
if (_paramReplacementMap.TryGetValue(varExpr.Declaration.RequestedName, out var replacement))
112133
{
113134
newParameters.Add(optionsParam.NullConditional().Property(replacement));
114-
var foo = optionsParam.NullConditional().Property(replacement).NullConditional().Invoke("ToString", Array.Empty<ValueExpression>());
115135
}
116136
}
117137
else if (param is InvokeMethodExpression invokeMethod && invokeMethod.MethodName == "ToString" &&
@@ -149,10 +169,21 @@ nullConditional.Inner is VariableExpression varExpr2 &&
149169
});
150170

151171
method.Update(signature: newSignature, bodyStatements: statements);
172+
return true;
152173
}
153174
}
154175
}
155176

177+
return false;
178+
}
179+
180+
/// <summary>
181+
/// Handles GetRawPagesAsync methods to add hasMore == false checks for pagination.
182+
/// </summary>
183+
/// <param name="method">The method to potentially handle. Will be modified in place if handling is successful.</param>
184+
/// <returns>True if the method was handled, false otherwise.</returns>
185+
private bool TryHandleGetRawPagesAsyncMethod(MethodProvider method)
186+
{
156187
// If the method is GetRawPagesAsync and is internal, we will modify the body statements to add a check for hasMore == false.
157188
// This is to ensure that pagination stops when hasMore is false, in addition to checking LastId.
158189
if (method.Signature.Name == "GetRawPagesAsync" && method.EnclosingType.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal))
@@ -197,6 +228,7 @@ binaryExpr.Right is KeywordExpression rightKeyword &&
197228
.SelectMany(bodyStatement => bodyStatement)
198229
.ToList();
199230

231+
// Check for the assignment of nextToken and add hasMore assignment
200232
for (int i = 0; i < statementList.Count; i++)
201233
{
202234
if (statementList[i] is ExpressionStatement expressionStatement &&
@@ -227,9 +259,9 @@ assignmentExpression.Value is MemberExpression memberExpression &&
227259
});
228260

229261
method.Update(bodyStatements: statements);
230-
return method;
262+
return true;
231263
}
232264

233-
return base.VisitMethod(method);
265+
return false;
234266
}
235267
}

codegen/generator/src/Visitors/VisitorHelpers.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ internal static void VisitExplodedMethodBodyStatements(
1616
for (int i = 0; i < statements.Count; i++)
1717
{
1818
statements[i] = visitorFunc.Invoke(statements[i]);
19-
19+
2020
if (statements[i] is ForEachStatement foreachStatement)
2121
{
2222
List<MethodBodyStatement> foreachBodyStatements
@@ -29,7 +29,17 @@ List<MethodBodyStatement> foreachBodyStatements
2929
}
3030
else if (statements[i] is IfStatement ifStatement)
3131
{
32-
// To do: traverse inside of "if"
32+
List<MethodBodyStatement> ifBodyStatements
33+
= ifStatement.Body
34+
.SelectMany(bodyStatement => bodyStatement)
35+
.ToList();
36+
VisitExplodedMethodBodyStatements(ifBodyStatements!, visitorFunc);
37+
var newIfStatement = new IfStatement(ifStatement.Condition);
38+
foreach (MethodBodyStatement bodyStatement in ifBodyStatements)
39+
{
40+
newIfStatement.Add(bodyStatement);
41+
}
42+
statements[i] = newIfStatement;
3343
}
3444
else if (statements[i] is ForStatement forStatement)
3545
{

src/Generated/ChatClient.RestClient.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@ internal virtual PipelineMessage CreateGetChatCompletionsRequest(string after, i
3838
}
3939
if (metadata != null && !(metadata is ChangeTrackingDictionary<string, string> changeTrackingDictionary && changeTrackingDictionary.IsUndefined))
4040
{
41-
List<object> list = new List<object>();
41+
// Plugin customization: remove unnecessary list declaration
4242
foreach (var @param in metadata)
4343
{
44-
list.Add(@param.Key);
45-
list.Add(@param.Value);
44+
// Plugin customization: Properly handle metadata query parameters
45+
uri.AppendQuery($"metadata[{@param.Key}]", @param.Value, true);
4646
}
47-
uri.AppendQueryDelimited("metadata", list, ",", null, true);
47+
// Plugin customization: remove unnecessary AppendQueryDelimited for metadata
4848
}
4949
if (model != null)
5050
{

0 commit comments

Comments
 (0)