Skip to content

Commit 4de4279

Browse files
Merge pull request #106 from alichherawalla/fix/llm-context-use-after-free
fix: prevent SIGSEGV by draining active completion before context release
2 parents e41b173 + 3604e36 commit 4de4279

File tree

2 files changed

+63
-94
lines changed

2 files changed

+63
-94
lines changed

__tests__/unit/services/llm.test.ts

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -949,37 +949,6 @@ describe('LLMService', () => {
949949
});
950950
});
951951

952-
// ========================================================================
953-
// getImageUris
954-
// ========================================================================
955-
describe('getImageUris', () => {
956-
it('extracts image URIs from messages', () => {
957-
const messages = [{
958-
id: 'msg-1',
959-
role: 'user' as const,
960-
content: 'Look',
961-
timestamp: Date.now(),
962-
attachments: [
963-
{ type: 'image' as const, uri: '/img1.jpg', name: 'img1.jpg' },
964-
{ type: 'audio' as const, uri: '/voice.wav', name: 'voice.wav' },
965-
{ type: 'image' as const, uri: '/img2.jpg', name: 'img2.jpg' },
966-
],
967-
}];
968-
const uris = (llmService as any).getImageUris(messages);
969-
970-
expect(uris).toHaveLength(2);
971-
expect(uris).toContain('/img1.jpg');
972-
expect(uris).toContain('/img2.jpg');
973-
});
974-
975-
it('returns empty array when no attachments', () => {
976-
const messages = [createUserMessage('Hello')];
977-
const uris = (llmService as any).getImageUris(messages);
978-
979-
expect(uris).toEqual([]);
980-
});
981-
});
982-
983952
// ========================================================================
984953
// context window tokenize fallback
985954
// ========================================================================

src/services/llm.ts

Lines changed: 63 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import {
1212
buildCompletionParams, createThinkInjector, getMaxContextForDevice, getGpuLayersForDevice, BYTES_PER_GB,
1313
} from './llmHelpers';
1414
import { hardwareService } from './hardware';
15-
import { formatLlamaMessages, extractImageUris, buildOAIMessages } from './llmMessages';
15+
import { formatLlamaMessages, buildOAIMessages } from './llmMessages';
1616
import { generateWithToolsImpl } from './llmToolGeneration';
1717
import type { ToolCall } from './tools/types';
1818

