Skip to content

Commit ddb85ca

Browse files
joshistoastlstein
andauthored
fix(prompts): 🐛 prompt attention behaviors, add tests (#8683)
* fix(prompts): 🐛 prompt attention adjust elevation edge cases, added tests * refactor(prompts): ♻️ create attention edit helper for prompt boxes * feat(prompts): ✨ apply attention keybinds to negative prompt * feat(prompts): 🚀 reconsider behaviors, simplify code * fix(prompts): 🐛 keybind attention update not tracked by undo/redo * feat(prompts): ✨ overhaul prompt attention behavior * fix(prompts): 🩹 remove unused type * fix(prompts): 🩹 remove unused `Token` type --------- Co-authored-by: Lincoln Stein <[email protected]>
1 parent ac245cb commit ddb85ca

File tree

9 files changed

+729
-654
lines changed

9 files changed

+729
-654
lines changed

invokeai/frontend/web/src/common/util/promptAST.test.ts

Lines changed: 73 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -7,72 +7,76 @@ describe('promptAST', () => {
77
it('should tokenize basic text', () => {
88
const tokens = tokenize('a cat');
99
expect(tokens).toEqual([
10-
{ type: 'word', value: 'a' },
11-
{ type: 'whitespace', value: ' ' },
12-
{ type: 'word', value: 'cat' },
10+
{ type: 'word', value: 'a', start: 0, end: 1 },
11+
{ type: 'whitespace', value: ' ', start: 1, end: 2 },
12+
{ type: 'word', value: 'cat', start: 2, end: 5 },
1313
]);
1414
});
1515

1616
it('should tokenize groups with parentheses', () => {
1717
const tokens = tokenize('(a cat)');
1818
expect(tokens).toEqual([
19-
{ type: 'lparen' },
20-
{ type: 'word', value: 'a' },
21-
{ type: 'whitespace', value: ' ' },
22-
{ type: 'word', value: 'cat' },
23-
{ type: 'rparen' },
19+
{ type: 'lparen', start: 0, end: 1 },
20+
{ type: 'word', value: 'a', start: 1, end: 2 },
21+
{ type: 'whitespace', value: ' ', start: 2, end: 3 },
22+
{ type: 'word', value: 'cat', start: 3, end: 6 },
23+
{ type: 'rparen', start: 6, end: 7 },
2424
]);
2525
});
2626

2727
it('should tokenize escaped parentheses', () => {
2828
const tokens = tokenize('\\(medium\\)');
2929
expect(tokens).toEqual([
30-
{ type: 'escaped_paren', value: '(' },
31-
{ type: 'word', value: 'medium' },
32-
{ type: 'escaped_paren', value: ')' },
30+
{ type: 'escaped_paren', value: '(', start: 0, end: 2 },
31+
{ type: 'word', value: 'medium', start: 2, end: 8 },
32+
{ type: 'escaped_paren', value: ')', start: 8, end: 10 },
3333
]);
3434
});
3535

3636
it('should tokenize mixed escaped and unescaped parentheses', () => {
3737
const tokens = tokenize('colored pencil \\(medium\\) (enhanced)');
3838
expect(tokens).toEqual([
39-
{ type: 'word', value: 'colored' },
40-
{ type: 'whitespace', value: ' ' },
41-
{ type: 'word', value: 'pencil' },
42-
{ type: 'whitespace', value: ' ' },
43-
{ type: 'escaped_paren', value: '(' },
44-
{ type: 'word', value: 'medium' },
45-
{ type: 'escaped_paren', value: ')' },
46-
{ type: 'whitespace', value: ' ' },
47-
{ type: 'lparen' },
48-
{ type: 'word', value: 'enhanced' },
49-
{ type: 'rparen' },
39+
{ type: 'word', value: 'colored', start: 0, end: 7 },
40+
{ type: 'whitespace', value: ' ', start: 7, end: 8 },
41+
{ type: 'word', value: 'pencil', start: 8, end: 14 },
42+
{ type: 'whitespace', value: ' ', start: 14, end: 15 },
43+
{ type: 'escaped_paren', value: '(', start: 15, end: 17 },
44+
{ type: 'word', value: 'medium', start: 17, end: 23 },
45+
{ type: 'escaped_paren', value: ')', start: 23, end: 25 },
46+
{ type: 'whitespace', value: ' ', start: 25, end: 26 },
47+
{ type: 'lparen', start: 26, end: 27 },
48+
{ type: 'word', value: 'enhanced', start: 27, end: 35 },
49+
{ type: 'rparen', start: 35, end: 36 },
5050
]);
5151
});
5252

5353
it('should tokenize groups with weights', () => {
5454
const tokens = tokenize('(a cat)1.2');
5555
expect(tokens).toEqual([
56-
{ type: 'lparen' },
57-
{ type: 'word', value: 'a' },
58-
{ type: 'whitespace', value: ' ' },
59-
{ type: 'word', value: 'cat' },
60-
{ type: 'rparen' },
61-
{ type: 'weight', value: 1.2 },
56+
{ type: 'lparen', start: 0, end: 1 },
57+
{ type: 'word', value: 'a', start: 1, end: 2 },
58+
{ type: 'whitespace', value: ' ', start: 2, end: 3 },
59+
{ type: 'word', value: 'cat', start: 3, end: 6 },
60+
{ type: 'rparen', start: 6, end: 7 },
61+
{ type: 'weight', value: 1.2, start: 7, end: 10 },
6262
]);
6363
});
6464

6565
it('should tokenize words with weights', () => {
6666
const tokens = tokenize('cat+');
6767
expect(tokens).toEqual([
68-
{ type: 'word', value: 'cat' },
69-
{ type: 'weight', value: '+' },
68+
{ type: 'word', value: 'cat', start: 0, end: 3 },
69+
{ type: 'weight', value: '+', start: 3, end: 4 },
7070
]);
7171
});
7272

7373
it('should tokenize embeddings', () => {
7474
const tokens = tokenize('<embedding_name>');
75-
expect(tokens).toEqual([{ type: 'lembed' }, { type: 'word', value: 'embedding_name' }, { type: 'rembed' }]);
75+
expect(tokens).toEqual([
76+
{ type: 'lembed', start: 0, end: 1 },
77+
{ type: 'word', value: 'embedding_name', start: 1, end: 15 },
78+
{ type: 'rembed', start: 15, end: 16 },
79+
]);
7680
});
7781
});
7882

@@ -81,9 +85,9 @@ describe('promptAST', () => {
8185
const tokens = tokenize('a cat');
8286
const ast = parseTokens(tokens);
8387
expect(ast).toEqual([
84-
{ type: 'word', text: 'a' },
85-
{ type: 'whitespace', value: ' ' },
86-
{ type: 'word', text: 'cat' },
88+
{ type: 'word', text: 'a', range: { start: 0, end: 1 }, attention: undefined },
89+
{ type: 'whitespace', value: ' ', range: { start: 1, end: 2 } },
90+
{ type: 'word', text: 'cat', range: { start: 2, end: 5 }, attention: undefined },
8791
]);
8892
});
8993

@@ -93,10 +97,12 @@ describe('promptAST', () => {
9397
expect(ast).toEqual([
9498
{
9599
type: 'group',
100+
range: { start: 0, end: 7 },
101+
attention: undefined,
96102
children: [
97-
{ type: 'word', text: 'a' },
98-
{ type: 'whitespace', value: ' ' },
99-
{ type: 'word', text: 'cat' },
103+
{ type: 'word', text: 'a', range: { start: 1, end: 2 }, attention: undefined },
104+
{ type: 'whitespace', value: ' ', range: { start: 2, end: 3 } },
105+
{ type: 'word', text: 'cat', range: { start: 3, end: 6 }, attention: undefined },
100106
],
101107
},
102108
]);
@@ -106,27 +112,29 @@ describe('promptAST', () => {
106112
const tokens = tokenize('\\(medium\\)');
107113
const ast = parseTokens(tokens);
108114
expect(ast).toEqual([
109-
{ type: 'escaped_paren', value: '(' },
110-
{ type: 'word', text: 'medium' },
111-
{ type: 'escaped_paren', value: ')' },
115+
{ type: 'escaped_paren', value: '(', range: { start: 0, end: 2 } },
116+
{ type: 'word', text: 'medium', range: { start: 2, end: 8 }, attention: undefined },
117+
{ type: 'escaped_paren', value: ')', range: { start: 8, end: 10 } },
112118
]);
113119
});
114120

115121
it('should parse mixed escaped and unescaped parentheses', () => {
116122
const tokens = tokenize('colored pencil \\(medium\\) (enhanced)');
117123
const ast = parseTokens(tokens);
118124
expect(ast).toEqual([
119-
{ type: 'word', text: 'colored' },
120-
{ type: 'whitespace', value: ' ' },
121-
{ type: 'word', text: 'pencil' },
122-
{ type: 'whitespace', value: ' ' },
123-
{ type: 'escaped_paren', value: '(' },
124-
{ type: 'word', text: 'medium' },
125-
{ type: 'escaped_paren', value: ')' },
126-
{ type: 'whitespace', value: ' ' },
125+
{ type: 'word', text: 'colored', range: { start: 0, end: 7 }, attention: undefined },
126+
{ type: 'whitespace', value: ' ', range: { start: 7, end: 8 } },
127+
{ type: 'word', text: 'pencil', range: { start: 8, end: 14 }, attention: undefined },
128+
{ type: 'whitespace', value: ' ', range: { start: 14, end: 15 } },
129+
{ type: 'escaped_paren', value: '(', range: { start: 15, end: 17 } },
130+
{ type: 'word', text: 'medium', range: { start: 17, end: 23 }, attention: undefined },
131+
{ type: 'escaped_paren', value: ')', range: { start: 23, end: 25 } },
132+
{ type: 'whitespace', value: ' ', range: { start: 25, end: 26 } },
127133
{
128134
type: 'group',
129-
children: [{ type: 'word', text: 'enhanced' }],
135+
range: { start: 26, end: 36 },
136+
attention: undefined,
137+
children: [{ type: 'word', text: 'enhanced', range: { start: 27, end: 35 }, attention: undefined }],
130138
},
131139
]);
132140
});
@@ -138,10 +146,11 @@ describe('promptAST', () => {
138146
{
139147
type: 'group',
140148
attention: 1.2,
149+
range: { start: 0, end: 10 },
141150
children: [
142-
{ type: 'word', text: 'a' },
143-
{ type: 'whitespace', value: ' ' },
144-
{ type: 'word', text: 'cat' },
151+
{ type: 'word', text: 'a', range: { start: 1, end: 2 }, attention: undefined },
152+
{ type: 'whitespace', value: ' ', range: { start: 2, end: 3 } },
153+
{ type: 'word', text: 'cat', range: { start: 3, end: 6 }, attention: undefined },
145154
],
146155
},
147156
]);
@@ -150,13 +159,13 @@ describe('promptAST', () => {
150159
it('should parse words with attention', () => {
151160
const tokens = tokenize('cat+');
152161
const ast = parseTokens(tokens);
153-
expect(ast).toEqual([{ type: 'word', text: 'cat', attention: '+' }]);
162+
expect(ast).toEqual([{ type: 'word', text: 'cat', attention: '+', range: { start: 0, end: 4 } }]);
154163
});
155164

156165
it('should parse embeddings', () => {
157166
const tokens = tokenize('<embedding_name>');
158167
const ast = parseTokens(tokens);
159-
expect(ast).toEqual([{ type: 'embedding', value: 'embedding_name' }]);
168+
expect(ast).toEqual([{ type: 'embedding', value: 'embedding_name', range: { start: 0, end: 16 } }]);
160169
});
161170
});
162171

@@ -243,19 +252,20 @@ describe('promptAST', () => {
243252

244253
// Should have escaped parens as nodes and a group with attention
245254
expect(ast).toEqual([
246-
{ type: 'word', text: 'portrait' },
247-
{ type: 'whitespace', value: ' ' },
248-
{ type: 'escaped_paren', value: '(' },
249-
{ type: 'word', text: 'realistic' },
250-
{ type: 'escaped_paren', value: ')' },
251-
{ type: 'whitespace', value: ' ' },
255+
{ type: 'word', text: 'portrait', range: { start: 0, end: 8 }, attention: undefined },
256+
{ type: 'whitespace', value: ' ', range: { start: 8, end: 9 } },
257+
{ type: 'escaped_paren', value: '(', range: { start: 9, end: 11 } },
258+
{ type: 'word', text: 'realistic', range: { start: 11, end: 20 }, attention: undefined },
259+
{ type: 'escaped_paren', value: ')', range: { start: 20, end: 22 } },
260+
{ type: 'whitespace', value: ' ', range: { start: 22, end: 23 } },
252261
{
253262
type: 'group',
254263
attention: 1.2,
264+
range: { start: 23, end: 40 },
255265
children: [
256-
{ type: 'word', text: 'high' },
257-
{ type: 'whitespace', value: ' ' },
258-
{ type: 'word', text: 'quality' },
266+
{ type: 'word', text: 'high', range: { start: 24, end: 28 }, attention: undefined },
267+
{ type: 'whitespace', value: ' ', range: { start: 28, end: 29 } },
268+
{ type: 'word', text: 'quality', range: { start: 29, end: 36 }, attention: undefined },
259269
],
260270
},
261271
]);

0 commit comments

Comments
 (0)