From 4358abe27e47ca0ed6449f8a4c3304835d41e1a0 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Sun, 17 Nov 2024 12:24:24 -0800 Subject: [PATCH 01/10] Optimize pipeline for non-retriable errors --- .../Program.cs | 4 +- .../MyHandler.cs | 4 +- .../AzureAIDocIntel/AzureAIDocIntelEngine.cs | 17 +++-- extensions/AzureQueues/AzureQueuesPipeline.cs | 40 ++++++++---- .../RabbitMQ.TestApplication/Program.cs | 3 +- .../RabbitMQ/RabbitMQ/RabbitMQPipeline.cs | 64 +++++++++++-------- .../Pipeline/IPipelineStepHandler.cs | 2 +- .../Pipeline/NonRetriableException.cs | 20 ++++++ service/Abstractions/Pipeline/Queue/IQueue.cs | 2 +- service/Abstractions/Pipeline/ResultType.cs | 10 +++ .../Core/Handlers/DeleteDocumentHandler.cs | 4 +- .../Handlers/DeleteGeneratedFilesHandler.cs | 4 +- service/Core/Handlers/DeleteIndexHandler.cs | 4 +- .../Handlers/GenerateEmbeddingsHandler.cs | 6 +- .../GenerateEmbeddingsParallelHandler.cs | 6 +- service/Core/Handlers/SaveRecordsHandler.cs | 4 +- service/Core/Handlers/SummarizationHandler.cs | 4 +- .../Handlers/SummarizationParallelHandler.cs | 4 +- .../Core/Handlers/TextExtractionHandler.cs | 4 +- .../Core/Handlers/TextPartitioningHandler.cs | 6 +- service/Core/Pipeline/BaseOrchestrator.cs | 6 +- .../DistributedPipelineOrchestrator.cs | 56 ++++++++-------- .../Pipeline/InProcessPipelineOrchestrator.cs | 33 ++++++---- .../Pipeline/Queue/DevTools/SimpleQueues.cs | 63 +++++++++++------- service/Service.AspNetCore/WebAPIEndpoints.cs | 3 +- service/Service/Program.cs | 23 ++++++- 26 files changed, 254 insertions(+), 142 deletions(-) create mode 100644 service/Abstractions/Pipeline/NonRetriableException.cs create mode 100644 service/Abstractions/Pipeline/ResultType.cs diff --git a/examples/201-dotnet-serverless-custom-handler/Program.cs b/examples/201-dotnet-serverless-custom-handler/Program.cs index 743ecac55..cd35e3aed 100644 --- a/examples/201-dotnet-serverless-custom-handler/Program.cs +++ b/examples/201-dotnet-serverless-custom-handler/Program.cs 
@@ -47,7 +47,7 @@ public MyHandler( public string StepName { get; } /// - public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { /* ... your custom ... @@ -64,6 +64,6 @@ public MyHandler( // Remove this - here only to avoid build errors await Task.Delay(0, cancellationToken).ConfigureAwait(false); - return (true, pipeline); + return (ResultType.Success, pipeline); } } diff --git a/examples/202-dotnet-custom-handler-as-a-service/MyHandler.cs b/examples/202-dotnet-custom-handler-as-a-service/MyHandler.cs index 5ebfbe948..e3c7fd89d 100644 --- a/examples/202-dotnet-custom-handler-as-a-service/MyHandler.cs +++ b/examples/202-dotnet-custom-handler-as-a-service/MyHandler.cs @@ -38,7 +38,7 @@ public Task StopAsync(CancellationToken cancellationToken = default) } /// - public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync(DataPipeline pipeline, CancellationToken cancellationToken = default) + public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync(DataPipeline pipeline, CancellationToken cancellationToken = default) { /* ... your custom ... * ... handler ... 
@@ -49,6 +49,6 @@ public Task StopAsync(CancellationToken cancellationToken = default) // Remove this - here only to avoid build errors await Task.Delay(0, cancellationToken).ConfigureAwait(false); - return (true, pipeline); + return (ResultType.Success, pipeline); } } diff --git a/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs b/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs index 49f3c2a7f..09920256d 100644 --- a/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs +++ b/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs @@ -58,12 +58,19 @@ public AzureAIDocIntelEngine( /// public async Task ExtractTextFromImageAsync(Stream imageContent, CancellationToken cancellationToken = default) { - // Start the OCR operation - var operation = await this._recognizerClient.AnalyzeDocumentAsync(WaitUntil.Completed, "prebuilt-read", imageContent, cancellationToken: cancellationToken).ConfigureAwait(false); + try + { + // Start the OCR operation + var operation = await this._recognizerClient.AnalyzeDocumentAsync(WaitUntil.Completed, "prebuilt-read", imageContent, cancellationToken: cancellationToken).ConfigureAwait(false); - // Wait for the result - Response operationResponse = await operation.WaitForCompletionAsync(cancellationToken).ConfigureAwait(false); + // Wait for the result + Response operationResponse = await operation.WaitForCompletionAsync(cancellationToken).ConfigureAwait(false); - return operationResponse.Value.Content; + return operationResponse.Value.Content; + } + catch (RequestFailedException e) when (e.Status is >= 400 and < 500) + { + throw new NonRetriableException(e.Message, e); + } } } diff --git a/extensions/AzureQueues/AzureQueuesPipeline.cs b/extensions/AzureQueues/AzureQueuesPipeline.cs index f04771b55..93b8a80dc 100644 --- a/extensions/AzureQueues/AzureQueuesPipeline.cs +++ b/extensions/AzureQueues/AzureQueuesPipeline.cs @@ -14,6 +14,7 @@ using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.Diagnostics; using 
Microsoft.KernelMemory.DocumentStorage; +using Microsoft.KernelMemory.Pipeline; using Microsoft.KernelMemory.Pipeline.Queue; using Timer = System.Timers.Timer; @@ -180,7 +181,7 @@ public async Task EnqueueAsync(string message, CancellationToken cancellationTok } /// - public void OnDequeue(Func> processMessageAction) + public void OnDequeue(Func> processMessageAction) { this.Received += async (object sender, MessageEventArgs args) => { @@ -191,20 +192,30 @@ public void OnDequeue(Func> processMessageAction) try { + ResultType resultType = await processMessageAction.Invoke(message.MessageText).ConfigureAwait(false); if (message.DequeueCount <= this._config.MaxRetriesBeforePoisonQueue) { - bool success = await processMessageAction.Invoke(message.MessageText).ConfigureAwait(false); - if (success) + switch (resultType) { - this._log.LogTrace("Message '{0}' successfully processed, deleting message", message.MessageId); - await this.DeleteMessageAsync(message, cancellationToken: default).ConfigureAwait(false); - } - else - { - var backoffDelay = TimeSpan.FromSeconds(1 * message.DequeueCount); - this._log.LogWarning("Message '{0}' failed to process, putting message back in the queue with a delay of {1} msecs", - message.MessageId, backoffDelay.TotalMilliseconds); - await this.UnlockMessageAsync(message, backoffDelay, cancellationToken: default).ConfigureAwait(false); + case ResultType.Success: + this._log.LogTrace("Message '{0}' successfully processed, deleting message", message.MessageId); + await this.DeleteMessageAsync(message, cancellationToken: default).ConfigureAwait(false); + break; + + case ResultType.RetriableError: + var backoffDelay = TimeSpan.FromSeconds(1 * message.DequeueCount); + this._log.LogWarning("Message '{0}' failed to process, putting message back in the queue with a delay of {1} msecs", + message.MessageId, backoffDelay.TotalMilliseconds); + await this.UnlockMessageAsync(message, backoffDelay, cancellationToken: default).ConfigureAwait(false); + 
break; + + case ResultType.NonRetriableError: + this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", message.MessageId); + await this.MoveMessageToPoisonQueueAsync(message, cancellationToken: default).ConfigureAwait(false); + break; + + default: + throw new ArgumentOutOfRangeException($"Unknown {resultType:G} result"); } } else @@ -213,6 +224,11 @@ public void OnDequeue(Func> processMessageAction) await this.MoveMessageToPoisonQueueAsync(message, cancellationToken: default).ConfigureAwait(false); } } + catch (NonRetriableException e) + { + this._log.LogError(e, "Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", message.MessageId); + await this.MoveMessageToPoisonQueueAsync(message, cancellationToken: default).ConfigureAwait(false); + } #pragma warning disable CA1031 // Must catch all to handle queue properly catch (Exception e) { diff --git a/extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs b/extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs index 10ece1b13..84c4a2e0d 100644 --- a/extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs +++ b/extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs @@ -4,6 +4,7 @@ using Microsoft.KernelMemory; using Microsoft.KernelMemory.Diagnostics; using Microsoft.KernelMemory.Orchestration.RabbitMQ; +using Microsoft.KernelMemory.Pipeline; using Microsoft.KernelMemory.Pipeline.Queue; using RabbitMQ.Client; using RabbitMQ.Client.Events; @@ -38,7 +39,7 @@ public static async Task Main() { Console.WriteLine($"{++counter} Received message: {msg}"); await Task.Delay(0); - return false; + return ResultType.RetriableError; }); await pipeline.ConnectToQueueAsync(QueueName, QueueOptions.PubSub); diff --git a/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs b/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs index 6fcc816bb..c660eb350 100644 --- a/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs +++ 
b/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.Diagnostics; +using Microsoft.KernelMemory.Pipeline; using Microsoft.KernelMemory.Pipeline.Queue; using RabbitMQ.Client; using RabbitMQ.Client.Events; @@ -68,7 +69,7 @@ public RabbitMQPipeline(RabbitMQConfig config, ILoggerFactory? loggerFactory = n } /// - /// About dead letters, see https://www.rabbitmq.com/docs/dlx + /// About poison queue and dead letters, see https://www.rabbitmq.com/docs/dlx public Task ConnectToQueueAsync(string queueName, QueueOptions options = default, CancellationToken cancellationToken = default) { ArgumentNullExceptionEx.ThrowIfNullOrWhiteSpace(queueName, nameof(queueName), "The queue name is empty"); @@ -82,7 +83,7 @@ public Task ConnectToQueueAsync(string queueName, QueueOptions options = ArgumentExceptionEx.ThrowIf((Encoding.UTF8.GetByteCount(poisonExchangeName) > 255), nameof(poisonExchangeName), $"The exchange name '{poisonExchangeName}' is too long, max 255 UTF8 bytes allowed, try using a shorter queue name"); ArgumentExceptionEx.ThrowIf((Encoding.UTF8.GetByteCount(poisonQueueName) > 255), nameof(poisonQueueName), - $"The dead letter queue name '{poisonQueueName}' is too long, max 255 UTF8 bytes allowed, try using a shorter queue name"); + $"The poison queue name '{poisonQueueName}' is too long, max 255 UTF8 bytes allowed, try using a shorter queue name"); if (!string.IsNullOrEmpty(this._queueName)) { @@ -173,7 +174,7 @@ public Task EnqueueAsync(string message, CancellationToken cancellationToken = d } /// - public void OnDequeue(Func> processMessageAction) + public void OnDequeue(Func> processMessageAction) { this._consumer.Received += async (object sender, BasicDeliverEventArgs args) => { @@ -192,33 +193,47 @@ public void OnDequeue(Func> processMessageAction) byte[] body = args.Body.ToArray(); string message = Encoding.UTF8.GetString(body); - bool success = await 
processMessageAction.Invoke(message).ConfigureAwait(false); - if (success) + var resultType = await processMessageAction.Invoke(message).ConfigureAwait(false); + switch (resultType) { - this._log.LogTrace("Message '{0}' successfully processed, deleting message", args.BasicProperties?.MessageId); - this._channel.BasicAck(args.DeliveryTag, multiple: false); - } - else - { - if (attemptNumber < this._maxAttempts) - { - this._log.LogWarning("Message '{0}' failed to process (attempt {1} of {2}), putting message back in the queue", - args.BasicProperties?.MessageId, attemptNumber, this._maxAttempts); - if (this._delayBeforeRetryingMsecs > 0) + case ResultType.Success: + this._log.LogTrace("Message '{0}' successfully processed, deleting message", args.BasicProperties?.MessageId); + this._channel.BasicAck(args.DeliveryTag, multiple: false); + break; + + case ResultType.RetriableError: + if (attemptNumber < this._maxAttempts) { - await Task.Delay(TimeSpan.FromMilliseconds(this._delayBeforeRetryingMsecs)).ConfigureAwait(false); + this._log.LogWarning("Message '{0}' failed to process (attempt {1} of {2}), putting message back in the queue", + args.BasicProperties?.MessageId, attemptNumber, this._maxAttempts); + if (this._delayBeforeRetryingMsecs > 0) + { + await Task.Delay(TimeSpan.FromMilliseconds(this._delayBeforeRetryingMsecs)).ConfigureAwait(false); + } } - } - else - { - this._log.LogError("Message '{0}' failed to process (attempt {1} of {2}), moving message to dead letter queue", - args.BasicProperties?.MessageId, attemptNumber, this._maxAttempts); - } + else + { + this._log.LogError("Message '{0}' failed to process (attempt {1} of {2}), moving message to poison queue", + args.BasicProperties?.MessageId, attemptNumber, this._maxAttempts); + } + + this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: true); + break; - // Note: if "requeue == false" the message would be moved to the dead letter exchange - this._channel.BasicNack(args.DeliveryTag, multiple: 
false, requeue: true); + case ResultType.NonRetriableError: + this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", args.BasicProperties?.MessageId); + this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: false); + break; + + default: + throw new ArgumentOutOfRangeException($"Unknown {resultType:G} result"); } } + catch (NonRetriableException e) + { + this._log.LogError(e, "Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", args.BasicProperties?.MessageId); + this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: false); + } #pragma warning disable CA1031 // Must catch all to handle queue properly catch (Exception e) { @@ -243,7 +258,6 @@ public void OnDequeue(Func> processMessageAction) } // TODO: verify and document what happens if this fails. RabbitMQ should automatically unlock messages. - // Note: if "requeue == false" the message would be moved to the dead letter exchange this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: true); } #pragma warning restore CA1031 diff --git a/service/Abstractions/Pipeline/IPipelineStepHandler.cs b/service/Abstractions/Pipeline/IPipelineStepHandler.cs index ce5296321..f785c4323 100644 --- a/service/Abstractions/Pipeline/IPipelineStepHandler.cs +++ b/service/Abstractions/Pipeline/IPipelineStepHandler.cs @@ -20,5 +20,5 @@ public interface IPipelineStepHandler /// Pipeline status /// Async task cancellation token /// Whether the pipeline step has been processed successfully, and the new pipeline status to use moving forward - Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync(DataPipeline pipeline, CancellationToken cancellationToken = default); + Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync(DataPipeline pipeline, CancellationToken cancellationToken = default); } diff --git a/service/Abstractions/Pipeline/NonRetriableException.cs 
b/service/Abstractions/Pipeline/NonRetriableException.cs new file mode 100644 index 000000000..4905f89af --- /dev/null +++ b/service/Abstractions/Pipeline/NonRetriableException.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.KernelMemory.Pipeline; + +#pragma warning disable IDE0130 // reduce number of "using" statements +// ReSharper disable once CheckNamespace - reduce number of "using" statements +namespace Microsoft.KernelMemory; + +public class NonRetriableException : OrchestrationException +{ + /// + public NonRetriableException() { } + + /// + public NonRetriableException(string message) : base(message) { } + + /// + public NonRetriableException(string message, Exception? innerException) : base(message, innerException) { } +} diff --git a/service/Abstractions/Pipeline/Queue/IQueue.cs b/service/Abstractions/Pipeline/Queue/IQueue.cs index b4bc7c671..ed8f05824 100644 --- a/service/Abstractions/Pipeline/Queue/IQueue.cs +++ b/service/Abstractions/Pipeline/Queue/IQueue.cs @@ -28,5 +28,5 @@ public interface IQueue : IDisposable /// Define the logic to execute when a new message is in the queue. /// /// Async action to execute - void OnDequeue(Func> processMessageAction); + void OnDequeue(Func> processMessageAction); } diff --git a/service/Abstractions/Pipeline/ResultType.cs b/service/Abstractions/Pipeline/ResultType.cs new file mode 100644 index 000000000..3207e0578 --- /dev/null +++ b/service/Abstractions/Pipeline/ResultType.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +namespace Microsoft.KernelMemory.Pipeline; + +public enum ResultType +{ + Success = 0, + RetriableError = 1, + NonRetriableError = 2, +} diff --git a/service/Core/Handlers/DeleteDocumentHandler.cs b/service/Core/Handlers/DeleteDocumentHandler.cs index 6324ed447..8e438511a 100644 --- a/service/Core/Handlers/DeleteDocumentHandler.cs +++ b/service/Core/Handlers/DeleteDocumentHandler.cs @@ -34,7 +34,7 @@ public DeleteDocumentHandler( } /// - public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Deleting document, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -60,6 +60,6 @@ await this._documentStorage.EmptyDocumentDirectoryAsync( documentId: pipeline.DocumentId, cancellationToken).ConfigureAwait(false); - return (true, pipeline); + return (ResultType.Success, pipeline); } } diff --git a/service/Core/Handlers/DeleteGeneratedFilesHandler.cs b/service/Core/Handlers/DeleteGeneratedFilesHandler.cs index 006eb58df..5792aad17 100644 --- a/service/Core/Handlers/DeleteGeneratedFilesHandler.cs +++ b/service/Core/Handlers/DeleteGeneratedFilesHandler.cs @@ -29,7 +29,7 @@ public DeleteGeneratedFilesHandler( } /// - public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Deleting generated files, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -40,6 +40,6 @@ await this._documentStorage.EmptyDocumentDirectoryAsync( documentId: pipeline.DocumentId, cancellationToken).ConfigureAwait(false); - return (true, pipeline); + return (ResultType.Success, pipeline); } } diff --git a/service/Core/Handlers/DeleteIndexHandler.cs 
b/service/Core/Handlers/DeleteIndexHandler.cs index a2897d489..a18dbab60 100644 --- a/service/Core/Handlers/DeleteIndexHandler.cs +++ b/service/Core/Handlers/DeleteIndexHandler.cs @@ -34,7 +34,7 @@ public DeleteIndexHandler( } /// - public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Deleting index, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -50,6 +50,6 @@ await this._documentStorage.DeleteIndexDirectoryAsync( index: pipeline.Index, cancellationToken).ConfigureAwait(false); - return (true, pipeline); + return (ResultType.Success, pipeline); } } diff --git a/service/Core/Handlers/GenerateEmbeddingsHandler.cs b/service/Core/Handlers/GenerateEmbeddingsHandler.cs index 30f725b5b..0bdb1a970 100644 --- a/service/Core/Handlers/GenerateEmbeddingsHandler.cs +++ b/service/Core/Handlers/GenerateEmbeddingsHandler.cs @@ -58,13 +58,13 @@ public GenerateEmbeddingsHandler( } /// - public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { if (!this._embeddingGenerationEnabled) { this._log.LogTrace("Embedding generation is disabled, skipping - pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); - return (true, pipeline); + return (ResultType.Success, pipeline); } foreach (ITextEmbeddingGenerator generator in this._embeddingGenerators) @@ -83,7 +83,7 @@ public GenerateEmbeddingsHandler( } } - return (true, pipeline); + return (ResultType.Success, pipeline); } protected override IPipelineStepHandler ActualInstance => this; diff --git a/service/Core/Handlers/GenerateEmbeddingsParallelHandler.cs b/service/Core/Handlers/GenerateEmbeddingsParallelHandler.cs index 360c82874..8150a126a 
100644 --- a/service/Core/Handlers/GenerateEmbeddingsParallelHandler.cs +++ b/service/Core/Handlers/GenerateEmbeddingsParallelHandler.cs @@ -58,13 +58,13 @@ public GenerateEmbeddingsParallelHandler( } /// - public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { if (!this._embeddingGenerationEnabled) { this._log.LogTrace("Embedding generation is disabled, skipping - pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); - return (true, pipeline); + return (ResultType.Success, pipeline); } foreach (ITextEmbeddingGenerator generator in this._embeddingGenerators) @@ -83,7 +83,7 @@ public GenerateEmbeddingsParallelHandler( } } - return (true, pipeline); + return (ResultType.Success, pipeline); } protected override IPipelineStepHandler ActualInstance => this; diff --git a/service/Core/Handlers/SaveRecordsHandler.cs b/service/Core/Handlers/SaveRecordsHandler.cs index 50baf535d..6987d6bc5 100644 --- a/service/Core/Handlers/SaveRecordsHandler.cs +++ b/service/Core/Handlers/SaveRecordsHandler.cs @@ -103,7 +103,7 @@ public SaveRecordsHandler( } /// - public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Saving memory records, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -241,7 +241,7 @@ record = PrepareRecord( this._log.LogWarning("Pipeline '{0}/{1}': step {2}: no records found, cannot save, moving to next pipeline step.", pipeline.Index, pipeline.DocumentId, this.StepName); } - return (true, pipeline); + return (ResultType.Success, pipeline); } private static IEnumerable GetListOfEmbeddingFiles(DataPipeline pipeline) diff --git a/service/Core/Handlers/SummarizationHandler.cs 
b/service/Core/Handlers/SummarizationHandler.cs index 81a08c65d..bd71163a9 100644 --- a/service/Core/Handlers/SummarizationHandler.cs +++ b/service/Core/Handlers/SummarizationHandler.cs @@ -54,7 +54,7 @@ public SummarizationHandler( } /// - public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Generating summary, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -125,7 +125,7 @@ public SummarizationHandler( } } - return (true, pipeline); + return (ResultType.Success, pipeline); } private async Task<(string summary, bool skip)> SummarizeAsync(string content, IContext context) diff --git a/service/Core/Handlers/SummarizationParallelHandler.cs b/service/Core/Handlers/SummarizationParallelHandler.cs index 19a685e3d..2370ea570 100644 --- a/service/Core/Handlers/SummarizationParallelHandler.cs +++ b/service/Core/Handlers/SummarizationParallelHandler.cs @@ -53,7 +53,7 @@ public SummarizationParallelHandler( } /// - public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Generating summary, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -133,7 +133,7 @@ await Parallel.ForEachAsync(uploadedFile.GeneratedFiles, options, async (generat } } - return (true, pipeline); + return (ResultType.Success, pipeline); } private async Task<(string summary, bool skip)> SummarizeAsync(string content) diff --git a/service/Core/Handlers/TextExtractionHandler.cs b/service/Core/Handlers/TextExtractionHandler.cs index c68d4f32d..6b5d34a89 100644 --- a/service/Core/Handlers/TextExtractionHandler.cs +++ b/service/Core/Handlers/TextExtractionHandler.cs @@ -54,7 +54,7 @@ 
public TextExtractionHandler( } /// - public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Extracting text, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -135,7 +135,7 @@ public TextExtractionHandler( uploadedFile.MarkProcessedBy(this); } - return (true, pipeline); + return (ResultType.Success, pipeline); } public void Dispose() diff --git a/service/Core/Handlers/TextPartitioningHandler.cs b/service/Core/Handlers/TextPartitioningHandler.cs index 9ac57da4a..905ecea3f 100644 --- a/service/Core/Handlers/TextPartitioningHandler.cs +++ b/service/Core/Handlers/TextPartitioningHandler.cs @@ -67,7 +67,7 @@ public TextPartitioningHandler( } /// - public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Partitioning text, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -75,7 +75,7 @@ public TextPartitioningHandler( if (pipeline.Files.Count == 0) { this._log.LogWarning("Pipeline '{0}/{1}': there are no files to process, moving to next pipeline step.", pipeline.Index, pipeline.DocumentId); - return (true, pipeline); + return (ResultType.Success, pipeline); } var context = pipeline.GetContext(); @@ -197,7 +197,7 @@ public TextPartitioningHandler( } } - return (true, pipeline); + return (ResultType.Success, pipeline); } #pragma warning disable CA2254 // the msg is always used diff --git a/service/Core/Pipeline/BaseOrchestrator.cs b/service/Core/Pipeline/BaseOrchestrator.cs index 84df7ec36..8311c0406 100644 --- a/service/Core/Pipeline/BaseOrchestrator.cs +++ b/service/Core/Pipeline/BaseOrchestrator.cs @@ -123,7 +123,7 @@ public async Task 
ImportDocumentAsync( } catch (Exception e) { - this.Log.LogError(e, "Pipeline start failed"); + this.Log.LogError(e, "Pipeline start failed."); throw; } } @@ -327,7 +327,7 @@ protected async Task CleanUpAfterCompletionAsync(DataPipeline pipeline, Cancella } catch (Exception e) { - this.Log.LogError(e, "Error while trying to delete the document directory"); + this.Log.LogError(e, "Error while trying to delete the document directory."); } } @@ -339,7 +339,7 @@ protected async Task CleanUpAfterCompletionAsync(DataPipeline pipeline, Cancella } catch (Exception e) { - this.Log.LogError(e, "Error while trying to delete the index directory"); + this.Log.LogError(e, "Error while trying to delete the index directory."); } } #pragma warning restore CA1031 diff --git a/service/Core/Pipeline/DistributedPipelineOrchestrator.cs b/service/Core/Pipeline/DistributedPipelineOrchestrator.cs index 8b35d6d58..cc916dd77 100644 --- a/service/Core/Pipeline/DistributedPipelineOrchestrator.cs +++ b/service/Core/Pipeline/DistributedPipelineOrchestrator.cs @@ -83,12 +83,6 @@ public override async Task AddHandlerAsync( throw new ArgumentException($"There is already a handler for step '{handler.StepName}'"); } - // When returning False a message is put back in the queue and processed again - const bool Retry = false; - - // When returning True a message is removed from the queue and deleted - const bool Complete = true; - // Create a new queue client and start listening for messages this._queues[handler.StepName] = this._queueClientFactory.Build(); this._queues[handler.StepName].OnDequeue(async msg => @@ -99,7 +93,7 @@ public override async Task AddHandlerAsync( if (pipelinePointer == null) { this.Log.LogError("Pipeline pointer deserialization failed, queue `{0}`. Message discarded.", handler.StepName); - return Complete; + return ResultType.NonRetriableError; } DataPipeline? 
pipeline; @@ -127,18 +121,18 @@ public override async Task AddHandlerAsync( } this.Log.LogError("Pipeline `{0}/{1}` not found, cancelling step `{2}`", pipelinePointer.Index, pipelinePointer.DocumentId, handler.StepName); - return Complete; + return ResultType.NonRetriableError; } catch (InvalidPipelineDataException) { this.Log.LogError("Pipeline `{0}/{1}` state load failed, invalid state, queue `{2}`", pipelinePointer.Index, pipelinePointer.DocumentId, handler.StepName); - return Retry; + return ResultType.RetriableError; } if (pipeline == null) { this.Log.LogError("Pipeline `{0}/{1}` state load failed, the state is null, queue `{2}`", pipelinePointer.Index, pipelinePointer.DocumentId, handler.StepName); - return Retry; + return ResultType.RetriableError; } if (pipelinePointer.ExecutionId != pipeline.ExecutionId) @@ -147,7 +141,7 @@ public override async Task AddHandlerAsync( "Document `{0}/{1}` has been updated without waiting for the previous pipeline execution `{2}` to complete (current execution: `{3}`). 
" + "Step `{4}` and any consecutive steps from the previous execution have been cancelled.", pipelinePointer.Index, pipelinePointer.DocumentId, pipelinePointer.ExecutionId, pipeline.ExecutionId, handler.StepName); - return Complete; + return ResultType.Success; } var currentStepName = pipeline.RemainingSteps.First(); @@ -207,7 +201,7 @@ public override async Task RunPipelineAsync(DataPipeline pipeline, CancellationT #region private - private async Task RunPipelineStepAsync( + private async Task RunPipelineStepAsync( DataPipeline pipeline, IPipelineStepHandler handler, CancellationToken cancellationToken) @@ -216,31 +210,37 @@ private async Task RunPipelineStepAsync( if (pipeline.Complete) { this.Log.LogInformation("Pipeline '{0}/{1}' complete", pipeline.Index, pipeline.DocumentId); - // Note: returning True, the message is removed from the queue - return true; + return ResultType.Success; } string currentStepName = pipeline.RemainingSteps.First(); // Execute the business logic - exceptions are automatically handled by IQueue - (bool success, DataPipeline updatedPipeline) = await handler.InvokeAsync(pipeline, cancellationToken).ConfigureAwait(false); - if (success) + (ResultType resultType, DataPipeline updatedPipeline) = await handler.InvokeAsync(pipeline, cancellationToken).ConfigureAwait(false); + switch (resultType) { - pipeline = updatedPipeline; - pipeline.LastUpdate = DateTimeOffset.UtcNow; + case ResultType.Success: + pipeline = updatedPipeline; + pipeline.LastUpdate = DateTimeOffset.UtcNow; - this.Log.LogInformation("Handler {0} processed pipeline {1} successfully", currentStepName, pipeline.DocumentId); - pipeline.MoveToNextStep(); - await this.MoveForwardAsync(pipeline, cancellationToken).ConfigureAwait(false); - } - else - { - this.Log.LogError("Handler {0} failed to process pipeline {1}", currentStepName, pipeline.DocumentId); + this.Log.LogInformation("Handler {0} processed pipeline {1} successfully", currentStepName, pipeline.DocumentId); + 
pipeline.MoveToNextStep(); + await this.MoveForwardAsync(pipeline, cancellationToken).ConfigureAwait(false); + break; + + case ResultType.RetriableError: + this.Log.LogError("Handler {0} failed to process pipeline {1}", currentStepName, pipeline.DocumentId); + break; + + case ResultType.NonRetriableError: + this.Log.LogError("Handler {0} failed to process pipeline {1} due to an unrecoverable error", currentStepName, pipeline.DocumentId); + break; + + default: + throw new ArgumentOutOfRangeException($"Unknown {resultType:G} result type"); } - // Note: returning True, the message is removed from the queue - // Note: returning False, the message is put back in the queue and processed again - return success; + return resultType; } private async Task MoveForwardAsync(DataPipeline pipeline, CancellationToken cancellationToken = default) diff --git a/service/Core/Pipeline/InProcessPipelineOrchestrator.cs b/service/Core/Pipeline/InProcessPipelineOrchestrator.cs index 66ba8c473..0dde21b2e 100644 --- a/service/Core/Pipeline/InProcessPipelineOrchestrator.cs +++ b/service/Core/Pipeline/InProcessPipelineOrchestrator.cs @@ -171,21 +171,30 @@ public override async Task RunPipelineAsync(DataPipeline pipeline, CancellationT } // Run handler - (bool success, DataPipeline updatedPipeline) = await stepHandler + (ResultType resultType, DataPipeline updatedPipeline) = await stepHandler .InvokeAsync(pipeline, this.CancellationTokenSource.Token) .ConfigureAwait(false); - if (success) - { - pipeline = updatedPipeline; - pipeline.LastUpdate = DateTimeOffset.UtcNow; - this.Log.LogInformation("Handler '{0}' processed pipeline '{1}/{2}' successfully", currentStepName, pipeline.Index, pipeline.DocumentId); - pipeline.MoveToNextStep(); - await this.UpdatePipelineStatusAsync(pipeline, cancellationToken).ConfigureAwait(false); - } - else + + switch (resultType) { - this.Log.LogError("Handler '{0}' failed to process pipeline '{1}/{2}'", currentStepName, pipeline.Index, pipeline.DocumentId); - throw 
new OrchestrationException($"Pipeline error, step {currentStepName} failed"); + case ResultType.Success: + pipeline = updatedPipeline; + pipeline.LastUpdate = DateTimeOffset.UtcNow; + this.Log.LogInformation("Handler '{0}' processed pipeline '{1}/{2}' successfully", currentStepName, pipeline.Index, pipeline.DocumentId); + pipeline.MoveToNextStep(); + await this.UpdatePipelineStatusAsync(pipeline, cancellationToken).ConfigureAwait(false); + break; + + case ResultType.RetriableError: + this.Log.LogError("Handler '{0}' failed to process pipeline '{1}/{2}'", currentStepName, pipeline.Index, pipeline.DocumentId); + throw new OrchestrationException($"Pipeline error, step {currentStepName} failed"); + + case ResultType.NonRetriableError: + this.Log.LogError("Handler '{0}' failed to process pipeline '{1}/{2}' due to an unrecoverable error", currentStepName, pipeline.Index, pipeline.DocumentId); + throw new NonRetriableException($"Unrecoverable pipeline error, step {currentStepName} failed and cannot be retried"); + + default: + throw new ArgumentOutOfRangeException($"Unknown {resultType:G} result type"); } } diff --git a/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs b/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs index 97c5aeeae..f7331c0a3 100644 --- a/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs +++ b/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs @@ -177,7 +177,7 @@ await this.StoreMessageAsync( /// /// about the logic handling dequeued messages. 
- public void OnDequeue(Func> processMessageAction) + public void OnDequeue(Func> processMessageAction) { this._log.LogInformation("Queue {0}: subscribing...", this._queueName); this.Received += async (sender, args) => @@ -193,39 +193,55 @@ public void OnDequeue(Func> processMessageAction) this._log.LogInformation("Queue {0}: message {0} received", this._queueName, message.Id); // Process message with the logic provided by the orchestrator - bool success = await processMessageAction.Invoke(message.Content).ConfigureAwait(false); - if (success) + var resultType = await processMessageAction.Invoke(message.Content).ConfigureAwait(false); + switch (resultType) { - this._log.LogTrace("Message '{0}' successfully processed, deleting message", message.Id); - await this.DeleteMessageAsync(message.Id, this._cancellation.Token).ConfigureAwait(false); - } - else - { - message.LastError = "Message handler returned false"; - if (message.DequeueCount == this._maxAttempts) - { - this._log.LogError("Message '{0}' processing failed to process, max attempts reached, moving to dead letter queue. Message content: {1}", message.Id, message.Content); + case ResultType.Success: + this._log.LogTrace("Message '{0}' successfully processed, deleting message", message.Id); + await this.DeleteMessageAsync(message.Id, this._cancellation.Token).ConfigureAwait(false); + break; + + case ResultType.RetriableError: + message.LastError = "Message handler returned false"; + if (message.DequeueCount == this._maxAttempts) + { + this._log.LogError("Message '{0}' processing failed to process, max attempts reached, moving to poison queue. Message content: {1}", message.Id, message.Content); + poison = true; + } + else + { + this._log.LogWarning("Message '{0}' failed to process, putting message back in the queue. 
Message content: {1}", message.Id, message.Content); + retry = true; + } + + break; + + case ResultType.NonRetriableError: + this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", message.Id); poison = true; - } - else - { - this._log.LogWarning("Message '{0}' failed to process, putting message back in the queue. Message content: {1}", message.Id, message.Content); - retry = true; - } + break; + + default: + throw new ArgumentOutOfRangeException($"Unknown {resultType:G} result"); } } + catch (NonRetriableException e) + { + this._log.LogError(e, "Message '{0}' failed to process due to a non-recoverable error, moving to poison queue.", message.Id); + poison = true; + } // Note: must catch all also because using a void event handler catch (Exception e) { message.LastError = $"{e.GetType().FullName}: {e.Message}"; if (message.DequeueCount == this._maxAttempts) { - this._log.LogError(e, "Message '{0}' processing failed with exception, max attempts reached, moving to dead letter queue. Message content: {1}", message.Id, message.Content); + this._log.LogError(e, "Message '{0}' processing failed with exception, max attempts reached, moving to poison queue. Message content: {1}.", message.Id, message.Content); poison = true; } else { - this._log.LogWarning(e, "Message '{0}' processing failed with exception, putting message back in the queue. Message content: {1}", message.Id, message.Content); + this._log.LogWarning(e, "Message '{0}' processing failed with exception, putting message back in the queue. Message content: {1}.", message.Id, message.Content); retry = true; } } @@ -260,8 +276,9 @@ private void PopulateQueue(object? 
sender, ElapsedEventArgs elapsedEventArgs) await s_lock.WaitAsync(this._cancellation.Token).ConfigureAwait(false); // Loop through all messages on storage - this._log.LogTrace("Queue {0}: polling...", this._queueName); var messagesOnStorage = (await this._fileSystem.GetAllFileNamesAsync(this._queueName, "", this._cancellation.Token).ConfigureAwait(false)).ToList(); + if (messagesOnStorage.Count == 0) { return; } + this._log.LogTrace("Queue {0}: {1} messages on storage, {2} ready to dispatch, max batch size {3}", this._queueName, messagesOnStorage.Count, this._queue.Count, this._config.FetchBatchSize); @@ -313,12 +330,12 @@ private void PopulateQueue(object? sender, ElapsedEventArgs elapsedEventArgs) } catch (DirectoryNotFoundException e) { - this._log.LogError(e, "Directory missing, recreating"); + this._log.LogError(e, "Directory missing, recreating."); await this.CreateDirectoriesAsync(this._cancellation.Token).ConfigureAwait(false); } catch (Exception e) { - this._log.LogError(e, "Queue {0}: Unexpected error while polling", this._queueName); + this._log.LogError(e, "Queue {0}: Unexpected error while polling.", this._queueName); } finally { diff --git a/service/Service.AspNetCore/WebAPIEndpoints.cs b/service/Service.AspNetCore/WebAPIEndpoints.cs index 49211441a..cef2388b8 100644 --- a/service/Service.AspNetCore/WebAPIEndpoints.cs +++ b/service/Service.AspNetCore/WebAPIEndpoints.cs @@ -322,7 +322,8 @@ async Task ( .Produces(StatusCodes.Status400BadRequest) .Produces(StatusCodes.Status401Unauthorized) .Produces(StatusCodes.Status403Forbidden) - .Produces(StatusCodes.Status404NotFound); + .Produces(StatusCodes.Status404NotFound) + .Produces(StatusCodes.Status413PayloadTooLarge); return route; } diff --git a/service/Service/Program.cs b/service/Service/Program.cs index 08a1db6e2..f0db0bed2 100644 --- a/service/Service/Program.cs +++ b/service/Service/Program.cs @@ -2,6 +2,8 @@ using System; using System.Collections.Generic; +using System.Globalization; +using 
System.IO; using System.Linq; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; @@ -182,8 +184,15 @@ public static void Main(string[] args) Console.WriteLine("* Memory type : " + memoryType); Console.WriteLine("* Pipeline handlers : " + $"{syncHandlersCount} synchronous / {asyncHandlersCount} asynchronous"); Console.WriteLine("* Web service : " + (config.Service.RunWebService ? "Enabled" : "Disabled")); - Console.WriteLine("* Web service auth : " + (config.ServiceAuthorization.Enabled ? "Enabled" : "Disabled")); - Console.WriteLine("* OpenAPI swagger : " + (config.Service.OpenApiEnabled ? "Enabled" : "Disabled")); + + if (config.Service.RunWebService) + { + const double AspnetDefaultMaxUploadSize = 30000000d / 1024 / 1024; + Console.WriteLine("* Web service auth : " + (config.ServiceAuthorization.Enabled ? "Enabled" : "Disabled")); + Console.WriteLine("* Max HTTP req size : " + (config.Service.MaxUploadSizeMb ?? AspnetDefaultMaxUploadSize).ToString("0.#", CultureInfo.CurrentCulture) + " Mb"); + Console.WriteLine("* OpenAPI swagger : " + (config.Service.OpenApiEnabled ? 
"Enabled" : "Disabled")); + } + Console.WriteLine("* Memory Db : " + app.Services.GetService()?.GetType().FullName); Console.WriteLine("* Document storage : " + app.Services.GetService()?.GetType().FullName); Console.WriteLine("* Embedding generation: " + app.Services.GetService()?.GetType().FullName); @@ -201,7 +210,15 @@ public static void Main(string[] args) config.Service.RunHandlers); // Start web service and handler services - app.Run(); + try + { + app.Run(); + } + catch (IOException e) + { + Console.WriteLine($"I/O error: {e.Message}"); + Environment.Exit(-1); + } } /// From c57657bb0b8196768768fd00ed9ebe5baf0d3028 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Sun, 17 Nov 2024 20:25:36 -0800 Subject: [PATCH 02/10] Support non retriable exceptions in AI clients --- .../AzureOpenAITextEmbeddingGenerator.cs | 21 ++++++++++++++++--- .../AzureOpenAI/AzureOpenAITextGenerator.cs | 14 +++++++++++-- .../OpenAI/OpenAITextEmbeddingGenerator.cs | 21 ++++++++++++++++--- .../OpenAI/OpenAI/OpenAITextGenerator.cs | 15 +++++++++++-- .../RabbitMQ/RabbitMQ/RabbitMQPipeline.cs | 2 +- .../Pipeline/Queue/DevTools/SimpleQueues.cs | 1 + 6 files changed, 63 insertions(+), 11 deletions(-) diff --git a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs index 4a096c4b7..9a5efbe02 100644 --- a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs +++ b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs @@ -12,6 +12,7 @@ using Microsoft.KernelMemory.AI.AzureOpenAI.Internals; using Microsoft.KernelMemory.AI.OpenAI; using Microsoft.KernelMemory.Diagnostics; +using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.AI.Embeddings; using Microsoft.SemanticKernel.Connectors.AzureOpenAI; @@ -121,7 +122,14 @@ public IReadOnlyList GetTokens(string text) public Task GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default) { 
this._log.LogTrace("Generating embedding"); - return this._client.GenerateEmbeddingAsync(text, cancellationToken); + try + { + return this._client.GenerateEmbeddingAsync(text, cancellationToken); + } + catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) + { + throw new NonRetriableException(e.Message, e); + } } /// @@ -129,7 +137,14 @@ public async Task GenerateEmbeddingBatchAsync(IEnumerable t { var list = textList.ToList(); this._log.LogTrace("Generating embeddings, batch size '{0}'", list.Count); - IList> embeddings = await this._client.GenerateEmbeddingsAsync(list, cancellationToken: cancellationToken).ConfigureAwait(false); - return embeddings.Select(e => new Embedding(e)).ToArray(); + try + { + IList> embeddings = await this._client.GenerateEmbeddingsAsync(list, cancellationToken: cancellationToken).ConfigureAwait(false); + return embeddings.Select(e => new Embedding(e)).ToArray(); + } + catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) + { + throw new NonRetriableException(e.Message, e); + } } } diff --git a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs index cb6126901..3522d3eec 100644 --- a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs +++ b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs @@ -5,6 +5,7 @@ using System.Net.Http; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Azure.AI.OpenAI; using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.AI.AzureOpenAI.Internals; @@ -140,8 +141,17 @@ public async IAsyncEnumerable GenerateTextAsync( } this._log.LogTrace("Sending chat message generation request"); - IAsyncEnumerable result = this._client.GetStreamingTextContentsAsync(prompt, skOptions, cancellationToken: cancellationToken); - await foreach 
(StreamingTextContent x in result) + IAsyncEnumerable result; + try + { + result = this._client.GetStreamingTextContentsAsync(prompt, skOptions, cancellationToken: cancellationToken); + } + catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) + { + throw new NonRetriableException(e.Message, e); + } + + await foreach (StreamingTextContent x in result.WithCancellation(cancellationToken)) { if (x.Text == null) { continue; } diff --git a/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs index 9efe22f42..227b1f0a6 100644 --- a/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs +++ b/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs @@ -9,6 +9,7 @@ using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.AI.OpenAI.Internals; using Microsoft.KernelMemory.Diagnostics; +using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.AI.Embeddings; using Microsoft.SemanticKernel.Embeddings; using OpenAI; @@ -122,7 +123,14 @@ public IReadOnlyList GetTokens(string text) public Task GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default) { this._log.LogTrace("Generating embedding"); - return this._client.GenerateEmbeddingAsync(text, cancellationToken); + try + { + return this._client.GenerateEmbeddingAsync(text, cancellationToken); + } + catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) + { + throw new NonRetriableException(e.Message, e); + } } /// @@ -130,7 +138,14 @@ public async Task GenerateEmbeddingBatchAsync(IEnumerable t { var list = textList.ToList(); this._log.LogTrace("Generating embeddings, batch size '{0}'", list.Count); - var embeddings = await this._client.GenerateEmbeddingsAsync(list, cancellationToken: cancellationToken).ConfigureAwait(false); - return embeddings.Select(e => new Embedding(e)).ToArray(); + try + { + var 
embeddings = await this._client.GenerateEmbeddingsAsync(list, cancellationToken: cancellationToken).ConfigureAwait(false); + return embeddings.Select(e => new Embedding(e)).ToArray(); + } + catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) + { + throw new NonRetriableException(e.Message, e); + } } } diff --git a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs index 1e9a96c06..3e53bbeaf 100644 --- a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs +++ b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs @@ -5,6 +5,7 @@ using System.Net.Http; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.AI.OpenAI.Internals; using Microsoft.KernelMemory.Diagnostics; @@ -139,8 +140,18 @@ public async IAsyncEnumerable GenerateTextAsync( } this._log.LogTrace("Sending chat message generation request"); - IAsyncEnumerable result = this._client.GetStreamingTextContentsAsync(prompt, skOptions, cancellationToken: cancellationToken); - await foreach (StreamingTextContent x in result) + + IAsyncEnumerable result; + try + { + result = this._client.GetStreamingTextContentsAsync(prompt, skOptions, cancellationToken: cancellationToken); + } + catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) + { + throw new NonRetriableException(e.Message, e); + } + + await foreach (StreamingTextContent x in result.WithCancellation(cancellationToken)) { // TODO: try catch // if (x.Metadata?["Usage"] is not null) diff --git a/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs b/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs index c660eb350..d31b6683b 100644 --- a/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs +++ b/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs @@ -69,7 +69,7 @@ public RabbitMQPipeline(RabbitMQConfig 
config, ILoggerFactory? loggerFactory = n } /// - /// About posion queue and dead letters, see https://www.rabbitmq.com/docs/dlx + /// About poison queue and dead letters, see https://www.rabbitmq.com/docs/dlx public Task ConnectToQueueAsync(string queueName, QueueOptions options = default, CancellationToken cancellationToken = default) { ArgumentNullExceptionEx.ThrowIfNullOrWhiteSpace(queueName, nameof(queueName), "The queue name is empty"); diff --git a/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs b/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs index f7331c0a3..d2123ae84 100644 --- a/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs +++ b/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs @@ -227,6 +227,7 @@ public void OnDequeue(Func> processMessageAction) } catch (NonRetriableException e) { + message.LastError = $"{e.GetType().FullName} [{e.InnerException?.GetType().FullName}]: {e.Message}"; this._log.LogError(e, "Message '{0}' failed to process due to a non-recoverable error, moving to poison queue.", message.Id); poison = true; } From cd0d270b1fcd934b84ee2067e5c742bd0208859a Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Sun, 17 Nov 2024 23:13:39 -0800 Subject: [PATCH 03/10] Refactoring --- .../Anthropic/Client/RawAnthropicClient.cs | 3 ++- .../AzureAIDocIntel/AzureAIDocIntelEngine.cs | 2 +- .../AzureAIDocIntelException.cs | 26 +++++++++++++++++++ .../AzureAISearchMemoryException.cs | 10 ++++--- .../AzureAISearch/Internals/MemoryDbSchema.cs | 16 ++++++------ .../AzureOpenAI/AzureOpenAIException.cs | 26 +++++++++++++++++++ .../AzureOpenAITextEmbeddingGenerator.cs | 4 +-- .../AzureOpenAI/AzureOpenAITextGenerator.cs | 2 +- extensions/AzureQueues/AzureQueuesPipeline.cs | 6 ++--- extensions/OpenAI/OpenAI/OpenAIException.cs | 26 +++++++++++++++++++ .../OpenAI/OpenAITextEmbeddingGenerator.cs | 4 +-- .../OpenAI/OpenAI/OpenAITextGenerator.cs | 2 +- .../RabbitMQ.TestApplication/Program.cs | 2 +- .../RabbitMQ/RabbitMQ/RabbitMQPipeline.cs | 6 ++--- 
service/Abstractions/KernelMemoryException.cs | 14 +++++++--- .../Pipeline/NonRetriableException.cs | 20 -------------- .../Pipeline/OrchestrationException.cs | 15 ++++++++--- service/Abstractions/Pipeline/ResultType.cs | 4 +-- .../DistributedPipelineOrchestrator.cs | 12 ++++----- .../Pipeline/InProcessPipelineOrchestrator.cs | 8 +++--- .../Pipeline/Queue/DevTools/SimpleQueues.cs | 6 ++--- 21 files changed, 147 insertions(+), 67 deletions(-) create mode 100644 extensions/AzureAIDocIntel/AzureAIDocIntelException.cs create mode 100644 extensions/AzureOpenAI/AzureOpenAI/AzureOpenAIException.cs create mode 100644 extensions/OpenAI/OpenAI/OpenAIException.cs delete mode 100644 service/Abstractions/Pipeline/NonRetriableException.cs diff --git a/extensions/Anthropic/Client/RawAnthropicClient.cs b/extensions/Anthropic/Client/RawAnthropicClient.cs index 7e23e3783..a0b1249b8 100644 --- a/extensions/Anthropic/Client/RawAnthropicClient.cs +++ b/extensions/Anthropic/Client/RawAnthropicClient.cs @@ -64,7 +64,8 @@ internal async IAsyncEnumerable CallClaudeStreamingAsy if (!response.IsSuccessStatusCode) { var responseError = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); - throw new KernelMemoryException($"Failed to send request: {response.StatusCode} - {responseError}"); + var isTransient = (new List { 500, 502, 503, 504 }).Contains((int)response.StatusCode); + throw new KernelMemoryException($"Failed to send request: {response.StatusCode} - {responseError}", isTransient: isTransient); } var responseStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); diff --git a/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs b/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs index 09920256d..7b5cdb262 100644 --- a/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs +++ b/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs @@ -70,7 +70,7 @@ public async Task ExtractTextFromImageAsync(Stream imageContent, Cancell } 
catch (RequestFailedException e) when (e.Status is >= 400 and < 500) { - throw new NonRetriableException(e.Message, e); + throw new AzureAIDocIntelException(e.Message, e, isTransient: false); } } } diff --git a/extensions/AzureAIDocIntel/AzureAIDocIntelException.cs b/extensions/AzureAIDocIntel/AzureAIDocIntelException.cs new file mode 100644 index 000000000..0df68ee16 --- /dev/null +++ b/extensions/AzureAIDocIntel/AzureAIDocIntelException.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.KernelMemory.DataFormats.AzureAIDocIntel; + +public class AzureAIDocIntelException : KernelMemoryException +{ + /// + public AzureAIDocIntelException(bool? isTransient = null) + { + this.IsTransient = isTransient; + } + + /// + public AzureAIDocIntelException(string message, bool? isTransient = null) : base(message) + { + this.IsTransient = isTransient; + } + + /// + public AzureAIDocIntelException(string message, Exception? innerException, bool? isTransient = null) : base(message, innerException) + { + this.IsTransient = isTransient; + } +} diff --git a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemoryException.cs b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemoryException.cs index f63ccc594..d91955c1b 100644 --- a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemoryException.cs +++ b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemoryException.cs @@ -7,17 +7,21 @@ namespace Microsoft.KernelMemory.MemoryDb.AzureAISearch; public class AzureAISearchMemoryException : KernelMemoryException { /// - public AzureAISearchMemoryException() + public AzureAISearchMemoryException(bool? isTransient = null) { + this.IsTransient = isTransient; } /// - public AzureAISearchMemoryException(string? message) : base(message) + public AzureAISearchMemoryException(string message, bool? isTransient = null) : base(message) { + this.IsTransient = isTransient; } /// - public AzureAISearchMemoryException(string? 
message, Exception? innerException) : base(message, innerException) + public AzureAISearchMemoryException(string message, Exception? innerException, bool? isTransient = null) : base(message, innerException) { + this.IsTransient = isTransient; } } + diff --git a/extensions/AzureAISearch/AzureAISearch/Internals/MemoryDbSchema.cs b/extensions/AzureAISearch/AzureAISearch/Internals/MemoryDbSchema.cs index a5f8522bc..f06245dfe 100644 --- a/extensions/AzureAISearch/AzureAISearch/Internals/MemoryDbSchema.cs +++ b/extensions/AzureAISearch/AzureAISearch/Internals/MemoryDbSchema.cs @@ -13,41 +13,41 @@ public void Validate(bool vectorSizeRequired = false) { if (this.Fields.Count == 0) { - throw new KernelMemoryException("The schema is empty"); + throw new AzureAISearchMemoryException("The schema is empty", isTransient: false); } if (this.Fields.All(x => x.Type != MemoryDbField.FieldType.Vector)) { - throw new KernelMemoryException("The schema doesn't contain a vector field"); + throw new AzureAISearchMemoryException("The schema doesn't contain a vector field", isTransient: false); } int keys = this.Fields.Count(x => x.IsKey); switch (keys) { case 0: - throw new KernelMemoryException("The schema doesn't contain a key field"); + throw new AzureAISearchMemoryException("The schema doesn't contain a key field", isTransient: false); case > 1: - throw new KernelMemoryException("The schema cannot contain more than one key"); + throw new AzureAISearchMemoryException("The schema cannot contain more than one key", isTransient: false); } if (vectorSizeRequired && this.Fields.Any(x => x is { Type: MemoryDbField.FieldType.Vector, VectorSize: 0 })) { - throw new KernelMemoryException("Vector fields must have a size greater than zero defined"); + throw new AzureAISearchMemoryException("Vector fields must have a size greater than zero defined", isTransient: false); } if (this.Fields.Any(x => x is { Type: MemoryDbField.FieldType.Bool, IsKey: true })) { - throw new KernelMemoryException("Boolean 
fields cannot be used as unique keys"); + throw new AzureAISearchMemoryException("Boolean fields cannot be used as unique keys", isTransient: false); } if (this.Fields.Any(x => x is { Type: MemoryDbField.FieldType.ListOfStrings, IsKey: true })) { - throw new KernelMemoryException("Collection fields cannot be used as unique keys"); + throw new AzureAISearchMemoryException("Collection fields cannot be used as unique keys", isTransient: false); } if (this.Fields.Any(x => x is { Type: MemoryDbField.FieldType.Vector, IsKey: true })) { - throw new KernelMemoryException("Vector fields cannot be used as unique keys"); + throw new AzureAISearchMemoryException("Vector fields cannot be used as unique keys", isTransient: false); } } } diff --git a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAIException.cs b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAIException.cs new file mode 100644 index 000000000..1f67ebc01 --- /dev/null +++ b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAIException.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.KernelMemory.AI.AzureOpenAI; + +public class AzureOpenAIException : KernelMemoryException +{ + /// + public AzureOpenAIException(bool? isTransient = null) + { + this.IsTransient = isTransient; + } + + /// + public AzureOpenAIException(string message, bool? isTransient = null) : base(message) + { + this.IsTransient = isTransient; + } + + /// + public AzureOpenAIException(string message, Exception? innerException, bool? 
isTransient = null) : base(message, innerException) + { + this.IsTransient = isTransient; + } +} diff --git a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs index 9a5efbe02..51631ce8f 100644 --- a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs +++ b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs @@ -128,7 +128,7 @@ public Task GenerateEmbeddingAsync(string text, CancellationToken can } catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) { - throw new NonRetriableException(e.Message, e); + throw new AzureOpenAIException(e.Message, e, isTransient: false); } } @@ -144,7 +144,7 @@ public async Task GenerateEmbeddingBatchAsync(IEnumerable t } catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) { - throw new NonRetriableException(e.Message, e); + throw new AzureOpenAIException(e.Message, e, isTransient: false); } } } diff --git a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs index 3522d3eec..3a38238c2 100644 --- a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs +++ b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs @@ -148,7 +148,7 @@ public async IAsyncEnumerable GenerateTextAsync( } catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) { - throw new NonRetriableException(e.Message, e); + throw new AzureOpenAIException(e.Message, e, isTransient: false); } await foreach (StreamingTextContent x in result.WithCancellation(cancellationToken)) diff --git a/extensions/AzureQueues/AzureQueuesPipeline.cs b/extensions/AzureQueues/AzureQueuesPipeline.cs index 93b8a80dc..27e8183c8 100644 --- 
a/extensions/AzureQueues/AzureQueuesPipeline.cs +++ b/extensions/AzureQueues/AzureQueuesPipeline.cs @@ -202,14 +202,14 @@ public void OnDequeue(Func> processMessageAction) await this.DeleteMessageAsync(message, cancellationToken: default).ConfigureAwait(false); break; - case ResultType.RetriableError: + case ResultType.TransientError: var backoffDelay = TimeSpan.FromSeconds(1 * message.DequeueCount); this._log.LogWarning("Message '{0}' failed to process, putting message back in the queue with a delay of {1} msecs", message.MessageId, backoffDelay.TotalMilliseconds); await this.UnlockMessageAsync(message, backoffDelay, cancellationToken: default).ConfigureAwait(false); break; - case ResultType.NonRetriableError: + case ResultType.UnrecoverableError: this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", message.MessageId); await this.MoveMessageToPoisonQueueAsync(message, cancellationToken: default).ConfigureAwait(false); break; @@ -224,7 +224,7 @@ public void OnDequeue(Func> processMessageAction) await this.MoveMessageToPoisonQueueAsync(message, cancellationToken: default).ConfigureAwait(false); } } - catch (NonRetriableException e) + catch (KernelMemoryException e) when (e.IsTransient.HasValue && !e.IsTransient.Value) { this._log.LogError(e, "Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", message.MessageId); await this.MoveMessageToPoisonQueueAsync(message, cancellationToken: default).ConfigureAwait(false); diff --git a/extensions/OpenAI/OpenAI/OpenAIException.cs b/extensions/OpenAI/OpenAI/OpenAIException.cs new file mode 100644 index 000000000..47451d0ce --- /dev/null +++ b/extensions/OpenAI/OpenAI/OpenAIException.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.KernelMemory.AI.OpenAI; + +public class OpenAIException : KernelMemoryException +{ + /// + public OpenAIException(bool? 
isTransient = null) + { + this.IsTransient = isTransient; + } + + /// + public OpenAIException(string message, bool? isTransient = null) : base(message) + { + this.IsTransient = isTransient; + } + + /// + public OpenAIException(string message, Exception? innerException, bool? isTransient = null) : base(message, innerException) + { + this.IsTransient = isTransient; + } +} diff --git a/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs index 227b1f0a6..0ea58cc55 100644 --- a/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs +++ b/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs @@ -129,7 +129,7 @@ public Task GenerateEmbeddingAsync(string text, CancellationToken can } catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) { - throw new NonRetriableException(e.Message, e); + throw new OpenAIException(e.Message, e, isTransient: false); } } @@ -145,7 +145,7 @@ public async Task GenerateEmbeddingBatchAsync(IEnumerable t } catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) { - throw new NonRetriableException(e.Message, e); + throw new OpenAIException(e.Message, e, isTransient: false); } } } diff --git a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs index 3e53bbeaf..caccc43df 100644 --- a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs +++ b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs @@ -148,7 +148,7 @@ public async IAsyncEnumerable GenerateTextAsync( } catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) { - throw new NonRetriableException(e.Message, e); + throw new OpenAIException(e.Message, e, isTransient: false); } await foreach (StreamingTextContent x in result.WithCancellation(cancellationToken)) diff --git 
a/extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs b/extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs index 84c4a2e0d..0b110c3bd 100644 --- a/extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs +++ b/extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs @@ -39,7 +39,7 @@ public static async Task Main() { Console.WriteLine($"{++counter} Received message: {msg}"); await Task.Delay(0); - return ResultType.RetriableError; + return ResultType.TransientError; }); await pipeline.ConnectToQueueAsync(QueueName, QueueOptions.PubSub); diff --git a/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs b/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs index d31b6683b..63af63a39 100644 --- a/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs +++ b/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs @@ -201,7 +201,7 @@ public void OnDequeue(Func> processMessageAction) this._channel.BasicAck(args.DeliveryTag, multiple: false); break; - case ResultType.RetriableError: + case ResultType.TransientError: if (attemptNumber < this._maxAttempts) { this._log.LogWarning("Message '{0}' failed to process (attempt {1} of {2}), putting message back in the queue", @@ -220,7 +220,7 @@ public void OnDequeue(Func> processMessageAction) this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: true); break; - case ResultType.NonRetriableError: + case ResultType.UnrecoverableError: this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", args.BasicProperties?.MessageId); this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: false); break; @@ -229,7 +229,7 @@ public void OnDequeue(Func> processMessageAction) throw new ArgumentOutOfRangeException($"Unknown {resultType:G} result"); } } - catch (NonRetriableException e) + catch (KernelMemoryException e) when (e.IsTransient.HasValue && !e.IsTransient.Value) { this._log.LogError(e, "Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", 
args.BasicProperties?.MessageId); this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: false); diff --git a/service/Abstractions/KernelMemoryException.cs b/service/Abstractions/KernelMemoryException.cs index 387106d13..3afcdfe97 100644 --- a/service/Abstractions/KernelMemoryException.cs +++ b/service/Abstractions/KernelMemoryException.cs @@ -9,19 +9,25 @@ namespace Microsoft.KernelMemory; /// public class KernelMemoryException : Exception { + public bool? IsTransient { get; protected init; } = null; + /// /// Initializes a new instance of the class with a default message. /// - public KernelMemoryException() + /// Optional parameter to indicate if the error is temporary and might disappear by retrying. + public KernelMemoryException(bool? isTransient = null) { + this.IsTransient = isTransient; } /// /// Initializes a new instance of the class with its message set to . /// /// A string that describes the error. - public KernelMemoryException(string? message) : base(message) + /// Optional parameter to indicate if the error is temporary and might disappear by retrying. + public KernelMemoryException(string? message, bool? isTransient = null) : base(message) { + this.IsTransient = isTransient; } /// @@ -29,7 +35,9 @@ public KernelMemoryException(string? message) : base(message) /// /// A string that describes the error. /// The exception that is the cause of the current exception. - public KernelMemoryException(string? message, Exception? innerException) : base(message, innerException) + /// Optional parameter to indicate if the error is temporary and might disappear by retrying. + public KernelMemoryException(string? message, Exception? innerException, bool? 
isTransient = null) : base(message, innerException) { + this.IsTransient = isTransient; } } diff --git a/service/Abstractions/Pipeline/NonRetriableException.cs b/service/Abstractions/Pipeline/NonRetriableException.cs deleted file mode 100644 index 4905f89af..000000000 --- a/service/Abstractions/Pipeline/NonRetriableException.cs +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using Microsoft.KernelMemory.Pipeline; - -#pragma warning disable IDE0130 // reduce number of "using" statements -// ReSharper disable once CheckNamespace - reduce number of "using" statements -namespace Microsoft.KernelMemory; - -public class NonRetriableException : OrchestrationException -{ - /// - public NonRetriableException() { } - - /// - public NonRetriableException(string message) : base(message) { } - - /// - public NonRetriableException(string message, Exception? innerException) : base(message, innerException) { } -} diff --git a/service/Abstractions/Pipeline/OrchestrationException.cs b/service/Abstractions/Pipeline/OrchestrationException.cs index 6d2e3246c..6670e3cc4 100644 --- a/service/Abstractions/Pipeline/OrchestrationException.cs +++ b/service/Abstractions/Pipeline/OrchestrationException.cs @@ -7,11 +7,20 @@ namespace Microsoft.KernelMemory.Pipeline; public class OrchestrationException : KernelMemoryException { /// - public OrchestrationException() { } + public OrchestrationException(bool? isTransient = null) + { + this.IsTransient = isTransient; + } /// - public OrchestrationException(string message) : base(message) { } + public OrchestrationException(string message, bool? isTransient = null) : base(message) + { + this.IsTransient = isTransient; + } /// - public OrchestrationException(string message, Exception? innerException) : base(message, innerException) { } + public OrchestrationException(string message, Exception? innerException, bool? 
isTransient = null) : base(message, innerException) + { + this.IsTransient = isTransient; + } } diff --git a/service/Abstractions/Pipeline/ResultType.cs b/service/Abstractions/Pipeline/ResultType.cs index 3207e0578..09064614a 100644 --- a/service/Abstractions/Pipeline/ResultType.cs +++ b/service/Abstractions/Pipeline/ResultType.cs @@ -5,6 +5,6 @@ namespace Microsoft.KernelMemory.Pipeline; public enum ResultType { Success = 0, - RetriableError = 1, - NonRetriableError = 2, + TransientError = 1, + UnrecoverableError = 2, } diff --git a/service/Core/Pipeline/DistributedPipelineOrchestrator.cs b/service/Core/Pipeline/DistributedPipelineOrchestrator.cs index cc916dd77..8c25592e8 100644 --- a/service/Core/Pipeline/DistributedPipelineOrchestrator.cs +++ b/service/Core/Pipeline/DistributedPipelineOrchestrator.cs @@ -93,7 +93,7 @@ public override async Task AddHandlerAsync( if (pipelinePointer == null) { this.Log.LogError("Pipeline pointer deserialization failed, queue `{0}`. Message discarded.", handler.StepName); - return ResultType.NonRetriableError; + return ResultType.UnrecoverableError; } DataPipeline? 
pipeline; @@ -121,18 +121,18 @@ public override async Task AddHandlerAsync( } this.Log.LogError("Pipeline `{0}/{1}` not found, cancelling step `{2}`", pipelinePointer.Index, pipelinePointer.DocumentId, handler.StepName); - return ResultType.NonRetriableError; + return ResultType.UnrecoverableError; } catch (InvalidPipelineDataException) { this.Log.LogError("Pipeline `{0}/{1}` state load failed, invalid state, queue `{2}`", pipelinePointer.Index, pipelinePointer.DocumentId, handler.StepName); - return ResultType.RetriableError; + return ResultType.TransientError; } if (pipeline == null) { this.Log.LogError("Pipeline `{0}/{1}` state load failed, the state is null, queue `{2}`", pipelinePointer.Index, pipelinePointer.DocumentId, handler.StepName); - return ResultType.RetriableError; + return ResultType.TransientError; } if (pipelinePointer.ExecutionId != pipeline.ExecutionId) @@ -228,11 +228,11 @@ private async Task RunPipelineStepAsync( await this.MoveForwardAsync(pipeline, cancellationToken).ConfigureAwait(false); break; - case ResultType.RetriableError: + case ResultType.TransientError: this.Log.LogError("Handler {0} failed to process pipeline {1}", currentStepName, pipeline.DocumentId); break; - case ResultType.NonRetriableError: + case ResultType.UnrecoverableError: this.Log.LogError("Handler {0} failed to process pipeline {1} due to an unrecoverable error", currentStepName, pipeline.DocumentId); break; diff --git a/service/Core/Pipeline/InProcessPipelineOrchestrator.cs b/service/Core/Pipeline/InProcessPipelineOrchestrator.cs index 0dde21b2e..40df40790 100644 --- a/service/Core/Pipeline/InProcessPipelineOrchestrator.cs +++ b/service/Core/Pipeline/InProcessPipelineOrchestrator.cs @@ -185,13 +185,13 @@ public override async Task RunPipelineAsync(DataPipeline pipeline, CancellationT await this.UpdatePipelineStatusAsync(pipeline, cancellationToken).ConfigureAwait(false); break; - case ResultType.RetriableError: + case ResultType.TransientError: 
this.Log.LogError("Handler '{0}' failed to process pipeline '{1}/{2}'", currentStepName, pipeline.Index, pipeline.DocumentId); - throw new OrchestrationException($"Pipeline error, step {currentStepName} failed"); + throw new OrchestrationException($"Pipeline error, step {currentStepName} failed", isTransient: true); - case ResultType.NonRetriableError: + case ResultType.UnrecoverableError: this.Log.LogError("Handler '{0}' failed to process pipeline '{1}/{2}' due to an unrecoverable error", currentStepName, pipeline.Index, pipeline.DocumentId); - throw new NonRetriableException($"Unrecoverable pipeline error, step {currentStepName} failed and cannot be retried"); + throw new OrchestrationException($"Unrecoverable pipeline error, step {currentStepName} failed and cannot be retried", isTransient: false); default: throw new ArgumentOutOfRangeException($"Unknown {resultType:G} result type"); diff --git a/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs b/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs index d2123ae84..0d6a0156e 100644 --- a/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs +++ b/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs @@ -201,7 +201,7 @@ public void OnDequeue(Func> processMessageAction) await this.DeleteMessageAsync(message.Id, this._cancellation.Token).ConfigureAwait(false); break; - case ResultType.RetriableError: + case ResultType.TransientError: message.LastError = "Message handler returned false"; if (message.DequeueCount == this._maxAttempts) { @@ -216,7 +216,7 @@ public void OnDequeue(Func> processMessageAction) break; - case ResultType.NonRetriableError: + case ResultType.UnrecoverableError: this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", message.Id); poison = true; break; @@ -225,7 +225,7 @@ public void OnDequeue(Func> processMessageAction) throw new ArgumentOutOfRangeException($"Unknown {resultType:G} result"); } } - catch (NonRetriableException e) + catch 
(KernelMemoryException e) when (e.IsTransient.HasValue && !e.IsTransient.Value) { message.LastError = $"{e.GetType().FullName} [{e.InnerException?.GetType().FullName}]: {e.Message}"; this._log.LogError(e, "Message '{0}' failed to process due to a non-recoverable error, moving to poison queue.", message.Id); From c879204ff8da55a39617870b6ed91ee9f46bdb63 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Mon, 18 Nov 2024 10:44:16 -0800 Subject: [PATCH 04/10] Refactoring --- .../Anthropic/Client/RawAnthropicClient.cs | 4 +- .../AzureAIDocIntel/AzureAIDocIntelEngine.cs | 2 +- .../AzureOpenAITextEmbeddingGenerator.cs | 4 +- .../AzureOpenAI/AzureOpenAITextGenerator.cs | 2 +- extensions/AzureQueues/AzureQueuesPipeline.cs | 2 +- .../OpenAI/OpenAITextEmbeddingGenerator.cs | 4 +- .../OpenAI/OpenAI/OpenAITextGenerator.cs | 2 +- .../RabbitMQ/RabbitMQ/RabbitMQPipeline.cs | 2 +- .../Abstractions/Diagnostics/HttpErrors.cs | 59 +++++++++++++++++++ service/Abstractions/Pipeline/ResultType.cs | 2 +- .../DistributedPipelineOrchestrator.cs | 6 +- .../Pipeline/InProcessPipelineOrchestrator.cs | 2 +- .../Pipeline/Queue/DevTools/SimpleQueues.cs | 2 +- 13 files changed, 76 insertions(+), 17 deletions(-) create mode 100644 service/Abstractions/Diagnostics/HttpErrors.cs diff --git a/extensions/Anthropic/Client/RawAnthropicClient.cs b/extensions/Anthropic/Client/RawAnthropicClient.cs index a0b1249b8..d24945e8c 100644 --- a/extensions/Anthropic/Client/RawAnthropicClient.cs +++ b/extensions/Anthropic/Client/RawAnthropicClient.cs @@ -64,8 +64,8 @@ internal async IAsyncEnumerable CallClaudeStreamingAsy if (!response.IsSuccessStatusCode) { var responseError = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); - var isTransient = (new List { 500, 502, 503, 504 }).Contains((int)response.StatusCode); - throw new KernelMemoryException($"Failed to send request: {response.StatusCode} - {responseError}", isTransient: isTransient); + throw new KernelMemoryException($"Failed 
to send request: {response.StatusCode} - {responseError}", + isTransient: response.StatusCode.IsTransientError()); } var responseStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); diff --git a/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs b/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs index 7b5cdb262..0f28ae663 100644 --- a/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs +++ b/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs @@ -68,7 +68,7 @@ public async Task ExtractTextFromImageAsync(Stream imageContent, Cancell return operationResponse.Value.Content; } - catch (RequestFailedException e) when (e.Status is >= 400 and < 500) + catch (RequestFailedException e) when (HttpErrors.IsFatalError(e.Status)) { throw new AzureAIDocIntelException(e.Message, e, isTransient: false); } diff --git a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs index 51631ce8f..f600c7dd2 100644 --- a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs +++ b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs @@ -126,7 +126,7 @@ public Task GenerateEmbeddingAsync(string text, CancellationToken can { return this._client.GenerateEmbeddingAsync(text, cancellationToken); } - catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) + catch (HttpOperationException e) when (e.StatusCode.IsFatalError()) { throw new AzureOpenAIException(e.Message, e, isTransient: false); } @@ -142,7 +142,7 @@ public async Task GenerateEmbeddingBatchAsync(IEnumerable t IList> embeddings = await this._client.GenerateEmbeddingsAsync(list, cancellationToken: cancellationToken).ConfigureAwait(false); return embeddings.Select(e => new Embedding(e)).ToArray(); } - catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 
500) + catch (HttpOperationException e) when (e.StatusCode.IsFatalError()) { throw new AzureOpenAIException(e.Message, e, isTransient: false); } diff --git a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs index 3a38238c2..47896fd0e 100644 --- a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs +++ b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs @@ -146,7 +146,7 @@ public async IAsyncEnumerable GenerateTextAsync( { result = this._client.GetStreamingTextContentsAsync(prompt, skOptions, cancellationToken: cancellationToken); } - catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) + catch (HttpOperationException e) when (e.StatusCode.IsFatalError()) { throw new AzureOpenAIException(e.Message, e, isTransient: false); } diff --git a/extensions/AzureQueues/AzureQueuesPipeline.cs b/extensions/AzureQueues/AzureQueuesPipeline.cs index 27e8183c8..be4d2282c 100644 --- a/extensions/AzureQueues/AzureQueuesPipeline.cs +++ b/extensions/AzureQueues/AzureQueuesPipeline.cs @@ -209,7 +209,7 @@ public void OnDequeue(Func> processMessageAction) await this.UnlockMessageAsync(message, backoffDelay, cancellationToken: default).ConfigureAwait(false); break; - case ResultType.UnrecoverableError: + case ResultType.FatalError: this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", message.MessageId); await this.MoveMessageToPoisonQueueAsync(message, cancellationToken: default).ConfigureAwait(false); break; diff --git a/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs index 0ea58cc55..d36423abb 100644 --- a/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs +++ b/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs @@ -127,7 +127,7 @@ public Task GenerateEmbeddingAsync(string text, 
CancellationToken can { return this._client.GenerateEmbeddingAsync(text, cancellationToken); } - catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) + catch (HttpOperationException e) when (e.StatusCode.IsFatalError()) { throw new OpenAIException(e.Message, e, isTransient: false); } @@ -143,7 +143,7 @@ public async Task GenerateEmbeddingBatchAsync(IEnumerable t var embeddings = await this._client.GenerateEmbeddingsAsync(list, cancellationToken: cancellationToken).ConfigureAwait(false); return embeddings.Select(e => new Embedding(e)).ToArray(); } - catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) + catch (HttpOperationException e) when (e.StatusCode.IsFatalError()) { throw new OpenAIException(e.Message, e, isTransient: false); } diff --git a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs index caccc43df..ef069df88 100644 --- a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs +++ b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs @@ -146,7 +146,7 @@ public async IAsyncEnumerable GenerateTextAsync( { result = this._client.GetStreamingTextContentsAsync(prompt, skOptions, cancellationToken: cancellationToken); } - catch (HttpOperationException e) when (e.StatusCode.HasValue && (int)e.StatusCode >= 400 && (int)e.StatusCode < 500) + catch (HttpOperationException e) when (e.StatusCode.IsFatalError()) { throw new OpenAIException(e.Message, e, isTransient: false); } diff --git a/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs b/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs index 63af63a39..3168e66a5 100644 --- a/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs +++ b/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs @@ -220,7 +220,7 @@ public void OnDequeue(Func> processMessageAction) this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: true); break; - case 
ResultType.UnrecoverableError: + case ResultType.FatalError: this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", args.BasicProperties?.MessageId); this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: false); break; diff --git a/service/Abstractions/Diagnostics/HttpErrors.cs b/service/Abstractions/Diagnostics/HttpErrors.cs new file mode 100644 index 000000000..7d52e9718 --- /dev/null +++ b/service/Abstractions/Diagnostics/HttpErrors.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Net; + +namespace Microsoft.KernelMemory.Diagnostics; + +public static class HttpErrors +{ + // Errors that might disappear by retrying + private static readonly HashSet<int> s_transientErrors = + [ + (int)HttpStatusCode.InternalServerError, + (int)HttpStatusCode.BadGateway, + (int)HttpStatusCode.ServiceUnavailable, + (int)HttpStatusCode.GatewayTimeout, + (int)HttpStatusCode.InsufficientStorage + ]; + + public static bool IsTransientError(this HttpStatusCode statusCode) + { + return s_transientErrors.Contains((int)statusCode); + } + + public static bool IsTransientError(this HttpStatusCode? statusCode) + { + return statusCode.HasValue && s_transientErrors.Contains((int)statusCode.Value); + } + + public static bool IsTransientError(int statusCode) + { + return s_transientErrors.Contains(statusCode); + } + + public static bool IsFatalError(this HttpStatusCode statusCode) + { + return IsError(statusCode) && !IsTransientError(statusCode); + } + + public static bool IsFatalError(this HttpStatusCode? statusCode) + { + return statusCode.HasValue && IsError(statusCode) && !IsTransientError(statusCode); + } + + public static bool IsFatalError(int statusCode) + { + return IsError(statusCode) && !IsTransientError(statusCode); + } + + private static bool IsError(this HttpStatusCode? 
statusCode) + { + return statusCode.HasValue && (int)statusCode.Value >= 400; + } + + private static bool IsError(int statusCode) + { + return statusCode >= 400; + } +} diff --git a/service/Abstractions/Pipeline/ResultType.cs b/service/Abstractions/Pipeline/ResultType.cs index 09064614a..b949fe2ca 100644 --- a/service/Abstractions/Pipeline/ResultType.cs +++ b/service/Abstractions/Pipeline/ResultType.cs @@ -6,5 +6,5 @@ public enum ResultType { Success = 0, TransientError = 1, - UnrecoverableError = 2, + FatalError = 2, } diff --git a/service/Core/Pipeline/DistributedPipelineOrchestrator.cs b/service/Core/Pipeline/DistributedPipelineOrchestrator.cs index 8c25592e8..98214ab5b 100644 --- a/service/Core/Pipeline/DistributedPipelineOrchestrator.cs +++ b/service/Core/Pipeline/DistributedPipelineOrchestrator.cs @@ -93,7 +93,7 @@ public override async Task AddHandlerAsync( if (pipelinePointer == null) { this.Log.LogError("Pipeline pointer deserialization failed, queue `{0}`. Message discarded.", handler.StepName); - return ResultType.UnrecoverableError; + return ResultType.FatalError; } DataPipeline? 
pipeline; @@ -121,7 +121,7 @@ public override async Task AddHandlerAsync( } this.Log.LogError("Pipeline `{0}/{1}` not found, cancelling step `{2}`", pipelinePointer.Index, pipelinePointer.DocumentId, handler.StepName); - return ResultType.UnrecoverableError; + return ResultType.FatalError; } catch (InvalidPipelineDataException) { @@ -232,7 +232,7 @@ private async Task RunPipelineStepAsync( this.Log.LogError("Handler {0} failed to process pipeline {1}", currentStepName, pipeline.DocumentId); break; - case ResultType.UnrecoverableError: + case ResultType.FatalError: this.Log.LogError("Handler {0} failed to process pipeline {1} due to an unrecoverable error", currentStepName, pipeline.DocumentId); break; diff --git a/service/Core/Pipeline/InProcessPipelineOrchestrator.cs b/service/Core/Pipeline/InProcessPipelineOrchestrator.cs index 40df40790..c8dbc9a95 100644 --- a/service/Core/Pipeline/InProcessPipelineOrchestrator.cs +++ b/service/Core/Pipeline/InProcessPipelineOrchestrator.cs @@ -189,7 +189,7 @@ public override async Task RunPipelineAsync(DataPipeline pipeline, CancellationT this.Log.LogError("Handler '{0}' failed to process pipeline '{1}/{2}'", currentStepName, pipeline.Index, pipeline.DocumentId); throw new OrchestrationException($"Pipeline error, step {currentStepName} failed", isTransient: true); - case ResultType.UnrecoverableError: + case ResultType.FatalError: this.Log.LogError("Handler '{0}' failed to process pipeline '{1}/{2}' due to an unrecoverable error", currentStepName, pipeline.Index, pipeline.DocumentId); throw new OrchestrationException($"Unrecoverable pipeline error, step {currentStepName} failed and cannot be retried", isTransient: false); diff --git a/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs b/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs index 0d6a0156e..8fe2ed713 100644 --- a/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs +++ b/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs @@ -216,7 +216,7 @@ public void 
OnDequeue(Func> processMessageAction) break; - case ResultType.UnrecoverableError: + case ResultType.FatalError: this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", message.Id); poison = true; break; From 7694ece8fc114258faa50b7215e5fe783db1629a Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Mon, 18 Nov 2024 10:57:58 -0800 Subject: [PATCH 05/10] refactoring --- extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs | 4 ++-- .../AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs | 8 ++++---- .../AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs | 4 ++-- extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs | 8 ++++---- extensions/OpenAI/OpenAI/OpenAITextGenerator.cs | 4 ++-- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs b/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs index 0f28ae663..98b6df6cf 100644 --- a/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs +++ b/extensions/AzureAIDocIntel/AzureAIDocIntelEngine.cs @@ -68,9 +68,9 @@ public async Task ExtractTextFromImageAsync(Stream imageContent, Cancell return operationResponse.Value.Content; } - catch (RequestFailedException e) when (HttpErrors.IsFatalError(e.Status)) + catch (RequestFailedException e) { - throw new AzureAIDocIntelException(e.Message, e, isTransient: false); + throw new AzureAIDocIntelException(e.Message, e, isTransient: HttpErrors.IsTransientError(e.Status)); } } } diff --git a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs index f600c7dd2..7aa954d36 100644 --- a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs +++ b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs @@ -126,9 +126,9 @@ public Task GenerateEmbeddingAsync(string text, CancellationToken can { return this._client.GenerateEmbeddingAsync(text, 
cancellationToken); } - catch (HttpOperationException e) when (e.StatusCode.IsFatalError()) + catch (HttpOperationException e) { - throw new AzureOpenAIException(e.Message, e, isTransient: false); + throw new AzureOpenAIException(e.Message, e, isTransient: e.StatusCode.IsTransientError()); } } @@ -142,9 +142,9 @@ public async Task GenerateEmbeddingBatchAsync(IEnumerable t IList> embeddings = await this._client.GenerateEmbeddingsAsync(list, cancellationToken: cancellationToken).ConfigureAwait(false); return embeddings.Select(e => new Embedding(e)).ToArray(); } - catch (HttpOperationException e) when (e.StatusCode.IsFatalError()) + catch (HttpOperationException e) { - throw new AzureOpenAIException(e.Message, e, isTransient: false); + throw new AzureOpenAIException(e.Message, e, isTransient: e.StatusCode.IsTransientError()); } } } diff --git a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs index 47896fd0e..0661bed3e 100644 --- a/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs +++ b/extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs @@ -146,9 +146,9 @@ public async IAsyncEnumerable GenerateTextAsync( { result = this._client.GetStreamingTextContentsAsync(prompt, skOptions, cancellationToken: cancellationToken); } - catch (HttpOperationException e) when (e.StatusCode.IsFatalError()) + catch (HttpOperationException e) { - throw new AzureOpenAIException(e.Message, e, isTransient: false); + throw new AzureOpenAIException(e.Message, e, isTransient: e.StatusCode.IsTransientError()); } await foreach (StreamingTextContent x in result.WithCancellation(cancellationToken)) diff --git a/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs index d36423abb..d9582d70c 100644 --- a/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs +++ b/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs @@ -127,9 +127,9 
@@ public Task GenerateEmbeddingAsync(string text, CancellationToken can { return this._client.GenerateEmbeddingAsync(text, cancellationToken); } - catch (HttpOperationException e) when (e.StatusCode.IsFatalError()) + catch (HttpOperationException e) { - throw new OpenAIException(e.Message, e, isTransient: false); + throw new OpenAIException(e.Message, e, isTransient: e.StatusCode.IsTransientError()); } } @@ -143,9 +143,9 @@ public async Task GenerateEmbeddingBatchAsync(IEnumerable t var embeddings = await this._client.GenerateEmbeddingsAsync(list, cancellationToken: cancellationToken).ConfigureAwait(false); return embeddings.Select(e => new Embedding(e)).ToArray(); } - catch (HttpOperationException e) when (e.StatusCode.IsFatalError()) + catch (HttpOperationException e) { - throw new OpenAIException(e.Message, e, isTransient: false); + throw new OpenAIException(e.Message, e, isTransient: e.StatusCode.IsTransientError()); } } } diff --git a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs index ef069df88..8faa3e947 100644 --- a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs +++ b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs @@ -146,9 +146,9 @@ public async IAsyncEnumerable GenerateTextAsync( { result = this._client.GetStreamingTextContentsAsync(prompt, skOptions, cancellationToken: cancellationToken); } - catch (HttpOperationException e) when (e.StatusCode.IsFatalError()) + catch (HttpOperationException e) { - throw new OpenAIException(e.Message, e, isTransient: false); + throw new OpenAIException(e.Message, e, isTransient: e.StatusCode.IsTransientError()); } await foreach (StreamingTextContent x in result.WithCancellation(cancellationToken)) From ee6215bd0c4cd5131f570e637212765a4310f855 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Mon, 18 Nov 2024 11:05:55 -0800 Subject: [PATCH 06/10] Refactoring --- .../Program.cs | 4 +-- .../MyHandler.cs | 4 +-- extensions/AzureQueues/AzureQueuesPipeline.cs | 14 
+++++----- .../RabbitMQ.TestApplication/Program.cs | 2 +- .../RabbitMQ/RabbitMQ/RabbitMQPipeline.cs | 14 +++++----- .../Pipeline/IPipelineStepHandler.cs | 2 +- service/Abstractions/Pipeline/Queue/IQueue.cs | 2 +- .../Pipeline/{ResultType.cs => ReturnType.cs} | 2 +- .../Core/Handlers/DeleteDocumentHandler.cs | 4 +-- .../Handlers/DeleteGeneratedFilesHandler.cs | 4 +-- service/Core/Handlers/DeleteIndexHandler.cs | 4 +-- .../Handlers/GenerateEmbeddingsHandler.cs | 6 ++-- .../GenerateEmbeddingsParallelHandler.cs | 6 ++-- service/Core/Handlers/SaveRecordsHandler.cs | 4 +-- service/Core/Handlers/SummarizationHandler.cs | 4 +-- .../Handlers/SummarizationParallelHandler.cs | 4 +-- .../Core/Handlers/TextExtractionHandler.cs | 4 +-- .../Core/Handlers/TextPartitioningHandler.cs | 6 ++-- .../DistributedPipelineOrchestrator.cs | 28 +++++++++---------- .../Pipeline/InProcessPipelineOrchestrator.cs | 12 ++++---- .../Pipeline/Queue/DevTools/SimpleQueues.cs | 14 +++++----- 21 files changed, 72 insertions(+), 72 deletions(-) rename service/Abstractions/Pipeline/{ResultType.cs => ReturnType.cs} (87%) diff --git a/examples/201-dotnet-serverless-custom-handler/Program.cs b/examples/201-dotnet-serverless-custom-handler/Program.cs index cd35e3aed..1ca961bab 100644 --- a/examples/201-dotnet-serverless-custom-handler/Program.cs +++ b/examples/201-dotnet-serverless-custom-handler/Program.cs @@ -47,7 +47,7 @@ public MyHandler( public string StepName { get; } /// - public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ReturnType returnType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { /* ... your custom ... 
@@ -64,6 +64,6 @@ public MyHandler( // Remove this - here only to avoid build errors await Task.Delay(0, cancellationToken).ConfigureAwait(false); - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } } diff --git a/examples/202-dotnet-custom-handler-as-a-service/MyHandler.cs b/examples/202-dotnet-custom-handler-as-a-service/MyHandler.cs index e3c7fd89d..8f0dff26d 100644 --- a/examples/202-dotnet-custom-handler-as-a-service/MyHandler.cs +++ b/examples/202-dotnet-custom-handler-as-a-service/MyHandler.cs @@ -38,7 +38,7 @@ public Task StopAsync(CancellationToken cancellationToken = default) } /// - public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync(DataPipeline pipeline, CancellationToken cancellationToken = default) + public async Task<(ReturnType returnType, DataPipeline updatedPipeline)> InvokeAsync(DataPipeline pipeline, CancellationToken cancellationToken = default) { /* ... your custom ... * ... handler ... @@ -49,6 +49,6 @@ public Task StopAsync(CancellationToken cancellationToken = default) // Remove this - here only to avoid build errors await Task.Delay(0, cancellationToken).ConfigureAwait(false); - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } } diff --git a/extensions/AzureQueues/AzureQueuesPipeline.cs b/extensions/AzureQueues/AzureQueuesPipeline.cs index be4d2282c..2ce06ebd8 100644 --- a/extensions/AzureQueues/AzureQueuesPipeline.cs +++ b/extensions/AzureQueues/AzureQueuesPipeline.cs @@ -181,7 +181,7 @@ public async Task EnqueueAsync(string message, CancellationToken cancellationTok } /// - public void OnDequeue(Func> processMessageAction) + public void OnDequeue(Func> processMessageAction) { this.Received += async (object sender, MessageEventArgs args) => { @@ -192,30 +192,30 @@ public void OnDequeue(Func> processMessageAction) try { - ResultType resultType = await processMessageAction.Invoke(message.MessageText).ConfigureAwait(false); + ReturnType 
returnType = await processMessageAction.Invoke(message.MessageText).ConfigureAwait(false); if (message.DequeueCount <= this._config.MaxRetriesBeforePoisonQueue) { - switch (resultType) + switch (returnType) { - case ResultType.Success: + case ReturnType.Success: this._log.LogTrace("Message '{0}' successfully processed, deleting message", message.MessageId); await this.DeleteMessageAsync(message, cancellationToken: default).ConfigureAwait(false); break; - case ResultType.TransientError: + case ReturnType.TransientError: var backoffDelay = TimeSpan.FromSeconds(1 * message.DequeueCount); this._log.LogWarning("Message '{0}' failed to process, putting message back in the queue with a delay of {1} msecs", message.MessageId, backoffDelay.TotalMilliseconds); await this.UnlockMessageAsync(message, backoffDelay, cancellationToken: default).ConfigureAwait(false); break; - case ResultType.FatalError: + case ReturnType.FatalError: this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", message.MessageId); await this.MoveMessageToPoisonQueueAsync(message, cancellationToken: default).ConfigureAwait(false); break; default: - throw new ArgumentOutOfRangeException($"Unknown {resultType:G} result"); + throw new ArgumentOutOfRangeException($"Unknown {returnType:G} result"); } } else diff --git a/extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs b/extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs index 0b110c3bd..ac6556a67 100644 --- a/extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs +++ b/extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs @@ -39,7 +39,7 @@ public static async Task Main() { Console.WriteLine($"{++counter} Received message: {msg}"); await Task.Delay(0); - return ResultType.TransientError; + return ReturnType.TransientError; }); await pipeline.ConnectToQueueAsync(QueueName, QueueOptions.PubSub); diff --git a/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs 
b/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs index 3168e66a5..acc08c337 100644 --- a/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs +++ b/extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs @@ -174,7 +174,7 @@ public Task EnqueueAsync(string message, CancellationToken cancellationToken = d } /// - public void OnDequeue(Func> processMessageAction) + public void OnDequeue(Func> processMessageAction) { this._consumer.Received += async (object sender, BasicDeliverEventArgs args) => { @@ -193,15 +193,15 @@ public void OnDequeue(Func> processMessageAction) byte[] body = args.Body.ToArray(); string message = Encoding.UTF8.GetString(body); - var resultType = await processMessageAction.Invoke(message).ConfigureAwait(false); - switch (resultType) + var returnType = await processMessageAction.Invoke(message).ConfigureAwait(false); + switch (returnType) { - case ResultType.Success: + case ReturnType.Success: this._log.LogTrace("Message '{0}' successfully processed, deleting message", args.BasicProperties?.MessageId); this._channel.BasicAck(args.DeliveryTag, multiple: false); break; - case ResultType.TransientError: + case ReturnType.TransientError: if (attemptNumber < this._maxAttempts) { this._log.LogWarning("Message '{0}' failed to process (attempt {1} of {2}), putting message back in the queue", @@ -220,13 +220,13 @@ public void OnDequeue(Func> processMessageAction) this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: true); break; - case ResultType.FatalError: + case ReturnType.FatalError: this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", args.BasicProperties?.MessageId); this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: false); break; default: - throw new ArgumentOutOfRangeException($"Unknown {resultType:G} result"); + throw new ArgumentOutOfRangeException($"Unknown {returnType:G} result"); } } catch (KernelMemoryException e) when (e.IsTransient.HasValue && 
!e.IsTransient.Value) diff --git a/service/Abstractions/Pipeline/IPipelineStepHandler.cs b/service/Abstractions/Pipeline/IPipelineStepHandler.cs index f785c4323..32930c9b9 100644 --- a/service/Abstractions/Pipeline/IPipelineStepHandler.cs +++ b/service/Abstractions/Pipeline/IPipelineStepHandler.cs @@ -20,5 +20,5 @@ public interface IPipelineStepHandler /// Pipeline status /// Async task cancellation token /// Whether the pipeline step has been processed successfully, and the new pipeline status to use moving forward - Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync(DataPipeline pipeline, CancellationToken cancellationToken = default); + Task<(ReturnType returnType, DataPipeline updatedPipeline)> InvokeAsync(DataPipeline pipeline, CancellationToken cancellationToken = default); } diff --git a/service/Abstractions/Pipeline/Queue/IQueue.cs b/service/Abstractions/Pipeline/Queue/IQueue.cs index ed8f05824..0a3ccad96 100644 --- a/service/Abstractions/Pipeline/Queue/IQueue.cs +++ b/service/Abstractions/Pipeline/Queue/IQueue.cs @@ -28,5 +28,5 @@ public interface IQueue : IDisposable /// Define the logic to execute when a new message is in the queue. 
/// /// Async action to execute - void OnDequeue(Func> processMessageAction); + void OnDequeue(Func> processMessageAction); } diff --git a/service/Abstractions/Pipeline/ResultType.cs b/service/Abstractions/Pipeline/ReturnType.cs similarity index 87% rename from service/Abstractions/Pipeline/ResultType.cs rename to service/Abstractions/Pipeline/ReturnType.cs index b949fe2ca..d751070f5 100644 --- a/service/Abstractions/Pipeline/ResultType.cs +++ b/service/Abstractions/Pipeline/ReturnType.cs @@ -2,7 +2,7 @@ namespace Microsoft.KernelMemory.Pipeline; -public enum ResultType +public enum ReturnType { Success = 0, TransientError = 1, diff --git a/service/Core/Handlers/DeleteDocumentHandler.cs b/service/Core/Handlers/DeleteDocumentHandler.cs index 8e438511a..d869ef981 100644 --- a/service/Core/Handlers/DeleteDocumentHandler.cs +++ b/service/Core/Handlers/DeleteDocumentHandler.cs @@ -34,7 +34,7 @@ public DeleteDocumentHandler( } /// - public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ReturnType returnType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Deleting document, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -60,6 +60,6 @@ await this._documentStorage.EmptyDocumentDirectoryAsync( documentId: pipeline.DocumentId, cancellationToken).ConfigureAwait(false); - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } } diff --git a/service/Core/Handlers/DeleteGeneratedFilesHandler.cs b/service/Core/Handlers/DeleteGeneratedFilesHandler.cs index 5792aad17..24c5763fd 100644 --- a/service/Core/Handlers/DeleteGeneratedFilesHandler.cs +++ b/service/Core/Handlers/DeleteGeneratedFilesHandler.cs @@ -29,7 +29,7 @@ public DeleteGeneratedFilesHandler( } /// - public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ReturnType returnType, DataPipeline 
updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Deleting generated files, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -40,6 +40,6 @@ await this._documentStorage.EmptyDocumentDirectoryAsync( documentId: pipeline.DocumentId, cancellationToken).ConfigureAwait(false); - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } } diff --git a/service/Core/Handlers/DeleteIndexHandler.cs b/service/Core/Handlers/DeleteIndexHandler.cs index a18dbab60..4deccf31a 100644 --- a/service/Core/Handlers/DeleteIndexHandler.cs +++ b/service/Core/Handlers/DeleteIndexHandler.cs @@ -34,7 +34,7 @@ public DeleteIndexHandler( } /// - public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ReturnType returnType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Deleting index, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -50,6 +50,6 @@ await this._documentStorage.DeleteIndexDirectoryAsync( index: pipeline.Index, cancellationToken).ConfigureAwait(false); - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } } diff --git a/service/Core/Handlers/GenerateEmbeddingsHandler.cs b/service/Core/Handlers/GenerateEmbeddingsHandler.cs index 0bdb1a970..9e41cbdc3 100644 --- a/service/Core/Handlers/GenerateEmbeddingsHandler.cs +++ b/service/Core/Handlers/GenerateEmbeddingsHandler.cs @@ -58,13 +58,13 @@ public GenerateEmbeddingsHandler( } /// - public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ReturnType returnType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { if (!this._embeddingGenerationEnabled) { this._log.LogTrace("Embedding generation is disabled, skipping - pipeline '{0}/{1}'", 
pipeline.Index, pipeline.DocumentId); - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } foreach (ITextEmbeddingGenerator generator in this._embeddingGenerators) @@ -83,7 +83,7 @@ public GenerateEmbeddingsHandler( } } - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } protected override IPipelineStepHandler ActualInstance => this; diff --git a/service/Core/Handlers/GenerateEmbeddingsParallelHandler.cs b/service/Core/Handlers/GenerateEmbeddingsParallelHandler.cs index 8150a126a..2483f414b 100644 --- a/service/Core/Handlers/GenerateEmbeddingsParallelHandler.cs +++ b/service/Core/Handlers/GenerateEmbeddingsParallelHandler.cs @@ -58,13 +58,13 @@ public GenerateEmbeddingsParallelHandler( } /// - public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ReturnType returnType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { if (!this._embeddingGenerationEnabled) { this._log.LogTrace("Embedding generation is disabled, skipping - pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } foreach (ITextEmbeddingGenerator generator in this._embeddingGenerators) @@ -83,7 +83,7 @@ public GenerateEmbeddingsParallelHandler( } } - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } protected override IPipelineStepHandler ActualInstance => this; diff --git a/service/Core/Handlers/SaveRecordsHandler.cs b/service/Core/Handlers/SaveRecordsHandler.cs index 6987d6bc5..cceeee66d 100644 --- a/service/Core/Handlers/SaveRecordsHandler.cs +++ b/service/Core/Handlers/SaveRecordsHandler.cs @@ -103,7 +103,7 @@ public SaveRecordsHandler( } /// - public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ReturnType returnType, DataPipeline updatedPipeline)> 
InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Saving memory records, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -241,7 +241,7 @@ record = PrepareRecord( this._log.LogWarning("Pipeline '{0}/{1}': step {2}: no records found, cannot save, moving to next pipeline step.", pipeline.Index, pipeline.DocumentId, this.StepName); } - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } private static IEnumerable GetListOfEmbeddingFiles(DataPipeline pipeline) diff --git a/service/Core/Handlers/SummarizationHandler.cs b/service/Core/Handlers/SummarizationHandler.cs index bd71163a9..a4ee36f9e 100644 --- a/service/Core/Handlers/SummarizationHandler.cs +++ b/service/Core/Handlers/SummarizationHandler.cs @@ -54,7 +54,7 @@ public SummarizationHandler( } /// - public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ReturnType returnType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Generating summary, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -125,7 +125,7 @@ public SummarizationHandler( } } - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } private async Task<(string summary, bool skip)> SummarizeAsync(string content, IContext context) diff --git a/service/Core/Handlers/SummarizationParallelHandler.cs b/service/Core/Handlers/SummarizationParallelHandler.cs index 2370ea570..44e309270 100644 --- a/service/Core/Handlers/SummarizationParallelHandler.cs +++ b/service/Core/Handlers/SummarizationParallelHandler.cs @@ -53,7 +53,7 @@ public SummarizationParallelHandler( } /// - public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ReturnType returnType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken 
cancellationToken = default) { this._log.LogDebug("Generating summary, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -133,7 +133,7 @@ await Parallel.ForEachAsync(uploadedFile.GeneratedFiles, options, async (generat } } - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } private async Task<(string summary, bool skip)> SummarizeAsync(string content) diff --git a/service/Core/Handlers/TextExtractionHandler.cs b/service/Core/Handlers/TextExtractionHandler.cs index 6b5d34a89..05279ea2b 100644 --- a/service/Core/Handlers/TextExtractionHandler.cs +++ b/service/Core/Handlers/TextExtractionHandler.cs @@ -54,7 +54,7 @@ public TextExtractionHandler( } /// - public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ReturnType returnType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Extracting text, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -135,7 +135,7 @@ public TextExtractionHandler( uploadedFile.MarkProcessedBy(this); } - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } public void Dispose() diff --git a/service/Core/Handlers/TextPartitioningHandler.cs b/service/Core/Handlers/TextPartitioningHandler.cs index 905ecea3f..960c4b47b 100644 --- a/service/Core/Handlers/TextPartitioningHandler.cs +++ b/service/Core/Handlers/TextPartitioningHandler.cs @@ -67,7 +67,7 @@ public TextPartitioningHandler( } /// - public async Task<(ResultType resultType, DataPipeline updatedPipeline)> InvokeAsync( + public async Task<(ReturnType returnType, DataPipeline updatedPipeline)> InvokeAsync( DataPipeline pipeline, CancellationToken cancellationToken = default) { this._log.LogDebug("Partitioning text, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId); @@ -75,7 +75,7 @@ public TextPartitioningHandler( if (pipeline.Files.Count == 0) { 
this._log.LogWarning("Pipeline '{0}/{1}': there are no files to process, moving to next pipeline step.", pipeline.Index, pipeline.DocumentId); - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } var context = pipeline.GetContext(); @@ -197,7 +197,7 @@ public TextPartitioningHandler( } } - return (ResultType.Success, pipeline); + return (ReturnType.Success, pipeline); } #pragma warning disable CA2254 // the msg is always used diff --git a/service/Core/Pipeline/DistributedPipelineOrchestrator.cs b/service/Core/Pipeline/DistributedPipelineOrchestrator.cs index 98214ab5b..2d2e462ff 100644 --- a/service/Core/Pipeline/DistributedPipelineOrchestrator.cs +++ b/service/Core/Pipeline/DistributedPipelineOrchestrator.cs @@ -93,7 +93,7 @@ public override async Task AddHandlerAsync( if (pipelinePointer == null) { this.Log.LogError("Pipeline pointer deserialization failed, queue `{0}`. Message discarded.", handler.StepName); - return ResultType.FatalError; + return ReturnType.FatalError; } DataPipeline? 
pipeline; @@ -121,18 +121,18 @@ public override async Task AddHandlerAsync( } this.Log.LogError("Pipeline `{0}/{1}` not found, cancelling step `{2}`", pipelinePointer.Index, pipelinePointer.DocumentId, handler.StepName); - return ResultType.FatalError; + return ReturnType.FatalError; } catch (InvalidPipelineDataException) { this.Log.LogError("Pipeline `{0}/{1}` state load failed, invalid state, queue `{2}`", pipelinePointer.Index, pipelinePointer.DocumentId, handler.StepName); - return ResultType.TransientError; + return ReturnType.TransientError; } if (pipeline == null) { this.Log.LogError("Pipeline `{0}/{1}` state load failed, the state is null, queue `{2}`", pipelinePointer.Index, pipelinePointer.DocumentId, handler.StepName); - return ResultType.TransientError; + return ReturnType.TransientError; } if (pipelinePointer.ExecutionId != pipeline.ExecutionId) @@ -141,7 +141,7 @@ public override async Task AddHandlerAsync( "Document `{0}/{1}` has been updated without waiting for the previous pipeline execution `{2}` to complete (current execution: `{3}`). 
" + "Step `{4}` and any consecutive steps from the previous execution have been cancelled.", pipelinePointer.Index, pipelinePointer.DocumentId, pipelinePointer.ExecutionId, pipeline.ExecutionId, handler.StepName); - return ResultType.Success; + return ReturnType.Success; } var currentStepName = pipeline.RemainingSteps.First(); @@ -201,7 +201,7 @@ public override async Task RunPipelineAsync(DataPipeline pipeline, CancellationT #region private - private async Task RunPipelineStepAsync( + private async Task RunPipelineStepAsync( DataPipeline pipeline, IPipelineStepHandler handler, CancellationToken cancellationToken) @@ -210,16 +210,16 @@ private async Task RunPipelineStepAsync( if (pipeline.Complete) { this.Log.LogInformation("Pipeline '{0}/{1}' complete", pipeline.Index, pipeline.DocumentId); - return ResultType.Success; + return ReturnType.Success; } string currentStepName = pipeline.RemainingSteps.First(); // Execute the business logic - exceptions are automatically handled by IQueue - (ResultType resultType, DataPipeline updatedPipeline) = await handler.InvokeAsync(pipeline, cancellationToken).ConfigureAwait(false); - switch (resultType) + (ReturnType returnType, DataPipeline updatedPipeline) = await handler.InvokeAsync(pipeline, cancellationToken).ConfigureAwait(false); + switch (returnType) { - case ResultType.Success: + case ReturnType.Success: pipeline = updatedPipeline; pipeline.LastUpdate = DateTimeOffset.UtcNow; @@ -228,19 +228,19 @@ private async Task RunPipelineStepAsync( await this.MoveForwardAsync(pipeline, cancellationToken).ConfigureAwait(false); break; - case ResultType.TransientError: + case ReturnType.TransientError: this.Log.LogError("Handler {0} failed to process pipeline {1}", currentStepName, pipeline.DocumentId); break; - case ResultType.FatalError: + case ReturnType.FatalError: this.Log.LogError("Handler {0} failed to process pipeline {1} due to an unrecoverable error", currentStepName, pipeline.DocumentId); break; default: - throw new 
ArgumentOutOfRangeException($"Unknown {resultType:G} result type"); + throw new ArgumentOutOfRangeException($"Unknown {returnType:G} return type"); } - return resultType; + return returnType; } private async Task MoveForwardAsync(DataPipeline pipeline, CancellationToken cancellationToken = default) diff --git a/service/Core/Pipeline/InProcessPipelineOrchestrator.cs b/service/Core/Pipeline/InProcessPipelineOrchestrator.cs index c8dbc9a95..07fd18371 100644 --- a/service/Core/Pipeline/InProcessPipelineOrchestrator.cs +++ b/service/Core/Pipeline/InProcessPipelineOrchestrator.cs @@ -171,13 +171,13 @@ public override async Task RunPipelineAsync(DataPipeline pipeline, CancellationT } // Run handler - (ResultType resultType, DataPipeline updatedPipeline) = await stepHandler + (ReturnType returnType, DataPipeline updatedPipeline) = await stepHandler .InvokeAsync(pipeline, this.CancellationTokenSource.Token) .ConfigureAwait(false); - switch (resultType) + switch (returnType) { - case ResultType.Success: + case ReturnType.Success: pipeline = updatedPipeline; pipeline.LastUpdate = DateTimeOffset.UtcNow; this.Log.LogInformation("Handler '{0}' processed pipeline '{1}/{2}' successfully", currentStepName, pipeline.Index, pipeline.DocumentId); @@ -185,16 +185,16 @@ public override async Task RunPipelineAsync(DataPipeline pipeline, CancellationT await this.UpdatePipelineStatusAsync(pipeline, cancellationToken).ConfigureAwait(false); break; - case ResultType.TransientError: + case ReturnType.TransientError: this.Log.LogError("Handler '{0}' failed to process pipeline '{1}/{2}'", currentStepName, pipeline.Index, pipeline.DocumentId); throw new OrchestrationException($"Pipeline error, step {currentStepName} failed", isTransient: true); - case ResultType.FatalError: + case ReturnType.FatalError: this.Log.LogError("Handler '{0}' failed to process pipeline '{1}/{2}' due to an unrecoverable error", currentStepName, pipeline.Index, pipeline.DocumentId); throw new 
OrchestrationException($"Unrecoverable pipeline error, step {currentStepName} failed and cannot be retried", isTransient: false); default: - throw new ArgumentOutOfRangeException($"Unknown {resultType:G} result type"); + throw new ArgumentOutOfRangeException($"Unknown {returnType:G} return type"); } } diff --git a/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs b/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs index 8fe2ed713..9ce65508d 100644 --- a/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs +++ b/service/Core/Pipeline/Queue/DevTools/SimpleQueues.cs @@ -177,7 +177,7 @@ await this.StoreMessageAsync( /// /// about the logic handling dequeued messages. - public void OnDequeue(Func> processMessageAction) + public void OnDequeue(Func> processMessageAction) { this._log.LogInformation("Queue {0}: subscribing...", this._queueName); this.Received += async (sender, args) => @@ -193,15 +193,15 @@ public void OnDequeue(Func> processMessageAction) this._log.LogInformation("Queue {0}: message {0} received", this._queueName, message.Id); // Process message with the logic provided by the orchestrator - var resultType = await processMessageAction.Invoke(message.Content).ConfigureAwait(false); - switch (resultType) + var returnType = await processMessageAction.Invoke(message.Content).ConfigureAwait(false); + switch (returnType) { - case ResultType.Success: + case ReturnType.Success: this._log.LogTrace("Message '{0}' successfully processed, deleting message", message.Id); await this.DeleteMessageAsync(message.Id, this._cancellation.Token).ConfigureAwait(false); break; - case ResultType.TransientError: + case ReturnType.TransientError: message.LastError = "Message handler returned false"; if (message.DequeueCount == this._maxAttempts) { @@ -216,13 +216,13 @@ public void OnDequeue(Func> processMessageAction) break; - case ResultType.FatalError: + case ReturnType.FatalError: this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to 
poison queue", message.Id); poison = true; break; default: - throw new ArgumentOutOfRangeException($"Unknown {resultType:G} result"); + throw new ArgumentOutOfRangeException($"Unknown {returnType:G} result"); } } catch (KernelMemoryException e) when (e.IsTransient.HasValue && !e.IsTransient.Value) From 39432f23bc0cc9cf3df702645780d75d18a8c4d2 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Mon, 18 Nov 2024 12:08:23 -0800 Subject: [PATCH 07/10] Revisit list of transient errors --- .../Abstractions/Diagnostics/HttpErrors.cs | 14 ++- .../Diagnostics/HttpErrorsTests.cs | 92 +++++++++++++++++++ 2 files changed, 101 insertions(+), 5 deletions(-) create mode 100644 service/tests/Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs diff --git a/service/Abstractions/Diagnostics/HttpErrors.cs b/service/Abstractions/Diagnostics/HttpErrors.cs index 7d52e9718..62a188e10 100644 --- a/service/Abstractions/Diagnostics/HttpErrors.cs +++ b/service/Abstractions/Diagnostics/HttpErrors.cs @@ -10,11 +10,15 @@ public static class HttpErrors // Errors that might disappear by retrying private static readonly HashSet s_transientErrors = [ - (int)HttpStatusCode.InternalServerError, - (int)HttpStatusCode.BadGateway, - (int)HttpStatusCode.ServiceUnavailable, - (int)HttpStatusCode.GatewayTimeout, - (int)HttpStatusCode.InsufficientStorage + (int)HttpStatusCode.RequestTimeout, // 408 + (int)HttpStatusCode.PreconditionFailed, // 412 + (int)HttpStatusCode.Locked, // 423 + (int)HttpStatusCode.TooManyRequests, // 429 + (int)HttpStatusCode.InternalServerError, // 500 + (int)HttpStatusCode.BadGateway, // 502 + (int)HttpStatusCode.ServiceUnavailable, // 503 + (int)HttpStatusCode.GatewayTimeout, // 504 + (int)HttpStatusCode.InsufficientStorage // 507 ]; public static bool IsTransientError(this HttpStatusCode statusCode) diff --git a/service/tests/Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs b/service/tests/Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs new file mode 100644 index 
000000000..1017568c0 --- /dev/null +++ b/service/tests/Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Net; +using Microsoft.KernelMemory.Diagnostics; + +namespace Microsoft.KM.Abstractions.UnitTests.Diagnostics; + +public sealed class HttpErrorsTests +{ + [Fact] + public void ItRecognizesErrorsFromNulls() + { + HttpStatusCode? statusCode = null; + + Assert.False(statusCode.IsTransientError()); + Assert.False(statusCode.IsFatalError()); + } + + [Theory] + [InlineData(HttpStatusCode.Continue)] // 100 + [InlineData(HttpStatusCode.SwitchingProtocols)] // 101 + [InlineData(HttpStatusCode.Processing)] // 102 + [InlineData(HttpStatusCode.EarlyHints)] // 103 + [InlineData(HttpStatusCode.OK)] // 200 + [InlineData(HttpStatusCode.Created)] // 201 + [InlineData(HttpStatusCode.Accepted)] // 202 + [InlineData(HttpStatusCode.NonAuthoritativeInformation)] // 203 + [InlineData(HttpStatusCode.NoContent)] // 204 + [InlineData(HttpStatusCode.ResetContent)] // 205 + [InlineData(HttpStatusCode.Ambiguous)] // 300 + [InlineData(HttpStatusCode.Moved)] // 301 + [InlineData(HttpStatusCode.Found)] // 302 + [InlineData(HttpStatusCode.RedirectMethod)] // 303 + [InlineData(HttpStatusCode.NotModified)] // 304 + [InlineData(HttpStatusCode.UseProxy)] // 305 + [InlineData(HttpStatusCode.Unused)] // 306 + [InlineData(HttpStatusCode.RedirectKeepVerb)] // 307 + [InlineData(HttpStatusCode.PermanentRedirect)] // 308 + public void ItRecognizesErrors(HttpStatusCode statusCode) + { + Assert.False(statusCode.IsTransientError()); + Assert.False(statusCode.IsFatalError()); + } + + [Theory] + [InlineData(HttpStatusCode.RequestTimeout)] // 408 + [InlineData(HttpStatusCode.PreconditionFailed)] // 412 + [InlineData(HttpStatusCode.Locked)] // 423 + [InlineData(HttpStatusCode.TooManyRequests)] // 429 + [InlineData(HttpStatusCode.InternalServerError)] // 500 + [InlineData(HttpStatusCode.BadGateway)] // 502 + 
[InlineData(HttpStatusCode.ServiceUnavailable)] // 503 + [InlineData(HttpStatusCode.GatewayTimeout)] // 504 + [InlineData(HttpStatusCode.InsufficientStorage)] // 507 + public void ItRecognizesTransientErrors(HttpStatusCode statusCode) + { + Assert.True(statusCode.IsTransientError()); + Assert.False(statusCode.IsFatalError()); + } + + [Theory] + [InlineData(HttpStatusCode.BadRequest)] // 400 + [InlineData(HttpStatusCode.Unauthorized)] // 401 + [InlineData(HttpStatusCode.PaymentRequired)] // 402 + [InlineData(HttpStatusCode.Forbidden)] // 403 + [InlineData(HttpStatusCode.NotFound)] // 404 + [InlineData(HttpStatusCode.MethodNotAllowed)] // 405 + [InlineData(HttpStatusCode.NotAcceptable)] // 406 + [InlineData(HttpStatusCode.ProxyAuthenticationRequired)] // 407 + [InlineData(HttpStatusCode.Conflict)] // 409 + [InlineData(HttpStatusCode.Gone)] // 410 + [InlineData(HttpStatusCode.LengthRequired)] // 411 + [InlineData(HttpStatusCode.RequestEntityTooLarge)] // 413 + [InlineData(HttpStatusCode.RequestUriTooLong)] // 414 + [InlineData(HttpStatusCode.UnsupportedMediaType)] // 415 + [InlineData(HttpStatusCode.RequestedRangeNotSatisfiable)] // 416 + [InlineData(HttpStatusCode.ExpectationFailed)] // 417 + [InlineData(HttpStatusCode.UnprocessableContent)] // 422 + [InlineData(HttpStatusCode.UpgradeRequired)] // 426 + [InlineData(HttpStatusCode.RequestHeaderFieldsTooLarge)] // 431 + [InlineData(HttpStatusCode.UnavailableForLegalReasons)] // 451 + [InlineData(HttpStatusCode.NotImplemented)] // 501 + [InlineData(HttpStatusCode.HttpVersionNotSupported)] // 505 + [InlineData(HttpStatusCode.LoopDetected)] // 508 + [InlineData(HttpStatusCode.NotExtended)] // 510 + [InlineData(HttpStatusCode.NetworkAuthenticationRequired)] // 511 + public void ItRecognizesFatalErrors(HttpStatusCode statusCode) + { + Assert.False(statusCode.IsTransientError()); + Assert.True(statusCode.IsFatalError()); + } +} From c054276a0be982ea49b6b9a65e0939824aff6d75 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: 
Mon, 18 Nov 2024 12:13:30 -0800 Subject: [PATCH 08/10] More tests --- .../Diagnostics/HttpErrorsTests.cs | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/service/tests/Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs b/service/tests/Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs index 1017568c0..579a2ce51 100644 --- a/service/tests/Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs +++ b/service/tests/Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs @@ -37,6 +37,35 @@ public void ItRecognizesErrorsFromNulls() [InlineData(HttpStatusCode.RedirectKeepVerb)] // 307 [InlineData(HttpStatusCode.PermanentRedirect)] // 308 public void ItRecognizesErrors(HttpStatusCode statusCode) + { + Assert.False(statusCode.IsTransientError()); + Assert.False(HttpErrors.IsTransientError((int)statusCode)); + + Assert.False(statusCode.IsFatalError()); + Assert.False(HttpErrors.IsFatalError((int)statusCode)); + } + + [Theory] + [InlineData(HttpStatusCode.Continue)] // 100 + [InlineData(HttpStatusCode.SwitchingProtocols)] // 101 + [InlineData(HttpStatusCode.Processing)] // 102 + [InlineData(HttpStatusCode.EarlyHints)] // 103 + [InlineData(HttpStatusCode.OK)] // 200 + [InlineData(HttpStatusCode.Created)] // 201 + [InlineData(HttpStatusCode.Accepted)] // 202 + [InlineData(HttpStatusCode.NonAuthoritativeInformation)] // 203 + [InlineData(HttpStatusCode.NoContent)] // 204 + [InlineData(HttpStatusCode.ResetContent)] // 205 + [InlineData(HttpStatusCode.Ambiguous)] // 300 + [InlineData(HttpStatusCode.Moved)] // 301 + [InlineData(HttpStatusCode.Found)] // 302 + [InlineData(HttpStatusCode.RedirectMethod)] // 303 + [InlineData(HttpStatusCode.NotModified)] // 304 + [InlineData(HttpStatusCode.UseProxy)] // 305 + [InlineData(HttpStatusCode.Unused)] // 306 + [InlineData(HttpStatusCode.RedirectKeepVerb)] // 307 + [InlineData(HttpStatusCode.PermanentRedirect)] // 308 + public void ItRecognizesErrors(HttpStatusCode? 
statusCode) { Assert.False(statusCode.IsTransientError()); Assert.False(statusCode.IsFatalError()); @@ -53,6 +82,25 @@ public void ItRecognizesErrors(HttpStatusCode statusCode) [InlineData(HttpStatusCode.GatewayTimeout)] // 504 [InlineData(HttpStatusCode.InsufficientStorage)] // 507 public void ItRecognizesTransientErrors(HttpStatusCode statusCode) + { + Assert.True(statusCode.IsTransientError()); + Assert.True(HttpErrors.IsTransientError((int)statusCode)); + + Assert.False(statusCode.IsFatalError()); + Assert.False(HttpErrors.IsFatalError((int)statusCode)); + } + + [Theory] + [InlineData(HttpStatusCode.RequestTimeout)] // 408 + [InlineData(HttpStatusCode.PreconditionFailed)] // 412 + [InlineData(HttpStatusCode.Locked)] // 423 + [InlineData(HttpStatusCode.TooManyRequests)] // 429 + [InlineData(HttpStatusCode.InternalServerError)] // 500 + [InlineData(HttpStatusCode.BadGateway)] // 502 + [InlineData(HttpStatusCode.ServiceUnavailable)] // 503 + [InlineData(HttpStatusCode.GatewayTimeout)] // 504 + [InlineData(HttpStatusCode.InsufficientStorage)] // 507 + public void ItRecognizesTransientErrors(HttpStatusCode? 
statusCode) { Assert.True(statusCode.IsTransientError()); Assert.False(statusCode.IsFatalError()); @@ -85,6 +133,41 @@ public void ItRecognizesTransientErrors(HttpStatusCode statusCode) [InlineData(HttpStatusCode.NotExtended)] // 510 [InlineData(HttpStatusCode.NetworkAuthenticationRequired)] // 511 public void ItRecognizesFatalErrors(HttpStatusCode statusCode) + { + Assert.False(statusCode.IsTransientError()); + Assert.False(HttpErrors.IsTransientError((int)statusCode)); + + Assert.True(statusCode.IsFatalError()); + Assert.True(HttpErrors.IsFatalError((int)statusCode)); + } + + [Theory] + [InlineData(HttpStatusCode.BadRequest)] // 400 + [InlineData(HttpStatusCode.Unauthorized)] // 401 + [InlineData(HttpStatusCode.PaymentRequired)] // 402 + [InlineData(HttpStatusCode.Forbidden)] // 403 + [InlineData(HttpStatusCode.NotFound)] // 404 + [InlineData(HttpStatusCode.MethodNotAllowed)] // 405 + [InlineData(HttpStatusCode.NotAcceptable)] // 406 + [InlineData(HttpStatusCode.ProxyAuthenticationRequired)] // 407 + [InlineData(HttpStatusCode.Conflict)] // 409 + [InlineData(HttpStatusCode.Gone)] // 410 + [InlineData(HttpStatusCode.LengthRequired)] // 411 + [InlineData(HttpStatusCode.RequestEntityTooLarge)] // 413 + [InlineData(HttpStatusCode.RequestUriTooLong)] // 414 + [InlineData(HttpStatusCode.UnsupportedMediaType)] // 415 + [InlineData(HttpStatusCode.RequestedRangeNotSatisfiable)] // 416 + [InlineData(HttpStatusCode.ExpectationFailed)] // 417 + [InlineData(HttpStatusCode.UnprocessableContent)] // 422 + [InlineData(HttpStatusCode.UpgradeRequired)] // 426 + [InlineData(HttpStatusCode.RequestHeaderFieldsTooLarge)] // 431 + [InlineData(HttpStatusCode.UnavailableForLegalReasons)] // 451 + [InlineData(HttpStatusCode.NotImplemented)] // 501 + [InlineData(HttpStatusCode.HttpVersionNotSupported)] // 505 + [InlineData(HttpStatusCode.LoopDetected)] // 508 + [InlineData(HttpStatusCode.NotExtended)] // 510 + [InlineData(HttpStatusCode.NetworkAuthenticationRequired)] // 511 + public void 
ItRecognizesFatalErrors(HttpStatusCode? statusCode) { Assert.False(statusCode.IsTransientError()); Assert.True(statusCode.IsFatalError()); From 1ada50de86373480b218c6bf6f372eb0a02d436b Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Mon, 18 Nov 2024 12:36:04 -0800 Subject: [PATCH 09/10] Stop retrying unsupported file types --- .../Pipeline/MimeTypeException.cs | 26 +++++++++++++++++++ service/Abstractions/Pipeline/MimeTypes.cs | 2 +- service/Core/Pipeline/BaseOrchestrator.cs | 2 +- 3 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 service/Abstractions/Pipeline/MimeTypeException.cs diff --git a/service/Abstractions/Pipeline/MimeTypeException.cs b/service/Abstractions/Pipeline/MimeTypeException.cs new file mode 100644 index 000000000..3a305790d --- /dev/null +++ b/service/Abstractions/Pipeline/MimeTypeException.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.KernelMemory.Pipeline; + +public class MimeTypeException : KernelMemoryException +{ + /// + public MimeTypeException(bool? isTransient = null) + { + this.IsTransient = isTransient; + } + + /// + public MimeTypeException(string message, bool? isTransient = null) : base(message) + { + this.IsTransient = isTransient; + } + + /// + public MimeTypeException(string message, Exception? innerException, bool? isTransient = null) : base(message, innerException) + { + this.IsTransient = isTransient; + } +} diff --git a/service/Abstractions/Pipeline/MimeTypes.cs b/service/Abstractions/Pipeline/MimeTypes.cs index 6efdc4aa8..956eeae39 100644 --- a/service/Abstractions/Pipeline/MimeTypes.cs +++ b/service/Abstractions/Pipeline/MimeTypes.cs @@ -221,7 +221,7 @@ public string GetFileType(string filename) return mimeType; } - throw new NotSupportedException($"File type not supported: {filename}"); + throw new MimeTypeException($"File type not supported: {filename}", isTransient: false); } public bool TryGetFileType(string filename, out string? 
mimeType) diff --git a/service/Core/Pipeline/BaseOrchestrator.cs b/service/Core/Pipeline/BaseOrchestrator.cs index 8311c0406..c613139b6 100644 --- a/service/Core/Pipeline/BaseOrchestrator.cs +++ b/service/Core/Pipeline/BaseOrchestrator.cs @@ -477,7 +477,7 @@ private async Task UploadFormFilesAsync(DataPipeline pipeline, CancellationToken { mimeType = this._mimeTypeDetection.GetFileType(file.FileName); } - catch (NotSupportedException) + catch (MimeTypeException) { this.Log.LogWarning("File type not supported, the ingestion pipeline might skip it"); } From e96e4aa7fa6e76089282ac06a549e7d4742de954 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Mon, 18 Nov 2024 14:01:04 -0800 Subject: [PATCH 10/10] Tag tests --- service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs | 2 ++ .../Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs b/service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs index 0d421dfcb..38ee33e06 100644 --- a/service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs +++ b/service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs @@ -113,6 +113,8 @@ public async Task UpsertAsync(string index, MemoryRecord record, Cancell records[r.Id] = r; } + this._log.LogDebug("{VectorCount} vectors loaded for similarity check", records.Count); + // Calculate all the distances from the given vector // Note: this is a brute force search, very slow, not meant for production use cases var similarity = new Dictionary(); diff --git a/service/tests/Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs b/service/tests/Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs index 579a2ce51..269330274 100644 --- a/service/tests/Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs +++ b/service/tests/Abstractions.UnitTests/Diagnostics/HttpErrorsTests.cs @@ -8,6 +8,7 @@ namespace Microsoft.KM.Abstractions.UnitTests.Diagnostics; public sealed class HttpErrorsTests { [Fact] + 
[Trait("Category", "UnitTest")] public void ItRecognizesErrorsFromNulls() { HttpStatusCode? statusCode = null; @@ -17,6 +18,7 @@ public void ItRecognizesErrorsFromNulls() } [Theory] + [Trait("Category", "UnitTest")] [InlineData(HttpStatusCode.Continue)] // 100 [InlineData(HttpStatusCode.SwitchingProtocols)] // 101 [InlineData(HttpStatusCode.Processing)] // 102 @@ -46,6 +48,7 @@ public void ItRecognizesErrors(HttpStatusCode statusCode) } [Theory] + [Trait("Category", "UnitTest")] [InlineData(HttpStatusCode.Continue)] // 100 [InlineData(HttpStatusCode.SwitchingProtocols)] // 101 [InlineData(HttpStatusCode.Processing)] // 102 @@ -72,6 +75,7 @@ public void ItRecognizesErrors(HttpStatusCode? statusCode) } [Theory] + [Trait("Category", "UnitTest")] [InlineData(HttpStatusCode.RequestTimeout)] // 408 [InlineData(HttpStatusCode.PreconditionFailed)] // 412 [InlineData(HttpStatusCode.Locked)] // 423 @@ -91,6 +95,7 @@ public void ItRecognizesTransientErrors(HttpStatusCode statusCode) } [Theory] + [Trait("Category", "UnitTest")] [InlineData(HttpStatusCode.RequestTimeout)] // 408 [InlineData(HttpStatusCode.PreconditionFailed)] // 412 [InlineData(HttpStatusCode.Locked)] // 423 @@ -107,6 +112,7 @@ public void ItRecognizesTransientErrors(HttpStatusCode? statusCode) } [Theory] + [Trait("Category", "UnitTest")] [InlineData(HttpStatusCode.BadRequest)] // 400 [InlineData(HttpStatusCode.Unauthorized)] // 401 [InlineData(HttpStatusCode.PaymentRequired)] // 402 @@ -142,6 +148,7 @@ public void ItRecognizesFatalErrors(HttpStatusCode statusCode) } [Theory] + [Trait("Category", "UnitTest")] [InlineData(HttpStatusCode.BadRequest)] // 400 [InlineData(HttpStatusCode.Unauthorized)] // 401 [InlineData(HttpStatusCode.PaymentRequired)] // 402