From d5b796263dd160d4f6fe0ccc0ad8c464c97ec01f Mon Sep 17 00:00:00 2001
From: System233
Date: Wed, 26 Feb 2025 02:44:57 +0800
Subject: [PATCH 1/4] Add a combobox component with auto-complete functionality
---
webview-ui/src/__mocks__/lucide-react.ts | 6 +
.../src/components/ui/combobox-primitive.tsx | 522 ++++++++++++++++++
webview-ui/src/components/ui/combobox.tsx | 177 ++++++
webview-ui/src/components/ui/input-base.tsx | 157 ++++++
4 files changed, 862 insertions(+)
create mode 100644 webview-ui/src/__mocks__/lucide-react.ts
create mode 100644 webview-ui/src/components/ui/combobox-primitive.tsx
create mode 100644 webview-ui/src/components/ui/combobox.tsx
create mode 100644 webview-ui/src/components/ui/input-base.tsx
diff --git a/webview-ui/src/__mocks__/lucide-react.ts b/webview-ui/src/__mocks__/lucide-react.ts
new file mode 100644
index 00000000000..d85cd25d6a7
--- /dev/null
+++ b/webview-ui/src/__mocks__/lucide-react.ts
@@ -0,0 +1,6 @@
+import React from "react"
+
+export const Check = () => React.createElement("div")
+export const ChevronsUpDown = () => React.createElement("div")
+export const Loader = () => React.createElement("div")
+export const X = () => React.createElement("div")
diff --git a/webview-ui/src/components/ui/combobox-primitive.tsx b/webview-ui/src/components/ui/combobox-primitive.tsx
new file mode 100644
index 00000000000..13bad87abac
--- /dev/null
+++ b/webview-ui/src/components/ui/combobox-primitive.tsx
@@ -0,0 +1,522 @@
+/* eslint-disable react/jsx-pascal-case */
+"use client"
+
+import * as React from "react"
+import { composeEventHandlers } from "@radix-ui/primitive"
+import { useComposedRefs } from "@radix-ui/react-compose-refs"
+import * as PopoverPrimitive from "@radix-ui/react-popover"
+import { Primitive } from "@radix-ui/react-primitive"
+import * as RovingFocusGroupPrimitive from "@radix-ui/react-roving-focus"
+import { useControllableState } from "@radix-ui/react-use-controllable-state"
+import { Command as CommandPrimitive } from "cmdk"
+
+export type ComboboxContextProps = {
+ inputValue: string
+ onInputValueChange: (inputValue: string, reason: "inputChange" | "itemSelect" | "clearClick") => void
+ onInputBlur?: (e: React.FocusEvent) => void
+ open: boolean
+ onOpenChange: (open: boolean) => void
+ currentTabStopId: string | null
+ onCurrentTabStopIdChange: (currentTabStopId: string | null) => void
+ inputRef: React.RefObject
+ tagGroupRef: React.RefObject>
+ disabled?: boolean
+ required?: boolean
+} & (
+ | Required>
+ | Required>
+)
+
+const ComboboxContext = React.createContext({
+ type: "single",
+ value: "",
+ onValueChange: () => {},
+ inputValue: "",
+ onInputValueChange: () => {},
+ onInputBlur: () => {},
+ open: false,
+ onOpenChange: () => {},
+ currentTabStopId: null,
+ onCurrentTabStopIdChange: () => {},
+ inputRef: { current: null },
+ tagGroupRef: { current: null },
+ disabled: false,
+ required: false,
+})
+
+export const useComboboxContext = () => React.useContext(ComboboxContext)
+
+export type ComboboxType = "single" | "multiple"
+
+export interface ComboboxBaseProps
+ extends React.ComponentProps,
+ Omit, "value" | "defaultValue" | "onValueChange"> {
+ type?: ComboboxType | undefined
+ inputValue?: string
+ defaultInputValue?: string
+ onInputValueChange?: (inputValue: string, reason: "inputChange" | "itemSelect" | "clearClick") => void
+ onInputBlur?: (e: React.FocusEvent) => void
+ disabled?: boolean
+ required?: boolean
+}
+
+export type ComboboxValue = T extends "single"
+ ? string
+ : T extends "multiple"
+ ? string[]
+ : never
+
+export interface ComboboxSingleProps {
+ type: "single"
+ value?: string
+ defaultValue?: string
+ onValueChange?: (value: string) => void
+}
+
+export interface ComboboxMultipleProps {
+ type: "multiple"
+ value?: string[]
+ defaultValue?: string[]
+ onValueChange?: (value: string[]) => void
+}
+
+export type ComboboxProps = ComboboxBaseProps & (ComboboxSingleProps | ComboboxMultipleProps)
+
+export const Combobox = React.forwardRef(
+ (
+ {
+ type = "single" as T,
+ open: openProp,
+ onOpenChange,
+ defaultOpen,
+ modal,
+ children,
+ value: valueProp,
+ defaultValue,
+ onValueChange,
+ inputValue: inputValueProp,
+ defaultInputValue,
+ onInputValueChange,
+ onInputBlur,
+ disabled,
+ required,
+ ...props
+ }: ComboboxProps,
+ ref: React.ForwardedRef>,
+ ) => {
+ const [value = type === "multiple" ? [] : "", setValue] = useControllableState>({
+ prop: valueProp as ComboboxValue,
+ defaultProp: defaultValue as ComboboxValue,
+ onChange: onValueChange as (value: ComboboxValue) => void,
+ })
+ const [inputValue = "", setInputValue] = useControllableState({
+ prop: inputValueProp,
+ defaultProp: defaultInputValue,
+ })
+ const [open = false, setOpen] = useControllableState({
+ prop: openProp,
+ defaultProp: defaultOpen,
+ onChange: onOpenChange,
+ })
+ const [currentTabStopId, setCurrentTabStopId] = React.useState(null)
+ const inputRef = React.useRef(null)
+ const tagGroupRef = React.useRef>(null)
+
+ const handleInputValueChange: ComboboxContextProps["onInputValueChange"] = React.useCallback(
+ (inputValue, reason) => {
+ setInputValue(inputValue)
+ onInputValueChange?.(inputValue, reason)
+ },
+ [setInputValue, onInputValueChange],
+ )
+
+ return (
+
+
+
+ {children}
+ {!open && }
+
+
+
+ )
+ },
+)
+Combobox.displayName = "Combobox"
+
+export const ComboboxTagGroup = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>((props, ref) => {
+ const { currentTabStopId, onCurrentTabStopIdChange, tagGroupRef, type } = useComboboxContext()
+
+ if (type !== "multiple") {
+ throw new Error(' should only be used when type is "multiple"')
+ }
+
+ const composedRefs = useComposedRefs(ref, tagGroupRef)
+
+ return (
+ onCurrentTabStopIdChange(null)}
+ {...props}
+ />
+ )
+})
+ComboboxTagGroup.displayName = "ComboboxTagGroup"
+
+export interface ComboboxTagGroupItemProps
+ extends React.ComponentPropsWithoutRef {
+ value: string
+ disabled?: boolean
+}
+
+const ComboboxTagGroupItemContext = React.createContext>({
+ value: "",
+ disabled: false,
+})
+
+const useComboboxTagGroupItemContext = () => React.useContext(ComboboxTagGroupItemContext)
+
+export const ComboboxTagGroupItem = React.forwardRef<
+ React.ElementRef,
+ ComboboxTagGroupItemProps
+>(({ onClick, onKeyDown, value: valueProp, disabled, ...props }, ref) => {
+ const { value, onValueChange, inputRef, currentTabStopId, type } = useComboboxContext()
+
+ if (type !== "multiple") {
+ throw new Error(' should only be used when type is "multiple"')
+ }
+
+ const lastItemValue = value.at(-1)
+
+ return (
+
+ {
+ if (event.key === "Escape") {
+ inputRef.current?.focus()
+ }
+ if (event.key === "ArrowUp" || event.key === "ArrowDown") {
+ event.preventDefault()
+ inputRef.current?.focus()
+ }
+ if (event.key === "ArrowRight" && currentTabStopId === lastItemValue) {
+ inputRef.current?.focus()
+ }
+ if (event.key === "Backspace" || event.key === "Delete") {
+ onValueChange(value.filter((v) => v !== currentTabStopId))
+ inputRef.current?.focus()
+ }
+ })}
+ onClick={composeEventHandlers(onClick, () => disabled && inputRef.current?.focus())}
+ tabStopId={valueProp}
+ focusable={!disabled}
+ data-disabled={disabled}
+ active={valueProp === lastItemValue}
+ {...props}
+ />
+
+ )
+})
+ComboboxTagGroupItem.displayName = "ComboboxTagGroupItem"
+
+export const ComboboxTagGroupItemRemove = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ onClick, ...props }, ref) => {
+ const { value, onValueChange, type } = useComboboxContext()
+
+ if (type !== "multiple") {
+ throw new Error(' should only be used when type is "multiple"')
+ }
+
+ const { value: valueProp, disabled } = useComboboxTagGroupItemContext()
+
+ return (
+ onValueChange(value.filter((v) => v !== valueProp)))}
+ {...props}
+ />
+ )
+})
+ComboboxTagGroupItemRemove.displayName = "ComboboxTagGroupItemRemove"
+
+export const ComboboxInput = React.forwardRef<
+ React.ElementRef,
+ Omit, "value" | "onValueChange">
+>(({ onKeyDown, onMouseDown, onFocus, onBlur, ...props }, ref) => {
+ const {
+ type,
+ inputValue,
+ onInputValueChange,
+ onInputBlur,
+ open,
+ onOpenChange,
+ value,
+ onValueChange,
+ inputRef,
+ disabled,
+ required,
+ tagGroupRef,
+ } = useComboboxContext()
+
+ const composedRefs = useComposedRefs(ref, inputRef)
+
+ return (
+ {
+ if (!open) {
+ onOpenChange(true)
+ }
+ // Schedule input value change to the next tick.
+ setTimeout(() => onInputValueChange(search, "inputChange"))
+ if (!search && type === "single") {
+ onValueChange("")
+ }
+ }}
+ onKeyDown={composeEventHandlers(onKeyDown, (event) => {
+ if (event.key === "ArrowUp" || event.key === "ArrowDown") {
+ if (!open) {
+ event.preventDefault()
+ onOpenChange(true)
+ }
+ }
+ if (type !== "multiple") {
+ return
+ }
+ if (event.key === "ArrowLeft" && !inputValue && value.length) {
+ tagGroupRef.current?.focus()
+ }
+ if (event.key === "Backspace" && !inputValue) {
+ onValueChange(value.slice(0, -1))
+ }
+ })}
+ onMouseDown={composeEventHandlers(onMouseDown, () => onOpenChange(!!inputValue || !open))}
+ onFocus={composeEventHandlers(onFocus, () => onOpenChange(true))}
+ onBlur={composeEventHandlers(onBlur, (event) => {
+ if (!event.relatedTarget?.hasAttribute("cmdk-list")) {
+ onInputBlur?.(event)
+ }
+ })}
+ {...props}
+ />
+ )
+})
+ComboboxInput.displayName = "ComboboxInput"
+
+export const ComboboxClear = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ onClick, ...props }, ref) => {
+ const { value, onValueChange, inputValue, onInputValueChange, type } = useComboboxContext()
+
+ const isValueEmpty = type === "single" ? !value : !value.length
+
+ return (
+ {
+ if (type === "single") {
+ onValueChange("")
+ } else {
+ onValueChange([])
+ }
+ onInputValueChange("", "clearClick")
+ })}
+ {...props}
+ />
+ )
+})
+ComboboxClear.displayName = "ComboboxClear"
+
+export const ComboboxTrigger = PopoverPrimitive.Trigger
+
+export const ComboboxAnchor = PopoverPrimitive.Anchor
+
+export const ComboboxPortal = PopoverPrimitive.Portal
+
+export const ComboboxContent = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ children, onOpenAutoFocus, onInteractOutside, ...props }, ref) => (
+ event.preventDefault())}
+ onCloseAutoFocus={composeEventHandlers(onOpenAutoFocus, (event) => event.preventDefault())}
+ onInteractOutside={composeEventHandlers(onInteractOutside, (event) => {
+ if (event.target instanceof Element && event.target.hasAttribute("cmdk-input")) {
+ event.preventDefault()
+ }
+ })}
+ {...props}>
+ {children}
+
+))
+ComboboxContent.displayName = "ComboboxContent"
+
+export const ComboboxEmpty = CommandPrimitive.Empty
+
+export const ComboboxLoading = CommandPrimitive.Loading
+
+export interface ComboboxItemProps extends Omit, "value"> {
+ value: string
+}
+
+const ComboboxItemContext = React.createContext({ isSelected: false })
+
+const useComboboxItemContext = () => React.useContext(ComboboxItemContext)
+
+const findComboboxItemText = (children: React.ReactNode) => {
+ let text = ""
+
+ React.Children.forEach(children, (child) => {
+ if (text) {
+ return
+ }
+
+ if (React.isValidElement<{ children: React.ReactNode }>(child)) {
+ if (child.type === ComboboxItemText) {
+ text = child.props.children as string
+ } else {
+ text = findComboboxItemText(child.props.children)
+ }
+ }
+ })
+
+ return text
+}
+
+export const ComboboxItem = React.forwardRef, ComboboxItemProps>(
+ ({ value: valueProp, children, onMouseDown, ...props }, ref) => {
+ const { type, value, onValueChange, onInputValueChange, onOpenChange } = useComboboxContext()
+
+ const inputValue = React.useMemo(() => findComboboxItemText(children), [children])
+
+ const isSelected = type === "single" ? value === valueProp : value.includes(valueProp)
+
+ return (
+
+ event.preventDefault())}
+ onSelect={() => {
+ if (type === "multiple") {
+ onValueChange(
+ value.includes(valueProp)
+ ? value.filter((v) => v !== valueProp)
+ : [...value, valueProp],
+ )
+ onInputValueChange("", "itemSelect")
+ } else {
+ onValueChange(valueProp)
+ onInputValueChange(inputValue, "itemSelect")
+ // Schedule open change to the next tick.
+ setTimeout(() => onOpenChange(false))
+ }
+ }}
+ value={inputValue}
+ {...props}>
+ {children}
+
+
+ )
+ },
+)
+ComboboxItem.displayName = "ComboboxItem"
+
+export const ComboboxItemIndicator = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>((props, ref) => {
+ const { isSelected } = useComboboxItemContext()
+
+ if (!isSelected) {
+ return null
+ }
+
+ return
+})
+ComboboxItemIndicator.displayName = "ComboboxItemIndicator"
+
+export interface ComboboxItemTextProps extends React.ComponentPropsWithoutRef {
+ children: string
+}
+
+export const ComboboxItemText = (props: ComboboxItemTextProps) =>
+ComboboxItemText.displayName = "ComboboxItemText"
+
+export const ComboboxGroup = CommandPrimitive.Group
+
+export const ComboboxSeparator = CommandPrimitive.Separator
+
+const Root = Combobox
+const TagGroup = ComboboxTagGroup
+const TagGroupItem = ComboboxTagGroupItem
+const TagGroupItemRemove = ComboboxTagGroupItemRemove
+const Input = ComboboxInput
+const Clear = ComboboxClear
+const Trigger = ComboboxTrigger
+const Anchor = ComboboxAnchor
+const Portal = ComboboxPortal
+const Content = ComboboxContent
+const Empty = ComboboxEmpty
+const Loading = ComboboxLoading
+const Item = ComboboxItem
+const ItemIndicator = ComboboxItemIndicator
+const ItemText = ComboboxItemText
+const Group = ComboboxGroup
+const Separator = ComboboxSeparator
+
+export {
+ Root,
+ TagGroup,
+ TagGroupItem,
+ TagGroupItemRemove,
+ Input,
+ Clear,
+ Trigger,
+ Anchor,
+ Portal,
+ Content,
+ Empty,
+ Loading,
+ Item,
+ ItemIndicator,
+ ItemText,
+ Group,
+ Separator,
+}
diff --git a/webview-ui/src/components/ui/combobox.tsx b/webview-ui/src/components/ui/combobox.tsx
new file mode 100644
index 00000000000..24b2f7be1f3
--- /dev/null
+++ b/webview-ui/src/components/ui/combobox.tsx
@@ -0,0 +1,177 @@
+"use client"
+
+import * as React from "react"
+import { Slottable } from "@radix-ui/react-slot"
+import { cva } from "class-variance-authority"
+import { Check, ChevronsUpDown, Loader, X } from "lucide-react"
+
+import { cn } from "@/lib/utils"
+import * as ComboboxPrimitive from "@/components/ui/combobox-primitive"
+import { badgeVariants } from "@/components/ui/badge"
+// import * as ComboboxPrimitive from "@/registry/default/ui/combobox-primitive"
+import {
+ InputBase,
+ InputBaseAdornmentButton,
+ InputBaseControl,
+ InputBaseFlexWrapper,
+ InputBaseInput,
+} from "@/components/ui/input-base"
+
+export const Combobox = ComboboxPrimitive.Root
+
+const ComboboxInputBase = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ children, ...props }, ref) => (
+
+
+ {children}
+
+
+
+
+
+
+
+
+
+
+
+
+))
+ComboboxInputBase.displayName = "ComboboxInputBase"
+
+export const ComboboxInput = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>((props, ref) => (
+
+
+
+
+
+
+
+))
+ComboboxInput.displayName = "ComboboxInput"
+
+export const ComboboxTagsInput = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ children, ...props }, ref) => (
+
+
+
+ {children}
+
+
+
+
+
+
+
+
+))
+ComboboxTagsInput.displayName = "ComboboxTagsInput"
+
+export const ComboboxTag = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ children, className, ...props }, ref) => (
+
+ {children}
+
+
+ Remove
+
+
+))
+ComboboxTag.displayName = "ComboboxTag"
+
+export const ComboboxContent = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, align = "start", alignOffset = 0, ...props }, ref) => (
+
+
+
+))
+ComboboxContent.displayName = "ComboboxContent"
+
+export const ComboboxEmpty = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+ComboboxEmpty.displayName = "ComboboxEmpty"
+
+export const ComboboxLoading = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+
+
+))
+ComboboxLoading.displayName = "ComboboxLoading"
+
+export const ComboboxGroup = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+ComboboxGroup.displayName = "ComboboxGroup"
+
+const ComboboxSeparator = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+ComboboxSeparator.displayName = "ComboboxSeparator"
+
+export const comboboxItemStyle = cva(
+ "relative flex w-full cursor-pointer select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none data-[disabled=true]:pointer-events-none data-[selected=true]:bg-accent data-[selected=true]:text-vscode-dropdown-foreground data-[disabled=true]:opacity-50",
+)
+
+export const ComboboxItem = React.forwardRef<
+ React.ElementRef,
+ Omit, "children"> &
+ Pick, "children">
+>(({ className, children, ...props }, ref) => (
+
+ {children}
+
+
+
+
+))
+ComboboxItem.displayName = "ComboboxItem"
diff --git a/webview-ui/src/components/ui/input-base.tsx b/webview-ui/src/components/ui/input-base.tsx
new file mode 100644
index 00000000000..9dbda6eb138
--- /dev/null
+++ b/webview-ui/src/components/ui/input-base.tsx
@@ -0,0 +1,157 @@
+/* eslint-disable react/jsx-no-comment-textnodes */
+/* eslint-disable react/jsx-pascal-case */
+"use client"
+
+import * as React from "react"
+import { composeEventHandlers } from "@radix-ui/primitive"
+import { composeRefs } from "@radix-ui/react-compose-refs"
+import { Primitive } from "@radix-ui/react-primitive"
+import { Slot } from "@radix-ui/react-slot"
+
+import { cn } from "@/lib/utils"
+import { Button } from "./button"
+
+export type InputBaseContextProps = Pick & {
+ controlRef: React.RefObject
+ onFocusedChange: (focused: boolean) => void
+}
+
+const InputBaseContext = React.createContext({
+ autoFocus: false,
+ controlRef: { current: null },
+ disabled: false,
+ onFocusedChange: () => {},
+})
+
+const useInputBaseContext = () => React.useContext(InputBaseContext)
+
+export interface InputBaseProps extends React.ComponentPropsWithoutRef {
+ autoFocus?: boolean
+ disabled?: boolean
+}
+
+export const InputBase = React.forwardRef, InputBaseProps>(
+ ({ autoFocus, disabled, className, onClick, ...props }, ref) => {
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ const [focused, setFocused] = React.useState(false)
+
+ const controlRef = React.useRef(null)
+
+ return (
+
+ {
+ // Based on MUI's implementation.
+ // https://github.com/mui/material-ui/blob/master/packages/mui-material/src/InputBase/InputBase.js#L458~L460
+ if (controlRef.current && event.currentTarget === event.target) {
+ controlRef.current.focus()
+ }
+ })}
+ className={cn(
+ "flex w-full text-vscode-input-foreground border border-vscode-dropdown-border bg-vscode-input-background rounded-xs px-3 py-0.5 text-base transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium file:text-foreground placeholder:text-muted-foreground focus:outline-0 focus-visible:outline-none focus-visible:border-vscode-focusBorder disabled:cursor-not-allowed disabled:opacity-50",
+ disabled && "cursor-not-allowed opacity-50",
+ className,
+ )}
+ {...props}
+ />
+
+ )
+ },
+)
+InputBase.displayName = "InputBase"
+
+export const InputBaseFlexWrapper = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+InputBaseFlexWrapper.displayName = "InputBaseFlexWrapper"
+
+export const InputBaseControl = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ onFocus, onBlur, ...props }, ref) => {
+ const { controlRef, autoFocus, disabled, onFocusedChange } = useInputBaseContext()
+
+ return (
+ onFocusedChange(true))}
+ onBlur={composeEventHandlers(onBlur, () => onFocusedChange(false))}
+ {...{ disabled }}
+ {...props}
+ />
+ )
+})
+InputBaseControl.displayName = "InputBaseControl"
+
+export interface InputBaseAdornmentProps extends React.ComponentPropsWithoutRef<"div"> {
+ asChild?: boolean
+ disablePointerEvents?: boolean
+}
+
+export const InputBaseAdornment = React.forwardRef, InputBaseAdornmentProps>(
+ ({ className, disablePointerEvents, asChild, children, ...props }, ref) => {
+ const Comp = asChild ? Slot : typeof children === "string" ? "p" : "div"
+
+ const isAction = React.isValidElement(children) && children.type === InputBaseAdornmentButton
+
+ return (
+
+ {children}
+
+ )
+ },
+)
+InputBaseAdornment.displayName = "InputBaseAdornment"
+
+export const InputBaseAdornmentButton = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ type = "button", variant = "ghost", size = "icon", disabled: disabledProp, className, ...props }, ref) => {
+ const { disabled } = useInputBaseContext()
+
+ return (
+
+ )
+})
+InputBaseAdornmentButton.displayName = "InputBaseAdornmentButton"
+
+export const InputBaseInput = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+InputBaseInput.displayName = "InputBaseInput"
From f065f039be6af037480314937fbef91b96f77045 Mon Sep 17 00:00:00 2001
From: System233
Date: Wed, 26 Feb 2025 06:01:33 +0800
Subject: [PATCH 2/4] Fixed the issue that Model ID cannot be saved
---
.../src/components/settings/ApiOptions.tsx | 194 ++++++++++++---
.../components/settings/GlamaModelPicker.tsx | 15 --
.../src/components/settings/ModelPicker.tsx | 227 +++++-------------
.../components/settings/OpenAiModelPicker.tsx | 27 ---
.../settings/OpenRouterModelPicker.tsx | 15 --
.../settings/RequestyModelPicker.tsx | 22 --
.../src/components/settings/SettingsView.tsx | 6 +-
.../settings/UnboundModelPicker.tsx | 15 --
.../settings/__tests__/ModelPicker.test.tsx | 59 +++--
webview-ui/src/utils/validate.ts | 13 +-
10 files changed, 257 insertions(+), 336 deletions(-)
delete mode 100644 webview-ui/src/components/settings/GlamaModelPicker.tsx
delete mode 100644 webview-ui/src/components/settings/OpenAiModelPicker.tsx
delete mode 100644 webview-ui/src/components/settings/OpenRouterModelPicker.tsx
delete mode 100644 webview-ui/src/components/settings/RequestyModelPicker.tsx
delete mode 100644 webview-ui/src/components/settings/UnboundModelPicker.tsx
diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx
index 8c2f382db6f..594cd2fd5fb 100644
--- a/webview-ui/src/components/settings/ApiOptions.tsx
+++ b/webview-ui/src/components/settings/ApiOptions.tsx
@@ -38,18 +38,14 @@ import { ExtensionMessage } from "../../../../src/shared/ExtensionMessage"
import { vscode } from "../../utils/vscode"
import VSCodeButtonLink from "../common/VSCodeButtonLink"
-import { OpenRouterModelPicker } from "./OpenRouterModelPicker"
-import OpenAiModelPicker from "./OpenAiModelPicker"
-import { GlamaModelPicker } from "./GlamaModelPicker"
-import { UnboundModelPicker } from "./UnboundModelPicker"
import { ModelInfoView } from "./ModelInfoView"
import { DROPDOWN_Z_INDEX } from "./styles"
-import { RequestyModelPicker } from "./RequestyModelPicker"
+import { ModelPicker } from "./ModelPicker"
import { TemperatureControl } from "./TemperatureControl"
interface ApiOptionsProps {
uriScheme: string | undefined
- apiConfiguration: ApiConfiguration | undefined
+ apiConfiguration: ApiConfiguration
setApiConfigurationField: (field: K, value: ApiConfiguration[K]) => void
apiErrorMessage?: string
modelIdErrorMessage?: string
@@ -67,6 +63,20 @@ const ApiOptions = ({
const [ollamaModels, setOllamaModels] = useState([])
const [lmStudioModels, setLmStudioModels] = useState([])
const [vsCodeLmModels, setVsCodeLmModels] = useState([])
+ const [openRouterModels, setOpenRouterModels] = useState>({
+ [openRouterDefaultModelId]: openRouterDefaultModelInfo,
+ })
+ const [glamaModels, setGlamaModels] = useState>({
+ [glamaDefaultModelId]: glamaDefaultModelInfo,
+ })
+ const [unboundModels, setUnboundModels] = useState>({
+ [unboundDefaultModelId]: unboundDefaultModelInfo,
+ })
+ const [requestyModels, setRequestyModels] = useState>({
+ [requestyDefaultModelId]: requestyDefaultModelInfo,
+ })
+ const [openAiModels, setOpenAiModels] = useState | null>(null)
+
const [anthropicBaseUrlSelected, setAnthropicBaseUrlSelected] = useState(!!apiConfiguration?.anthropicBaseUrl)
const [azureApiVersionSelected, setAzureApiVersionSelected] = useState(!!apiConfiguration?.azureApiVersion)
const [openRouterBaseUrlSelected, setOpenRouterBaseUrlSelected] = useState(!!apiConfiguration?.openRouterBaseUrl)
@@ -104,24 +114,93 @@ const ApiOptions = ({
vscode.postMessage({ type: "requestLmStudioModels", text: apiConfiguration?.lmStudioBaseUrl })
} else if (selectedProvider === "vscode-lm") {
vscode.postMessage({ type: "requestVsCodeLmModels" })
+ } else if (selectedProvider === "openai") {
+ vscode.postMessage({
+ type: "refreshOpenAiModels",
+ values: {
+ baseUrl: apiConfiguration?.openAiBaseUrl,
+ apiKey: apiConfiguration?.openAiApiKey,
+ },
+ })
+ } else if (selectedProvider === "openrouter") {
+ vscode.postMessage({ type: "refreshOpenRouterModels", values: {} })
+ } else if (selectedProvider === "glama") {
+ vscode.postMessage({ type: "refreshGlamaModels", values: {} })
+ } else if (selectedProvider === "requesty") {
+ vscode.postMessage({
+ type: "refreshRequestyModels",
+ values: {
+ apiKey: apiConfiguration?.requestyApiKey,
+ },
+ })
}
},
250,
- [selectedProvider, apiConfiguration?.ollamaBaseUrl, apiConfiguration?.lmStudioBaseUrl],
+ [
+ selectedProvider,
+ apiConfiguration?.ollamaBaseUrl,
+ apiConfiguration?.lmStudioBaseUrl,
+ apiConfiguration?.openAiBaseUrl,
+ apiConfiguration?.openAiApiKey,
+ apiConfiguration?.requestyApiKey,
+ ],
)
const handleMessage = useCallback((event: MessageEvent) => {
const message: ExtensionMessage = event.data
-
- if (message.type === "ollamaModels" && Array.isArray(message.ollamaModels)) {
- const newModels = message.ollamaModels
- setOllamaModels(newModels)
- } else if (message.type === "lmStudioModels" && Array.isArray(message.lmStudioModels)) {
- const newModels = message.lmStudioModels
- setLmStudioModels(newModels)
- } else if (message.type === "vsCodeLmModels" && Array.isArray(message.vsCodeLmModels)) {
- const newModels = message.vsCodeLmModels
- setVsCodeLmModels(newModels)
+ switch (message.type) {
+ case "ollamaModels":
+ {
+ const newModels = message.ollamaModels ?? []
+ setOllamaModels(newModels)
+ }
+ break
+ case "lmStudioModels":
+ {
+ const newModels = message.lmStudioModels ?? []
+ setLmStudioModels(newModels)
+ }
+ break
+ case "vsCodeLmModels":
+ {
+ const newModels = message.vsCodeLmModels ?? []
+ setVsCodeLmModels(newModels)
+ }
+ break
+ case "glamaModels": {
+ const updatedModels = message.glamaModels ?? {}
+ setGlamaModels({
+ [glamaDefaultModelId]: glamaDefaultModelInfo, // in case the extension sent a model list without the default model
+ ...updatedModels,
+ })
+ break
+ }
+ case "openRouterModels": {
+ const updatedModels = message.openRouterModels ?? {}
+ setOpenRouterModels({
+ [openRouterDefaultModelId]: openRouterDefaultModelInfo, // in case the extension sent a model list without the default model
+ ...updatedModels,
+ })
+ break
+ }
+ case "openAiModels": {
+ const updatedModels = message.openAiModels ?? []
+ setOpenAiModels(Object.fromEntries(updatedModels.map((item) => [item, openAiModelInfoSaneDefaults])))
+ break
+ }
+ case "unboundModels": {
+ const updatedModels = message.unboundModels ?? {}
+ setUnboundModels(updatedModels)
+ break
+ }
+ case "requestyModels": {
+ const updatedModels = message.requestyModels ?? {}
+ setRequestyModels({
+ [requestyDefaultModelId]: requestyDefaultModelInfo, // in case the extension sent a model list without the default model
+ ...updatedModels,
+ })
+ break
+ }
}
}, [])
@@ -616,7 +695,17 @@ const ApiOptions = ({
placeholder="Enter API Key...">
API Key
-
+
{
+ onInput={handleInputChange("openAiCustomModelInfo", (e) => {
const value = parseInt((e.target as HTMLInputElement).value)
return {
...(apiConfiguration?.openAiCustomModelInfo ||
@@ -751,7 +840,7 @@ const ApiOptions = ({
})(),
}}
title="Total number of tokens (input + output) the model can process in a single request"
- onChange={handleInputChange("openAiCustomModelInfo", (e) => {
+ onInput={handleInputChange("openAiCustomModelInfo", (e) => {
const value = (e.target as HTMLInputElement).value
const parsed = parseInt(value)
return {
@@ -897,7 +986,7 @@ const ApiOptions = ({
: "var(--vscode-errorForeground)"
})(),
}}
- onChange={handleInputChange("openAiCustomModelInfo", (e) => {
+ onInput={handleInputChange("openAiCustomModelInfo", (e) => {
const value = (e.target as HTMLInputElement).value
const parsed = parseInt(value)
return {
@@ -942,7 +1031,7 @@ const ApiOptions = ({
: "var(--vscode-errorForeground)"
})(),
}}
- onChange={handleInputChange("openAiCustomModelInfo", (e) => {
+ onInput={handleInputChange("openAiCustomModelInfo", (e) => {
const value = (e.target as HTMLInputElement).value
const parsed = parseInt(value)
return {
@@ -1011,6 +1100,7 @@ const ApiOptions = ({
placeholder={"e.g. meta-llama-3.1-8b-instruct"}>
Model ID
+
{lmStudioModels.length > 0 && (
This key is stored locally and only used to make API requests from this extension.
-
+
)}
@@ -1236,9 +1337,49 @@ const ApiOptions = ({
)}
- {selectedProvider === "glama" && }
- {selectedProvider === "openrouter" && }
- {selectedProvider === "requesty" && }
+ {selectedProvider === "glama" && (
+
+ )}
+
+ {selectedProvider === "openrouter" && (
+
+ )}
+ {selectedProvider === "requesty" && (
+
+ )}
{selectedProvider !== "glama" &&
selectedProvider !== "openrouter" &&
@@ -1260,7 +1401,6 @@ const ApiOptions = ({
{selectedProvider === "deepseek" && createDropdown(deepSeekModels)}
{selectedProvider === "mistral" && createDropdown(mistralModels)}
-
(
-
-)
diff --git a/webview-ui/src/components/settings/ModelPicker.tsx b/webview-ui/src/components/settings/ModelPicker.tsx
index b21b37ef0f4..8fd6d82daa7 100644
--- a/webview-ui/src/components/settings/ModelPicker.tsx
+++ b/webview-ui/src/components/settings/ModelPicker.tsx
@@ -1,185 +1,90 @@
import { VSCodeLink } from "@vscode/webview-ui-toolkit/react"
-import debounce from "debounce"
-import { useMemo, useState, useCallback, useEffect, useRef } from "react"
-import { useMount } from "react-use"
-import { CaretSortIcon, CheckIcon } from "@radix-ui/react-icons"
+import { useMemo, useState, useCallback, useEffect } from "react"
-import { cn } from "@/lib/utils"
-import {
- Button,
- Command,
- CommandEmpty,
- CommandGroup,
- CommandInput,
- CommandItem,
- CommandList,
- Popover,
- PopoverContent,
- PopoverTrigger,
-} from "@/components/ui"
-
-import { useExtensionState } from "../../context/ExtensionStateContext"
-import { vscode } from "../../utils/vscode"
import { normalizeApiConfiguration } from "./ApiOptions"
import { ModelInfoView } from "./ModelInfoView"
-
-type ModelProvider = "glama" | "openRouter" | "unbound" | "requesty" | "openAi"
-
-type ModelKeys = `${T}Models`
-type ConfigKeys = `${T}ModelId`
-type InfoKeys = `${T}ModelInfo`
-type RefreshMessageType = `refresh${Capitalize}Models`
-
-interface ModelPickerProps {
- defaultModelId: string
- modelsKey: ModelKeys
- configKey: ConfigKeys
- infoKey: InfoKeys
- refreshMessageType: RefreshMessageType
- refreshValues?: Record
+import { ApiConfiguration, ModelInfo } from "../../../../src/shared/api"
+import { Combobox, ComboboxContent, ComboboxEmpty, ComboboxInput, ComboboxItem } from "../ui/combobox"
+
+type ExtractType = NonNullable<
+ { [K in keyof ApiConfiguration]: Required[K] extends T ? K : never }[keyof ApiConfiguration]
+>
+
+type ModelIdKeys = NonNullable<
+ { [K in keyof ApiConfiguration]: K extends `${string}ModelId` ? K : never }[keyof ApiConfiguration]
+>
+declare module "react" {
+ interface CSSProperties {
+ // Allow CSS variables
+ [key: `--${string}`]: string | number
+ }
+}
+interface ModelPickerProps {
+ defaultModelId?: string
+ models: Record | null
+ modelIdKey: ModelIdKeys
+ modelInfoKey: ExtractType
serviceName: string
serviceUrl: string
recommendedModel: string
- allowCustomModel?: boolean
+ apiConfiguration: ApiConfiguration
+ setApiConfigurationField: (field: K, value: ApiConfiguration[K]) => void
+ defaultModelInfo?: ModelInfo
}
export const ModelPicker = ({
defaultModelId,
- modelsKey,
- configKey,
- infoKey,
- refreshMessageType,
- refreshValues,
+ models,
+ modelIdKey,
+ modelInfoKey,
serviceName,
serviceUrl,
recommendedModel,
- allowCustomModel = false,
+ apiConfiguration,
+ setApiConfigurationField,
+ defaultModelInfo,
}: ModelPickerProps) => {
- const [customModelId, setCustomModelId] = useState("")
- const [isCustomModel, setIsCustomModel] = useState(false)
- const [open, setOpen] = useState(false)
- const [value, setValue] = useState(defaultModelId)
const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false)
- const prevRefreshValuesRef = useRef | undefined>()
-
- const { apiConfiguration, [modelsKey]: models, onUpdateApiConfig, setApiConfiguration } = useExtensionState()
- const modelIds = useMemo(
- () => (Array.isArray(models) ? models : Object.keys(models)).sort((a, b) => a.localeCompare(b)),
- [models],
- )
+ const modelIds = useMemo(() => Object.keys(models ?? {}).sort((a, b) => a.localeCompare(b)), [models])
const { selectedModelId, selectedModelInfo } = useMemo(
() => normalizeApiConfiguration(apiConfiguration),
[apiConfiguration],
)
-
- const onSelectCustomModel = useCallback(
- (modelId: string) => {
- setCustomModelId(modelId)
- const modelInfo = { id: modelId }
- const apiConfig = { ...apiConfiguration, [configKey]: modelId, [infoKey]: modelInfo }
- setApiConfiguration(apiConfig)
- onUpdateApiConfig(apiConfig)
- setValue(modelId)
- setOpen(false)
- setIsCustomModel(false)
- },
- [apiConfiguration, configKey, infoKey, onUpdateApiConfig, setApiConfiguration],
- )
-
const onSelect = useCallback(
(modelId: string) => {
- const modelInfo = Array.isArray(models)
- ? { id: modelId } // For OpenAI models which are just strings
- : models[modelId] // For other models that have full info objects
- const apiConfig = { ...apiConfiguration, [configKey]: modelId, [infoKey]: modelInfo }
- setApiConfiguration(apiConfig)
- onUpdateApiConfig(apiConfig)
- setValue(modelId)
- setOpen(false)
+ const modelInfo = models?.[modelId]
+ setApiConfigurationField(modelIdKey, modelId)
+ setApiConfigurationField(modelInfoKey, modelInfo ?? defaultModelInfo)
},
- [apiConfiguration, configKey, infoKey, models, onUpdateApiConfig, setApiConfiguration],
+ [modelIdKey, modelInfoKey, models, setApiConfigurationField, defaultModelInfo],
)
-
- const debouncedRefreshModels = useMemo(() => {
- return debounce(() => {
- const message = refreshValues
- ? { type: refreshMessageType, values: refreshValues }
- : { type: refreshMessageType }
- vscode.postMessage(message)
- }, 100)
- }, [refreshMessageType, refreshValues])
-
- useMount(() => {
- debouncedRefreshModels()
- return () => debouncedRefreshModels.clear()
- })
-
useEffect(() => {
- if (!refreshValues) {
- prevRefreshValuesRef.current = undefined
- return
- }
-
- // Check if all values in refreshValues are truthy
- if (Object.values(refreshValues).some((value) => !value)) {
- prevRefreshValuesRef.current = undefined
- return
- }
-
- // Compare with previous values
- const prevValues = prevRefreshValuesRef.current
- if (prevValues && JSON.stringify(prevValues) === JSON.stringify(refreshValues)) {
- return
+ if (apiConfiguration[modelIdKey] == null && defaultModelId) {
+ onSelect(defaultModelId)
}
-
- prevRefreshValuesRef.current = refreshValues
- debouncedRefreshModels()
- }, [debouncedRefreshModels, refreshValues])
-
- useEffect(() => setValue(selectedModelId), [selectedModelId])
+ }, [apiConfiguration, defaultModelId, modelIdKey, onSelect])
return (
<>
Model
-
-
-
-
-
-
-
-
- No model found.
-
- {modelIds.map((model) => (
-
- {model}
-
-
- ))}
-
- {allowCustomModel && (
-
- {
- setIsCustomModel(true)
- setOpen(false)
- }}>
- + Add custom model
-
-
- )}
-
-
-
-
+
+
+
+ No model found.
+ {modelIds.map((model) => (
+
+ {model}
+
+ ))}
+
+
+
{selectedModelId && selectedModelInfo && (
onSelect(recommendedModel)}>{recommendedModel}.
You can also try searching "free" for no-cost options currently available.
- {allowCustomModel && isCustomModel && (
-
-
-
Add Custom Model
-
setCustomModelId(e.target.value)}
- />
-
-
-
-
-
-
- )}
>
)
}
diff --git a/webview-ui/src/components/settings/OpenAiModelPicker.tsx b/webview-ui/src/components/settings/OpenAiModelPicker.tsx
deleted file mode 100644
index 040da1d4210..00000000000
--- a/webview-ui/src/components/settings/OpenAiModelPicker.tsx
+++ /dev/null
@@ -1,27 +0,0 @@
-import React from "react"
-import { useExtensionState } from "../../context/ExtensionStateContext"
-import { ModelPicker } from "./ModelPicker"
-
-const OpenAiModelPicker: React.FC = () => {
- const { apiConfiguration } = useExtensionState()
-
- return (
-
- )
-}
-
-export default OpenAiModelPicker
diff --git a/webview-ui/src/components/settings/OpenRouterModelPicker.tsx b/webview-ui/src/components/settings/OpenRouterModelPicker.tsx
deleted file mode 100644
index c773478e542..00000000000
--- a/webview-ui/src/components/settings/OpenRouterModelPicker.tsx
+++ /dev/null
@@ -1,15 +0,0 @@
-import { ModelPicker } from "./ModelPicker"
-import { openRouterDefaultModelId } from "../../../../src/shared/api"
-
-export const OpenRouterModelPicker = () => (
-
-)
diff --git a/webview-ui/src/components/settings/RequestyModelPicker.tsx b/webview-ui/src/components/settings/RequestyModelPicker.tsx
deleted file mode 100644
index c65067068aa..00000000000
--- a/webview-ui/src/components/settings/RequestyModelPicker.tsx
+++ /dev/null
@@ -1,22 +0,0 @@
-import { ModelPicker } from "./ModelPicker"
-import { requestyDefaultModelId } from "../../../../src/shared/api"
-import { useExtensionState } from "@/context/ExtensionStateContext"
-
-export const RequestyModelPicker = () => {
- const { apiConfiguration } = useExtensionState()
- return (
-
- )
-}
diff --git a/webview-ui/src/components/settings/SettingsView.tsx b/webview-ui/src/components/settings/SettingsView.tsx
index 761e8565214..75ba11107c4 100644
--- a/webview-ui/src/components/settings/SettingsView.tsx
+++ b/webview-ui/src/components/settings/SettingsView.tsx
@@ -1,4 +1,4 @@
-import { forwardRef, memo, useCallback, useEffect, useImperativeHandle, useRef, useState } from "react"
+import { forwardRef, memo, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from "react"
import { VSCodeButton, VSCodeCheckbox, VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
import { Dropdown, type DropdownOption } from "vscrui"
@@ -45,7 +45,6 @@ const SettingsView = forwardRef(({ onDone },
// TODO: Reduce WebviewMessage/ExtensionState complexity
const { currentApiConfigName } = extensionState
const {
- apiConfiguration,
alwaysAllowReadOnly,
allowedCommands,
alwaysAllowBrowser,
@@ -69,6 +68,9 @@ const SettingsView = forwardRef(({ onDone },
terminalOutputLineLimit,
writeDelayMs,
} = cachedState
+
+ //Make sure apiConfiguration is initialized and managed by SettingsView
+ const apiConfiguration = useMemo(() => cachedState.apiConfiguration ?? {}, [cachedState.apiConfiguration])
useEffect(() => {
// Update only when currentApiConfigName is changed
diff --git a/webview-ui/src/components/settings/UnboundModelPicker.tsx b/webview-ui/src/components/settings/UnboundModelPicker.tsx
deleted file mode 100644
index 4901884f1e6..00000000000
--- a/webview-ui/src/components/settings/UnboundModelPicker.tsx
+++ /dev/null
@@ -1,15 +0,0 @@
-import { ModelPicker } from "./ModelPicker"
-import { unboundDefaultModelId } from "../../../../src/shared/api"
-
-export const UnboundModelPicker = () => (
-
-)
diff --git a/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx b/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx
index 4e7c67c1872..49d60c55c48 100644
--- a/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx
+++ b/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx
@@ -3,7 +3,6 @@
import { screen, fireEvent, render } from "@testing-library/react"
import { act } from "react"
import { ModelPicker } from "../ModelPicker"
-import { useExtensionState } from "../../../context/ExtensionStateContext"
jest.mock("../../../context/ExtensionStateContext", () => ({
useExtensionState: jest.fn(),
@@ -20,36 +19,40 @@ global.ResizeObserver = MockResizeObserver
Element.prototype.scrollIntoView = jest.fn()
describe("ModelPicker", () => {
- const mockOnUpdateApiConfig = jest.fn()
- const mockSetApiConfiguration = jest.fn()
-
+ const mockSetApiConfigurationField = jest.fn()
+ const modelInfo = {
+ maxTokens: 8192,
+ contextWindow: 200_000,
+ supportsImages: true,
+ supportsComputerUse: true,
+ supportsPromptCache: true,
+ inputPrice: 3.0,
+ outputPrice: 15.0,
+ cacheWritesPrice: 3.75,
+ cacheReadsPrice: 0.3,
+ }
+ const mockModels = {
+ model1: { name: "Model 1", description: "Test model 1", ...modelInfo },
+ model2: { name: "Model 2", description: "Test model 2", ...modelInfo },
+ }
const defaultProps = {
+ apiConfiguration: {},
defaultModelId: "model1",
- modelsKey: "glamaModels" as const,
- configKey: "glamaModelId" as const,
- infoKey: "glamaModelInfo" as const,
- refreshMessageType: "refreshGlamaModels" as const,
+ defaultModelInfo: modelInfo,
+ modelIdKey: "glamaModelId" as const,
+ modelInfoKey: "glamaModelInfo" as const,
serviceName: "Test Service",
serviceUrl: "https://test.service",
recommendedModel: "recommended-model",
- }
-
- const mockModels = {
- model1: { name: "Model 1", description: "Test model 1" },
- model2: { name: "Model 2", description: "Test model 2" },
+ models: mockModels,
+ setApiConfigurationField: mockSetApiConfigurationField,
}
beforeEach(() => {
jest.clearAllMocks()
- ;(useExtensionState as jest.Mock).mockReturnValue({
- apiConfiguration: {},
- setApiConfiguration: mockSetApiConfiguration,
- glamaModels: mockModels,
- onUpdateApiConfig: mockOnUpdateApiConfig,
- })
})
- it("calls onUpdateApiConfig when a model is selected", async () => {
+ it("calls setApiConfigurationField when a model is selected", async () => {
await act(async () => {
render()
})
@@ -67,20 +70,12 @@ describe("ModelPicker", () => {
await act(async () => {
// Find and click the model item by its value.
- const modelItem = screen.getByRole("option", { name: "model2" })
- fireEvent.click(modelItem)
+ const modelItem = screen.getByTestId("model-input")
+ fireEvent.input(modelItem, { target: { value: "model2" } })
})
// Verify the API config was updated.
- expect(mockSetApiConfiguration).toHaveBeenCalledWith({
- glamaModelId: "model2",
- glamaModelInfo: mockModels["model2"],
- })
-
- // Verify onUpdateApiConfig was called with the new config.
- expect(mockOnUpdateApiConfig).toHaveBeenCalledWith({
- glamaModelId: "model2",
- glamaModelInfo: mockModels["model2"],
- })
+ expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelIdKey, "model2")
+ expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelInfoKey, mockModels.model2)
})
})
diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts
index 19b13e2c6c2..97c702637c4 100644
--- a/webview-ui/src/utils/validate.ts
+++ b/webview-ui/src/utils/validate.ts
@@ -1,9 +1,4 @@
-import {
- ApiConfiguration,
- glamaDefaultModelId,
- openRouterDefaultModelId,
- unboundDefaultModelId,
-} from "../../../src/shared/api"
+import { ApiConfiguration } from "../../../src/shared/api"
import { ModelInfo } from "../../../src/shared/api"
export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): string | undefined {
if (apiConfiguration) {
@@ -86,7 +81,7 @@ export function validateModelId(
if (apiConfiguration) {
switch (apiConfiguration.apiProvider) {
case "glama":
- const glamaModelId = apiConfiguration.glamaModelId || glamaDefaultModelId // in case the user hasn't changed the model id, it will be undefined by default
+ const glamaModelId = apiConfiguration.glamaModelId
if (!glamaModelId) {
return "You must provide a model ID."
}
@@ -96,7 +91,7 @@ export function validateModelId(
}
break
case "openrouter":
- const modelId = apiConfiguration.openRouterModelId || openRouterDefaultModelId // in case the user hasn't changed the model id, it will be undefined by default
+ const modelId = apiConfiguration.openRouterModelId
if (!modelId) {
return "You must provide a model ID."
}
@@ -106,7 +101,7 @@ export function validateModelId(
}
break
case "unbound":
- const unboundModelId = apiConfiguration.unboundModelId || unboundDefaultModelId
+ const unboundModelId = apiConfiguration.unboundModelId
if (!unboundModelId) {
return "You must provide a model ID."
}
From 1a3b8700ba2c8f93f31cdfed7aa0fadb38d584d0 Mon Sep 17 00:00:00 2001
From: System233
Date: Wed, 26 Feb 2025 06:42:29 +0800
Subject: [PATCH 3/4] Improved error message feedback in settings panel
---
webview-ui/src/__mocks__/vscrui.ts | 3 +
.../components/settings/ApiErrorMessage.tsx | 16 +++++
.../src/components/settings/ApiOptions.tsx | 69 ++++++++-----------
.../src/components/settings/ModelPicker.tsx | 41 ++++++++---
.../src/components/settings/SettingsView.tsx | 53 ++++----------
.../settings/__tests__/ApiOptions.test.tsx | 4 ++
.../src/components/welcome/WelcomeView.tsx | 2 +
7 files changed, 98 insertions(+), 90 deletions(-)
create mode 100644 webview-ui/src/components/settings/ApiErrorMessage.tsx
diff --git a/webview-ui/src/__mocks__/vscrui.ts b/webview-ui/src/__mocks__/vscrui.ts
index 76760ba5cce..9b4a20f4d6b 100644
--- a/webview-ui/src/__mocks__/vscrui.ts
+++ b/webview-ui/src/__mocks__/vscrui.ts
@@ -8,6 +8,9 @@ export const Dropdown = ({ children, value, onChange }: any) =>
export const Pane = ({ children }: any) => React.createElement("div", { "data-testid": "mock-pane" }, children)
+export const Button = ({ children, ...props }: any) =>
+ React.createElement("div", { "data-testid": "mock-button", ...props }, children)
+
export type DropdownOption = {
label: string
value: string
diff --git a/webview-ui/src/components/settings/ApiErrorMessage.tsx b/webview-ui/src/components/settings/ApiErrorMessage.tsx
new file mode 100644
index 00000000000..4b419957b6c
--- /dev/null
+++ b/webview-ui/src/components/settings/ApiErrorMessage.tsx
@@ -0,0 +1,16 @@
+import React from "react"
+
+interface ApiErrorMessageProps {
+ errorMessage: string | undefined
+ children?: React.ReactNode
+}
+const ApiErrorMessage = ({ errorMessage, children }: ApiErrorMessageProps) => {
+ return (
+
+
+ {errorMessage}
+ {children}
+
+ )
+}
+export default ApiErrorMessage
diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx
index 594cd2fd5fb..f0c2b0e45fa 100644
--- a/webview-ui/src/components/settings/ApiOptions.tsx
+++ b/webview-ui/src/components/settings/ApiOptions.tsx
@@ -1,4 +1,4 @@
-import { memo, useCallback, useMemo, useState } from "react"
+import React, { memo, useCallback, useEffect, useMemo, useState } from "react"
import { useDebounce, useEvent } from "react-use"
import { Checkbox, Dropdown, Pane, type DropdownOption } from "vscrui"
import { VSCodeLink, VSCodeRadio, VSCodeRadioGroup, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
@@ -42,23 +42,25 @@ import { ModelInfoView } from "./ModelInfoView"
import { DROPDOWN_Z_INDEX } from "./styles"
import { ModelPicker } from "./ModelPicker"
import { TemperatureControl } from "./TemperatureControl"
+import { validateApiConfiguration, validateModelId } from "@/utils/validate"
+import ApiErrorMessage from "./ApiErrorMessage"
interface ApiOptionsProps {
uriScheme: string | undefined
apiConfiguration: ApiConfiguration
setApiConfigurationField: (field: K, value: ApiConfiguration[K]) => void
- apiErrorMessage?: string
- modelIdErrorMessage?: string
fromWelcomeView?: boolean
+ errorMessage: string | undefined
+ setErrorMessage: React.Dispatch>
}
const ApiOptions = ({
uriScheme,
apiConfiguration,
setApiConfigurationField,
- apiErrorMessage,
- modelIdErrorMessage,
fromWelcomeView,
+ errorMessage,
+ setErrorMessage,
}: ApiOptionsProps) => {
const [ollamaModels, setOllamaModels] = useState([])
const [lmStudioModels, setLmStudioModels] = useState([])
@@ -146,6 +148,13 @@ const ApiOptions = ({
],
)
+ useEffect(() => {
+ const apiValidationResult =
+ validateApiConfiguration(apiConfiguration) ||
+ validateModelId(apiConfiguration, glamaModels, openRouterModels, unboundModels)
+ setErrorMessage(apiValidationResult)
+ }, [apiConfiguration, glamaModels, openRouterModels, setErrorMessage, unboundModels])
+
const handleMessage = useCallback((event: MessageEvent) => {
const message: ExtensionMessage = event.data
switch (message.type) {
@@ -626,6 +635,7 @@ const ApiOptions = ({
]}
/>
+ {errorMessage && }
{/* end Model Info Configuration */}
-
-
-
- (Note: Roo Code uses complex prompts and works best
- with Claude models. Less capable models may not work as expected.)
-
-
)}
@@ -1100,6 +1099,7 @@ const ApiOptions = ({
placeholder={"e.g. meta-llama-3.1-8b-instruct"}>
Model ID
+ {errorMessage && }
{lmStudioModels.length > 0 && (
Model ID
+ {errorMessage && (
+
+
+ {errorMessage}
+
+ )}
{ollamaModels.length > 0 && (
)}
- {apiErrorMessage && (
-
-
- {apiErrorMessage}
-
- )}
-
{selectedProvider === "glama" && (
)}
@@ -1364,6 +1360,7 @@ const ApiOptions = ({
serviceName="OpenRouter"
serviceUrl="https://openrouter.ai/models"
recommendedModel="anthropic/claude-3.7-sonnet"
+ errorMessage={errorMessage}
/>
)}
{selectedProvider === "requesty" && (
@@ -1378,6 +1375,7 @@ const ApiOptions = ({
serviceName="Requesty"
serviceUrl="https://requesty.ai"
recommendedModel="anthropic/claude-3-7-sonnet-latest"
+ errorMessage={errorMessage}
/>
)}
@@ -1401,6 +1399,7 @@ const ApiOptions = ({
{selectedProvider === "deepseek" && createDropdown(deepSeekModels)}
{selectedProvider === "mistral" && createDropdown(mistralModels)}
+ {errorMessage && }
)}
-
- {modelIdErrorMessage && (
-
-
- {modelIdErrorMessage}
-
- )}
)
}
diff --git a/webview-ui/src/components/settings/ModelPicker.tsx b/webview-ui/src/components/settings/ModelPicker.tsx
index 8fd6d82daa7..fd62bfb97b6 100644
--- a/webview-ui/src/components/settings/ModelPicker.tsx
+++ b/webview-ui/src/components/settings/ModelPicker.tsx
@@ -5,6 +5,7 @@ import { normalizeApiConfiguration } from "./ApiOptions"
import { ModelInfoView } from "./ModelInfoView"
import { ApiConfiguration, ModelInfo } from "../../../../src/shared/api"
import { Combobox, ComboboxContent, ComboboxEmpty, ComboboxInput, ComboboxItem } from "../ui/combobox"
+import ApiErrorMessage from "./ApiErrorMessage"
type ExtractType = NonNullable<
{ [K in keyof ApiConfiguration]: Required[K] extends T ? K : never }[keyof ApiConfiguration]
@@ -30,6 +31,7 @@ interface ModelPickerProps {
apiConfiguration: ApiConfiguration
setApiConfigurationField: (field: K, value: ApiConfiguration[K]) => void
defaultModelInfo?: ModelInfo
+ errorMessage?: string
}
export const ModelPicker = ({
@@ -43,6 +45,7 @@ export const ModelPicker = ({
apiConfiguration,
setApiConfigurationField,
defaultModelInfo,
+ errorMessage,
}: ModelPickerProps) => {
const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false)
@@ -69,11 +72,16 @@ export const ModelPicker = ({
return (
<>
Model
-
+
No model found.
@@ -85,13 +93,30 @@ export const ModelPicker = ({
- {selectedModelId && selectedModelInfo && (
-
+ {errorMessage ? (
+
+
+
+ Note: Roo Code uses complex prompts and works best
+ with Claude models. Less capable models may not work as expected.
+
+
+
+ ) : (
+ selectedModelId &&
+ selectedModelInfo && (
+
+ )
)}
The extension automatically fetches the latest list of models available on{" "}
diff --git a/webview-ui/src/components/settings/SettingsView.tsx b/webview-ui/src/components/settings/SettingsView.tsx
index 75ba11107c4..ee032c3ee06 100644
--- a/webview-ui/src/components/settings/SettingsView.tsx
+++ b/webview-ui/src/components/settings/SettingsView.tsx
@@ -1,6 +1,6 @@
import { forwardRef, memo, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from "react"
import { VSCodeButton, VSCodeCheckbox, VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
-import { Dropdown, type DropdownOption } from "vscrui"
+import { Button, Dropdown, type DropdownOption } from "vscrui"
import {
AlertDialog,
@@ -14,7 +14,6 @@ import {
} from "@/components/ui"
import { vscode } from "../../utils/vscode"
-import { validateApiConfiguration, validateModelId } from "../../utils/validate"
import { ExtensionStateContextType, useExtensionState } from "../../context/ExtensionStateContext"
import { EXPERIMENT_IDS, experimentConfigsMap, ExperimentId } from "../../../../src/shared/experiments"
import { ApiConfiguration } from "../../../../src/shared/api"
@@ -33,14 +32,13 @@ export interface SettingsViewRef {
const SettingsView = forwardRef(({ onDone }, ref) => {
const extensionState = useExtensionState()
- const [apiErrorMessage, setApiErrorMessage] = useState(undefined)
- const [modelIdErrorMessage, setModelIdErrorMessage] = useState(undefined)
const [commandInput, setCommandInput] = useState("")
const [isDiscardDialogShow, setDiscardDialogShow] = useState(false)
const [cachedState, setCachedState] = useState(extensionState)
const [isChangeDetected, setChangeDetected] = useState(false)
const prevApiConfigName = useRef(extensionState.currentApiConfigName)
const confirmDialogHandler = useRef<() => void>()
+ const [errorMessage, setErrorMessage] = useState(undefined)
// TODO: Reduce WebviewMessage/ExtensionState complexity
const { currentApiConfigName } = extensionState
@@ -135,20 +133,9 @@ const SettingsView = forwardRef(({ onDone },
}
})
}, [])
-
+ const isSettingValid = !errorMessage
const handleSubmit = () => {
- const apiValidationResult = validateApiConfiguration(apiConfiguration)
-
- const modelIdValidationResult = validateModelId(
- apiConfiguration,
- extensionState.glamaModels,
- extensionState.openRouterModels,
- )
-
- setApiErrorMessage(apiValidationResult)
- setModelIdErrorMessage(modelIdValidationResult)
-
- if (!apiValidationResult && !modelIdValidationResult) {
+ if (isSettingValid) {
vscode.postMessage({ type: "alwaysAllowReadOnly", bool: alwaysAllowReadOnly })
vscode.postMessage({ type: "alwaysAllowWrite", bool: alwaysAllowWrite })
vscode.postMessage({ type: "alwaysAllowExecute", bool: alwaysAllowExecute })
@@ -177,23 +164,6 @@ const SettingsView = forwardRef(({ onDone },
}
}
- useEffect(() => {
- setApiErrorMessage(undefined)
- setModelIdErrorMessage(undefined)
- }, [apiConfiguration])
-
- // Initial validation on mount
- useEffect(() => {
- const apiValidationResult = validateApiConfiguration(apiConfiguration)
- const modelIdValidationResult = validateModelId(
- apiConfiguration,
- extensionState.glamaModels,
- extensionState.openRouterModels,
- )
- setApiErrorMessage(apiValidationResult)
- setModelIdErrorMessage(modelIdValidationResult)
- }, [apiConfiguration, extensionState.glamaModels, extensionState.openRouterModels])
-
const checkUnsaveChanges = useCallback(
(then: () => void) => {
if (isChangeDetected) {
@@ -287,13 +257,14 @@ const SettingsView = forwardRef(({ onDone },
justifyContent: "space-between",
gap: "6px",
}}>
-
+ disabled={!isChangeDetected || !isSettingValid}>
Save
-
+
(({ onDone },
uriScheme={extensionState.uriScheme}
apiConfiguration={apiConfiguration}
setApiConfigurationField={setApiConfigurationField}
- apiErrorMessage={apiErrorMessage}
- modelIdErrorMessage={modelIdErrorMessage}
+ errorMessage={errorMessage}
+ setErrorMessage={setErrorMessage}
/>
diff --git a/webview-ui/src/components/settings/__tests__/ApiOptions.test.tsx b/webview-ui/src/components/settings/__tests__/ApiOptions.test.tsx
index 8f2d0dff893..73394bae104 100644
--- a/webview-ui/src/components/settings/__tests__/ApiOptions.test.tsx
+++ b/webview-ui/src/components/settings/__tests__/ApiOptions.test.tsx
@@ -51,6 +51,8 @@ describe("ApiOptions", () => {
render(
{}}
uriScheme={undefined}
apiConfiguration={{}}
setApiConfigurationField={() => {}}
@@ -69,4 +71,6 @@ describe("ApiOptions", () => {
renderApiOptions({ fromWelcomeView: true })
expect(screen.queryByTestId("temperature-control")).not.toBeInTheDocument()
})
+
+ //TODO: More test cases needed
})
diff --git a/webview-ui/src/components/welcome/WelcomeView.tsx b/webview-ui/src/components/welcome/WelcomeView.tsx
index 858d2622f39..5d880efc0b9 100644
--- a/webview-ui/src/components/welcome/WelcomeView.tsx
+++ b/webview-ui/src/components/welcome/WelcomeView.tsx
@@ -42,6 +42,8 @@ const WelcomeView = () => {
apiConfiguration={apiConfiguration || {}}
uriScheme={uriScheme}
setApiConfigurationField={(field, value) => setApiConfiguration({ [field]: value })}
+ errorMessage={errorMessage}
+ setErrorMessage={setErrorMessage}
/>
From 48975003afe593c975476acfd348ceb6110d7ced Mon Sep 17 00:00:00 2001
From: System233
Date: Wed, 26 Feb 2025 06:50:11 +0800
Subject: [PATCH 4/4] Remove ModelInfo related exports from
ExtensionStateContext
---
.../src/context/ExtensionStateContext.tsx | 84 +------------------
1 file changed, 1 insertion(+), 83 deletions(-)
diff --git a/webview-ui/src/context/ExtensionStateContext.tsx b/webview-ui/src/context/ExtensionStateContext.tsx
index 3dca8d5f51c..c2c4d181e4a 100644
--- a/webview-ui/src/context/ExtensionStateContext.tsx
+++ b/webview-ui/src/context/ExtensionStateContext.tsx
@@ -1,18 +1,7 @@
import React, { createContext, useCallback, useContext, useEffect, useState } from "react"
import { useEvent } from "react-use"
import { ApiConfigMeta, ExtensionMessage, ExtensionState } from "../../../src/shared/ExtensionMessage"
-import {
- ApiConfiguration,
- ModelInfo,
- glamaDefaultModelId,
- glamaDefaultModelInfo,
- openRouterDefaultModelId,
- openRouterDefaultModelInfo,
- unboundDefaultModelId,
- unboundDefaultModelInfo,
- requestyDefaultModelId,
- requestyDefaultModelInfo,
-} from "../../../src/shared/api"
+import { ApiConfiguration } from "../../../src/shared/api"
import { vscode } from "../utils/vscode"
import { convertTextMateToHljs } from "../utils/textMateToHljs"
import { findLastIndex } from "../../../src/shared/array"
@@ -26,11 +15,6 @@ export interface ExtensionStateContextType extends ExtensionState {
didHydrateState: boolean
showWelcome: boolean
theme: any
- glamaModels: Record
- requestyModels: Record
- openRouterModels: Record
- unboundModels: Record
- openAiModels: string[]
mcpServers: McpServer[]
currentCheckpoint?: string
filePaths: string[]
@@ -70,7 +54,6 @@ export interface ExtensionStateContextType extends ExtensionState {
setRateLimitSeconds: (value: number) => void
setCurrentApiConfigName: (value: string) => void
setListApiConfigMeta: (value: ApiConfigMeta[]) => void
- onUpdateApiConfig: (apiConfig: ApiConfiguration) => void
mode: Mode
setMode: (value: Mode) => void
setCustomModePrompts: (value: CustomModePrompts) => void
@@ -124,21 +107,8 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
const [showWelcome, setShowWelcome] = useState(false)
const [theme, setTheme] = useState(undefined)
const [filePaths, setFilePaths] = useState([])
- const [glamaModels, setGlamaModels] = useState>({
- [glamaDefaultModelId]: glamaDefaultModelInfo,
- })
const [openedTabs, setOpenedTabs] = useState>([])
- const [openRouterModels, setOpenRouterModels] = useState>({
- [openRouterDefaultModelId]: openRouterDefaultModelInfo,
- })
- const [unboundModels, setUnboundModels] = useState>({
- [unboundDefaultModelId]: unboundDefaultModelInfo,
- })
- const [requestyModels, setRequestyModels] = useState>({
- [requestyDefaultModelId]: requestyDefaultModelInfo,
- })
- const [openAiModels, setOpenAiModels] = useState([])
const [mcpServers, setMcpServers] = useState([])
const [currentCheckpoint, setCurrentCheckpoint] = useState()
@@ -146,18 +116,6 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
(value: ApiConfigMeta[]) => setState((prevState) => ({ ...prevState, listApiConfigMeta: value })),
[],
)
-
- const onUpdateApiConfig = useCallback((apiConfig: ApiConfiguration) => {
- setState((currentState) => {
- vscode.postMessage({
- type: "upsertApiConfiguration",
- text: currentState.currentApiConfigName,
- apiConfiguration: { ...currentState.apiConfiguration, ...apiConfig },
- })
- return currentState // No state update needed
- })
- }, [])
-
const handleMessage = useCallback(
(event: MessageEvent) => {
const message: ExtensionMessage = event.data
@@ -202,40 +160,6 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
})
break
}
- case "glamaModels": {
- const updatedModels = message.glamaModels ?? {}
- setGlamaModels({
- [glamaDefaultModelId]: glamaDefaultModelInfo, // in case the extension sent a model list without the default model
- ...updatedModels,
- })
- break
- }
- case "openRouterModels": {
- const updatedModels = message.openRouterModels ?? {}
- setOpenRouterModels({
- [openRouterDefaultModelId]: openRouterDefaultModelInfo, // in case the extension sent a model list without the default model
- ...updatedModels,
- })
- break
- }
- case "openAiModels": {
- const updatedModels = message.openAiModels ?? []
- setOpenAiModels(updatedModels)
- break
- }
- case "unboundModels": {
- const updatedModels = message.unboundModels ?? {}
- setUnboundModels(updatedModels)
- break
- }
- case "requestyModels": {
- const updatedModels = message.requestyModels ?? {}
- setRequestyModels({
- [requestyDefaultModelId]: requestyDefaultModelInfo, // in case the extension sent a model list without the default model
- ...updatedModels,
- })
- break
- }
case "mcpServers": {
setMcpServers(message.mcpServers ?? [])
break
@@ -264,11 +188,6 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
didHydrateState,
showWelcome,
theme,
- glamaModels,
- requestyModels,
- openRouterModels,
- openAiModels,
- unboundModels,
mcpServers,
currentCheckpoint,
filePaths,
@@ -316,7 +235,6 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
setRateLimitSeconds: (value) => setState((prevState) => ({ ...prevState, rateLimitSeconds: value })),
setCurrentApiConfigName: (value) => setState((prevState) => ({ ...prevState, currentApiConfigName: value })),
setListApiConfigMeta,
- onUpdateApiConfig,
setMode: (value: Mode) => setState((prevState) => ({ ...prevState, mode: value })),
setCustomModePrompts: (value) => setState((prevState) => ({ ...prevState, customModePrompts: value })),
setCustomSupportPrompts: (value) => setState((prevState) => ({ ...prevState, customSupportPrompts: value })),