Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type" : "bugfix",
"description" : "Amazon Q: Fix data isolation between tabs to prevent interference when using /doc in multiple tabs"
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,8 @@ class DocController(
private val authController: AuthController = AuthController(),
) : InboundAppMessagesHandler {
val messenger = context.messagesFromAppToUi
var mode: Mode = Mode.CREATE
val toolWindow = ToolWindowManager.getInstance(context.project).getToolWindow(AmazonQToolWindowFactory.WINDOW_ID)
var docGenerationTask = DocGenerationTask()
private val docGenerationTasks = DocGenerationTasks()

override suspend fun processPromptChatMessage(message: IncomingDocMessage.ChatPrompt) {
handleChat(
Expand All @@ -148,7 +147,7 @@ class DocController(
}

override suspend fun processTabRemovedMessage(message: IncomingDocMessage.TabRemoved) {
docGenerationTask.reset()
docGenerationTasks.deleteTask(message.tabId)
chatSessionStorage.deleteSession(message.tabId)
}

Expand All @@ -160,6 +159,7 @@ class DocController(

override suspend fun processFollowupClickedMessage(message: IncomingDocMessage.FollowupClicked) {
val session = getSessionInfo(message.tabId)
val docGenerationTask = docGenerationTasks.getTask(message.tabId)

session.preloader(message.followUp.pillText, messenger) // also stores message in session history

Expand All @@ -173,7 +173,7 @@ class DocController(
FollowUpTypes.CLOSE_SESSION -> closeSession(message.tabId)
FollowUpTypes.CREATE_DOCUMENTATION -> {
docGenerationTask.interactionType = DocInteractionType.GENERATE_README
mode = Mode.CREATE
docGenerationTask.mode = Mode.CREATE
promptForDocTarget(message.tabId)
}

Expand All @@ -183,19 +183,19 @@ class DocController(
}

FollowUpTypes.CANCEL_FOLDER_SELECTION -> {
docGenerationTask.folderLevel = DocFolderLevel.ENTIRE_WORKSPACE
docGenerationTask.reset()
newTask(message.tabId)
}

FollowUpTypes.PROCEED_FOLDER_SELECTION -> if (mode == Mode.EDIT) makeChanges(message.tabId) else onDocsGeneration(message)
FollowUpTypes.PROCEED_FOLDER_SELECTION -> if (docGenerationTask.mode == Mode.EDIT) makeChanges(message.tabId) else onDocsGeneration(message)
FollowUpTypes.ACCEPT_CHANGES -> {
docGenerationTask.userDecision = DocUserDecision.ACCEPT
sendDocAcceptanceTelemetry(message.tabId)
acceptChanges(message)
}

FollowUpTypes.MAKE_CHANGES -> {
mode = Mode.EDIT
docGenerationTask.mode = Mode.EDIT
makeChanges(message.tabId)
}

Expand All @@ -206,12 +206,12 @@ class DocController(
}

FollowUpTypes.SYNCHRONIZE_DOCUMENTATION -> {
mode = Mode.SYNC
docGenerationTask.mode = Mode.SYNC
promptForDocTarget(message.tabId)
}

FollowUpTypes.EDIT_DOCUMENTATION -> {
mode = Mode.EDIT
docGenerationTask.mode = Mode.EDIT
docGenerationTask.interactionType = DocInteractionType.EDIT_README
promptForDocTarget(message.tabId)
}
Expand Down Expand Up @@ -241,7 +241,6 @@ class DocController(
session.sessionState.token?.cancel()
}

docGenerationTask.reset()
newTask(message.tabId)
}

Expand Down Expand Up @@ -307,13 +306,14 @@ class DocController(

private suspend fun promptForDocTarget(tabId: String) {
val session = getSessionInfo(tabId)
val docGenerationTask = docGenerationTasks.getTask(tabId)

val currentSourceFolder = session.context.selectedSourceFolder

try {
messenger.sendFolderConfirmationMessage(
tabId = tabId,
message = if (mode == Mode.CREATE) message("amazonqDoc.prompt.create.confirmation") else message("amazonqDoc.prompt.update"),
message = if (docGenerationTask.mode == Mode.CREATE) message("amazonqDoc.prompt.create.confirmation") else message("amazonqDoc.prompt.update"),
folderPath = currentSourceFolder.name,
followUps = listOf(
FollowUp(
Expand Down Expand Up @@ -452,6 +452,9 @@ class DocController(
var session: DocSession? = null
try {
session = getSessionInfo(tabId)
val docGenerationTask = docGenerationTasks.getTask(tabId)
docGenerationTask.mode = Mode.NONE

logger.debug { "$FEATURE_NAME: Session created with id: ${session.tabID}" }

val credentialState = authController.getAuthNeededStates(context.project).amazonQ
Expand Down Expand Up @@ -528,7 +531,7 @@ class DocController(
}

private suspend fun newTask(tabId: String) {
docGenerationTask = DocGenerationTask()
docGenerationTasks.deleteTask(tabId)
chatSessionStorage.deleteSession(tabId)

messenger.sendAnswer(
Expand Down Expand Up @@ -577,7 +580,7 @@ class DocController(
)

messenger.sendChatInputEnabledMessage(tabId = tabId, enabled = false)
docGenerationTask.reset()
docGenerationTasks.deleteTask(tabId)
}

private suspend fun provideFeedbackAndRegenerateCode(tabId: String) {
Expand Down Expand Up @@ -728,6 +731,7 @@ class DocController(
message: String,
) {
var session: DocSession? = null
val docGenerationTask = docGenerationTasks.getTask(tabId)
try {
logger.debug { "$FEATURE_NAME: Processing message: $message" }
session = getSessionInfo(tabId)
Expand All @@ -746,7 +750,7 @@ class DocController(

when (session.sessionState.phase) {
SessionStatePhase.CODEGEN -> {
onCodeGeneration(session, message, tabId, mode)
onCodeGeneration(session, message, tabId, docGenerationTask.mode)
}

else -> null
Expand All @@ -756,7 +760,7 @@ class DocController(
is PrepareDocGenerationState -> state.filePaths
else -> emptyList()
}
sendDocGenerationTelemetry(filePaths, session)
sendDocGenerationTelemetry(filePaths, session, docGenerationTask)
broadcastQEvent(QFeatureEvent.INVOCATION)

if (filePaths.isNotEmpty()) {
Expand All @@ -767,7 +771,7 @@ class DocController(
} catch (err: Exception) {
// For non edit mode lock the chat input until they explicitly click one of the follow-ups
var isEnableChatInput = false
if (err is DocException && Mode.EDIT == mode) {
if (err is DocException && docGenerationTask.mode == Mode.EDIT) {
isEnableChatInput = err.remainingIterations != null && err.remainingIterations > 0
}

Expand All @@ -779,15 +783,16 @@ class DocController(
messenger.sendUpdatePromptProgress(tabId = followUpMessage.tabId, inProgress(progress = 10, message("amazonqDoc.progress_message.scanning")))

val session = getSessionInfo(followUpMessage.tabId)
val docGenerationTask = docGenerationTasks.getTask(followUpMessage.tabId)

messenger.sendAnswer(
message = docGenerationProgressMessage(DocGenerationStep.UPLOAD_TO_S3, this.mode),
message = docGenerationProgressMessage(DocGenerationStep.UPLOAD_TO_S3, docGenerationTask.mode),
messageType = DocMessageType.AnswerPart,
tabId = followUpMessage.tabId,
)

try {
val sessionMessage: String = when (mode) {
val sessionMessage: String = when (docGenerationTask.mode) {
Mode.CREATE -> message("amazonqDoc.session.create")
else -> message("amazonqDoc.session.sync")
}
Expand Down Expand Up @@ -821,10 +826,10 @@ class DocController(
return
}

sendDocGenerationTelemetry(filePaths, session)
sendDocGenerationTelemetry(filePaths, session, docGenerationTask)

messenger.sendAnswer(
message = docGenerationProgressMessage(DocGenerationStep.COMPLETE, mode),
message = docGenerationProgressMessage(DocGenerationStep.COMPLETE, docGenerationTask.mode),
messageType = DocMessageType.AnswerPart,
tabId = followUpMessage.tabId,
)
Expand Down Expand Up @@ -907,7 +912,6 @@ class DocController(

private suspend fun retryRequests(tabId: String) {
var session: DocSession? = null
docGenerationTask = DocGenerationTask()
try {
messenger.sendAsyncEventProgress(
tabId = tabId,
Expand Down Expand Up @@ -954,6 +958,7 @@ class DocController(
val session = getSessionInfo(tabId)
val currentSourceFolder = session.context.selectedSourceFolder
val projectRoot = session.context.projectRoot
val docGenerationTask = docGenerationTasks.getTask(tabId)

withContext(EDT) {
messenger.sendAnswer(
Expand Down Expand Up @@ -1017,7 +1022,7 @@ class DocController(
}
}

private fun sendDocGenerationTelemetry(filePaths: List<NewFileZipInfo>, session: DocSession) {
private fun sendDocGenerationTelemetry(filePaths: List<NewFileZipInfo>, session: DocSession, docGenerationTask: DocGenerationTask) {
docGenerationTask.conversationId = session.conversationId
val (totalGeneratedChars, totalGeneratedLines, totalGeneratedFiles) = session.countedGeneratedContent(filePaths, docGenerationTask.interactionType)
docGenerationTask.numberOfGeneratedChars = totalGeneratedChars
Expand All @@ -1030,6 +1035,7 @@ class DocController(

private fun sendDocAcceptanceTelemetry(tabId: String) {
val session = getSessionInfo(tabId)
val docGenerationTask = docGenerationTasks.getTask(tabId)
var filePaths: List<NewFileZipInfo> = emptyList()

when (val state = session.sessionState) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ suspend fun DocController.onCodeGeneration(session: DocSession, message: String,
messenger.sendAsyncEventProgress(tabId, inProgress = true)
messenger.sendUpdatePromptProgress(tabId, inProgress(progress = 10, message("amazonqDoc.progress_message.scanning")))
messenger.sendAnswer(
message = docGenerationProgressMessage(DocGenerationStep.UPLOAD_TO_S3, this.mode),
message = docGenerationProgressMessage(DocGenerationStep.UPLOAD_TO_S3, mode),
messageType = DocMessageType.AnswerPart,
tabId = tabId,
)
Expand Down Expand Up @@ -108,7 +108,7 @@ suspend fun DocController.onCodeGeneration(session: DocSession, message: String,
messenger.sendAnswer(
tabId = tabId,
messageType = DocMessageType.Answer,
message = if (this.mode === Mode.CREATE) {
message = if (mode === Mode.CREATE) {
message("amazonqDoc.answer.readmeCreated")
} else {
"${message("amazonqDoc.answer.readmeUpdated")} ${message("amazonqDoc.answer.codeResult")}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,19 @@ import software.amazon.awssdk.services.codewhispererruntime.model.DocV2Generatio
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.getLogger

class DocGenerationTasks {
private val tasks: MutableMap<String, DocGenerationTask> = mutableMapOf()

fun getTask(tabId: String): DocGenerationTask = tasks.getOrPut(tabId) { DocGenerationTask() }

fun deleteTask(tabId: String) {
tasks.remove(tabId)
}
}

class DocGenerationTask {
var mode: Mode = Mode.NONE

// Telemetry fields
var conversationId: String? = null
var numberOfAddedChars: Int? = null
Expand All @@ -22,8 +34,8 @@ class DocGenerationTask {
var numberOfGeneratedFiles: Int? = null
var userDecision: DocUserDecision? = null
var interactionType: DocInteractionType? = null
var numberOfNavigations = 0
var folderLevel: DocFolderLevel? = DocFolderLevel.ENTIRE_WORKSPACE
var numberOfNavigations: Int = 0
var folderLevel: DocFolderLevel = DocFolderLevel.ENTIRE_WORKSPACE
fun docGenerationEventBase(): DocV2GenerationEvent {
val undefinedProps = this::class.java.declaredFields
.filter { it.get(this) == null }
Expand Down
Loading