Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions src/config.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Computed, Dict, Schema, Session, Time } from 'koishi'
import { Size } from './utils'
import { ALLOWED_TYPES, Size } from './utils'

const options: Computed.Options = {
userFields: ['authority'],
Expand Down Expand Up @@ -183,10 +183,14 @@ export interface PromptConfig {
translator?: boolean
lowerCase?: boolean
maxWords?: Computed<number>
allowedInputImageTypes?: Computed<string[]>
transformPromptSyntax?: Computed<boolean>
}

export const PromptConfig: Schema<PromptConfig> = Schema.object({
basePrompt: Schema.computed(Schema.string().role('textarea'), options).description('默认附加的标签。').default('best quality, amazing quality, very aesthetic, absurdres'),
basePrompt: Schema.computed(Schema.string().role('textarea'), options)
.description('默认附加的标签。')
.default('best quality, amazing quality, very aesthetic, absurdres'),
negativePrompt: Schema.computed(Schema.string().role('textarea'), options).description('默认附加的反向标签。').default(ucPreset),
forbidden: Schema.computed(Schema.string().role('textarea'), options).description('违禁词列表。请求中的违禁词将会被自动删除。').default(''),
defaultPromptSw: Schema.boolean().description('是否启用默认标签。').default(false),
Expand All @@ -199,6 +203,13 @@ export const PromptConfig: Schema<PromptConfig> = Schema.object({
latinOnly: Schema.computed(Schema.boolean(), options).description('是否只接受英文输入。').default(false),
lowerCase: Schema.boolean().description('是否将输入的标签转换为小写。').default(true),
maxWords: Schema.computed(Schema.natural(), options).description('允许的最大单词数量。').default(0),
allowedInputImageTypes: Schema.computed(
Schema.array(Schema.string()).role('table').default(ALLOWED_TYPES),
options,
)
.description('允许从聊天平台获取的图片的文件类型,设为空列表表示忽略类型。')
.default(ALLOWED_TYPES),
transformPromptSyntax: Schema.computed(Schema.boolean(), options).description('是否自动转换输入标签的括号语法。').default(false),
}).description('输入设置')

interface FeatureConfig {
Expand Down Expand Up @@ -485,7 +496,7 @@ export function parseInput(session: Session, input: string, config: Config, over
return [
null,
[session.resolve(config.basePrompt), session.resolve(config.defaultPrompt)].join(','),
session.resolve(config.negativePrompt)
session.resolve(config.negativePrompt),
]
}

Expand All @@ -497,14 +508,16 @@ export function parseInput(session: Session, input: string, config: Config, over
.replace(/《/g, '<')
.replace(/》/g, '>')

if (config.type === 'sd-webui') {
input = input
.split('\\{').map(s => s.replace(/\{/g, '(')).join('\\{')
.split('\\}').map(s => s.replace(/\}/g, ')')).join('\\}')
} else {
input = input
.split('\\(').map(s => s.replace(/\(/g, '{')).join('\\(')
.split('\\)').map(s => s.replace(/\)/g, '}')).join('\\)')
if (session.resolve(config.transformPromptSyntax)) {
if (config.type === 'sd-webui') {
input = input
.split('\\{').map(s => s.replace(/\{/g, '(')).join('\\{')
.split('\\}').map(s => s.replace(/\}/g, ')')).join('\\}')
} else {
input = input
.split('\\(').map(s => s.replace(/\(/g, '{')).join('\\(')
.split('\\)').map(s => s.replace(/\)/g, '}')).join('\\)')
}
}

input = input
Expand Down
6 changes: 3 additions & 3 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ export function apply(ctx: Context, config: Config) {
}
} else {
input = haveInput ? h('', h.transform(h.parse(input), {
image(attrs) {
img(attrs) {
throw new SessionError('commands.novelai.messages.invalid-content')
},
})).toString(true) : input
Expand Down Expand Up @@ -236,7 +236,7 @@ export function apply(ctx: Context, config: Config) {

if (imgUrl) {
try {
image = await download(ctx, imgUrl)
image = await download(ctx, imgUrl, { allowedTypes: session.resolve(config.allowedInputImageTypes) })
} catch (err) {
if (err instanceof NetworkError) {
return session.text(err.message, err.params)
Expand Down Expand Up @@ -731,7 +731,7 @@ export function apply(ctx: Context, config: Config) {
if (!imgUrl) return session.text('.expect-image')
let image: ImageData
try {
image = await download(ctx, imgUrl)
image = await download(ctx, imgUrl, { allowedTypes: session.resolve(config.allowedInputImageTypes) })
} catch (err) {
if (err instanceof NetworkError) {
return session.text(err.message, err.params)
Expand Down
40 changes: 22 additions & 18 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,35 @@ export function getImageSize(buffer: ArrayBuffer): Size {
return pick(image, ['width', 'height'])
}

const MAX_OUTPUT_SIZE = 1048576
const MAX_CONTENT_SIZE = 10485760
const ALLOWED_TYPES = ['image/jpeg', 'image/png']

export async function download(ctx: Context, url: string, headers = {}): Promise<ImageData> {
export const MAX_OUTPUT_SIZE = 1048576
export const MAX_CONTENT_SIZE = 10485760
export const ALLOWED_TYPES = ['image/jpeg', 'image/png']

export async function download(
ctx: Context,
url: string,
{ headers, allowedTypes }: { headers?: Dict; allowedTypes?: string[] | null },
): Promise<ImageData> {
allowedTypes = allowedTypes === undefined ? ALLOWED_TYPES : allowedTypes

let mime: string
let buffer: ArrayBuffer
if (url.startsWith('data:') || url.startsWith('file:')) {
const { mime, data } = await ctx.http.file(url)
if (!ALLOWED_TYPES.includes(mime)) {
throw new NetworkError('.unsupported-file-type')
}
const base64 = arrayBufferToBase64(data)
return { buffer: data, base64, dataUrl: `data:${mime};base64,${base64}` }
({ mime, data: buffer } = await ctx.http.file(url))
} else {
const image = await ctx.http(url, { responseType: 'arraybuffer', headers })
if (+image.headers.get('content-length') > MAX_CONTENT_SIZE) {
throw new NetworkError('.file-too-large')
}
const mimetype = image.headers.get('content-type')
if (!ALLOWED_TYPES.includes(mimetype)) {
throw new NetworkError('.unsupported-file-type')
}
const buffer = image.data
const base64 = arrayBufferToBase64(buffer)
return { buffer, base64, dataUrl: `data:${mimetype};base64,${base64}` }
mime = image.headers.get('content-type')
buffer = image.data
}

if (allowedTypes && allowedTypes.length > 0 && !allowedTypes.includes(mime)) {
throw new NetworkError('.unsupported-file-type')
}
const base64 = arrayBufferToBase64(buffer)
return { buffer, base64, dataUrl: `data:${mime};base64,${base64}` }
}

export async function calcAccessKey(email: string, password: string) {
Expand Down