Skip to content

Commit 2ea6c92

Browse files
sort template workflows by required vram (#6285)
## Summary Resolves #6281 by implementing the stubbed out vram sorting. Previously was waiting for it to be added to the templates data and it now has (https://github.com/Comfy-Org/workflow_templates/blob/main/templates/index.json) ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-6285-sort-template-workflows-by-required-vram-2976d73d36508164a8f9fab438f53b21) by [Unito](https://www.unito.io)
1 parent 23e0d26 commit 2ea6c92

File tree

3 files changed

+261
-3
lines changed

3 files changed

+261
-3
lines changed

src/composables/useTemplateFiltering.ts

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,17 @@ export function useTemplateFiltering(
128128
})
129129
})
130130

131+
const getVramMetric = (template: TemplateInfo) => {
132+
if (
133+
typeof template.vram === 'number' &&
134+
Number.isFinite(template.vram) &&
135+
template.vram > 0
136+
) {
137+
return template.vram
138+
}
139+
return Number.POSITIVE_INFINITY
140+
}
141+
131142
const sortedTemplates = computed(() => {
132143
const templates = [...filteredByLicenses.value]
133144

@@ -145,9 +156,21 @@ export function useTemplateFiltering(
145156
return dateB.getTime() - dateA.getTime()
146157
})
147158
case 'vram-low-to-high':
148-
// TODO: Implement VRAM sorting when VRAM data is available
149-
// For now, keep original order
150-
return templates
159+
return templates.sort((a, b) => {
160+
const vramA = getVramMetric(a)
161+
const vramB = getVramMetric(b)
162+
163+
if (vramA === vramB) {
164+
const nameA = a.title || a.name || ''
165+
const nameB = b.title || b.name || ''
166+
return nameA.localeCompare(nameB)
167+
}
168+
169+
if (vramA === Number.POSITIVE_INFINITY) return 1
170+
if (vramB === Number.POSITIVE_INFINITY) return -1
171+
172+
return vramA - vramB
173+
})
151174
case 'model-size-low-to-high':
152175
return templates.sort((a: any, b: any) => {
153176
const sizeA =

src/platform/workflow/templates/types/template.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ export interface TemplateInfo {
1818
date?: string
1919
useCase?: string
2020
license?: string
21+
/**
22+
* Estimated VRAM requirement in bytes.
23+
*/
24+
vram?: number
2125
size?: number
2226
}
2327

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
import { afterEach, describe, expect, it, vi } from 'vitest'
2+
import { nextTick, ref } from 'vue'
3+
4+
import { useTemplateFiltering } from '@/composables/useTemplateFiltering'
5+
import type { TemplateInfo } from '@/platform/workflow/templates/types/template'
6+
7+
describe('useTemplateFiltering', () => {
8+
afterEach(() => {
9+
vi.useRealTimers()
10+
})
11+
12+
it('sorts templates by VRAM from low to high and pushes missing values last', () => {
13+
const gb = (value: number) => value * 1024 ** 3
14+
15+
const templates = ref<TemplateInfo[]>([
16+
{
17+
name: 'missing-vram',
18+
description: 'no vram value',
19+
mediaType: 'image',
20+
mediaSubtype: 'png'
21+
},
22+
{
23+
name: 'highest-vram',
24+
description: 'high usage',
25+
mediaType: 'image',
26+
mediaSubtype: 'png',
27+
vram: gb(12)
28+
},
29+
{
30+
name: 'mid-vram',
31+
description: 'medium usage',
32+
mediaType: 'image',
33+
mediaSubtype: 'png',
34+
vram: gb(7.5)
35+
},
36+
{
37+
name: 'low-vram',
38+
description: 'low usage',
39+
mediaType: 'image',
40+
mediaSubtype: 'png',
41+
vram: gb(5)
42+
},
43+
{
44+
name: 'zero-vram',
45+
description: 'unknown usage',
46+
mediaType: 'image',
47+
mediaSubtype: 'png',
48+
vram: 0
49+
}
50+
])
51+
52+
const { sortBy, filteredTemplates } = useTemplateFiltering(templates)
53+
54+
sortBy.value = 'vram-low-to-high'
55+
56+
expect(filteredTemplates.value.map((template) => template.name)).toEqual([
57+
'low-vram',
58+
'mid-vram',
59+
'highest-vram',
60+
'missing-vram',
61+
'zero-vram'
62+
])
63+
})
64+
65+
it('filters by search text, models, tags, and license with debounce handling', async () => {
66+
vi.useFakeTimers()
67+
68+
const templates = ref<TemplateInfo[]>([
69+
{
70+
name: 'api-template',
71+
description: 'Enterprise API workflow for video',
72+
mediaType: 'image',
73+
mediaSubtype: 'png',
74+
tags: ['API', 'Video'],
75+
models: ['Flux'],
76+
date: '2024-06-01',
77+
vram: 15 * 1024 ** 3
78+
},
79+
{
80+
name: 'portrait-flow',
81+
description: 'Portrait template tuned for SDXL',
82+
mediaType: 'image',
83+
mediaSubtype: 'png',
84+
tags: ['Portrait'],
85+
models: ['SDXL'],
86+
date: '2024-05-15',
87+
vram: 10 * 1024 ** 3
88+
},
89+
{
90+
name: 'landscape-lite',
91+
description: 'Lightweight landscape generator',
92+
mediaType: 'image',
93+
mediaSubtype: 'png',
94+
tags: ['Landscape'],
95+
models: ['SDXL', 'Flux'],
96+
date: '2024-04-20'
97+
}
98+
])
99+
100+
const {
101+
searchQuery,
102+
selectedModels,
103+
selectedUseCases,
104+
selectedLicenses,
105+
filteredTemplates,
106+
availableModels,
107+
availableUseCases,
108+
availableLicenses,
109+
filteredCount,
110+
totalCount,
111+
removeUseCaseFilter,
112+
resetFilters
113+
} = useTemplateFiltering(templates)
114+
115+
expect(totalCount.value).toBe(3)
116+
expect(availableModels.value).toEqual(['Flux', 'SDXL'])
117+
expect(availableUseCases.value).toEqual([
118+
'API',
119+
'Landscape',
120+
'Portrait',
121+
'Video'
122+
])
123+
expect(availableLicenses.value).toEqual([
124+
'Open Source',
125+
'Closed Source (API Nodes)'
126+
])
127+
128+
searchQuery.value = 'enterprise'
129+
await nextTick()
130+
await vi.runOnlyPendingTimersAsync()
131+
await nextTick()
132+
expect(filteredTemplates.value.map((template) => template.name)).toEqual([
133+
'api-template'
134+
])
135+
136+
selectedLicenses.value = ['Closed Source (API Nodes)']
137+
await nextTick()
138+
expect(filteredTemplates.value.map((template) => template.name)).toEqual([
139+
'api-template'
140+
])
141+
142+
selectedModels.value = ['Flux']
143+
await nextTick()
144+
expect(filteredTemplates.value.map((template) => template.name)).toEqual([
145+
'api-template'
146+
])
147+
148+
selectedUseCases.value = ['Video']
149+
await nextTick()
150+
expect(filteredTemplates.value.map((template) => template.name)).toEqual([
151+
'api-template'
152+
])
153+
expect(filteredCount.value).toBe(1)
154+
155+
removeUseCaseFilter('Video')
156+
await nextTick()
157+
expect(selectedUseCases.value).toHaveLength(0)
158+
159+
resetFilters()
160+
await nextTick()
161+
await vi.runOnlyPendingTimersAsync()
162+
await nextTick()
163+
expect(filteredTemplates.value.map((template) => template.name)).toEqual([
164+
'api-template',
165+
'portrait-flow',
166+
'landscape-lite'
167+
])
168+
})
169+
170+
it('supports alphabetical, newest, and size-based sorting options', async () => {
171+
const templates = ref<TemplateInfo[]>([
172+
{
173+
name: 'zeta-extended',
174+
description: 'older template',
175+
mediaType: 'image',
176+
mediaSubtype: 'png',
177+
date: '2024-01-01',
178+
size: 300
179+
},
180+
{
181+
name: 'alpha-starter',
182+
description: 'new template',
183+
mediaType: 'image',
184+
mediaSubtype: 'png',
185+
date: '2024-07-01',
186+
size: 100
187+
},
188+
{
189+
name: 'beta-pro',
190+
description: 'mid template',
191+
mediaType: 'image',
192+
mediaSubtype: 'png',
193+
date: '2024-05-01',
194+
size: 200
195+
}
196+
])
197+
198+
const { sortBy, filteredTemplates } = useTemplateFiltering(templates)
199+
200+
// default is 'newest'
201+
expect(filteredTemplates.value.map((template) => template.name)).toEqual([
202+
'alpha-starter',
203+
'beta-pro',
204+
'zeta-extended'
205+
])
206+
207+
sortBy.value = 'alphabetical'
208+
await nextTick()
209+
expect(filteredTemplates.value.map((template) => template.name)).toEqual([
210+
'alpha-starter',
211+
'beta-pro',
212+
'zeta-extended'
213+
])
214+
215+
sortBy.value = 'model-size-low-to-high'
216+
await nextTick()
217+
expect(filteredTemplates.value.map((template) => template.name)).toEqual([
218+
'alpha-starter',
219+
'beta-pro',
220+
'zeta-extended'
221+
])
222+
223+
sortBy.value = 'default'
224+
await nextTick()
225+
expect(filteredTemplates.value.map((template) => template.name)).toEqual([
226+
'zeta-extended',
227+
'alpha-starter',
228+
'beta-pro'
229+
])
230+
})
231+
})

0 commit comments

Comments
 (0)