Commit 541b54e

Merge pull request #1368 from adamwlarson/speculative_decoding
Adding support for Speculative Decoding for LMStudio Local Models
2 parents: fd60c94 + 1c8f9ed

File tree: 5 files changed, +111 −5 lines

src/api/providers/lmstudio.ts

Lines changed: 24 additions & 5 deletions
@@ -30,13 +30,24 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
 		]

 		try {
-			const stream = await this.client.chat.completions.create({
+			// Create params object with optional draft model
+			const params: any = {
 				model: this.getModel().id,
 				messages: openAiMessages,
 				temperature: this.options.modelTemperature ?? LMSTUDIO_DEFAULT_TEMPERATURE,
 				stream: true,
-			})
-			for await (const chunk of stream) {
+			}
+
+			// Add draft model if speculative decoding is enabled and a draft model is specified
+			if (this.options.lmStudioSpeculativeDecodingEnabled && this.options.lmStudioDraftModelId) {
+				params.draft_model = this.options.lmStudioDraftModelId
+			}
+
+			const results = await this.client.chat.completions.create(params)
+
+			// Stream handling
+			// @ts-ignore
+			for await (const chunk of results) {
 				const delta = chunk.choices[0]?.delta
 				if (delta?.content) {
 					yield {
@@ -62,12 +73,20 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {

 	async completePrompt(prompt: string): Promise<string> {
 		try {
-			const response = await this.client.chat.completions.create({
+			// Create params object with optional draft model
+			const params: any = {
 				model: this.getModel().id,
 				messages: [{ role: "user", content: prompt }],
 				temperature: this.options.modelTemperature ?? LMSTUDIO_DEFAULT_TEMPERATURE,
 				stream: false,
-			})
+			}
+
+			// Add draft model if speculative decoding is enabled and a draft model is specified
+			if (this.options.lmStudioSpeculativeDecodingEnabled && this.options.lmStudioDraftModelId) {
+				params.draft_model = this.options.lmStudioDraftModelId
+			}
+
+			const response = await this.client.chat.completions.create(params)
 			return response.choices[0]?.message.content || ""
 		} catch (error) {
 			throw new Error(
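
Speculative decoding pairs the main model with a smaller "draft" model that proposes tokens for the larger model to verify, speeding up generation without changing the output distribution. As a sketch of what the handler above ends up sending, here is the equivalent raw request against LM Studio's OpenAI-compatible server; the port is LM Studio's default and both model IDs are placeholders, not part of the commit:

// Sketch only: the JSON body chat.completions.create(params) posts once
// draft_model has been attached. Assumes LM Studio's default server at
// http://localhost:1234; both model IDs are hypothetical.
const body = {
	model: "lmstudio-community/llama-3.2-8b-instruct", // main model (placeholder)
	messages: [{ role: "user", content: "Write a haiku about code review." }],
	temperature: 0,
	stream: false,
	// Sent only when lmStudioSpeculativeDecodingEnabled and lmStudioDraftModelId are set:
	draft_model: "lmstudio-community/llama-3.2-1b-instruct", // same-family draft (placeholder)
}

const res = await fetch("http://localhost:1234/v1/chat/completions", {
	method: "POST",
	headers: { "Content-Type": "application/json" },
	body: JSON.stringify(body),
})
const completion = await res.json()
console.log(completion.choices[0]?.message?.content)

The params object is typed as any because draft_model is not part of the OpenAI SDK's typed request shape, which is presumably also why the stream loop needs the @ts-ignore.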

src/core/webview/ClineProvider.ts

Lines changed: 10 additions & 0 deletions
@@ -1676,6 +1676,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			modelTemperature,
 			modelMaxTokens,
 			modelMaxThinkingTokens,
+			lmStudioDraftModelId,
+			lmStudioSpeculativeDecodingEnabled,
 		} = apiConfiguration
 		await Promise.all([
 			this.updateGlobalState("apiProvider", apiProvider),
@@ -1725,6 +1727,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.updateGlobalState("modelTemperature", modelTemperature),
 			this.updateGlobalState("modelMaxTokens", modelMaxTokens),
 			this.updateGlobalState("anthropicThinking", modelMaxThinkingTokens),
+			this.updateGlobalState("lmStudioDraftModelId", lmStudioDraftModelId),
+			this.updateGlobalState("lmStudioSpeculativeDecodingEnabled", lmStudioSpeculativeDecodingEnabled),
 		])
 		if (this.cline) {
 			this.cline.api = buildApiHandler(apiConfiguration)
@@ -2221,6 +2225,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			modelMaxThinkingTokens,
 			maxOpenTabsContext,
 			browserToolEnabled,
+			lmStudioSpeculativeDecodingEnabled,
+			lmStudioDraftModelId,
 		] = await Promise.all([
 			this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,
 			this.getGlobalState("apiModelId") as Promise<string | undefined>,
@@ -2306,6 +2312,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getGlobalState("anthropicThinking") as Promise<number | undefined>,
 			this.getGlobalState("maxOpenTabsContext") as Promise<number | undefined>,
 			this.getGlobalState("browserToolEnabled") as Promise<boolean | undefined>,
+			this.getGlobalState("lmStudioSpeculativeDecodingEnabled") as Promise<boolean | undefined>,
+			this.getGlobalState("lmStudioDraftModelId") as Promise<string | undefined>,
 		])

 		let apiProvider: ApiProvider
@@ -2371,6 +2379,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				modelTemperature,
 				modelMaxTokens,
 				modelMaxThinkingTokens,
+				lmStudioSpeculativeDecodingEnabled,
+				lmStudioDraftModelId,
 			},
 			lastShownAnnouncementId,
 			customInstructions,
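
These hunks extend the provider's save and load paths in lockstep. Note the ordering contract on the load path: values come back through a positional array destructure over Promise.all, so the two new getGlobalState calls (speculative flag first, draft model ID second) must sit at the same indexes as the new names in the destructured array, which the commit gets right. A minimal sketch of the underlying persistence, assuming updateGlobalState/getGlobalState are thin wrappers over the standard VS Code Memento API (the wrappers themselves are not shown in this diff):

import * as vscode from "vscode"

// Sketch only: the plain Memento calls the ClineProvider wrappers would delegate to.
async function saveLmStudioSettings(
	context: vscode.ExtensionContext,
	draftModelId: string | undefined,
	speculativeDecodingEnabled: boolean | undefined,
) {
	await Promise.all([
		context.globalState.update("lmStudioDraftModelId", draftModelId),
		context.globalState.update("lmStudioSpeculativeDecodingEnabled", speculativeDecodingEnabled),
	])
}

function loadLmStudioSettings(context: vscode.ExtensionContext) {
	return {
		lmStudioDraftModelId: context.globalState.get<string>("lmStudioDraftModelId"),
		lmStudioSpeculativeDecodingEnabled: context.globalState.get<boolean>("lmStudioSpeculativeDecodingEnabled"),
	}
}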

src/shared/api.ts

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,8 @@ export interface ApiHandlerOptions {
 	ollamaBaseUrl?: string
 	lmStudioModelId?: string
 	lmStudioBaseUrl?: string
+	lmStudioDraftModelId?: string
+	lmStudioSpeculativeDecodingEnabled?: boolean
 	geminiApiKey?: string
 	openAiNativeApiKey?: string
 	mistralApiKey?: string
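
Both fields are optional, so existing configurations remain valid. A hypothetical options object exercising the new fields (model IDs are placeholders; the base URL is LM Studio's default):

import type { ApiHandlerOptions } from "./api" // path relative to src/shared

const options: ApiHandlerOptions = {
	lmStudioBaseUrl: "http://localhost:1234",
	lmStudioModelId: "qwen2.5-7b-instruct", // main model (placeholder)
	lmStudioSpeculativeDecodingEnabled: true,
	lmStudioDraftModelId: "qwen2.5-0.5b-instruct", // smaller same-family draft (placeholder)
}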

src/shared/globalState.ts

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,8 @@ export type GlobalStateKey =
 	| "ollamaBaseUrl"
 	| "lmStudioModelId"
 	| "lmStudioBaseUrl"
+	| "lmStudioDraftModelId"
+	| "lmStudioSpeculativeDecodingEnabled"
 	| "anthropicBaseUrl"
 	| "azureApiVersion"
 	| "openAiStreamingEnabled"

webview-ui/src/components/settings/ApiOptions.tsx

Lines changed: 73 additions & 0 deletions
@@ -1107,6 +1107,79 @@ const ApiOptions = ({
 						))}
 					</VSCodeRadioGroup>
 				)}
+				<div style={{ display: "flex", alignItems: "center", marginTop: "16px", marginBottom: "8px" }}>
+					<Checkbox
+						checked={apiConfiguration?.lmStudioSpeculativeDecodingEnabled === true}
+						onChange={(checked) => {
+							// Explicitly set the boolean value using direct method
+							setApiConfigurationField("lmStudioSpeculativeDecodingEnabled", checked)
+						}}>
+						Enable Speculative Decoding
+					</Checkbox>
+				</div>
+				{apiConfiguration?.lmStudioSpeculativeDecodingEnabled && (
+					<>
+						<VSCodeTextField
+							value={apiConfiguration?.lmStudioDraftModelId || ""}
+							style={{ width: "100%" }}
+							onInput={handleInputChange("lmStudioDraftModelId")}
+							placeholder={"e.g. lmstudio-community/llama-3.2-1b-instruct"}>
+							<span className="font-medium">Draft Model ID</span>
+						</VSCodeTextField>
+						<div
+							style={{
+								fontSize: "11px",
+								color: "var(--vscode-descriptionForeground)",
+								marginTop: 4,
+								display: "flex",
+								alignItems: "center",
+								gap: 4,
+							}}>
+							<i className="codicon codicon-info" style={{ fontSize: "12px" }}></i>
+							<span>
+								Draft model must be from the same model family for speculative decoding to work
+								correctly.
+							</span>
+						</div>
+						{lmStudioModels.length > 0 && (
+							<>
+								<div style={{ marginTop: "8px" }}>
+									<span className="font-medium">Select Draft Model</span>
+								</div>
+								<VSCodeRadioGroup
+									value={
+										lmStudioModels.includes(apiConfiguration?.lmStudioDraftModelId || "")
+											? apiConfiguration?.lmStudioDraftModelId
+											: ""
+									}
+									onChange={handleInputChange("lmStudioDraftModelId")}>
+									{lmStudioModels.map((model) => (
+										<VSCodeRadio key={`draft-${model}`} value={model}>
+											{model}
+										</VSCodeRadio>
+									))}
+								</VSCodeRadioGroup>
+								{lmStudioModels.length === 0 && (
+									<div
+										style={{
+											fontSize: "12px",
+											marginTop: "8px",
+											padding: "6px",
+											backgroundColor: "var(--vscode-inputValidation-infoBackground)",
+											border: "1px solid var(--vscode-inputValidation-infoBorder)",
+											borderRadius: "3px",
+											color: "var(--vscode-inputValidation-infoForeground)",
+										}}>
+										<i className="codicon codicon-info" style={{ marginRight: "5px" }}></i>
+										No draft models found. Please ensure LM Studio is running with Server Mode
+										enabled.
+									</div>
+								)}
+							</>
+						)}
+					</>
+				)}
+
 				<p
 					style={{
 						fontSize: "12px",
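
The new controls reuse the component's existing plumbing: the checkbox writes its boolean straight through setApiConfigurationField, while the text field and the radio group both funnel into the same lmStudioDraftModelId key via handleInputChange. One wrinkle worth noting: the "No draft models found" notice is guarded by lmStudioModels.length === 0 but nested inside the lmStudioModels.length > 0 branch, so as committed it can never render. A sketch of the assumed glue (the helper implementations are not part of this diff):

// Assumed shape of the helpers the JSX above relies on; real signatures in
// ApiOptions.tsx may differ.
declare function setApiConfigurationField(field: string, value: unknown): void

// Adapts a field name into an event handler for VSCodeTextField / VSCodeRadioGroup.
const handleInputChange = (field: string) => (event: any) =>
	setApiConfigurationField(field, event.target.value)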
