diff --git a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/clients/AmazonQStreamingClientTest.kt b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/clients/AmazonQStreamingClientTest.kt index bd1c65d257f..3727a46aa7f 100644 --- a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/clients/AmazonQStreamingClientTest.kt +++ b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/clients/AmazonQStreamingClientTest.kt @@ -5,12 +5,15 @@ package software.aws.toolkits.jetbrains.services.amazonq.clients import com.intellij.testFramework.RuleChain import com.intellij.testFramework.replaceService +import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest +import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Rule import org.junit.Test import org.mockito.kotlin.any import org.mockito.kotlin.argumentCaptor +import org.mockito.kotlin.doAnswer import org.mockito.kotlin.doReturn import org.mockito.kotlin.mock import org.mockito.kotlin.stub @@ -20,6 +23,7 @@ import software.amazon.awssdk.services.codewhispererstreaming.CodeWhispererStrea import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveRequest import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveResponseHandler +import software.amazon.awssdk.services.codewhispererstreaming.model.ValidationException import software.amazon.awssdk.services.ssooidc.SsoOidcClient import software.aws.toolkits.core.TokenConnectionSettings import software.aws.toolkits.core.utils.test.aString @@ -81,4 +85,156 @@ class AmazonQStreamingClientTest : AmazonQTestBase() { verify(streamingBearerClient).exportResultArchive(requestCaptor.capture(), handlerCaptor.capture()) } } + + @Test + fun `verify retry on ValidationException`(): Unit = runBlocking { + var attemptCount = 0 + streamingBearerClient = mockClientManagerRule.create().stub { + on { + exportResultArchive(any(), any()) + } doAnswer { + attemptCount++ + if (attemptCount <= 2) { + CompletableFuture.runAsync { + throw VALIDATION_EXCEPTION + } + } else { + CompletableFuture.completedFuture(mock()) + } + } + } + + amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {}) + + assertThat(attemptCount).isEqualTo(3) + } + + @Test + fun `verify retry gives up after max attempts`(): Unit = runBlocking { + var attemptCount = 0 + streamingBearerClient = mockClientManagerRule.create().stub { + on { + exportResultArchive(any(), any()) + } doAnswer { + attemptCount++ + CompletableFuture.runAsync { + throw VALIDATION_EXCEPTION + } + } + } + + val thrown = catchCoroutineException { + amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {}) + } + + assertThat(attemptCount).isEqualTo(3) + assertThat(thrown) + .isInstanceOf(ValidationException::class.java) + .hasMessage("Resource validation failed") + } + + @Test + fun `verify no retry on non-retryable exception`(): Unit = runBlocking { + var attemptCount = 0 + + streamingBearerClient = mockClientManagerRule.create().stub { + on { + exportResultArchive(any(), any()) + } doAnswer { + attemptCount++ + CompletableFuture.runAsync { + throw IllegalArgumentException("Non-retryable error") + } + } + } + + val thrown = catchCoroutineException { + amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {}) + } + + assertThat(attemptCount).isEqualTo(1) + assertThat(thrown) + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Non-retryable error") + } + + @Test + fun `verify backoff timing between retries`(): Unit = runBlocking { + var lastAttemptTime = 0L + var minBackoffObserved = Long.MAX_VALUE + var maxBackoffObserved = 0L + + streamingBearerClient = mockClientManagerRule.create().stub { + on { + exportResultArchive(any(), any()) + } doAnswer { + val currentTime = System.currentTimeMillis() + if (lastAttemptTime > 0) { + val backoffTime = currentTime - lastAttemptTime + minBackoffObserved = minOf(minBackoffObserved, backoffTime) + maxBackoffObserved = maxOf(maxBackoffObserved, backoffTime) + } + lastAttemptTime = currentTime + + CompletableFuture.runAsync { + throw VALIDATION_EXCEPTION + } + } + } + + val thrown = catchCoroutineException { + amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {}) + } + + assertThat(thrown) + .isInstanceOf(ValidationException::class.java) + .hasMessage("Resource validation failed") + assertThat(minBackoffObserved).isGreaterThanOrEqualTo(100) + assertThat(maxBackoffObserved).isLessThanOrEqualTo(10000) + } + + @Test + fun `verify onError callback is called with final exception`(): Unit = runBlocking { + var errorCaught: Exception? = null + + streamingBearerClient = mockClientManagerRule.create().stub { + on { + exportResultArchive(any(), any()) + } doAnswer { + CompletableFuture.runAsync { + throw VALIDATION_EXCEPTION + } + } + } + + val thrown = catchCoroutineException { + amazonQStreamingClient.exportResultArchive( + "test-id", + ExportIntent.TRANSFORMATION, + null, + { errorCaught = it }, + {} + ) + } + + assertThat(thrown) + .isInstanceOf(ValidationException::class.java) + .hasMessage("Resource validation failed") + assertThat(errorCaught).isEqualTo(VALIDATION_EXCEPTION) + } + + private suspend fun catchCoroutineException(block: suspend () -> Unit): Throwable { + try { + block() + error("Expected exception was not thrown") + } catch (e: Throwable) { + return e + } + } + + companion object { + private val VALIDATION_EXCEPTION = ValidationException.builder() + .message("Resource validation failed") + .build() + } } diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererZipUploadManager.kt b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererZipUploadManager.kt index c259f06973a..edee7644674 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererZipUploadManager.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererZipUploadManager.kt @@ -20,6 +20,7 @@ import software.amazon.awssdk.utils.IoUtils import software.aws.toolkits.core.utils.debug import software.aws.toolkits.core.utils.getLogger import software.aws.toolkits.jetbrains.core.AwsClientManager +import software.aws.toolkits.jetbrains.services.amazonq.RetryableOperation import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.CodeWhispererCodeScanSession.Companion.APPLICATION_ZIP import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.CodeWhispererCodeScanSession.Companion.AWS_KMS import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.CodeWhispererCodeScanSession.Companion.CONTENT_MD5 @@ -208,41 +209,3 @@ fun getTelemetryErrorMessage(e: Exception, featureUseCase: CodeWhispererConstant else -> message("testgen.message.failed") } } - -class RetryableOperation { - private var attempts = 0 - private var currentDelay = INITIAL_DELAY - private var lastException: Exception? = null - - fun execute( - operation: () -> T, - isRetryable: (Exception) -> Boolean, - errorHandler: (Exception, Int) -> Nothing, - ): T { - while (attempts < MAX_RETRY_ATTEMPTS) { - try { - return operation() - } catch (e: Exception) { - lastException = e - - attempts++ - if (attempts < MAX_RETRY_ATTEMPTS && isRetryable(e)) { - Thread.sleep(currentDelay) - currentDelay = (currentDelay * 2).coerceAtMost(MAX_BACKOFF) - continue - } - - errorHandler(e, attempts) - } - } - - // This line should never be reached due to errorHandler throwing exception - throw RuntimeException("Unexpected state after $attempts attempts") - } - - companion object { - private const val INITIAL_DELAY = 100L // milliseconds - private const val MAX_BACKOFF = 10000L // milliseconds - private const val MAX_RETRY_ATTEMPTS = 3 - } -} diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/RetryableOperation.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/RetryableOperation.kt new file mode 100644 index 00000000000..049ca957d08 --- /dev/null +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/RetryableOperation.kt @@ -0,0 +1,53 @@ +// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package software.aws.toolkits.jetbrains.services.amazonq + +import kotlinx.coroutines.delay +import kotlinx.coroutines.runBlocking +import software.amazon.awssdk.core.exception.RetryableException +import kotlin.random.Random + +class RetryableOperation { + private var attempts = 0 + private var currentDelay = INITIAL_DELAY + + private fun getJitteredDelay(): Long { + currentDelay = (currentDelay * 2).coerceAtMost(MAX_BACKOFF) + return (currentDelay * (0.5 + Random.nextDouble(0.5))).toLong() + } + + fun execute( + operation: () -> T, + isRetryable: (Exception) -> Boolean = { it is RetryableException }, + errorHandler: ((Exception, Int) -> Nothing), + ): T = runBlocking { + executeSuspend(operation, isRetryable, errorHandler) + } + + suspend fun executeSuspend( + operation: suspend () -> T, + isRetryable: (Exception) -> Boolean = { it is RetryableException }, + errorHandler: (suspend (Exception, Int) -> Nothing), + ): T { + while (attempts < MAX_RETRY_ATTEMPTS) { + try { + return operation() + } catch (e: Exception) { + attempts++ + if (attempts >= MAX_RETRY_ATTEMPTS || !isRetryable(e)) { + errorHandler.invoke(e, attempts) + } + delay(getJitteredDelay()) + } + } + + throw RuntimeException("Unexpected state after $attempts attempts") + } + + companion object { + private const val INITIAL_DELAY = 100L + private const val MAX_BACKOFF = 10000L + private const val MAX_RETRY_ATTEMPTS = 3 + } +} diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/clients/AmazonQStreamingClient.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/clients/AmazonQStreamingClient.kt index 07967cc89fd..43c6e85c8b1 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/clients/AmazonQStreamingClient.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/clients/AmazonQStreamingClient.kt @@ -7,16 +7,21 @@ import com.intellij.openapi.components.Service import com.intellij.openapi.components.service import com.intellij.openapi.project.Project import kotlinx.coroutines.future.await +import software.amazon.awssdk.core.exception.SdkException import software.amazon.awssdk.services.codewhispererstreaming.CodeWhispererStreamingAsyncClient import software.amazon.awssdk.services.codewhispererstreaming.model.ExportContext import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveResponseHandler +import software.amazon.awssdk.services.codewhispererstreaming.model.ThrottlingException +import software.amazon.awssdk.services.codewhispererstreaming.model.ValidationException import software.aws.toolkits.core.utils.getLogger import software.aws.toolkits.core.utils.warn import software.aws.toolkits.jetbrains.core.awsClient import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection +import software.aws.toolkits.jetbrains.services.amazonq.RetryableOperation import java.time.Instant +import java.util.concurrent.TimeoutException import java.util.concurrent.atomic.AtomicReference @Service(Service.Level.PROJECT) @@ -54,30 +59,45 @@ class AmazonQStreamingClient(private val project: Project) { val checksum = AtomicReference("") try { - val result = streamingBearerClient().exportResultArchive( - { - it.exportId(exportId) - it.exportIntent(exportIntent) - it.exportContext(exportContext) + RetryableOperation().executeSuspend( + operation = { + val result = streamingBearerClient().exportResultArchive( + { + it.exportId(exportId) + it.exportIntent(exportIntent) + it.exportContext(exportContext) + }, + ExportResultArchiveResponseHandler.builder().subscriber( + ExportResultArchiveResponseHandler.Visitor.builder() + .onBinaryMetadataEvent { + checksum.set(it.contentChecksum()) + }.onBinaryPayloadEvent { + val payloadBytes = it.bytes().asByteArray() + byteBufferList.add(payloadBytes) + }.onDefault { + LOG.warn { "Received unknown payload stream: $it" } + } + .build() + ) + .build() + ) + result.await() }, - ExportResultArchiveResponseHandler.builder().subscriber( - ExportResultArchiveResponseHandler.Visitor.builder() - .onBinaryMetadataEvent { - checksum.set(it.contentChecksum()) - }.onBinaryPayloadEvent { - val payloadBytes = it.bytes().asByteArray() - byteBufferList.add(payloadBytes) - }.onDefault { - LOG.warn { "Received unknown payload stream: $it" } - } - .build() - ) - .build() + isRetryable = { e -> + when (e) { + is ValidationException, + is ThrottlingException, + is SdkException, + is TimeoutException, + -> true + else -> false + } + }, + errorHandler = { e, attempts -> + onError(e) + throw e + } ) - result.await() - } catch (e: Exception) { - onError(e) - throw e } finally { onStreamingFinished(startTime) }