diff --git a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonqCodeTest/CodeWhispererUTGChatManager.kt b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonqCodeTest/CodeWhispererUTGChatManager.kt index e1419c4c107..04c1e4dc6db 100644 --- a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonqCodeTest/CodeWhispererUTGChatManager.kt +++ b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonqCodeTest/CodeWhispererUTGChatManager.kt @@ -16,7 +16,6 @@ import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job import kotlinx.coroutines.delay import kotlinx.coroutines.launch -import software.amazon.awssdk.core.exception.SdkServiceException import software.amazon.awssdk.services.codewhispererruntime.model.GetTestGenerationResponse import software.amazon.awssdk.services.codewhispererruntime.model.Range import software.amazon.awssdk.services.codewhispererruntime.model.StartTestGenerationResponse @@ -24,6 +23,7 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode import software.amazon.awssdk.services.codewhispererruntime.model.TestGenerationJobStatus import software.amazon.awssdk.services.codewhispererstreaming.model.ExportContext import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent +import software.aws.toolkits.core.utils.Waiters.waitUntil import software.aws.toolkits.core.utils.debug import software.aws.toolkits.core.utils.error import software.aws.toolkits.core.utils.getLogger @@ -58,6 +58,7 @@ import java.io.ByteArrayOutputStream import java.io.File import java.io.IOException import java.nio.file.Paths +import java.time.Duration import java.time.Instant import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.ZipInputStream @@ -109,29 +110,38 @@ class CodeWhispererUTGChatManager(val project: Project, private val cs: Coroutin // 2nd API call: StartTestGeneration val startTestGenerationResponse = try { - startTestGeneration( - uploadId = createUploadUrlResponse.uploadId(), - targetCode = listOf( - TargetCode.builder() - .relativeTargetPath(codeTestResponseContext.currentFileRelativePath.toString()) - .targetLineRangeList( - if (selectionRange != null) { - listOf( - selectionRange + var response: StartTestGenerationResponse? = null + + waitUntil( + succeedOn = { response?.sdkHttpResponse()?.statusCode() == 200 }, + maxDuration = Duration.ofSeconds(1), // 1 second timeout + ) { + try { + response = startTestGeneration( + uploadId = createUploadUrlResponse.uploadId(), + targetCode = listOf( + TargetCode.builder() + .relativeTargetPath(codeTestResponseContext.currentFileRelativePath.toString()) + .targetLineRangeList( + if (selectionRange != null) { + listOf(selectionRange) + } else { + emptyList() + } ) - } else { - emptyList() - } - ) - .build() - ), - userInput = prompt - ) - } catch (e: Exception) { - val statusCode = when { - e is SdkServiceException -> e.statusCode() - else -> 400 + .build() + ), + userInput = prompt + ) + delay(200) + response?.testGenerationJob() != null + } catch (e: Exception) { + throw e + } } + + response ?: throw RuntimeException("Failed to start test generation") + } catch (e: Exception) { LOG.error(e) { "Unexpected error while creating test generation job" } val errorMessage = getTelemetryErrorMessage(e, CodeWhispererConstants.FeatureName.TEST_GENERATION) throw CodeTestException( diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererZipUploadManager.kt b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererZipUploadManager.kt index b9ad79b12ca..c259f06973a 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererZipUploadManager.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererZipUploadManager.kt @@ -12,6 +12,8 @@ import software.amazon.awssdk.services.codewhispererruntime.model.CodeAnalysisUp import software.amazon.awssdk.services.codewhispererruntime.model.CodeFixUploadContext import software.amazon.awssdk.services.codewhispererruntime.model.CreateUploadUrlRequest import software.amazon.awssdk.services.codewhispererruntime.model.CreateUploadUrlResponse +import software.amazon.awssdk.services.codewhispererruntime.model.InternalServerException +import software.amazon.awssdk.services.codewhispererruntime.model.ThrottlingException import software.amazon.awssdk.services.codewhispererruntime.model.UploadContext import software.amazon.awssdk.services.codewhispererruntime.model.UploadIntent import software.amazon.awssdk.utils.IoUtils @@ -82,40 +84,50 @@ class CodeWhispererZipUploadManager(private val project: Project) { requestHeaders: Map?, featureUseCase: CodeWhispererConstants.FeatureName, ) { - try { - val uploadIdJson = """{"uploadId":"$uploadId"}""" - HttpRequests.put(url, "application/zip").userAgent(AwsClientManager.getUserAgent()).tuner { - if (requestHeaders.isNullOrEmpty()) { - it.setRequestProperty(CONTENT_MD5, md5) - it.setRequestProperty(CONTENT_TYPE, APPLICATION_ZIP) - it.setRequestProperty(SERVER_SIDE_ENCRYPTION, AWS_KMS) - if (kmsArn?.isNotEmpty() == true) { - it.setRequestProperty(SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID, kmsArn) - } - it.setRequestProperty(SERVER_SIDE_ENCRYPTION_CONTEXT, Base64.getEncoder().encodeToString(uploadIdJson.toByteArray())) - } else { - requestHeaders.forEach { entry -> - it.setRequestProperty(entry.key, entry.value) + RetryableOperation().execute( + operation = { + val uploadIdJson = """{"uploadId":"$uploadId"}""" + HttpRequests.put(url, "application/zip").userAgent(AwsClientManager.getUserAgent()).tuner { + if (requestHeaders.isNullOrEmpty()) { + it.setRequestProperty(CONTENT_MD5, md5) + it.setRequestProperty(CONTENT_TYPE, APPLICATION_ZIP) + it.setRequestProperty(SERVER_SIDE_ENCRYPTION, AWS_KMS) + if (kmsArn?.isNotEmpty() == true) { + it.setRequestProperty(SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID, kmsArn) + } + it.setRequestProperty(SERVER_SIDE_ENCRYPTION_CONTEXT, Base64.getEncoder().encodeToString(uploadIdJson.toByteArray())) + } else { + requestHeaders.forEach { entry -> + it.setRequestProperty(entry.key, entry.value) + } } + }.connect { + val connection = it.connection as HttpURLConnection + connection.setFixedLengthStreamingMode(fileToUpload.length()) + IoUtils.copy(fileToUpload.inputStream(), connection.outputStream) + } + }, + isRetryable = { e -> + when (e) { + is IOException -> true + else -> false + } + }, + errorHandler = { e, attempts -> + val errorMessage = getTelemetryErrorMessage(e, featureUseCase) + when (featureUseCase) { + CodeWhispererConstants.FeatureName.CODE_REVIEW -> + codeScanServerException("CreateUploadUrlException: $errorMessage") + CodeWhispererConstants.FeatureName.TEST_GENERATION -> + throw CodeTestException( + "UploadTestArtifactToS3Error: $errorMessage", + "UploadTestArtifactToS3Error", + message("testgen.error.generic_technical_error_message") + ) + else -> throw RuntimeException("$errorMessage (after $attempts attempts)") } - }.connect { - val connection = it.connection as HttpURLConnection - connection.setFixedLengthStreamingMode(fileToUpload.length()) - IoUtils.copy(fileToUpload.inputStream(), connection.outputStream) - } - } catch (e: Exception) { - LOG.debug { "$featureUseCase: Artifact failed to upload in the S3 bucket: ${e.message}" } - val errorMessage = getTelemetryErrorMessage(e, featureUseCase) - when (featureUseCase) { - CodeWhispererConstants.FeatureName.CODE_REVIEW -> codeScanServerException("CreateUploadUrlException: $errorMessage") - CodeWhispererConstants.FeatureName.TEST_GENERATION -> throw CodeTestException( - "UploadTestArtifactToS3Error: $errorMessage", - "UploadTestArtifactToS3Error", - message("testgen.error.generic_technical_error_message") - ) - else -> throw RuntimeException(errorMessage) // Adding else for safety check } - } + ) } fun createUploadUrl( @@ -124,35 +136,44 @@ class CodeWhispererZipUploadManager(private val project: Project) { uploadTaskType: CodeWhispererConstants.UploadTaskType, taskName: String, featureUseCase: CodeWhispererConstants.FeatureName, - ): CreateUploadUrlResponse = try { - CodeWhispererClientAdaptor.getInstance(project).createUploadUrl( - CreateUploadUrlRequest.builder() - .contentMd5(md5Content) - .artifactType(artifactType) - .uploadIntent(getUploadIntent(uploadTaskType)) - .uploadContext( - // For UTG we don't need uploadContext but sending else case as UploadContext - if (uploadTaskType == CodeWhispererConstants.UploadTaskType.CODE_FIX) { - UploadContext.fromCodeFixUploadContext(CodeFixUploadContext.builder().codeFixName(taskName).build()) - } else { - UploadContext.fromCodeAnalysisUploadContext(CodeAnalysisUploadContext.builder().codeScanName(taskName).build()) - } - ) - .build() - ) - } catch (e: Exception) { - LOG.debug { "$featureUseCase: Create Upload URL failed: ${e.message}" } - val errorMessage = getTelemetryErrorMessage(e, featureUseCase) - when (featureUseCase) { - CodeWhispererConstants.FeatureName.CODE_REVIEW -> codeScanServerException("CreateUploadUrlException: $errorMessage") - CodeWhispererConstants.FeatureName.TEST_GENERATION -> throw CodeTestException( - "CreateUploadUrlError: $errorMessage", - "CreateUploadUrlError", - message("testgen.error.generic_technical_error_message") + ): CreateUploadUrlResponse = RetryableOperation().execute( + operation = { + CodeWhispererClientAdaptor.getInstance(project).createUploadUrl( + CreateUploadUrlRequest.builder() + .contentMd5(md5Content) + .artifactType(artifactType) + .uploadIntent(getUploadIntent(uploadTaskType)) + .uploadContext( + // For UTG we don't need uploadContext but sending else case as UploadContext + if (uploadTaskType == CodeWhispererConstants.UploadTaskType.CODE_FIX) { + UploadContext.fromCodeFixUploadContext(CodeFixUploadContext.builder().codeFixName(taskName).build()) + } else { + UploadContext.fromCodeAnalysisUploadContext(CodeAnalysisUploadContext.builder().codeScanName(taskName).build()) + } + ) + .build() ) - else -> throw RuntimeException(errorMessage) // Adding else for safety check + }, + isRetryable = { e -> + e is ThrottlingException || e is InternalServerException + }, + errorHandler = { e, attempts -> + val errorMessage = getTelemetryErrorMessage(e, featureUseCase) + when (featureUseCase) { + CodeWhispererConstants.FeatureName.CODE_REVIEW -> + codeScanServerException("CreateUploadUrlException after $attempts attempts: $errorMessage") + + CodeWhispererConstants.FeatureName.TEST_GENERATION -> + throw CodeTestException( + "CreateUploadUrlError after $attempts attempts: $errorMessage", + "CreateUploadUrlError", + message("testgen.error.generic_technical_error_message") + ) + + else -> throw RuntimeException("$errorMessage (after $attempts attempts)") + } } - } + ) private fun getUploadIntent(uploadTaskType: CodeWhispererConstants.UploadTaskType): UploadIntent = when (uploadTaskType) { CodeWhispererConstants.UploadTaskType.SCAN_FILE -> UploadIntent.AUTOMATIC_FILE_SECURITY_SCAN @@ -187,3 +208,41 @@ fun getTelemetryErrorMessage(e: Exception, featureUseCase: CodeWhispererConstant else -> message("testgen.message.failed") } } + +class RetryableOperation { + private var attempts = 0 + private var currentDelay = INITIAL_DELAY + private var lastException: Exception? = null + + fun execute( + operation: () -> T, + isRetryable: (Exception) -> Boolean, + errorHandler: (Exception, Int) -> Nothing, + ): T { + while (attempts < MAX_RETRY_ATTEMPTS) { + try { + return operation() + } catch (e: Exception) { + lastException = e + + attempts++ + if (attempts < MAX_RETRY_ATTEMPTS && isRetryable(e)) { + Thread.sleep(currentDelay) + currentDelay = (currentDelay * 2).coerceAtMost(MAX_BACKOFF) + continue + } + + errorHandler(e, attempts) + } + } + + // This line should never be reached due to errorHandler throwing exception + throw RuntimeException("Unexpected state after $attempts attempts") + } + + companion object { + private const val INITIAL_DELAY = 100L // milliseconds + private const val MAX_BACKOFF = 10000L // milliseconds + private const val MAX_RETRY_ATTEMPTS = 3 + } +}