Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions api/OpenAI.net8.0.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,14 @@ public class ChatClient {
public virtual Task<ClientResult> GetChatCompletionAsync(string completionId, RequestOptions options);
[Experimental("OPENAI001")]
public virtual Task<ClientResult<ChatCompletion>> GetChatCompletionAsync(string completionId, CancellationToken cancellationToken = default);
[Experimental("OPENAI001")]
public virtual CollectionResult<ChatCompletion> GetChatCompletions(ChatCompletionCollectionOptions options = null, CancellationToken cancellationToken = default);
[Experimental("OPENAI001")]
public virtual CollectionResult GetChatCompletions(string after, int? limit, string order, IDictionary<string, string> metadata, string model, RequestOptions options);
[Experimental("OPENAI001")]
public virtual AsyncCollectionResult<ChatCompletion> GetChatCompletionsAsync(ChatCompletionCollectionOptions options = null, CancellationToken cancellationToken = default);
[Experimental("OPENAI001")]
public virtual AsyncCollectionResult GetChatCompletionsAsync(string after, int? limit, string order, IDictionary<string, string> metadata, string model, RequestOptions options);
}
public class ChatCompletion : IJsonModel<ChatCompletion>, IPersistableModel<ChatCompletion> {
[Experimental("OPENAI001")]
Expand Down Expand Up @@ -1460,6 +1468,33 @@ public class ChatCompletion : IJsonModel<ChatCompletion>, IPersistableModel<Chat
protected virtual BinaryData PersistableModelWriteCore(ModelReaderWriterOptions options);
}
[Experimental("OPENAI001")]
public class ChatCompletionCollectionOptions : IJsonModel<ChatCompletionCollectionOptions>, IPersistableModel<ChatCompletionCollectionOptions> {
public string AfterId { get; set; }
public IDictionary<string, string> Metadata { get; }
public string Model { get; set; }
public ChatCompletionCollectionOrder? Order { get; set; }
public int? PageSizeLimit { get; set; }
protected virtual ChatCompletionCollectionOptions JsonModelCreateCore(ref Utf8JsonReader reader, ModelReaderWriterOptions options);
protected virtual void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWriterOptions options);
protected virtual ChatCompletionCollectionOptions PersistableModelCreateCore(BinaryData data, ModelReaderWriterOptions options);
protected virtual BinaryData PersistableModelWriteCore(ModelReaderWriterOptions options);
}
[Experimental("OPENAI001")]
public readonly partial struct ChatCompletionCollectionOrder : IEquatable<ChatCompletionCollectionOrder> {
public ChatCompletionCollectionOrder(string value);
public static ChatCompletionCollectionOrder Ascending { get; }
public static ChatCompletionCollectionOrder Descending { get; }
public readonly bool Equals(ChatCompletionCollectionOrder other);
[EditorBrowsable(EditorBrowsableState.Never)]
public override readonly bool Equals(object obj);
[EditorBrowsable(EditorBrowsableState.Never)]
public override readonly int GetHashCode();
public static bool operator ==(ChatCompletionCollectionOrder left, ChatCompletionCollectionOrder right);
public static implicit operator ChatCompletionCollectionOrder(string value);
public static bool operator !=(ChatCompletionCollectionOrder left, ChatCompletionCollectionOrder right);
public override readonly string ToString();
}
[Experimental("OPENAI001")]
public class ChatCompletionDeletionResult : IJsonModel<ChatCompletionDeletionResult>, IPersistableModel<ChatCompletionDeletionResult> {
public string ChatCompletionId { get; }
public bool Deleted { get; }
Expand Down
29 changes: 29 additions & 0 deletions api/OpenAI.netstandard2.0.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,10 @@ public class ChatClient {
public virtual ClientResult<ChatCompletion> GetChatCompletion(string completionId, CancellationToken cancellationToken = default);
public virtual Task<ClientResult> GetChatCompletionAsync(string completionId, RequestOptions options);
public virtual Task<ClientResult<ChatCompletion>> GetChatCompletionAsync(string completionId, CancellationToken cancellationToken = default);
public virtual CollectionResult<ChatCompletion> GetChatCompletions(ChatCompletionCollectionOptions options = null, CancellationToken cancellationToken = default);
public virtual CollectionResult GetChatCompletions(string after, int? limit, string order, IDictionary<string, string> metadata, string model, RequestOptions options);
public virtual AsyncCollectionResult<ChatCompletion> GetChatCompletionsAsync(ChatCompletionCollectionOptions options = null, CancellationToken cancellationToken = default);
public virtual AsyncCollectionResult GetChatCompletionsAsync(string after, int? limit, string order, IDictionary<string, string> metadata, string model, RequestOptions options);
}
public class ChatCompletion : IJsonModel<ChatCompletion>, IPersistableModel<ChatCompletion> {
public IReadOnlyList<ChatMessageAnnotation> Annotations { get; }
Expand All @@ -1308,6 +1312,31 @@ public class ChatCompletion : IJsonModel<ChatCompletion>, IPersistableModel<Chat
protected virtual ChatCompletion PersistableModelCreateCore(BinaryData data, ModelReaderWriterOptions options);
protected virtual BinaryData PersistableModelWriteCore(ModelReaderWriterOptions options);
}
public class ChatCompletionCollectionOptions : IJsonModel<ChatCompletionCollectionOptions>, IPersistableModel<ChatCompletionCollectionOptions> {
public string AfterId { get; set; }
public IDictionary<string, string> Metadata { get; }
public string Model { get; set; }
public ChatCompletionCollectionOrder? Order { get; set; }
public int? PageSizeLimit { get; set; }
protected virtual ChatCompletionCollectionOptions JsonModelCreateCore(ref Utf8JsonReader reader, ModelReaderWriterOptions options);
protected virtual void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWriterOptions options);
protected virtual ChatCompletionCollectionOptions PersistableModelCreateCore(BinaryData data, ModelReaderWriterOptions options);
protected virtual BinaryData PersistableModelWriteCore(ModelReaderWriterOptions options);
}
public readonly partial struct ChatCompletionCollectionOrder : IEquatable<ChatCompletionCollectionOrder> {
public ChatCompletionCollectionOrder(string value);
public static ChatCompletionCollectionOrder Ascending { get; }
public static ChatCompletionCollectionOrder Descending { get; }
public readonly bool Equals(ChatCompletionCollectionOrder other);
[EditorBrowsable(EditorBrowsableState.Never)]
public override readonly bool Equals(object obj);
[EditorBrowsable(EditorBrowsableState.Never)]
public override readonly int GetHashCode();
public static bool operator ==(ChatCompletionCollectionOrder left, ChatCompletionCollectionOrder right);
public static implicit operator ChatCompletionCollectionOrder(string value);
public static bool operator !=(ChatCompletionCollectionOrder left, ChatCompletionCollectionOrder right);
public override readonly string ToString();
}
public class ChatCompletionDeletionResult : IJsonModel<ChatCompletionDeletionResult>, IPersistableModel<ChatCompletionDeletionResult> {
public string ChatCompletionId { get; }
public bool Deleted { get; }
Expand Down
23 changes: 23 additions & 0 deletions codegen/README.MD
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Debugging the generator

To configure VS Code for debugging the generator, specifically visitors, add the following to your `launch.json` in the root of the workspace

```json
{
"version": "0.2.0",
"configurations": [
{
"name": "Debug OpenAI Library Plugin",
"type": "coreclr",
"request": "launch",
"program": "dotnet",
"args": [
"${workspaceFolder}/codegen/dist/generator/Microsoft.TypeSpec.Generator.dll",
"${workspaceFolder}",
"-g",
"OpenAILibraryGenerator"
],
}
]
}
```
2 changes: 2 additions & 0 deletions codegen/generator/src/OpenAILibraryGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ protected override void Configure()
AddVisitor(new ModelSerializationVisitor());
AddVisitor(new ExperimentalAttributeVisitor());
AddVisitor(new ModelDirectoryVisitor());
AddVisitor(new PaginationVisitor());
AddVisitor(new MetadataQueryParamVisitor());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ protected override MethodProvider VisitMethod(MethodProvider method)
if (method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Explicit) &&
method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Operator) &&
method.Signature.Parameters.Count == 1 &&
method.Signature.Parameters[0].Type.Name == nameof(ClientResult))
method.Signature.Parameters[0].Type.Name == nameof(ClientResult) &&
!method.EnclosingType.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal))
{
return null;
}
Expand Down
130 changes: 130 additions & 0 deletions codegen/generator/src/Visitors/MetadataQueryParamVisitor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.TypeSpec.Generator.ClientModel;
using Microsoft.TypeSpec.Generator.Expressions;
using Microsoft.TypeSpec.Generator.Primitives;
using Microsoft.TypeSpec.Generator.Providers;
using Microsoft.TypeSpec.Generator.Snippets;
using Microsoft.TypeSpec.Generator.Statements;
using static OpenAILibraryPlugin.Visitors.VisitorHelpers;

