@@ -24,6 +24,7 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode
2424import software.amazon.awssdk.services.codewhispererruntime.model.TestGenerationJobStatus
2525import software.amazon.awssdk.services.codewhispererstreaming.model.ExportContext
2626import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent
27+ import software.aws.toolkits.core.utils.Waiters.waitUntil
2728import software.aws.toolkits.core.utils.debug
2829import software.aws.toolkits.core.utils.error
2930import software.aws.toolkits.core.utils.getLogger
@@ -58,6 +59,7 @@ import java.io.ByteArrayOutputStream
5859import java.io.File
5960import java.io.IOException
6061import java.nio.file.Paths
62+ import java.time.Duration
6163import java.time.Instant
6264import java.util.concurrent.atomic.AtomicBoolean
6365import java.util.zip.ZipInputStream
@@ -109,29 +111,38 @@ class CodeWhispererUTGChatManager(val project: Project, private val cs: Coroutin
109111
110112 // 2nd API call: StartTestGeneration
111113 val startTestGenerationResponse = try {
112- startTestGeneration(
113- uploadId = createUploadUrlResponse.uploadId(),
114- targetCode = listOf (
115- TargetCode .builder()
116- .relativeTargetPath(codeTestResponseContext.currentFileRelativePath.toString())
117- .targetLineRangeList(
118- if (selectionRange != null ) {
119- listOf (
120- selectionRange
114+ var response: StartTestGenerationResponse ? = null
115+
116+ waitUntil(
117+ succeedOn = { response?.sdkHttpResponse()?.statusCode() == 200 },
118+ maxDuration = Duration .ofSeconds(1 ), // 1 second timeout
119+ ) {
120+ try {
121+ response = startTestGeneration(
122+ uploadId = createUploadUrlResponse.uploadId(),
123+ targetCode = listOf (
124+ TargetCode .builder()
125+ .relativeTargetPath(codeTestResponseContext.currentFileRelativePath.toString())
126+ .targetLineRangeList(
127+ if (selectionRange != null ) {
128+ listOf (selectionRange)
129+ } else {
130+ emptyList()
131+ }
121132 )
122- } else {
123- emptyList()
124- }
125- )
126- .build()
127- ),
128- userInput = prompt
129- )
130- } catch (e: Exception ) {
131- val statusCode = when {
132- e is SdkServiceException -> e.statusCode()
133- else -> 400
133+ .build()
134+ ),
135+ userInput = prompt
136+ )
137+ delay(200 )
138+ response?.testGenerationJob() != null
139+ } catch (e: Exception ) {
140+ throw e
141+ }
134142 }
143+
144+ response ? : throw RuntimeException (" Failed to start test generation" )
145+ } catch (e: Exception ) {
135146 LOG .error(e) { " Unexpected error while creating test generation job" }
136147 val errorMessage = getTelemetryErrorMessage(e, CodeWhispererConstants .FeatureName .TEST_GENERATION )
137148 throw CodeTestException (
0 commit comments