diff --git a/src/config.ts b/src/config.ts index be66299..80d622ce 100644 --- a/src/config.ts +++ b/src/config.ts @@ -1,5 +1,5 @@ import { Computed, Dict, Schema, Session, Time } from 'koishi' -import { Size } from './utils' +import { ALLOWED_TYPES, Size } from './utils' const options: Computed.Options = { userFields: ['authority'], @@ -183,10 +183,14 @@ export interface PromptConfig { translator?: boolean lowerCase?: boolean maxWords?: Computed + allowedInputImageTypes?: Computed + transformPromptSyntax?: Computed } export const PromptConfig: Schema = Schema.object({ - basePrompt: Schema.computed(Schema.string().role('textarea'), options).description('默认附加的标签。').default('best quality, amazing quality, very aesthetic, absurdres'), + basePrompt: Schema.computed(Schema.string().role('textarea'), options) + .description('默认附加的标签。') + .default('best quality, amazing quality, very aesthetic, absurdres'), negativePrompt: Schema.computed(Schema.string().role('textarea'), options).description('默认附加的反向标签。').default(ucPreset), forbidden: Schema.computed(Schema.string().role('textarea'), options).description('违禁词列表。请求中的违禁词将会被自动删除。').default(''), defaultPromptSw: Schema.boolean().description('是否启用默认标签。').default(false), @@ -199,6 +203,13 @@ export const PromptConfig: Schema = Schema.object({ latinOnly: Schema.computed(Schema.boolean(), options).description('是否只接受英文输入。').default(false), lowerCase: Schema.boolean().description('是否将输入的标签转换为小写。').default(true), maxWords: Schema.computed(Schema.natural(), options).description('允许的最大单词数量。').default(0), + allowedInputImageTypes: Schema.computed( + Schema.array(Schema.string()).role('table').default(ALLOWED_TYPES), + options, + ) + .description('允许从聊天平台获取的图片的文件类型,设为空列表表示忽略类型。') + .default(ALLOWED_TYPES), + transformPromptSyntax: Schema.computed(Schema.boolean(), options).description('是否自动转换输入标签的括号语法。').default(false), }).description('输入设置') interface FeatureConfig { @@ -485,7 +496,7 @@ export function parseInput(session: Session, input: string, config: Config, over return [ null, [session.resolve(config.basePrompt), session.resolve(config.defaultPrompt)].join(','), - session.resolve(config.negativePrompt) + session.resolve(config.negativePrompt), ] } @@ -497,14 +508,16 @@ export function parseInput(session: Session, input: string, config: Config, over .replace(/《/g, '<') .replace(/》/g, '>') - if (config.type === 'sd-webui') { - input = input - .split('\\{').map(s => s.replace(/\{/g, '(')).join('\\{') - .split('\\}').map(s => s.replace(/\}/g, ')')).join('\\}') - } else { - input = input - .split('\\(').map(s => s.replace(/\(/g, '{')).join('\\(') - .split('\\)').map(s => s.replace(/\)/g, '}')).join('\\)') + if (session.resolve(config.transformPromptSyntax)) { + if (config.type === 'sd-webui') { + input = input + .split('\\{').map(s => s.replace(/\{/g, '(')).join('\\{') + .split('\\}').map(s => s.replace(/\}/g, ')')).join('\\}') + } else { + input = input + .split('\\(').map(s => s.replace(/\(/g, '{')).join('\\(') + .split('\\)').map(s => s.replace(/\)/g, '}')).join('\\)') + } } input = input diff --git a/src/index.ts b/src/index.ts index 75b8a8b..e626c8e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -180,7 +180,7 @@ export function apply(ctx: Context, config: Config) { } } else { input = haveInput ? h('', h.transform(h.parse(input), { - image(attrs) { + img(attrs) { throw new SessionError('commands.novelai.messages.invalid-content') }, })).toString(true) : input @@ -236,7 +236,7 @@ export function apply(ctx: Context, config: Config) { if (imgUrl) { try { - image = await download(ctx, imgUrl) + image = await download(ctx, imgUrl, { allowedTypes: session.resolve(config.allowedInputImageTypes) }) } catch (err) { if (err instanceof NetworkError) { return session.text(err.message, err.params) @@ -731,7 +731,7 @@ export function apply(ctx: Context, config: Config) { if (!imgUrl) return session.text('.expect-image') let image: ImageData try { - image = await download(ctx, imgUrl) + image = await download(ctx, imgUrl, { allowedTypes: session.resolve(config.allowedInputImageTypes) }) } catch (err) { if (err instanceof NetworkError) { return session.text(err.message, err.params) diff --git a/src/utils.ts b/src/utils.ts index 7c97433..fd2ce62 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -29,31 +29,35 @@ export function getImageSize(buffer: ArrayBuffer): Size { return pick(image, ['width', 'height']) } -const MAX_OUTPUT_SIZE = 1048576 -const MAX_CONTENT_SIZE = 10485760 -const ALLOWED_TYPES = ['image/jpeg', 'image/png'] - -export async function download(ctx: Context, url: string, headers = {}): Promise { +export const MAX_OUTPUT_SIZE = 1048576 +export const MAX_CONTENT_SIZE = 10485760 +export const ALLOWED_TYPES = ['image/jpeg', 'image/png'] + +export async function download( + ctx: Context, + url: string, + { headers, allowedTypes }: { headers?: Dict; allowedTypes?: string[] | null }, +): Promise { + allowedTypes = allowedTypes === undefined ? ALLOWED_TYPES : allowedTypes + + let mime: string + let buffer: ArrayBuffer if (url.startsWith('data:') || url.startsWith('file:')) { - const { mime, data } = await ctx.http.file(url) - if (!ALLOWED_TYPES.includes(mime)) { - throw new NetworkError('.unsupported-file-type') - } - const base64 = arrayBufferToBase64(data) - return { buffer: data, base64, dataUrl: `data:${mime};base64,${base64}` } + ({ mime, data: buffer } = await ctx.http.file(url)) } else { const image = await ctx.http(url, { responseType: 'arraybuffer', headers }) if (+image.headers.get('content-length') > MAX_CONTENT_SIZE) { throw new NetworkError('.file-too-large') } - const mimetype = image.headers.get('content-type') - if (!ALLOWED_TYPES.includes(mimetype)) { - throw new NetworkError('.unsupported-file-type') - } - const buffer = image.data - const base64 = arrayBufferToBase64(buffer) - return { buffer, base64, dataUrl: `data:${mimetype};base64,${base64}` } + mime = image.headers.get('content-type') + buffer = image.data + } + + if (allowedTypes && allowedTypes.length > 0 && !allowedTypes.includes(mime)) { + throw new NetworkError('.unsupported-file-type') } + const base64 = arrayBufferToBase64(buffer) + return { buffer, base64, dataUrl: `data:${mime};base64,${base64}` } } export async function calcAccessKey(email: string, password: string) {