namespace OpenAILibraryPlugin.Visitors;

/// <summary>
/// This visitor modifies GetRawPagesAsync methods to consider HasMore in addition to LastId when deciding whether to continue pagination.
/// It also replaces specific parameters with an options type for pagination methods.
/// </summary>
public class MetadataQueryParamVisitor : ScmLibraryVisitor
{

private static readonly string[] _chatParamsToReplace = ["after", "before", "limit", "order", "model", "metadata"];
private static readonly Dictionary<string, string> _paramReplacementMap = new()
{
{ "after", "AfterId" },
{ "before", "LastId" },
{ "limit", "PageSizeLimit" },
{ "order", "Order" },
{ "model", "Model" },
{ "metadata", "Metadata" }
};
private static readonly Dictionary<string, (string ReturnType, string OptionsType, string[] ParamsToReplace)> _optionsReplacements = new()
{
{
"GetChatCompletions",
("ChatCompletion", "ChatCompletionCollectionOptions", _chatParamsToReplace)
},
{
"GetChatCompletionsAsync",
("ChatCompletion", "ChatCompletionCollectionOptions", _chatParamsToReplace)
}
};

/// <summary>
/// Visits Create*Request methods to modify how metadata query parameters are handled.
/// It replaces the following statements:
/// <code>
/// List<object> list = new List<object>();
/// foreach (var @param in metadata)
/// {
/// uri.AppendQuery($"metadata[{@param.Key}]", @param.Value, true);
/// list.Add(@param.Key);
/// list.Add(@param.Value);
/// }
/// uri.AppendQueryDelimited("metadata", list, ",", null, true);
/// </code>
/// with:
/// <code>
/// foreach (var @param in metadata)
/// {
/// uri.AppendQuery($"metadata[{@param.Key}]", @param.Value, true);
/// }
/// </summary>
/// <param name="method"></param>
/// <returns></returns>
protected override MethodProvider? VisitMethod(MethodProvider method)
{
// Check if the method is one of the Create*Request methods and has a signature that takes a metadata parameter like IDictionary<string, string> metadata
if (method.Signature.Name.StartsWith("Create") && method.Signature.Name.EndsWith("Request") &&
method.Signature.Parameters.Any(p => p.Type.IsDictionary && p.Name == "metadata"))
{
ValueExpression? uri = null;
var statements = method.BodyStatements?.ToList() ?? new List<MethodBodyStatement>();
VisitExplodedMethodBodyStatements(
statements!,
statement =>
{
// Check if the statement is an assignment to a variable named "uri"
// Capture it if so
if (statement is ExpressionStatement expressionStatement &&
expressionStatement.Expression is AssignmentExpression assignmentExpression &&
assignmentExpression.Variable is DeclarationExpression declarationExpression &&
declarationExpression.Variable is VariableExpression variableExpression &&
variableExpression.Declaration.RequestedName == "uri")
{
uri = variableExpression;
}
// Try to remove the unnecessary list declaration
if (statement is ExpressionStatement expressionStatement2 &&
expressionStatement2.Expression is AssignmentExpression assignmentExpression2 &&
assignmentExpression2.Variable is DeclarationExpression declarationExpression2 &&
declarationExpression2.Variable is VariableExpression variableExpression2 &&
variableExpression2.Declaration.RequestedName == "list" &&
variableExpression2.Type.IsCollection && variableExpression2.Type.IsGenericType)
{
// Remove the list declaration
return new SingleLineCommentStatement("Plugin customization: remove unnecessary list declaration");
}

if (uri is not null &&
statement is ForEachStatement foreachStatement &&
foreachStatement.Enumerable is DictionaryExpression dictionaryExpression &&
dictionaryExpression.Original is VariableExpression variable &&
variable.Declaration.RequestedName == "metadata")
{
var formatString = new FormattableStringExpression("metadata[{0}]", [foreachStatement.ItemVariable.Property("Key")]);
var appendQueryStatement = uri.Invoke("AppendQuery", [formatString, foreachStatement.ItemVariable.Property("Value"), Snippet.True]);
foreachStatement.Body.Clear();
foreachStatement.Body.Add(new SingleLineCommentStatement("Plugin customization: Properly handle metadata query parameters"));
foreachStatement.Body.Add(appendQueryStatement.Terminate());
}

// Remove the call to AppendQueryDelimited for metadata
if (statement is ExpressionStatement expressionStatement3 &&
expressionStatement3.Expression is InvokeMethodExpression invokeMethodExpression &&
invokeMethodExpression.MethodName == "AppendQueryDelimited" &&
invokeMethodExpression.Arguments.Count == 5 &&
invokeMethodExpression.Arguments[0].ToDisplayString() == "\"metadata\"")
{
return new SingleLineCommentStatement("Plugin customization: remove unnecessary AppendQueryDelimited for metadata");
}
return statement;
});

// Rebuild the method body with the modified statements
method.Update(bodyStatements: statements);
}

return base.VisitMethod(method);
}
}
Loading