Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Wisp/Models/Local/SpriteChat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ final class SpriteChat {
var messagesData: Data?
var streamEventUUIDsData: Data?
var draftInputText: String?
var draftAttachmentPaths: [String]?
var isClosed: Bool
var spriteCreatedAt: Date?
var firstMessagePreview: String?
Expand Down
8 changes: 8 additions & 0 deletions Wisp/ViewModels/ChatViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ final class ChatViewModel {
inputText = draft
}

if attachedFiles.isEmpty, let paths = chat.draftAttachmentPaths, !paths.isEmpty {
attachedFiles = paths.map { path in
AttachedFile(name: (path as NSString).lastPathComponent, path: path)
}
}

if let context = chat.forkContext, !context.isEmpty {
let notice = ChatMessage(role: .system, content: [.text("Forked from a previous chat")])
messages.insert(notice, at: 0)
Expand All @@ -203,6 +209,8 @@ final class ChatViewModel {
func saveDraft(modelContext: ModelContext) {
guard let chat = fetchChat(modelContext: modelContext) else { return }
chat.draftInputText = inputText.isEmpty ? nil : inputText
let paths = attachedFiles.map(\.path)
chat.draftAttachmentPaths = paths.isEmpty ? nil : paths
try? modelContext.save()
}

Expand Down
3 changes: 3 additions & 0 deletions Wisp/Views/SpriteDetail/Chat/ChatView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ struct ChatView: View {
.onChange(of: viewModel.inputText) {
viewModel.saveDraft(modelContext: modelContext)
}
.onChange(of: viewModel.attachedFiles.count) {
viewModel.saveDraft(modelContext: modelContext)
}
.onDisappear {
viewModel.saveDraft(modelContext: modelContext)
}
Expand Down
60 changes: 60 additions & 0 deletions WispTests/ChatViewModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -703,4 +703,64 @@ struct ChatViewModelTests {
#expect(vm.inputText == "long prompt I want to come back to")
#expect(vm.stashedDraft == nil)
}

// MARK: - Draft attachment persistence

@Test func saveDraft_persistsAttachmentPaths() throws {
let ctx = try makeModelContext()
let (vm, chat) = makeChatViewModel(modelContext: ctx)

vm.attachedFiles = [
AttachedFile(name: "main.py", path: "/home/sprite/project/main.py"),
AttachedFile(name: "README.md", path: "/home/sprite/project/README.md"),
]
vm.saveDraft(modelContext: ctx)

#expect(chat.draftAttachmentPaths == [
"/home/sprite/project/main.py",
"/home/sprite/project/README.md",
])
}

@Test func saveDraft_clearsAttachmentPathsWhenEmpty() throws {
let ctx = try makeModelContext()
let (vm, chat) = makeChatViewModel(modelContext: ctx)

chat.draftAttachmentPaths = ["/home/sprite/project/old.py"]
vm.attachedFiles = []
vm.saveDraft(modelContext: ctx)

#expect(chat.draftAttachmentPaths == nil)
}

@Test func loadSession_restoresDraftAttachments() throws {
let ctx = try makeModelContext()
let (vm, chat) = makeChatViewModel(modelContext: ctx)

chat.draftAttachmentPaths = [
"/home/sprite/project/main.py",
"/home/sprite/project/photo.png",
]

vm.loadSession(apiClient: SpritesAPIClient(), modelContext: ctx)

#expect(vm.attachedFiles.count == 2)
#expect(vm.attachedFiles[0].name == "main.py")
#expect(vm.attachedFiles[0].path == "/home/sprite/project/main.py")
#expect(vm.attachedFiles[1].name == "photo.png")
#expect(vm.attachedFiles[1].path == "/home/sprite/project/photo.png")
}

@Test func loadSession_doesNotOverwriteExistingAttachments() throws {
let ctx = try makeModelContext()
let (vm, chat) = makeChatViewModel(modelContext: ctx)

chat.draftAttachmentPaths = ["/home/sprite/project/persisted.py"]
vm.attachedFiles = [AttachedFile(name: "live.py", path: "/home/sprite/project/live.py")]

vm.loadSession(apiClient: SpritesAPIClient(), modelContext: ctx)

#expect(vm.attachedFiles.count == 1)
#expect(vm.attachedFiles[0].name == "live.py")
}
}
Loading