Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions codegen/generator/src/Visitors/PaginationVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace OpenAILibraryPlugin.Visitors;

/// <summary>
/// This visitor modifies GetRawPagesAsync methods to consider HasMore in addition to LastId when deciding whether to continue pagination.
/// This visitor modifies GetRawPagesAsync and GetRawPages methods to consider HasMore in addition to LastId when deciding whether to continue pagination.
/// It also replaces specific parameters with an options type for pagination methods.
/// </summary>
public class PaginationVisitor : ScmLibraryVisitor
Expand All @@ -37,6 +37,14 @@ public class PaginationVisitor : ScmLibraryVisitor
{
"GetChatCompletionsAsync",
("ChatCompletion", "ChatCompletionCollectionOptions", _chatParamsToReplace)
},
{
"GetChatCompletionMessages",
("ChatCompletionMessageListDatum", "ChatCompletionCollectionOptions", _chatParamsToReplace)
},
{
"GetChatCompletionMessagesAsync",
("ChatCompletionMessageListDatum", "ChatCompletionMessageCollectionOptions", _chatParamsToReplace)
}
};

Expand Down Expand Up @@ -174,9 +182,9 @@ nullConditional.Inner is VariableExpression varExpr2 &&
/// <returns>True if the method was handled, false otherwise.</returns>
private bool TryHandleGetRawPagesAsyncMethod(MethodProvider method)
{
// If the method is GetRawPagesAsync and is internal, we will modify the body statements to add a check for hasMore == false.
// If the method is GetRawPagesAsync or GetRawPages and is internal, we will modify the body statements to add a check for hasMore == false.
// This is to ensure that pagination stops when hasMore is false, in addition to checking LastId.
if (method.Signature.Name == "GetRawPagesAsync" && method.EnclosingType.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal))
if ((method.Signature.Name == "GetRawPagesAsync" || method.Signature.Name == "GetRawPages") && method.EnclosingType.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal))
{
var statements = method.BodyStatements?.ToList() ?? new List<MethodBodyStatement>();
VisitExplodedMethodBodyStatements(
Expand Down
22 changes: 22 additions & 0 deletions specification/client/models/chat.models.tsp
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,26 @@ model ChatCompletionCollectionOptions {
@query `model`?: string,
}

alias ChatCompletionMessageCollectionOrderQueryParameter = {
/**
* Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and`desc`
* for descending order.
*/
@query order?: ChatCompletionCollectionOrder;
};

union ChatCompletionMessageCollectionOrder {
string,
Ascending: "asc",
Descending: "desc",
}

@access(Access.public)
@usage(Usage.input)
model ChatCompletionMessageCollectionOptions {
...CollectionAfterQueryParameter,
...CollectionLimitQueryParameter,
...ChatCompletionMessageCollectionOrderQueryParameter,
}


2 changes: 0 additions & 2 deletions src/Custom/Chat/ChatClient.Protocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
namespace OpenAI.Chat;

/// <summary> The service client for the OpenAI Chat Completions endpoint. </summary>
[CodeGenSuppress("GetChatCompletionMessagesAsync", typeof(string), typeof(string), typeof(int?), typeof(string), typeof(RequestOptions))]
[CodeGenSuppress("GetChatCompletionMessages", typeof(string), typeof(string), typeof(int?), typeof(string), typeof(RequestOptions))]
[CodeGenSuppress("UpdateChatCompletionAsync", typeof(string), typeof(BinaryContent), typeof(RequestOptions))]
[CodeGenSuppress("UpdateChatCompletion", typeof(string), typeof(BinaryContent), typeof(RequestOptions))]
public partial class ChatClient
Expand Down
2 changes: 0 additions & 2 deletions src/Custom/Chat/ChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ namespace OpenAI.Chat;
[CodeGenSuppress("ChatClient", typeof(ClientPipeline), typeof(Uri))]
[CodeGenSuppress("CompleteChat", typeof(ChatCompletionOptions), typeof(CancellationToken))]
[CodeGenSuppress("CompleteChatAsync", typeof(ChatCompletionOptions), typeof(CancellationToken))]
[CodeGenSuppress("GetChatCompletionMessages", typeof(string), typeof(string), typeof(int?), typeof(OpenAI.VectorStores.VectorStoreCollectionOrder?), typeof(CancellationToken))]
[CodeGenSuppress("GetChatCompletionMessagesAsync", typeof(string), typeof(string), typeof(int?), typeof(OpenAI.VectorStores.VectorStoreCollectionOrder?), typeof(CancellationToken))]
[CodeGenSuppress("UpdateChatCompletion", typeof(string), typeof(IDictionary<string, string>), typeof(CancellationToken))]
[CodeGenSuppress("UpdateChatCompletionAsync", typeof(string), typeof(IDictionary<string, string>), typeof(CancellationToken))]
public partial class ChatClient
Expand Down
6 changes: 6 additions & 0 deletions src/Custom/Chat/ChatCompletionMessageCollectionOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

namespace OpenAI.Chat;

// CUSTOM: Use the correct namespace.
[CodeGenType("ChatCompletionMessageCollectionOptions")]
public partial class ChatCompletionMessageCollectionOptions {}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
namespace OpenAI.Chat;

[CodeGenType("ChatCompletionMessageListDatum")]
internal partial class InternalChatCompletionMessageListDatum
public partial class ChatCompletionMessageListDatum
{
// CUSTOM: Ensure enumerated value is used.
[CodeGenMember("Role")]
Expand Down
56 changes: 56 additions & 0 deletions src/Generated/ChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,61 @@ public virtual async Task<ClientResult> CompleteChatAsync(BinaryContent content,
using PipelineMessage message = CreateCompleteChatRequest(content, options);
return ClientResult.FromResponse(await Pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false));
}

[Experimental("OPENAI001")]
public virtual CollectionResult GetChatCompletionMessages(string completionId, string after, int? limit, string order, RequestOptions options)
{
Argument.AssertNotNullOrEmpty(completionId, nameof(completionId));

return new ChatClientGetChatCompletionMessagesCollectionResult(
this,
completionId,
after,
limit,
order,
options);
}

[Experimental("OPENAI001")]
public virtual AsyncCollectionResult GetChatCompletionMessagesAsync(string completionId, string after, int? limit, string order, RequestOptions options)
{
Argument.AssertNotNullOrEmpty(completionId, nameof(completionId));

return new ChatClientGetChatCompletionMessagesAsyncCollectionResult(
this,
completionId,
after,
limit,
order,
options);
}

[Experimental("OPENAI001")]
public virtual CollectionResult<ChatCompletionMessageListDatum> GetChatCompletionMessages(string completionId, ChatCompletionCollectionOptions options = default, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(completionId, nameof(completionId));

return new ChatClientGetChatCompletionMessagesCollectionResultOfT(
this,
completionId,
options?.AfterId,
options?.PageSizeLimit,
options?.Order?.ToString(),
cancellationToken.CanBeCanceled ? new RequestOptions { CancellationToken = cancellationToken } : null);
}

[Experimental("OPENAI001")]
public virtual AsyncCollectionResult<ChatCompletionMessageListDatum> GetChatCompletionMessagesAsync(string completionId, ChatCompletionMessageCollectionOptions options = default, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(completionId, nameof(completionId));

return new ChatClientGetChatCompletionMessagesAsyncCollectionResultOfT(
this,
completionId,
options?.AfterId,
options?.PageSizeLimit,
options?.Order?.ToString(),
cancellationToken.CanBeCanceled ? new RequestOptions { CancellationToken = cancellationToken } : null);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using OpenAI;

namespace OpenAI.Chat
{
internal partial class ChatClientGetChatCompletionMessagesAsyncCollectionResult : AsyncCollectionResult
{
private readonly ChatClient _client;
private readonly string _completionId;
private readonly string _after;
private readonly int? _limit;
private readonly string _order;
private readonly RequestOptions _options;

public ChatClientGetChatCompletionMessagesAsyncCollectionResult(ChatClient client, string completionId, string after, int? limit, string order, RequestOptions options)
{
Argument.AssertNotNullOrEmpty(completionId, nameof(completionId));

_client = client;
_completionId = completionId;
_after = after;
_limit = limit;
_order = order;
_options = options;
}

public override async IAsyncEnumerable<ClientResult> GetRawPagesAsync()
{
PipelineMessage message = _client.CreateGetChatCompletionMessagesRequest(_completionId, _after, _limit, _order, _options);
string nextToken = null;
while (true)
{
ClientResult result = ClientResult.FromResponse(await _client.Pipeline.ProcessMessageAsync(message, _options).ConfigureAwait(false));
yield return result;

// Plugin customization: add hasMore assignment
bool hasMore = ((InternalChatCompletionMessageList)result).HasMore;
nextToken = ((InternalChatCompletionMessageList)result).LastId;
// Plugin customization: add hasMore == false check to pagination condition
if (nextToken == null || hasMore == false)
{
yield break;
}
message = _client.CreateGetChatCompletionMessagesRequest(_completionId, nextToken, _limit, _order, _options);
}
}

public override ContinuationToken GetContinuationToken(ClientResult page)
{
string nextPage = ((InternalChatCompletionMessageList)page).LastId;
if (nextPage != null)
{
return ContinuationToken.FromBytes(BinaryData.FromString(nextPage));
}
else
{
return null;
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Threading.Tasks;
using OpenAI;

namespace OpenAI.Chat
{
internal partial class ChatClientGetChatCompletionMessagesAsyncCollectionResultOfT : AsyncCollectionResult<ChatCompletionMessageListDatum>
{
private readonly ChatClient _client;
private readonly string _completionId;
private readonly string _after;
private readonly int? _limit;
private readonly string _order;
private readonly RequestOptions _options;

public ChatClientGetChatCompletionMessagesAsyncCollectionResultOfT(ChatClient client, string completionId, string after, int? limit, string order, RequestOptions options)
{
Argument.AssertNotNullOrEmpty(completionId, nameof(completionId));

_client = client;
_completionId = completionId;
_after = after;
_limit = limit;
_order = order;
_options = options;
}

public override async IAsyncEnumerable<ClientResult> GetRawPagesAsync()
{
PipelineMessage message = _client.CreateGetChatCompletionMessagesRequest(_completionId, _after, _limit, _order, _options);
string nextToken = null;
while (true)
{
ClientResult result = ClientResult.FromResponse(await _client.Pipeline.ProcessMessageAsync(message, _options).ConfigureAwait(false));
yield return result;

// Plugin customization: add hasMore assignment
bool hasMore = ((InternalChatCompletionMessageList)result).HasMore;
nextToken = ((InternalChatCompletionMessageList)result).LastId;
// Plugin customization: add hasMore == false check to pagination condition
if (nextToken == null || hasMore == false)
{
yield break;
}
message = _client.CreateGetChatCompletionMessagesRequest(_completionId, nextToken, _limit, _order, _options);
}
}

public override ContinuationToken GetContinuationToken(ClientResult page)
{
string nextPage = ((InternalChatCompletionMessageList)page).LastId;
if (nextPage != null)
{
return ContinuationToken.FromBytes(BinaryData.FromString(nextPage));
}
else
{
return null;
}
}

protected override async IAsyncEnumerable<ChatCompletionMessageListDatum> GetValuesFromPageAsync(ClientResult page)
{
foreach (ChatCompletionMessageListDatum item in ((InternalChatCompletionMessageList)page).Data)
{
yield return item;
await Task.Yield();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using OpenAI;

namespace OpenAI.Chat
{
internal partial class ChatClientGetChatCompletionMessagesCollectionResult : CollectionResult
{
private readonly ChatClient _client;
private readonly string _completionId;
private readonly string _after;
private readonly int? _limit;
private readonly string _order;
private readonly RequestOptions _options;

public ChatClientGetChatCompletionMessagesCollectionResult(ChatClient client, string completionId, string after, int? limit, string order, RequestOptions options)
{
Argument.AssertNotNullOrEmpty(completionId, nameof(completionId));

_client = client;
_completionId = completionId;
_after = after;
_limit = limit;
_order = order;
_options = options;
}

public override IEnumerable<ClientResult> GetRawPages()
{
PipelineMessage message = _client.CreateGetChatCompletionMessagesRequest(_completionId, _after, _limit, _order, _options);
string nextToken = null;
while (true)
{
ClientResult result = ClientResult.FromResponse(_client.Pipeline.ProcessMessage(message, _options));
yield return result;

// Plugin customization: add hasMore assignment
bool hasMore = ((InternalChatCompletionMessageList)result).HasMore;
nextToken = ((InternalChatCompletionMessageList)result).LastId;
// Plugin customization: add hasMore == false check to pagination condition
if (nextToken == null || hasMore == false)
{
yield break;
}
message = _client.CreateGetChatCompletionMessagesRequest(_completionId, nextToken, _limit, _order, _options);
}
}

public override ContinuationToken GetContinuationToken(ClientResult page)
{
string nextPage = ((InternalChatCompletionMessageList)page).LastId;
if (nextPage != null)
{
return ContinuationToken.FromBytes(BinaryData.FromString(nextPage));
}
else
{
return null;
}
}
}
}
Loading