11import { Computed , Context , Dict , h , Logger , omit , Quester , Session , SessionError , trimSlash } from 'koishi'
22import { Config , modelMap , models , orientMap , parseInput , sampler , upscalers , scheduler } from './config'
3- import { ImageData , StableDiffusionWebUI } from './types'
3+ import { ImageData , NovelAI , StableDiffusionWebUI } from './types'
44import { closestMultiple , download , forceDataPrefix , getImageSize , login , NetworkError , project , resizeInput , Size } from './utils'
55import { } from '@koishijs/translator'
66import { } from '@koishijs/plugin-help'
@@ -13,9 +13,7 @@ export * from './config'
1313export const reactive = true
1414export const name = 'novelai'
1515
16- const logger = new Logger ( 'novelai' )
17-
18- function handleError ( session : Session , err : Error ) {
16+ function handleError ( { logger } : Context , session : Session , err : Error ) {
1917 if ( Quester . Error . is ( err ) ) {
2018 if ( err . response ?. status === 402 ) {
2119 return session . text ( '.unauthorized' )
@@ -201,7 +199,7 @@ export function apply(ctx: Context, config: Config) {
201199 try {
202200 input = await ctx . translator . translate ( { input, target : 'en' } )
203201 } catch ( err ) {
204- logger . warn ( err )
202+ ctx . logger . warn ( err )
205203 }
206204 }
207205
@@ -215,7 +213,7 @@ export function apply(ctx: Context, config: Config) {
215213 if ( err instanceof NetworkError ) {
216214 return session . text ( err . message , err . params )
217215 }
218- logger . error ( err )
216+ ctx . logger . error ( err )
219217 return session . text ( '.unknown-error' )
220218 }
221219
@@ -243,7 +241,7 @@ export function apply(ctx: Context, config: Config) {
243241 if ( err instanceof NetworkError ) {
244242 return session . text ( err . message , err . params )
245243 }
246- logger . error ( err )
244+ ctx . logger . error ( err )
247245 return session . text ( '.download-error' )
248246 }
249247
@@ -335,29 +333,49 @@ export function apply(ctx: Context, config: Config) {
335333 delete parameters . uc
336334 }
337335 parameters . dynamic_thresholding = options . decrisper ?? config . decrisper
338- if ( model === 'nai-diffusion-3' ) {
336+ if ( model === 'nai-diffusion-3' || model === 'nai-diffusion-4-curated-preview' ) {
339337 parameters . legacy = false
340- parameters . legacy_v3_extend = false
341- parameters . sm_dyn = options . smeaDyn ?? config . smeaDyn
342- parameters . sm = ( options . smea ?? config . smea ) || parameters . sm_dyn
343338 parameters . noise_schedule = options . scheduler ?? config . scheduler
344- if ( [ 'k_euler_ancestral' , 'k_dpmpp_2s_ancestral' ] . includes ( parameters . sampler )
345- && parameters . noise_schedule === 'karras' ) {
346- parameters . noise_schedule = 'native'
347- }
348- if ( parameters . sampler === 'ddim_v3' ) {
349- parameters . sm = false
350- parameters . sm_dyn = false
351- delete parameters . noise_schedule
352- }
353339 // Max scale for nai-v3 is 10, but not 20.
354340 // If the given value is greater than 10,
355341 // we can assume it is configured with an older version (max 20)
356342 if ( parameters . scale > 10 ) {
357343 parameters . scale = parameters . scale / 2
358344 }
345+ if ( model === 'nai-diffusion-3' ) {
346+ parameters . legacy_v3_extend = false
347+ parameters . sm_dyn = options . smeaDyn ?? config . smeaDyn
348+ parameters . sm = ( options . smea ?? config . smea ) || parameters . sm_dyn
349+ if ( [ 'k_euler_ancestral' , 'k_dpmpp_2s_ancestral' ] . includes ( parameters . sampler )
350+ && parameters . noise_schedule === 'karras' ) {
351+ parameters . noise_schedule = 'native'
352+ }
353+ if ( parameters . sampler === 'ddim_v3' ) {
354+ parameters . sm = false
355+ parameters . sm_dyn = false
356+ delete parameters . noise_schedule
357+ }
358+ }
359+ if ( model === 'nai-diffusion-4-curated-preview' ) {
360+ parameters . use_coords = false // unknown
361+ parameters . characterPrompts = [ ] satisfies NovelAI . V4CharacterPrompt [ ]
362+ parameters . v4_prompt = {
363+ caption : {
364+ base_caption : prompt ,
365+ char_captions : [ ] ,
366+ } ,
367+ use_coords : parameters . use_coords ,
368+ use_order : true ,
369+ } satisfies NovelAI . V4PromptPositive
370+ parameters . v4_negative_prompt = {
371+ caption : {
372+ base_caption : parameters . negative_prompt ,
373+ char_captions : [ ] ,
374+ } ,
375+ } satisfies NovelAI . V4Prompt
376+ }
359377 }
360- return { model, input : prompt , parameters : omit ( parameters , [ 'prompt' ] ) }
378+ return { model, input : prompt , action : 'generate' , parameters : omit ( parameters , [ 'prompt' ] ) }
361379 }
362380 case 'sd-webui' : {
363381 return {
@@ -417,7 +435,7 @@ export function apply(ctx: Context, config: Config) {
417435 ? resolve ( ctx . baseDir , config . workflowImage2Image )
418436 : resolve ( __dirname , '../data/default-comfyui-i2i-wf.json' )
419437 const workflow = image ? workflowImage2Image : workflowText2Image
420- logger . debug ( 'workflow:' , workflow )
438+ ctx . logger . debug ( 'workflow:' , workflow )
421439 const prompt = JSON . parse ( await readFile ( workflow , 'utf8' ) )
422440
423441 // have to upload image to the comfyui server first
@@ -480,7 +498,7 @@ export function apply(ctx: Context, config: Config) {
480498 break
481499 }
482500 }
483- logger . debug ( 'prompt:' , prompt )
501+ ctx . logger . debug ( 'prompt:' , prompt )
484502 return { prompt }
485503 }
486504 }
@@ -520,7 +538,7 @@ export function apply(ctx: Context, config: Config) {
520538 try {
521539 finalPrompt = ( JSON . parse ( data . info ) ) . prompt
522540 } catch ( err ) {
523- logger . warn ( err )
541+ ctx . logger . warn ( err )
524542 }
525543 }
526544 return forceDataPrefix ( data . images [ 0 ] )
@@ -598,8 +616,7 @@ export function apply(ctx: Context, config: Config) {
598616 continue
599617 }
600618 }
601-
602- return await session . send ( handleError ( session , err ) )
619+ return await session . send ( handleError ( ctx , session , err ) )
603620 }
604621 }
605622
@@ -639,7 +656,7 @@ export function apply(ctx: Context, config: Config) {
639656 return result
640657 }
641658
642- logger . debug ( `${ session . uid } : ${ finalPrompt } ` )
659+ ctx . logger . debug ( `${ session . uid } : ${ finalPrompt } ` )
643660 const messageIds = await session . send ( getContent ( ) )
644661 if ( messageIds . length && config . recallTimeout ) {
645662 ctx . setTimeout ( ( ) => {
@@ -708,7 +725,7 @@ export function apply(ctx: Context, config: Config) {
708725 if ( err instanceof NetworkError ) {
709726 return session . text ( err . message , err . params )
710727 }
711- logger . error ( err )
728+ ctx . logger . error ( err )
712729 return session . text ( '.download-error' )
713730 }
714731
@@ -737,7 +754,7 @@ export function apply(ctx: Context, config: Config) {
737754 } )
738755 return h . image ( forceDataPrefix ( data . image ) )
739756 } catch ( e ) {
740- logger . warn ( e )
757+ ctx . logger . warn ( e )
741758 return session . text ( '.unknown-error' )
742759 }
743760 } )
0 commit comments