Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.info
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.core.coroutines.EDT
import software.aws.toolkits.jetbrains.services.amazonq.CHAT_INPLICIT_PROJECT_CONTEXT_TIMEOUT
import software.aws.toolkits.jetbrains.services.amazonq.CodeWhispererFeatureConfigService
import software.aws.toolkits.jetbrains.services.amazonq.apps.AmazonQAppInitContext
import software.aws.toolkits.jetbrains.services.amazonq.auth.AuthController
Expand Down Expand Up @@ -137,15 +138,15 @@ class ChatController private constructor(
shouldUseWorkspaceContext = true
prompt = prompt.replace("@workspace", "")
val projectContextController = ProjectContextController.getInstance(context.project)
queryResult = projectContextController.query(prompt)
queryResult = projectContextController.query(prompt, timeout = null)
if (!projectContextController.getProjectContextIndexComplete()) shouldAddIndexInProgressMessage = true
logger.info { "project context relevant document count: ${queryResult.size}" }
} else {
sendOpenSettingsMessage(message.tabId)
}
} else if (CodeWhispererSettings.getInstance().isProjectContextEnabled() && isDataCollectionGroup) {
val projectContextController = ProjectContextController.getInstance(context.project)
queryResult = projectContextController.query(prompt)
queryResult = projectContextController.query(prompt, timeout = CHAT_INPLICIT_PROJECT_CONTEXT_TIMEOUT)
}

handleChat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,10 @@ class ProjectContextProviderTest {
}

@Test
fun `query should send correct encrypted request to lsp`() {
sut.query("foo")
fun `query should send correct encrypted request to lsp`() = runTest {
sut = ProjectContextProvider(project, encoderServer, this)
val r = sut.query("foo", null)
advanceUntilIdle()

val request = QueryChatRequest("foo")
val requestJson = mapper.writeValueAsString(request)
Expand Down Expand Up @@ -269,13 +271,15 @@ class ProjectContextProviderTest {
)

assertThrows<Exception> {
sut.query("foo")
sut.query("foo", null)
}
}

@Test
fun `query chat should return deserialized relevantDocument`() = runTest {
val r = sut.query("foo")
sut = ProjectContextProvider(project, encoderServer, this)
val r = sut.query("foo", null)
advanceUntilIdle()
assertThat(r).hasSize(2)
assertThat(r[0]).isEqualTo(
RelevantDocument(
Expand Down Expand Up @@ -377,6 +381,29 @@ class ProjectContextProviderTest {
}
}

@Test
fun `queryChat should throw if time elapsed is greather than 500ms`() = runTest {
assertThrows<TimeoutCancellationException> {
sut = ProjectContextProvider(project, encoderServer, this)
stubFor(
any(urlPathEqualTo("/query")).willReturn(
aResponse()
.withStatus(200)
.withResponseBody(
Body(validQueryChatResponse)
)
.withFixedDelay(501)
)
)

withContext(getCoroutineBgContext()) {
sut.query("foo", timeout = 500L)
}

advanceUntilIdle()
}
}

@Test
fun `test index payload is encrypted`() = runTest {
whenever(encoderServer.port).thenReturn(3000)
Expand All @@ -390,12 +417,9 @@ class ProjectContextProviderTest {

@Test
fun `test query payload is encrypted`() = runTest {
whenever(encoderServer.port).thenReturn(3000)
try {
sut.query("what does this project do")
} catch (e: ConnectException) {
// no-op
}
sut = ProjectContextProvider(project, encoderServer, this)
sut.query("what does this project do", null)
advanceUntilIdle()
verify(encoderServer, times(1)).encrypt(any())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

package software.aws.toolkits.jetbrains.services.amazonq

import kotlin.time.Duration.Companion.minutes

const val APPLICATION_ZIP = "application/zip"
const val SERVER_SIDE_ENCRYPTION = "x-amz-server-side-encryption"
const val AWS_KMS = "aws:kms"
Expand Down Expand Up @@ -39,3 +41,7 @@ const val CODE_TRANSFORM_PREREQUISITES =
const val FEATURE_EVALUATION_PRODUCT_NAME = "CodeWhisperer"

const val SUPPLEMENTAL_CONTEXT_TIMEOUT = 100L

val CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT = 5.minutes.inWholeMilliseconds

const val CHAT_INPLICIT_PROJECT_CONTEXT_TIMEOUT = 500L // 500ms
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ class ProjectContextController(private val project: Project, private val cs: Cor

fun getProjectContextIndexComplete() = projectContextProvider.isIndexComplete.get()

fun query(prompt: String): List<RelevantDocument> {
suspend fun query(prompt: String, timeout: Long?): List<RelevantDocument> {
try {
return projectContextProvider.query(prompt)
return projectContextProvider.query(prompt, timeout)
} catch (e: Exception) {
logger.warn { "error while querying for project context $e.message" }
return emptyList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeout
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.error
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.info
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.services.amazonq.CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT
import software.aws.toolkits.jetbrains.services.amazonq.FeatureDevSessionContext
import software.aws.toolkits.jetbrains.services.amazonq.SUPPLEMENTAL_CONTEXT_TIMEOUT
import software.aws.toolkits.jetbrains.services.cwc.controller.chat.telemetry.getStartUrl
Expand Down Expand Up @@ -161,17 +161,15 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
}

// TODO: rename queryChat
fun query(prompt: String): List<RelevantDocument> {
val encrypted = encryptRequest(QueryChatRequest(prompt))
val response = sendMsgToLsp(LspMessage.QueryChat, encrypted)
// TODO: timeout window not decided
suspend fun query(prompt: String, timeout: Long?): List<RelevantDocument> = withTimeout(timeout ?: CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT) {
cs.async {
val encrypted = encryptRequest(QueryChatRequest(prompt))
val response = sendMsgToLsp(LspMessage.QueryChat, encrypted)

return try {
val parsedResponse = mapper.readValue<List<Chunk>>(response.responseBody)
queryResultToRelevantDocuments(parsedResponse)
} catch (e: Exception) {
logger.error { "error parsing query response ${e.message}" }
throw e
}
}.await()
}

suspend fun queryInline(query: String, filePath: String): List<InlineBm25Chunk> = withTimeout(SUPPLEMENTAL_CONTEXT_TIMEOUT) {
Expand Down
Loading