Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ package software.aws.toolkits.jetbrains.services.amazonq.clients

import com.intellij.testFramework.RuleChain
import com.intellij.testFramework.replaceService
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.mockito.kotlin.any
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.doAnswer
import org.mockito.kotlin.doReturn
import org.mockito.kotlin.mock
import org.mockito.kotlin.stub
Expand All @@ -20,6 +23,7 @@ import software.amazon.awssdk.services.codewhispererstreaming.CodeWhispererStrea
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveRequest
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveResponseHandler
import software.amazon.awssdk.services.codewhispererstreaming.model.ValidationException
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.core.utils.test.aString
Expand Down Expand Up @@ -81,4 +85,156 @@ class AmazonQStreamingClientTest : AmazonQTestBase() {
verify(streamingBearerClient).exportResultArchive(requestCaptor.capture(), handlerCaptor.capture())
}
}

@Test
fun `verify retry on ValidationException`(): Unit = runBlocking {
var attemptCount = 0
streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
attemptCount++
if (attemptCount <= 2) {
CompletableFuture<Void>().apply {
completeExceptionally(VALIDATION_EXCEPTION)
}
} else {
CompletableFuture.completedFuture(mock())
}
}
}

amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {})

assertThat(attemptCount).isEqualTo(3)
}

@Test
fun `verify retry gives up after max attempts`(): Unit = runBlocking {
var attemptCount = 0
streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
attemptCount++
CompletableFuture<Void>().apply {
completeExceptionally(VALIDATION_EXCEPTION)
}
}
}

val thrown = catchCoroutineException {
amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {})
}

assertThat(attemptCount).isEqualTo(3)
assertThat(thrown)
.isInstanceOf(ValidationException::class.java)
.hasMessage("Resource validation failed")
}

@Test
fun `verify no retry on non-retryable exception`(): Unit = runBlocking {
var attemptCount = 0

streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
attemptCount++
CompletableFuture<Void>().apply {
completeExceptionally(IllegalArgumentException("Non-retryable error"))
}
}
}

val thrown = catchCoroutineException {
amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {})
}

assertThat(attemptCount).isEqualTo(1)
assertThat(thrown)
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Non-retryable error")
}

@Test
fun `verify backoff timing between retries`(): Unit = runBlocking {
var lastAttemptTime = 0L
var minBackoffObserved = Long.MAX_VALUE
var maxBackoffObserved = 0L

streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
val currentTime = System.currentTimeMillis()
if (lastAttemptTime > 0) {
val backoffTime = currentTime - lastAttemptTime
minBackoffObserved = minOf(minBackoffObserved, backoffTime)
maxBackoffObserved = maxOf(maxBackoffObserved, backoffTime)
}
lastAttemptTime = currentTime

CompletableFuture<Void>().apply {
completeExceptionally(VALIDATION_EXCEPTION)
}
}
}

val thrown = catchCoroutineException {
amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {})
}

assertThat(thrown)
.isInstanceOf(ValidationException::class.java)
.hasMessage("Resource validation failed")
assertThat(minBackoffObserved).isGreaterThanOrEqualTo(100)
assertThat(maxBackoffObserved).isLessThanOrEqualTo(10000)
}

@Test
fun `verify onError callback is called with final exception`(): Unit = runBlocking {
var errorCaught: Exception? = null

streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
CompletableFuture<Void>().apply {
completeExceptionally(VALIDATION_EXCEPTION)
}
}
}

val thrown = catchCoroutineException {
amazonQStreamingClient.exportResultArchive(
"test-id",
ExportIntent.TRANSFORMATION,
null,
{ errorCaught = it },
{}
)
}

assertThat(thrown)
.isInstanceOf(ValidationException::class.java)
.hasMessage("Resource validation failed")
assertThat(errorCaught).isEqualTo(VALIDATION_EXCEPTION)
}

private suspend fun catchCoroutineException(block: suspend () -> Unit): Throwable {
try {
block()
error("Expected exception was not thrown")
} catch (e: Throwable) {
return e
}
}