@@ -25,6 +25,7 @@ class LLMService {
2525
private context: LlamaContext | null = null;
2626
private currentModelPath: string | null = null;
2727
private isGenerating: boolean = false;
28+
private activeCompletionPromise: Promise<void> | null = null;
2829
private multimodalSupport: MultimodalSupport | null = null;
2930
private multimodalInitialized: boolean = false;
3031
private performanceStats: LLMPerformanceStats = {
@@ -113,19 +114,12 @@ class LLMService {
113114
async initializeMultimodal(mmProjPath: string): Promise<boolean> {
114115
if (!this.context) { logger.warn('[LLM] initializeMultimodal: no context'); return false; }
115116
try {
116-
const stat = await RNFS.stat(mmProjPath);
117-
const sizeMB = (Number(stat.size) / (1024 * 1024)).toFixed(1);
118-
logger.log(`[LLM] mmproj file size: ${sizeMB} MB`);
119-
if (Number(stat.size) < 100 * 1024 * 1024) {
120-
console.warn(`[LLM] WARNING: mmproj file seems too small (${sizeMB} MB) - may be incomplete download!`);
121-
}
122-
} catch (statErr) {
123-
console.error('[LLM] Failed to stat mmproj file:', statErr);
124-
}
117+
const sizeMB = Number((await RNFS.stat(mmProjPath)).size) / (1024 * 1024);
118+
logger.log(`[LLM] mmproj file size: ${sizeMB.toFixed(1)} MB`);
119+
if (sizeMB < 100) console.warn(`[LLM] WARNING: mmproj file seems too small (${sizeMB.toFixed(1)} MB)`);
120+
} catch (statErr) { console.error('[LLM] Failed to stat mmproj file:', statErr); }
125121
const devInfo = useAppStore.getState().deviceInfo;
126-
const mem = devInfo?.totalMemory ?? 0;
127-
const useGpuForClip = Platform.OS === 'ios' && !devInfo?.isEmulator && mem > 4 * BYTES_PER_GB;
128-
logger.log('[LLM] Calling initMultimodal, use_gpu:', useGpuForClip);
122+
const useGpuForClip = Platform.OS === 'ios' && !devInfo?.isEmulator && (devInfo?.totalMemory ?? 0) > 4 * BYTES_PER_GB;
129123
const { initialized, support } = await initMultimodal(this.context, mmProjPath, useGpuForClip);
130124
this.multimodalInitialized = initialized;
131125
this.multimodalSupport = support;
@@ -134,8 +128,7 @@ class LLMService {
134128

135129
async checkMultimodalSupport(): Promise<MultimodalSupport> {
136130
if (!this.context) { this.multimodalSupport = { vision: false, audio: false }; return this.multimodalSupport; }
137-
this.multimodalSupport = await checkContextMultimodal(this.context);
138-
return this.multimodalSupport;
131+
this.multimodalSupport = await checkContextMultimodal(this.context); return this.multimodalSupport;
139132
}
140133
getMultimodalSupport(): MultimodalSupport | null { return this.multimodalSupport; }
141134
supportsVision(): boolean { return this.multimodalSupport?.vision || false; }
@@ -144,33 +137,29 @@ class LLMService {
144137
private detectToolCallingSupport(): void {
145138
if (!this.context) { this.toolCallingSupported = false; return; }
146139
try {
147-
const model = (this.context as any)?.model;
148-
const jinja = model?.chatTemplates?.jinja;
149-
logger.log('[LLM] Chat template jinja:', JSON.stringify(jinja, null, 2));
150-
logger.log('[LLM] Chat template keys:', model?.chatTemplates ? Object.keys(model.chatTemplates) : 'none');
140+
const jinja = (this.context as any)?.model?.chatTemplates?.jinja;
151141
this.toolCallingSupported = !!(jinja?.defaultCaps?.toolCalls || jinja?.toolUse || jinja?.toolUseCaps?.toolCalls);
152142
logger.log('[LLM] Tool calling supported:', this.toolCallingSupported);
153-
} catch (e) {
154-
logger.warn('[LLM] Error detecting tool calling support:', e);
155-
this.toolCallingSupported = false;
156-
}
143+
} catch (e) { logger.warn('[LLM] Error detecting tool calling support:', e); this.toolCallingSupported = false; }
157144
}
158-
159145
private detectThinkingSupport(): void {
160146
if (!this.context) { this.thinkingSupported = false; return; }
161147
try {
162-
const model = (this.context as any)?.model;
163-
const metadata = model?.metadata || {};
164-
const template = metadata['tokenizer.chat_template'] || '';
148+
const template = (this.context as any)?.model?.metadata?.['tokenizer.chat_template'] || '';
165149
this.thinkingSupported = typeof template === 'string' && template.includes('<think>');
166-
logger.log('[LLM] Thinking supported:', this.thinkingSupported);
167-
} catch (_e) {
168-
this.thinkingSupported = false;
169-
}
150+
} catch (_e) { this.thinkingSupported = false; }
170151
}
171152

172153
async unloadModel(): Promise<void> {
173154
if (this.context) {
155+
if (this.isGenerating) {
156+
try { await this.context.stopCompletion(); } catch (e) { logger.log('[LLM] Stop during unload:', e); }
157+
this.isGenerating = false;
158+
}
159+
if (this.activeCompletionPromise !== null) {
160+
await this.activeCompletionPromise;
161+
this.activeCompletionPromise = null;
162+
}
174163
await this.context.release();
175164
useAppStore.getState().setModelMaxContext(null);
176165
Object.assign(this, {
@@ -191,7 +180,8 @@ class LLMService {
191180
if (!this.context) throw new Error('No model loaded');
192181
if (this.isGenerating) throw new Error('Generation already in progress');
193182
this.isGenerating = true;
194-
try {
183+
const ctx = this.context;
184+
const completionWork = (async () => {
195185
const managed = await this.manageContextWindow(messages);
196186
const hasImages = managed.some(m => m.attachments?.some(a => a.type === 'image'));
197187
const useMultimodal = hasImages && this.multimodalInitialized;
@@ -208,7 +198,7 @@ class LLMService {
208198
let firstReceived = false;
209199
const thinkStream = this.thinkingSupported && onStream
210200
? createThinkInjector(t => onStream(t)) : null;
211-
const completionResult = await this.context.completion({
201+
const completionResult = await ctx.completion({
212202
messages: oaiMessages,
213203
...buildCompletionParams(settings),
214204
}, (data) => {
@@ -219,30 +209,39 @@ class LLMService {
219209
if (thinkStream) { thinkStream(data.token); } else { onStream?.(data.token); }
220210
});
221211
this.performanceStats = recordGenerationStats(startTime, firstTokenMs, tokenCount);
222-
this.isGenerating = false;
223212
if (completionResult?.context_full) {
224213
logger.log('[LLM] Context full detected — signalling for compaction');
225214
throw new Error('Context is full');
226215
}
227216
onComplete?.(fullResponse);
228217
return fullResponse;
229-
} catch (error) {
218+
})();
219+
this.activeCompletionPromise = completionWork.then(() => {}, () => {});
220+
try {
221+
return await completionWork;
222+
} finally {
230223
this.isGenerating = false;
231-
throw error;
224+
this.activeCompletionPromise = null;
232225
}
233226
}
234227

235228
async generateResponseWithTools(
236229
messages: Message[],
237230
options: { tools: any[]; onStream?: StreamCallback; onComplete?: CompleteCallback },
238231
): Promise<{ fullResponse: string; toolCalls: ToolCall[] }> {
239-
return generateWithToolsImpl({
232+
const work = generateWithToolsImpl({
240233
context: this.context, isGenerating: this.isGenerating,
241234
manageContextWindow: (msgs, extra?) => this.manageContextWindow(msgs, extra),
242235
convertToOAIMessages: (msgs) => this.convertToOAIMessages(msgs),
243236
setPerformanceStats: (s) => { this.performanceStats = s; },
244237
setIsGenerating: (v) => { this.isGenerating = v; },
245238
}, messages, options);
239+
this.activeCompletionPromise = work.then(() => {}, () => {});
240+
try {
241+
return await work;
242+
} finally {
243+
this.activeCompletionPromise = null;
244+
}
246245
}
247246

248247
/** No-op pass-through — lets llama.rn's native ctx_shift handle overflow for KV cache reuse. */
@@ -253,61 +252,63 @@ class LLMService {
253252
/** Generate a completion with a hard token cap (used for summarization, not user-facing). */
254253
async generateWithMaxTokens(messages: Message[], maxTokens: number): Promise<string> {
255254
if (!this.context) throw new Error('No model loaded');
255+
if (this.isGenerating) throw new Error('Generation already in progress');
256+
this.isGenerating = true;
256257
const oaiMessages = this.convertToOAIMessages(messages);
257258
const { settings } = useAppStore.getState();
258259
let fullResponse = '';
259-
await this.context.completion(
260+
const completionWork = this.context.completion(
260261
{ messages: oaiMessages, ...buildCompletionParams(settings), n_predict: maxTokens },
261-
(data) => { if (data.token) fullResponse += data.token; },
262+
(data) => { if (this.isGenerating && data.token) fullResponse += data.token; },
262263
);
263-
return fullResponse.trim();
264+
this.activeCompletionPromise = completionWork.then(() => {}, () => {});
265+
try {
266+
await completionWork;
267+
return fullResponse.trim();
268+
} finally {
269+
this.isGenerating = false;
270+
this.activeCompletionPromise = null;
271+
}
264272
}
265273
async stopGeneration(): Promise<void> {
266274
if (this.context) { try { await this.context.stopCompletion(); } catch (e) { logger.log('[LLM] Stop error:', e); } }
267275
this.isGenerating = false;
276+
if (this.activeCompletionPromise !== null) {
277+
await this.activeCompletionPromise;
278+
this.activeCompletionPromise = null;
279+
}
268280
}
269281
async clearKVCache(clearData: boolean = false): Promise<void> {
270282
if (!this.context || this.isGenerating) return;
271283
try { await (this.context as any).clearCache(clearData); } catch (e) { logger.log('[LLM] Clear cache error:', e); }
272284
}
273-
getEstimatedMemoryUsage(): { contextMemoryMB: number; totalEstimatedMB: number } {
274-
if (!this.context) return { contextMemoryMB: 0, totalEstimatedMB: 0 };
275-
const contextMemoryMB = (this.currentSettings.contextLength || 2048) * 0.5;
285+
getEstimatedMemoryUsage() {
286+
const contextMemoryMB = this.context ? (this.currentSettings.contextLength || 2048) * 0.5 : 0;
276287
return { contextMemoryMB, totalEstimatedMB: contextMemoryMB };
277288
}
278-
getGpuInfo(): { gpu: boolean; gpuBackend: string; gpuLayers: number; reasonNoGPU: string } {
279-
let backend = 'CPU';
280-
if (this.gpuEnabled) {
281-
if (Platform.OS === 'ios') backend = 'Metal';
282-
else if (this.gpuDevices.length > 0) backend = this.gpuDevices.join(', ');
283-
else backend = 'OpenCL';
284-
}
289+
getGpuInfo() {
290+
const backend = !this.gpuEnabled ? 'CPU' : Platform.OS === 'ios' ? 'Metal'
291+
: this.gpuDevices.length > 0 ? this.gpuDevices.join(', ') : 'OpenCL';
285292
return { gpu: this.gpuEnabled, gpuBackend: backend, gpuLayers: this.activeGpuLayers, reasonNoGPU: this.gpuReason };
286293
}
287294
isCurrentlyGenerating(): boolean { return this.isGenerating; }
288295
private formatMessages(messages: Message[]): string { return formatLlamaMessages(messages, this.supportsVision()); }
289-
private getImageUris(messages: Message[]): string[] { return extractImageUris(messages); }
290296
private convertToOAIMessages(messages: Message[]): RNLlamaOAICompatibleMessage[] { return buildOAIMessages(messages); }
291-
292-
async getModelInfo(): Promise<{ contextLength: number; vocabSize: number } | null> {
293-
return this.context ? { contextLength: APP_CONFIG.maxContextLength, vocabSize: 0 } : null;
294-
}
295-
async tokenize(text: string): Promise<number[]> {
297+
async getModelInfo() { return this.context ? { contextLength: APP_CONFIG.maxContextLength, vocabSize: 0 } : null; }
298+
async tokenize(text: string) {
296299
if (!this.context) throw new Error('No model loaded');
297300
return (await this.context.tokenize(text)).tokens || [];
298301
}
299-
async getTokenCount(text: string): Promise<number> {
302+
async getTokenCount(text: string) {
300303
if (!this.context) throw new Error('No model loaded');
301304
return (await this.context.tokenize(text)).tokens?.length || 0;
302305
}
303-
async estimateContextUsage(messages: Message[]): Promise<{ tokenCount: number; percentUsed: number; willFit: boolean }> {
304-
const prompt = this.formatMessages(messages);
305-
const tokenCount = await this.getTokenCount(prompt);
306+
async estimateContextUsage(messages: Message[]) {
307+
const tokenCount = await this.getTokenCount(this.formatMessages(messages));
306308
const ctxLen = this.currentSettings.contextLength || APP_CONFIG.maxContextLength;
307309
return { tokenCount, percentUsed: (tokenCount / ctxLen) * 100, willFit: tokenCount < ctxLen * 0.9 };
308310
}
309311
getFormattedPrompt(messages: Message[]): string { return this.formatMessages(messages); }
310-
311312
async getContextDebugInfo(messages: Message[]) {
312313
const managed = await this.manageContextWindow(messages);
313314
const fmt = this.formatMessages(managed);
@@ -329,8 +330,7 @@ class LLMService {
329330
async reloadWithSettings(modelPath: string, settings: LLMPerformanceSettings): Promise<void> {
330331
this.updatePerformanceSettings(settings);
331332
if (this.context) await this.unloadModel();
332-
const { settings: appS } = useAppStore.getState();
333-
const { baseParams, nGpuLayers } = buildModelParams(modelPath, { ...appS, ...settings });
333+
const { baseParams, nGpuLayers } = buildModelParams(modelPath, { ...useAppStore.getState().settings, ...settings });
334334
logger.log(`[LLM] Reloading with threads=${settings.nThreads}, batch=${settings.nBatch}, ctx=${settings.contextLength}`);
335335
try {
336336
const { context, gpuAttemptFailed } = await initContextWithFallback(baseParams, settings.contextLength, nGpuLayers);

0 commit comments

Comments (0)