Skip to content

Commit 4833547

Browse files
authored
Chores: Updates for Asset Services (#5872)
## Changes 1. Updates schema to match new API 2. Adds additional relevant models to the registry ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-5872-Chores-Updates-for-Asset-Services-27f6d73d36508117b89fd473f1a7090d) by [Unito](https://www.unito.io)
1 parent 0d3d258 commit 4833547

File tree

3 files changed

+75
-16
lines changed

3 files changed

+75
-16
lines changed

src/platform/assets/schemas/assetSchema.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@ import { z } from 'zod'
44
const zAsset = z.object({
55
id: z.string(),
66
name: z.string(),
7-
asset_hash: z.string().nullable(),
7+
asset_hash: z.string().optional(),
88
size: z.number(),
9-
mime_type: z.string().nullable(),
10-
tags: z.array(z.string()),
9+
mime_type: z.string().optional(),
10+
tags: z.array(z.string()).optional().default([]),
11+
preview_id: z.string().nullable().optional(),
1112
preview_url: z.string().optional(),
1213
created_at: z.string(),
1314
updated_at: z.string().optional(),
14-
last_access_time: z.string(),
15-
user_metadata: z.record(z.unknown()).optional(), // API allows arbitrary key-value pairs
16-
preview_id: z.string().nullable().optional()
15+
last_access_time: z.string().optional(),
16+
user_metadata: z.record(z.unknown()).optional() // API allows arbitrary key-value pairs
1717
})
1818

1919
const zAssetResponse = z.object({

src/stores/modelToNodeStore.ts

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
2929
return new Set(
3030
Object.values(modelToNodeMap.value)
3131
.flat()
32+
.filter((provider) => !!provider.nodeDef)
3233
.map((provider) => provider.nodeDef.name)
3334
)
3435
})
@@ -38,6 +39,8 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
3839
const lookup: Record<string, string> = {}
3940
for (const [category, providers] of Object.entries(modelToNodeMap.value)) {
4041
for (const provider of providers) {
42+
// Extension nodes may not be installed
43+
if (!provider.nodeDef) continue
4144
// Only store the first category for each node type (matches current assetService behavior)
4245
if (!lookup[provider.nodeDef.name]) {
4346
lookup[provider.nodeDef.name] = category
@@ -98,6 +101,7 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
98101
nodeProvider: ModelNodeProvider
99102
) {
100103
registerDefaults()
104+
if (!nodeProvider.nodeDef) return
101105
if (!modelToNodeMap.value[modelType]) {
102106
modelToNodeMap.value[modelType] = []
103107
}
@@ -131,10 +135,24 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
131135
quickRegister('loras', 'LoraLoaderModelOnly', 'lora_name')
132136
quickRegister('vae', 'VAELoader', 'vae_name')
133137
quickRegister('controlnet', 'ControlNetLoader', 'control_net_name')
134-
quickRegister('unet', 'UNETLoader', 'unet_name')
138+
quickRegister('diffusion_models', 'UNETLoader', 'unet_name')
135139
quickRegister('upscale_models', 'UpscaleModelLoader', 'model_name')
136-
quickRegister('style_models', 'StyleModelLoader', 'style_model')
140+
quickRegister('style_models', 'StyleModelLoader', 'style_model_name')
137141
quickRegister('gligen', 'GLIGENLoader', 'gligen_name')
142+
quickRegister('clip_vision', 'CLIPVisionLoader', 'clip_name')
143+
quickRegister('text_encoders', 'CLIPLoader', 'clip_name')
144+
quickRegister('audio_encoders', 'AudioEncoderLoader', 'audio_encoder_name')
145+
quickRegister('model_patches', 'ModelPatchLoader', 'name')
146+
quickRegister(
147+
'animatediff_models',
148+
'ADE_LoadAnimateDiffModel',
149+
'model_name'
150+
)
151+
quickRegister(
152+
'animatediff_motion_lora',
153+
'ADE_AnimateDiffLoRALoader',
154+
'name'
155+
)
138156
}
139157

140158
return {

tests-ui/tests/store/modelToNodeStore.test.ts

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,16 @@ const EXPECTED_DEFAULT_TYPES = [
1313
'loras',
1414
'vae',
1515
'controlnet',
16-
'unet',
16+
'diffusion_models',
1717
'upscale_models',
1818
'style_models',
19-
'gligen'
19+
'gligen',
20+
'clip_vision',
21+
'text_encoders',
22+
'audio_encoders',
23+
'model_patches',
24+
'animatediff_models',
25+
'animatediff_motion_lora'
2026
] as const
2127

2228
type NodeDefStoreType = ReturnType<typeof useNodeDefStore>
@@ -48,7 +54,13 @@ const MOCK_NODE_NAMES = [
4854
'UNETLoader',
4955
'UpscaleModelLoader',
5056
'StyleModelLoader',
51-
'GLIGENLoader'
57+
'GLIGENLoader',
58+
'CLIPVisionLoader',
59+
'CLIPLoader',
60+
'AudioEncoderLoader',
61+
'ModelPatchLoader',
62+
'ADE_LoadAnimateDiffModel',
63+
'ADE_AnimateDiffLoRALoader'
5264
] as const
5365

5466
const mockNodeDefsByName = Object.fromEntries(
@@ -84,7 +96,7 @@ describe('useModelToNodeStore', () => {
8496
const modelToNodeStore = useModelToNodeStore()
8597
modelToNodeStore.registerDefaults()
8698
expect(Object.keys(modelToNodeStore.modelToNodeMap)).toEqual(
87-
expect.arrayContaining(['checkpoints', 'unet'])
99+
expect.arrayContaining(['checkpoints', 'diffusion_models'])
88100
)
89101
})
90102
})
@@ -153,9 +165,10 @@ describe('useModelToNodeStore', () => {
153165
const modelToNodeStore = useModelToNodeStore()
154166
modelToNodeStore.registerDefaults()
155167

156-
const unetProviders = modelToNodeStore.getAllNodeProviders('unet')
157-
expect(unetProviders).toHaveLength(1)
158-
expect(unetProviders[0].nodeDef.name).toBe('UNETLoader')
168+
const diffusionModelProviders =
169+
modelToNodeStore.getAllNodeProviders('diffusion_models')
170+
expect(diffusionModelProviders).toHaveLength(1)
171+
expect(diffusionModelProviders[0].nodeDef.name).toBe('UNETLoader')
159172
})
160173

161174
it('should return empty array for unregistered model type', () => {
@@ -173,6 +186,22 @@ describe('useModelToNodeStore', () => {
173186
})
174187

175188
describe('registerNodeProvider', () => {
189+
it('should not register provider when nodeDef is undefined', () => {
190+
const modelToNodeStore = useModelToNodeStore()
191+
const providerWithoutNodeDef = new ModelNodeProvider(
192+
undefined as any,
193+
'custom_key'
194+
)
195+
196+
modelToNodeStore.registerNodeProvider(
197+
'custom_type',
198+
providerWithoutNodeDef
199+
)
200+
201+
const retrieved = modelToNodeStore.getNodeProvider('custom_type')
202+
expect(retrieved).toBeUndefined()
203+
})
204+
176205
it('should register provider directly', () => {
177206
const modelToNodeStore = useModelToNodeStore()
178207
const nodeDefStore = useNodeDefStore()
@@ -250,8 +279,20 @@ describe('useModelToNodeStore', () => {
250279
}).not.toThrow()
251280

252281
const provider = modelToNodeStore.getNodeProvider('test_type')
253-
// Optional chaining needed since getNodeProvider() can return undefined
254282
expect(provider?.nodeDef).toBeUndefined()
283+
284+
expect(() => modelToNodeStore.getRegisteredNodeTypes()).not.toThrow()
285+
expect(() =>
286+
modelToNodeStore.getCategoryForNodeType('NonExistentLoader')
287+
).not.toThrow()
288+
289+
// Non-existent nodes are filtered out from registered types
290+
const types = modelToNodeStore.getRegisteredNodeTypes()
291+
expect(types.has('NonExistentLoader')).toBe(false)
292+
293+
expect(
294+
modelToNodeStore.getCategoryForNodeType('NonExistentLoader')
295+
).toBeUndefined()
255296
})
256297

257298
it('should allow multiple node classes for same model type', () => {

0 commit comments

Comments
 (0)