companion object {
private val VALIDATION_EXCEPTION = ValidationException.builder()
.message("Resource validation failed")
.build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,26 @@ package software.aws.toolkits.jetbrains.services.amazonq.clients
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import kotlinx.coroutines.delay
import kotlinx.coroutines.future.await
import software.amazon.awssdk.core.exception.RetryableException
import software.amazon.awssdk.core.exception.SdkException
import software.amazon.awssdk.services.codewhispererstreaming.CodeWhispererStreamingAsyncClient
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportContext
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveResponseHandler
import software.amazon.awssdk.services.codewhispererstreaming.model.ThrottlingException
import software.amazon.awssdk.services.codewhispererstreaming.model.ValidationException
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.core.awsClient
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import java.time.Instant
import java.util.concurrent.TimeoutException
import java.util.concurrent.atomic.AtomicReference
import javax.naming.ServiceUnavailableException
import kotlin.random.Random

@Service(Service.Level.PROJECT)
class AmazonQStreamingClient(private val project: Project) {
Expand Down Expand Up @@ -54,27 +62,42 @@ class AmazonQStreamingClient(private val project: Project) {
val checksum = AtomicReference("")

try {
val result = streamingBearerClient().exportResultArchive(
{
it.exportId(exportId)
it.exportIntent(exportIntent)
it.exportContext(exportContext)
withRetry(
block = {
val result = streamingBearerClient().exportResultArchive(
{
it.exportId(exportId)
it.exportIntent(exportIntent)
it.exportContext(exportContext)
},
ExportResultArchiveResponseHandler.builder().subscriber(
ExportResultArchiveResponseHandler.Visitor.builder()
.onBinaryMetadataEvent {
checksum.set(it.contentChecksum())
}.onBinaryPayloadEvent {
val payloadBytes = it.bytes().asByteArray()
byteBufferList.add(payloadBytes)
}.onDefault {
LOG.warn { "Received unknown payload stream: $it" }
}
.build()
)
.build()
)
result.await()
},
ExportResultArchiveResponseHandler.builder().subscriber(
ExportResultArchiveResponseHandler.Visitor.builder()
.onBinaryMetadataEvent {
checksum.set(it.contentChecksum())
}.onBinaryPayloadEvent {
val payloadBytes = it.bytes().asByteArray()
byteBufferList.add(payloadBytes)
}.onDefault {
LOG.warn { "Received unknown payload stream: $it" }
}
.build()
)
.build()
isRetryable = { e ->
when (e) {
is ValidationException,
is ThrottlingException,
is ServiceUnavailableException,
is SdkException,
is TimeoutException,
-> true
else -> false
}
}
)
result.await()
} catch (e: Exception) {
onError(e)
throw e
Expand All @@ -85,8 +108,43 @@ class AmazonQStreamingClient(private val project: Project) {
return byteBufferList
}

/**
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: If possible we can move this to a shared place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Felt it was alright to keep it here in AmazonQStreamingClient class. Assuming that any other streaming APIs will be added here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@ashishrp-aws ashishrp-aws Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this streaming API is a suspend function. tried using a common function for both but way to many conflicts

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can have a thin wrapper for the non-suspending use case instead of duplicating the logic

* Helper function to implement retry logic with exponential backoff and jitter
*
* @param block The suspend function to execute with retry logic
* @param isRetryable A function that determines if an exception should trigger a retry
* @return The result of the block execution
*/
private suspend fun <T> withRetry(
block: suspend () -> T,
isRetryable: (Exception) -> Boolean = { it is RetryableException },
): T {
var currentDelay = INITIAL_DELAY
var attempt = 0

while (true) {
try {
return block()
} catch (e: Exception) {
attempt++
if (attempt >= MAX_RETRY_ATTEMPTS || !isRetryable(e)) {
throw e
}

// Calculate delay with exponential backoff and jitter
currentDelay = (currentDelay * 2).coerceAtMost(MAX_BACKOFF)
val jitteredDelay = currentDelay * (0.5 + Random.nextDouble(0.5))

delay(jitteredDelay.toLong())
}
}
}

companion object {
private val LOG = getLogger<AmazonQStreamingClient>()
private const val INITIAL_DELAY = 100L // milliseconds
private const val MAX_BACKOFF = 10000L // milliseconds
private const val MAX_RETRY_ATTEMPTS = 3

fun getInstance(project: Project) = project.service<AmazonQStreamingClient>()
}
Expand Down
Loading