Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions webview-ui/src/components/settings/ModelPicker.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ export const ModelPicker = ({
const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false)
const isInitialized = useRef(false)
const searchInputRef = useRef<HTMLInputElement>(null)
const selectTimeoutRef = useRef<NodeJS.Timeout | null>(null)
const closeTimeoutRef = useRef<NodeJS.Timeout | null>(null)

const modelIds = useMemo(() => {
const filteredModels = filterModels(models, apiConfiguration.apiProvider, organizationAllowList)
Expand All @@ -79,8 +81,13 @@ export const ModelPicker = ({
setOpen(false)
setApiConfigurationField(modelIdKey, modelId)

// Clear any existing timeout
if (selectTimeoutRef.current) {
clearTimeout(selectTimeoutRef.current)
}

// Delay to ensure the popover is closed before setting the search value.
setTimeout(() => setSearchValue(modelId), 100)
selectTimeoutRef.current = setTimeout(() => setSearchValue(modelId), 100)
},
[modelIdKey, setApiConfigurationField],
)
Expand All @@ -91,8 +98,13 @@ export const ModelPicker = ({

// Abandon the current search if the popover is closed.
if (!open) {
// Clear any existing timeout
if (closeTimeoutRef.current) {
clearTimeout(closeTimeoutRef.current)
}

// Delay to ensure the popover is closed before setting the search value.
setTimeout(() => setSearchValue(selectedModelId), 100)
closeTimeoutRef.current = setTimeout(() => setSearchValue(selectedModelId), 100)
}
},
[selectedModelId],
Expand All @@ -112,6 +124,18 @@ export const ModelPicker = ({
isInitialized.current = true
}, [modelIds, setApiConfigurationField, modelIdKey, selectedModelId, defaultModelId])

// Cleanup timeouts on unmount to prevent test flakiness
useEffect(() => {
return () => {
if (selectTimeoutRef.current) {
clearTimeout(selectTimeoutRef.current)
}
if (closeTimeoutRef.current) {
clearTimeout(closeTimeoutRef.current)
}
}
}, [])

return (
<>
<div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import { screen, fireEvent, render } from "@testing-library/react"
import { act } from "react"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import { vi } from "vitest"

import { ModelInfo } from "@roo-code/types"

Expand Down Expand Up @@ -58,6 +59,13 @@ describe("ModelPicker", () => {

beforeEach(() => {
vi.clearAllMocks()
vi.useFakeTimers()
})

afterEach(() => {
// Clear any pending timers to prevent test flakiness
vi.clearAllTimers()
vi.useRealTimers()
})

it("calls setApiConfigurationField when a model is selected", async () => {
Expand All @@ -71,7 +79,7 @@ describe("ModelPicker", () => {

// Wait for popover to open and animations to complete.
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 100))
vi.advanceTimersByTime(100)
})

await act(async () => {
Expand All @@ -87,6 +95,11 @@ describe("ModelPicker", () => {
fireEvent.click(modelItem)
})

// Advance timers to trigger the setTimeout in onSelect
await act(async () => {
vi.advanceTimersByTime(100)
})

// Verify the API config was updated.
expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelIdKey, "model2")
})
Expand All @@ -102,7 +115,7 @@ describe("ModelPicker", () => {

// Wait for popover to open and animations to complete.
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 100))
vi.advanceTimersByTime(100)
})

const customModelId = "custom-model-id"
Expand All @@ -115,7 +128,7 @@ describe("ModelPicker", () => {

// Wait for the UI to update
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 100))
vi.advanceTimersByTime(100)
})

// Find and click the "Use custom" option
Expand All @@ -125,6 +138,11 @@ describe("ModelPicker", () => {
fireEvent.click(customOption)
})

// Advance timers to trigger the setTimeout in onSelect
await act(async () => {
vi.advanceTimersByTime(100)
})

// Verify the API config was updated with the custom model ID
expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelIdKey, customModelId)
})
Expand Down
Loading