Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,26 @@ package software.aws.toolkits.jetbrains.services.amazonq.clients
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import kotlinx.coroutines.delay
import kotlinx.coroutines.future.await
import software.amazon.awssdk.core.exception.RetryableException
import software.amazon.awssdk.core.exception.SdkException
import software.amazon.awssdk.services.codewhispererstreaming.CodeWhispererStreamingAsyncClient
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportContext
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveResponseHandler
import software.amazon.awssdk.services.codewhispererstreaming.model.ThrottlingException
import software.amazon.awssdk.services.codewhispererstreaming.model.ValidationException
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.core.awsClient
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import java.time.Instant
import java.util.concurrent.TimeoutException
import java.util.concurrent.atomic.AtomicReference
import javax.naming.ServiceUnavailableException
import kotlin.random.Random

@Service(Service.Level.PROJECT)
class AmazonQStreamingClient(private val project: Project) {
Expand Down Expand Up @@ -54,27 +62,42 @@ class AmazonQStreamingClient(private val project: Project) {
val checksum = AtomicReference("")

try {
val result = streamingBearerClient().exportResultArchive(
{
it.exportId(exportId)
it.exportIntent(exportIntent)
it.exportContext(exportContext)
withRetry(
block = {
val result = streamingBearerClient().exportResultArchive(
{
it.exportId(exportId)
it.exportIntent(exportIntent)
it.exportContext(exportContext)
},
ExportResultArchiveResponseHandler.builder().subscriber(
ExportResultArchiveResponseHandler.Visitor.builder()
.onBinaryMetadataEvent {
checksum.set(it.contentChecksum())
}.onBinaryPayloadEvent {
val payloadBytes = it.bytes().asByteArray()
byteBufferList.add(payloadBytes)
}.onDefault {
LOG.warn { "Received unknown payload stream: $it" }
}
.build()
)
.build()
)
result.await()
},
ExportResultArchiveResponseHandler.builder().subscriber(
ExportResultArchiveResponseHandler.Visitor.builder()
.onBinaryMetadataEvent {
checksum.set(it.contentChecksum())
}.onBinaryPayloadEvent {
val payloadBytes = it.bytes().asByteArray()
byteBufferList.add(payloadBytes)
}.onDefault {
LOG.warn { "Received unknown payload stream: $it" }
}
.build()
)
.build()
isRetryable = { e ->
when (e) {
is ValidationException,
is ThrottlingException,
is ServiceUnavailableException,
is SdkException,
is TimeoutException,
-> true
else -> false
}
}
)
result.await()
} catch (e: Exception) {
onError(e)
throw e
Expand All @@ -85,8 +108,43 @@ class AmazonQStreamingClient(private val project: Project) {
return byteBufferList
}

/**
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: If possible we can move this to a shared place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Felt it was alright to keep it here in AmazonQStreamingClient class. Assuming that any other streaming APIs will be added here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@ashishrp-aws ashishrp-aws Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this streaming API is a suspend function. tried using a common function for both but way to many conflicts

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can have a thin wrapper for the non-suspending use case instead of duplicating the logic

* Helper function to implement retry logic with exponential backoff and jitter
*
* @param block The suspend function to execute with retry logic
* @param isRetryable A function that determines if an exception should trigger a retry
* @return The result of the block execution
*/
private suspend fun <T> withRetry(
block: suspend () -> T,
isRetryable: (Exception) -> Boolean = { it is RetryableException },
): T {
var currentDelay = INITIAL_DELAY
var attempt = 0

while (true) {
try {
return block()
} catch (e: Exception) {
attempt++
if (attempt >= MAX_RETRY_ATTEMPTS || !isRetryable(e)) {
throw e
}

// Calculate delay with exponential backoff and jitter
currentDelay = (currentDelay * 2).coerceAtMost(MAX_BACKOFF)
val jitteredDelay = currentDelay * (0.5 + Random.nextDouble(0.5))

delay(jitteredDelay.toLong())
}
}
}

companion object {
private val LOG = getLogger<AmazonQStreamingClient>()
private const val INITIAL_DELAY = 100L // milliseconds
private const val MAX_BACKOFF = 10000L // milliseconds
private const val MAX_RETRY_ATTEMPTS = 3

fun getInstance(project: Project) = project.service<AmazonQStreamingClient>()
}
Expand Down
Loading