Skip to content

Commit 9d28004

Browse files
authored
feat(retry): ✨ Add model-switch retry for task/page flows (#63)
* feat(retry): add model switch for task/page retry flows * fix(renderer): 🐛 resolve typecheck errors in retry model flow * fix(retry): 🐛 harden model parsing and review tests
1 parent b8bc729 commit 9d28004

32 files changed

+1597
-146
lines changed

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"migrate:reset": "prisma migrate reset --schema=./src/core/infrastructure/db/schema.prisma",
3333
"generate": "prisma generate --schema=./src/core/infrastructure/db/schema.prisma",
3434
"logo": "electron-icon-builder --input=./src/renderer/assets/logo.png --output=./public",
35-
"test": "vitest",
35+
"test": "vitest run",
3636
"test:unit": "vitest run --config vitest.config.ts",
3737
"test:renderer": "vitest run --config vitest.config.renderer.ts",
3838
"test:watch": "vitest watch",

src/core/application/workers/ConverterWorker.ts

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,14 @@ export class ConverterWorker extends WorkerBase {
370370
// Step 2: Check task is not cancelled
371371
const task = await tx.task.findUnique({
372372
where: { id: page.task },
373-
select: { status: true, pages: true, completed_count: true, failed_count: true },
373+
select: {
374+
status: true,
375+
pages: true,
376+
completed_count: true,
377+
failed_count: true,
378+
provider: true,
379+
model: true,
380+
},
374381
});
375382

376383
if (!task) {
@@ -393,6 +400,8 @@ export class ConverterWorker extends WorkerBase {
393400
completed_at: new Date(),
394401
worker_id: null, // Release worker
395402
error: null,
403+
provider: task.provider,
404+
model: task.model,
396405
},
397406
});
398407

@@ -475,7 +484,14 @@ export class ConverterWorker extends WorkerBase {
475484
// Step 2: Check task is not cancelled
476485
const task = await tx.task.findUnique({
477486
where: { id: page.task },
478-
select: { status: true, pages: true, completed_count: true, failed_count: true },
487+
select: {
488+
status: true,
489+
pages: true,
490+
completed_count: true,
491+
failed_count: true,
492+
provider: true,
493+
model: true,
494+
},
479495
});
480496

481497
if (!task) {
@@ -494,6 +510,8 @@ export class ConverterWorker extends WorkerBase {
494510
error: errorMessage,
495511
completed_at: new Date(),
496512
worker_id: null, // Release worker
513+
provider: task.provider,
514+
model: task.model,
497515
},
498516
});
499517

src/core/application/workers/__tests__/ConverterWorker.test.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,21 +585,27 @@ describe('ConverterWorker', () => {
585585
};
586586

587587
it('should update page status to COMPLETED', async () => {
588+
let pageUpdateData: any;
589+
588590
vi.mocked(prisma.$transaction).mockImplementation(async (callback: any) => {
589591
const tx = {
590592
taskDetail: {
591593
findUnique: vi.fn().mockResolvedValue({
592594
worker_id: worker.getWorkerId(),
593595
status: PageStatus.PROCESSING,
594596
}),
595-
update: vi.fn(),
597+
update: vi.fn().mockImplementation((params: any) => {
598+
pageUpdateData = params.data;
599+
}),
596600
},
597601
task: {
598602
findUnique: vi.fn().mockResolvedValue({
599603
status: TaskStatus.PROCESSING,
600604
pages: 10,
601605
completed_count: 5,
602606
failed_count: 0,
607+
provider: 1,
608+
model: 'gpt-4o',
603609
}),
604610
update: vi.fn().mockResolvedValue({
605611
completed_count: 6,
@@ -618,6 +624,8 @@ describe('ConverterWorker', () => {
618624
await (worker as any).completePageSuccess(mockPage, mockResult);
619625

620626
expect(prisma.$transaction).toHaveBeenCalled();
627+
expect(pageUpdateData.provider).toBe(1);
628+
expect(pageUpdateData.model).toBe('gpt-4o');
621629
});
622630

623631
it('should skip if page already completed (idempotency)', async () => {
@@ -823,6 +831,8 @@ describe('ConverterWorker', () => {
823831
pages: 10,
824832
completed_count: 5,
825833
failed_count: 1,
834+
provider: 9,
835+
model: 'claude-3-7-sonnet',
826836
}),
827837
update: vi.fn().mockResolvedValue({ failed_count: 2 }),
828838
},
@@ -839,6 +849,8 @@ describe('ConverterWorker', () => {
839849

840850
expect(pageUpdateData.status).toBe(PageStatus.FAILED);
841851
expect(pageUpdateData.error).toBe('Test error');
852+
expect(pageUpdateData.provider).toBe(9);
853+
expect(pageUpdateData.model).toBe('claude-3-7-sonnet');
842854
});
843855

844856
it('should set task to FAILED when all pages failed', async () => {

src/core/infrastructure/services/CloudService.ts

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import type {
1212
CloudCancelTaskResponse,
1313
CloudRetryPageResponse,
1414
CloudApiPagination,
15+
CloudModelTier,
1516
PaymentCheckoutApiResponse,
1617
PaymentCheckoutStatusApiResponse,
1718
PaymentHistoryApiItem,
@@ -433,15 +434,23 @@ class CloudService {
433434
/**
434435
* Retry an entire task (creates a new task)
435436
*/
436-
public async retryTask(id: string): Promise<{
437+
public async retryTask(id: string, model?: CloudModelTier): Promise<{
437438
success: boolean;
438439
data?: CreateTaskResponse;
439440
error?: string;
440441
}> {
441442
try {
442-
const res = await authManager.fetchWithAuth(`${API_BASE_URL}/api/v1/tasks/${encodeURIComponent(id)}/retry`, {
443-
method: 'POST',
444-
});
443+
const hasModelOverride = typeof model === 'string' && model.length > 0;
444+
const res = await authManager.fetchWithAuth(
445+
`${API_BASE_URL}/api/v1/tasks/${encodeURIComponent(id)}/retry`,
446+
hasModelOverride
447+
? {
448+
method: 'POST',
449+
headers: { 'Content-Type': 'application/json' },
450+
body: JSON.stringify({ model }),
451+
}
452+
: { method: 'POST' }
453+
);
445454

446455
if (!res.ok) {
447456
const errorBody = await res.json().catch(() => null);
@@ -469,15 +478,22 @@ class CloudService {
469478
/**
470479
* Retry a single page
471480
*/
472-
public async retryPage(taskId: string, pageNumber: number): Promise<{
481+
public async retryPage(taskId: string, pageNumber: number, model?: CloudModelTier): Promise<{
473482
success: boolean;
474483
data?: CloudRetryPageResponse;
475484
error?: string;
476485
}> {
477486
try {
487+
const hasModelOverride = typeof model === 'string' && model.length > 0;
478488
const res = await authManager.fetchWithAuth(
479489
`${API_BASE_URL}/api/v1/tasks/${encodeURIComponent(taskId)}/pages/${encodeURIComponent(String(pageNumber))}/retry`,
480-
{ method: 'POST' },
490+
hasModelOverride
491+
? {
492+
method: 'POST',
493+
headers: { 'Content-Type': 'application/json' },
494+
body: JSON.stringify({ model }),
495+
}
496+
: { method: 'POST' },
481497
);
482498

483499
if (!res.ok) {

src/core/infrastructure/services/__tests__/CloudService.test.ts

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,37 @@ describe('CloudService', () => {
205205
})
206206
})
207207

208+
it('retryTask/retryPage send model override payload when model is provided', async () => {
209+
const cloudService = (await import('../CloudService.js')).default
210+
211+
mockAuthManager.fetchWithAuth
212+
.mockResolvedValueOnce(makeJsonResponse(200, { success: true, data: { task_id: 'task-2', events_url: '/events' } }))
213+
.mockResolvedValueOnce(makeJsonResponse(200, { success: true, data: { task_id: 'task-1', page: 3, status: 'queued' } }))
214+
215+
await cloudService.retryTask('task-1', 'pro')
216+
await cloudService.retryPage('task-1', 3, 'ultra')
217+
218+
const retryTaskCall = mockAuthManager.fetchWithAuth.mock.calls[0]
219+
expect(retryTaskCall[0]).toContain('/api/v1/tasks/task-1/retry')
220+
expect(retryTaskCall[1]).toEqual(
221+
expect.objectContaining({
222+
method: 'POST',
223+
headers: { 'Content-Type': 'application/json' },
224+
body: JSON.stringify({ model: 'pro' }),
225+
})
226+
)
227+
228+
const retryPageCall = mockAuthManager.fetchWithAuth.mock.calls[1]
229+
expect(retryPageCall[0]).toContain('/api/v1/tasks/task-1/pages/3/retry')
230+
expect(retryPageCall[1]).toEqual(
231+
expect.objectContaining({
232+
method: 'POST',
233+
headers: { 'Content-Type': 'application/json' },
234+
body: JSON.stringify({ model: 'ultra' }),
235+
})
236+
)
237+
})
238+
208239
it('cancelTask/retryTask/retryPage/deleteTask return API errors when non-OK', async () => {
209240
const cloudService = (await import('../CloudService.js')).default
210241
mockAuthManager.fetchWithAuth

src/main/ipc/__tests__/handlers.test.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ vi.mock('../../../shared/ipc/channels.js', () => ({
139139
GET_ALL: 'task:getAll',
140140
GET_BY_ID: 'task:getById',
141141
UPDATE: 'task:update',
142+
RETRY: 'task:retry',
142143
DELETE: 'task:delete',
143144
HAS_RUNNING: 'task:hasRunningTasks',
144145
},

src/main/ipc/handlers/__tests__/cloud.handler.test.ts

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,21 @@ describe('cloud.handler', () => {
170170

171171
expect(mockCloudService.getTasks).toHaveBeenCalledWith(2, 20)
172172
expect(mockCloudService.getTaskById).toHaveBeenCalledWith('t1')
173-
expect(mockCloudService.retryTask).toHaveBeenCalledWith('t1')
173+
expect(mockCloudService.retryTask).toHaveBeenCalledWith('t1', undefined)
174174
expect(mockCloudService.deleteTask).toHaveBeenCalledWith('t1')
175175
})
176176

177+
it('passes model override for retryTask and retryPage', async () => {
178+
mockCloudService.retryTask.mockResolvedValueOnce({ success: true, data: { task_id: 'new' } })
179+
mockCloudService.retryPage.mockResolvedValueOnce({ success: true, data: { task_id: 't1', page: 1, status: 1 } })
180+
181+
await handlers.get('cloud:retryTask')!({}, { id: 't1', model: 'pro' })
182+
await handlers.get('cloud:retryPage')!({}, { taskId: 't1', pageNumber: 1, model: 'ultra' })
183+
184+
expect(mockCloudService.retryTask).toHaveBeenCalledWith('t1', 'pro')
185+
expect(mockCloudService.retryPage).toHaveBeenCalledWith('t1', 1, 'ultra')
186+
})
187+
177188
describe('cloud:downloadPdf', () => {
178189
it('returns service error when download fails', async () => {
179190
mockCloudService.downloadPdf.mockResolvedValueOnce({ success: false, error: 'bad' })

0 commit comments

Comments
 (0)