Skip to content

Commit d84328c

Browse files
authored
feat: support nai nai-diffusion-4-curated-preview model (#268)
1 parent 3fdf021 commit d84328c

File tree

3 files changed

+97
-29
lines changed

3 files changed

+97
-29
lines changed

src/config.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ export const modelMap = {
1010
nai: 'nai-diffusion',
1111
furry: 'nai-diffusion-furry',
1212
'nai-v3': 'nai-diffusion-3',
13+
'nai-v4-curated-preview': 'nai-diffusion-4-curated-preview',
1314
} as const
1415

1516
export const orientMap = {
@@ -35,6 +36,7 @@ export const orients = Object.keys(orientMap) as Orient[]
3536

3637
export namespace scheduler {
3738
export const nai = ['native', 'karras', 'exponential', 'polyexponential'] as const
39+
export const nai4 = ['karras', 'exponential', 'polyexponential'] as const
3840
export const sd = ['Automatic', 'Uniform', 'Karras', 'Exponential', 'Polyexponential', 'SGM Uniform'] as const
3941
export const horde = ['karras'] as const
4042
export const comfyUI = ['normal', 'karras', 'exponential', 'sgm_uniform', 'simple', 'ddim_uniform'] as const
@@ -58,6 +60,17 @@ export namespace sampler {
5860
'ddim_v3': 'DDIM V3',
5961
}
6062

63+
export const nai4 = {
64+
// recommended
65+
'k_euler': 'Euler',
66+
'k_euler_a': 'Euler ancestral',
67+
'k_dpmpp_2s_ancestral': 'DPM++ 2S ancestral',
68+
'k_dpmpp_2m_sde': 'DPM++ 2M SDE',
69+
// other
70+
'k_dpmpp_2m': 'DPM++ 2M',
71+
'k_dpmpp_sde': 'DPM++ SDE',
72+
}
73+
6174
// samplers in stable-diffusion-webui
6275
// auto-generated by `build/fetch-sd-samplers.js`
6376
export const sd = require('../data/sd-samplers.json') as Dict<string>
@@ -387,6 +400,11 @@ export const Config = Schema.intersect([
387400
smeaDyn: Schema.boolean().description('默认启用 SMEA 采样器的 DYN 变体。'),
388401
scheduler: Schema.union(scheduler.nai).description('默认的调度器。').default('native'),
389402
}),
403+
Schema.object({
404+
model: Schema.const('nai-v4-curated-preview'),
405+
sampler: sampler.createSchema(sampler.nai4),
406+
scheduler: Schema.union(scheduler.nai4).description('默认的调度器。').default('karras'),
407+
}),
390408
Schema.object({ sampler: sampler.createSchema(sampler.nai) }),
391409
]),
392410
Schema.object({ decrisper: Schema.boolean().description('默认启用 decrisper') }),

src/index.ts

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { Computed, Context, Dict, h, Logger, omit, Quester, Session, SessionError, trimSlash } from 'koishi'
22
import { Config, modelMap, models, orientMap, parseInput, sampler, upscalers, scheduler } from './config'
3-
import { ImageData, StableDiffusionWebUI } from './types'
3+
import { ImageData, NovelAI, StableDiffusionWebUI } from './types'
44
import { closestMultiple, download, forceDataPrefix, getImageSize, login, NetworkError, project, resizeInput, Size } from './utils'
55
import { } from '@koishijs/translator'
66
import { } from '@koishijs/plugin-help'
@@ -13,9 +13,7 @@ export * from './config'
1313
export const reactive = true
1414
export const name = 'novelai'
1515

16-
const logger = new Logger('novelai')
17-
18-
function handleError(session: Session, err: Error) {
16+
function handleError({ logger }: Context, session: Session, err: Error) {
1917
if (Quester.Error.is(err)) {
2018
if (err.response?.status === 402) {
2119
return session.text('.unauthorized')
@@ -201,7 +199,7 @@ export function apply(ctx: Context, config: Config) {
201199
try {
202200
input = await ctx.translator.translate({ input, target: 'en' })
203201
} catch (err) {
204-
logger.warn(err)
202+
ctx.logger.warn(err)
205203
}
206204
}
207205

@@ -215,7 +213,7 @@ export function apply(ctx: Context, config: Config) {
215213
if (err instanceof NetworkError) {
216214
return session.text(err.message, err.params)
217215
}
218-
logger.error(err)
216+
ctx.logger.error(err)
219217
return session.text('.unknown-error')
220218
}
221219

@@ -243,7 +241,7 @@ export function apply(ctx: Context, config: Config) {
243241
if (err instanceof NetworkError) {
244242
return session.text(err.message, err.params)
245243
}
246-
logger.error(err)
244+
ctx.logger.error(err)
247245
return session.text('.download-error')
248246
}
249247

@@ -335,29 +333,49 @@ export function apply(ctx: Context, config: Config) {
335333
delete parameters.uc
336334
}
337335
parameters.dynamic_thresholding = options.decrisper ?? config.decrisper
338-
if (model === 'nai-diffusion-3') {
336+
if (model === 'nai-diffusion-3' || model === 'nai-diffusion-4-curated-preview') {
339337
parameters.legacy = false
340-
parameters.legacy_v3_extend = false
341-
parameters.sm_dyn = options.smeaDyn ?? config.smeaDyn
342-
parameters.sm = (options.smea ?? config.smea) || parameters.sm_dyn
343338
parameters.noise_schedule = options.scheduler ?? config.scheduler
344-
if (['k_euler_ancestral', 'k_dpmpp_2s_ancestral'].includes(parameters.sampler)
345-
&& parameters.noise_schedule === 'karras') {
346-
parameters.noise_schedule = 'native'
347-
}
348-
if (parameters.sampler === 'ddim_v3') {
349-
parameters.sm = false
350-
parameters.sm_dyn = false
351-
delete parameters.noise_schedule
352-
}
353339
// Max scale for nai-v3 is 10, but not 20.
354340
// If the given value is greater than 10,
355341
// we can assume it is configured with an older version (max 20)
356342
if (parameters.scale > 10) {
357343
parameters.scale = parameters.scale / 2
358344
}
345+
if (model === 'nai-diffusion-3') {
346+
parameters.legacy_v3_extend = false
347+
parameters.sm_dyn = options.smeaDyn ?? config.smeaDyn
348+
parameters.sm = (options.smea ?? config.smea) || parameters.sm_dyn
349+
if (['k_euler_ancestral', 'k_dpmpp_2s_ancestral'].includes(parameters.sampler)
350+
&& parameters.noise_schedule === 'karras') {
351+
parameters.noise_schedule = 'native'
352+
}
353+
if (parameters.sampler === 'ddim_v3') {
354+
parameters.sm = false
355+
parameters.sm_dyn = false
356+
delete parameters.noise_schedule
357+
}
358+
}
359+
if (model === 'nai-diffusion-4-curated-preview') {
360+
parameters.use_coords = false // unknown
361+
parameters.characterPrompts = [] satisfies NovelAI.V4CharacterPrompt[]
362+
parameters.v4_prompt = {
363+
caption: {
364+
base_caption: prompt,
365+
char_captions: [],
366+
},
367+
use_coords: parameters.use_coords,
368+
use_order: true,
369+
} satisfies NovelAI.V4PromptPositive
370+
parameters.v4_negative_prompt = {
371+
caption: {
372+
base_caption: parameters.negative_prompt,
373+
char_captions: [],
374+
},
375+
} satisfies NovelAI.V4Prompt
376+
}
359377
}
360-
return { model, input: prompt, parameters: omit(parameters, ['prompt']) }
378+
return { model, input: prompt, action: 'generate', parameters: omit(parameters, ['prompt']) }
361379
}
362380
case 'sd-webui': {
363381
return {
@@ -417,7 +435,7 @@ export function apply(ctx: Context, config: Config) {
417435
? resolve(ctx.baseDir, config.workflowImage2Image)
418436
: resolve(__dirname, '../data/default-comfyui-i2i-wf.json')
419437
const workflow = image ? workflowImage2Image : workflowText2Image
420-
logger.debug('workflow:', workflow)
438+
ctx.logger.debug('workflow:', workflow)
421439
const prompt = JSON.parse(await readFile(workflow, 'utf8'))
422440

423441
// have to upload image to the comfyui server first
@@ -480,7 +498,7 @@ export function apply(ctx: Context, config: Config) {
480498
break
481499
}
482500
}
483-
logger.debug('prompt:', prompt)
501+
ctx.logger.debug('prompt:', prompt)
484502
return { prompt }
485503
}
486504
}
@@ -520,7 +538,7 @@ export function apply(ctx: Context, config: Config) {
520538
try {
521539
finalPrompt = (JSON.parse(data.info)).prompt
522540
} catch (err) {
523-
logger.warn(err)
541+
ctx.logger.warn(err)
524542
}
525543
}
526544
return forceDataPrefix(data.images[0])
@@ -598,8 +616,7 @@ export function apply(ctx: Context, config: Config) {
598616
continue
599617
}
600618
}
601-
602-
return await session.send(handleError(session, err))
619+
return await session.send(handleError(ctx, session, err))
603620
}
604621
}
605622

@@ -639,7 +656,7 @@ export function apply(ctx: Context, config: Config) {
639656
return result
640657
}
641658

642-
logger.debug(`${session.uid}: ${finalPrompt}`)
659+
ctx.logger.debug(`${session.uid}: ${finalPrompt}`)
643660
const messageIds = await session.send(getContent())
644661
if (messageIds.length && config.recallTimeout) {
645662
ctx.setTimeout(() => {
@@ -708,7 +725,7 @@ export function apply(ctx: Context, config: Config) {
708725
if (err instanceof NetworkError) {
709726
return session.text(err.message, err.params)
710727
}
711-
logger.error(err)
728+
ctx.logger.error(err)
712729
return session.text('.download-error')
713730
}
714731

@@ -737,7 +754,7 @@ export function apply(ctx: Context, config: Config) {
737754
})
738755
return h.image(forceDataPrefix(data.image))
739756
} catch (e) {
740-
logger.warn(e)
757+
ctx.logger.warn(e)
741758
return session.text('.unknown-error')
742759
}
743760
})

src/types.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,39 @@ export interface ImageData {
4444
dataUrl: string
4545
}
4646

47+
export namespace NovelAI {
48+
/** 0.5, 0.5 means make ai choose */
49+
export interface V4CharacterPromptCenter {
50+
x: number
51+
y: number
52+
}
53+
54+
export interface V4CharacterPrompt {
55+
prompt: string
56+
uc: string
57+
center: V4CharacterPromptCenter
58+
}
59+
60+
export interface V4CharCaption {
61+
char_caption: string
62+
centers: V4CharacterPromptCenter[]
63+
}
64+
65+
export interface V4PromptCaption {
66+
base_caption: string
67+
char_captions: V4CharCaption[]
68+
}
69+
70+
export interface V4Prompt {
71+
caption: V4PromptCaption
72+
}
73+
74+
export interface V4PromptPositive extends V4Prompt {
75+
use_coords: boolean
76+
use_order: boolean
77+
}
78+
}
79+
4780
export namespace StableDiffusionWebUI {
4881
export interface Request {
4982
prompt: string

0 commit comments

Comments
 (0)