Skip to content

Commit 7fea6e6

Browse files
committed
Fix multi-line token adjustment
1 parent a9cb61e commit 7fea6e6

File tree

2 files changed

+88
-56
lines changed

2 files changed

+88
-56
lines changed

StabilityMatrix.Avalonia/Behaviors/TextEditorWeightAdjustmentBehavior.cs

Lines changed: 66 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,25 @@ private void HandleWeightAdjustment(double delta)
125125

126126
try
127127
{
128+
// 1. Tokenize the entire text
129+
var text = textEditor.Document.Text;
130+
var tokenizeResult = TokenizerProvider.TokenizeLine(text);
131+
132+
// 2. Get the token segment
128133
if (
129-
(editorSelectionSegment != null ? GetSelectedTokenSpan() : GetCaretTokenSpan())
134+
(
135+
editorSelectionSegment != null
136+
? GetSelectedTokenSpan(tokenizeResult)
137+
: GetCaretTokenSpan(tokenizeResult)
138+
)
130139
is not { } tokenSegment
131140
)
141+
{
142+
Logger.Warn("No token segment found");
132143
return;
144+
}
133145

134-
// 2. Tokenize the entire text
135-
var text = textEditor.Document.Text;
136-
var tokenizeResult = TokenizerProvider.TokenizeLine(text);
146+
Logger.Debug("Token segment: {Segment}", tokenSegment);
137147

138148
// 3. Build the AST
139149
var astBuilder = new PromptSyntaxBuilder(tokenizeResult, text);
@@ -145,7 +155,8 @@ is not { } tokenSegment
145155
// If empty use intersection instead
146156
if (selectedNodes.Count == 0)
147157
{
148-
selectedNodes = ast.RootNode.Content.Where(node => node.Span.IntersectsWith(tokenSegment))
158+
selectedNodes = ast
159+
.RootNode.Content.Where(node => node.Span.IntersectsWith(tokenSegment))
149160
.ToList();
150161
}
151162

@@ -174,7 +185,11 @@ is not { } tokenSegment
174185

175186
// Go up and find the first parenthesized node, if any
176187
// (Considering only the first and last of the smallest nodes)
177-
var parenthesisTargets = smallestNodes.Take(1).Concat(smallestNodes.TakeLast(1)).ToList();
188+
var parenthesisTargets = smallestNodes.Take(1).ToList();
189+
if (smallestNodes.Count > 1)
190+
{
191+
parenthesisTargets.Add(smallestNodes.Last());
192+
}
178193

179194
Logger.Trace("Parenthesis targets: {Nodes}", parenthesisTargets);
180195

@@ -183,6 +198,8 @@ is not { } tokenSegment
183198
.OfType<ParenthesizedNode>()
184199
.FirstOrDefault();
185200

201+
// Logger.Trace("Parenthesized node: {Node} of {Nodes}", parenthesizedNode, parenthesisTargets);
202+
186203
var currentWeight = 1.0;
187204
int replacementOffset; // Offset to start replacing text
188205
int replacementLength; // Length of text to replace (0 if inserting)
@@ -224,6 +241,12 @@ is not { } tokenSegment
224241
}
225242

226243
// 8. Replace the text.
244+
Logger.Debug(
245+
"Replacing source text {SrcText} at {SrcRange}, with new text {NewText}",
246+
text[replacementOffset..(replacementOffset + replacementLength)],
247+
new TextSpan(replacementOffset, replacementLength),
248+
newText
249+
);
227250
textEditor.Document.Replace(replacementOffset, replacementLength, newText);
228251

229252
// Plus 1 to caret if we added parenthesis
@@ -256,7 +279,7 @@ is not { } tokenSegment
256279
}
257280
}
258281

259-
private TextSpan? GetSelectedTokenSpan()
282+
private TextSpan? GetSelectedTokenSpan(ITokenizeLineResult result)
260283
{
261284
if (textEditor is null)
262285
return null;
@@ -267,25 +290,14 @@ is not { } tokenSegment
267290
var selectionStart = textEditor.SelectionStart;
268291
var selectionEnd = selectionStart + textEditor.SelectionLength;
269292

270-
var startLine = textEditor.Document.GetLineByOffset(selectionStart);
271-
var endLine = textEditor.Document.GetLineByOffset(selectionEnd);
272-
273-
// For simplicity, we'll assume the selection is on a single line.
274-
// Multi-line weight adjustment would be significantly more complex.
275-
if (startLine.LineNumber != endLine.LineNumber)
276-
return null;
277-
278-
var lineText = textEditor.Document.GetText(startLine.Offset, startLine.Length);
279-
var result = TokenizerProvider!.TokenizeLine(lineText);
280-
281293
IToken? startToken = null;
282294
IToken? endToken = null;
283295

284296
// Find the tokens that intersect the selection.
285297
foreach (var token in result.Tokens)
286298
{
287-
var tokenStart = token.StartIndex + startLine.Offset;
288-
var tokenEnd = token.EndIndex + startLine.Offset;
299+
var tokenStart = token.StartIndex;
300+
var tokenEnd = token.EndIndex;
289301

290302
if (tokenEnd > selectionStart && startToken is null)
291303
{
@@ -306,23 +318,15 @@ is not { } tokenSegment
306318
return null;
307319

308320
// Ensure end index is within length of text
309-
var endIndex = Math.Min(endToken.EndIndex + startLine.Offset, textEditor.Document.TextLength);
321+
var endIndex = Math.Min(endToken.EndIndex, textEditor.Document.TextLength);
310322

311323
return TextSpan.FromBounds(startToken.StartIndex, endIndex);
312324
}
313325

314-
private TextSpan? GetCaretTokenSpan()
326+
private TextSpan? GetCaretTokenSpan(ITokenizeLineResult result)
315327
{
316-
var caret = textEditor!.CaretOffset;
317-
318-
// Get the line the caret is on
319-
var line = textEditor.Document.GetLineByOffset(caret);
320-
var lineText = textEditor.Document.GetText(line.Offset, line.Length);
321-
322-
var caretAbsoluteOffset = caret - line.Offset;
323-
324-
// Tokenize
325-
var result = TokenizerProvider!.TokenizeLine(lineText);
328+
var caretAbsoluteOffset = textEditor!.CaretOffset;
329+
var textEndOffset = textEditor.Document.TextLength;
326330

327331
IToken? currentToken = null;
328332
var currentTokenIndex = -1;
@@ -331,11 +335,11 @@ is not { } tokenSegment
331335
{
332336
var token = result.Tokens[i];
333337
// If we see a line comment token anywhere, return null
334-
var isComment = token.Scopes.Any(s => s.Contains("comment.line"));
338+
/*var isComment = token.Scopes.Any(s => s.Contains("comment.line"));
335339
if (isComment)
336340
{
337341
return null;
338-
}
342+
}*/
339343

340344
// Find match
341345
if (caretAbsoluteOffset >= token.StartIndex && caretAbsoluteOffset < token.EndIndex)
@@ -354,18 +358,33 @@ is not { } tokenSegment
354358
}
355359
}
356360

357-
// Check if the token is a separator, if so check the next token instead
358-
if (
359-
currentToken?.Scopes is { } scopes
360-
&& scopes.Contains("meta.structure.array.prompt")
361-
&& scopes.Contains("punctuation.separator.variable.prompt")
362-
)
361+
// Check if the token is a separator, if so check the previous or next token instead
362+
if (currentToken?.Scopes is { } scopes && scopes.Contains("punctuation.separator.variable.prompt"))
363363
{
364364
// Check if we have a prev token
365-
var nextToken = result.Tokens.ElementAtOrDefault(currentTokenIndex - 1);
366-
if (nextToken is not null)
365+
if (
366+
result.Tokens.ElementAtOrDefault(currentTokenIndex - 1) is { } prevToken
367+
&& !prevToken.Scopes.Contains("punctuation.separator.variable.prompt")
368+
)
369+
{
370+
Logger.Trace(
371+
"Matched on seperator, using previous token: {Current} -> {Prev}",
372+
currentToken,
373+
prevToken
374+
);
375+
currentToken = prevToken;
376+
}
377+
// Check if we have a next token
378+
else if (
379+
result.Tokens.ElementAtOrDefault(currentTokenIndex + 1) is { } nextToken
380+
&& !nextToken.Scopes.Contains("punctuation.separator.variable.prompt")
381+
)
367382
{
368-
// Use the next token instead
383+
Logger.Trace(
384+
"Matched on seperator, using next token: {Current} -> {Next}",
385+
currentToken,
386+
nextToken
387+
);
369388
currentToken = nextToken;
370389
}
371390
}
@@ -382,11 +401,11 @@ is not { } tokenSegment
382401
return null;
383402
}
384403

385-
var startOffset = currentToken.StartIndex + line.Offset;
386-
var endOffset = currentToken.EndIndex + line.Offset;
387-
388404
// Cap the offsets by the line offsets
389-
return TextSpan.FromBounds(Math.Max(startOffset, line.Offset), Math.Min(endOffset, line.EndOffset));
405+
var startOffset = Math.Max(currentToken.StartIndex, 0);
406+
var endOffset = Math.Min(currentToken.EndIndex, textEndOffset);
407+
408+
return TextSpan.FromBounds(startOffset, endOffset);
390409
}
391410

392411
[Localizable(false)]

StabilityMatrix.Core/Models/PromptSyntax/PromptSyntaxBuilder.cs

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public PromptSyntaxTree BuildAST()
2828
var rootNode = new DocumentNode
2929
{
3030
Span = new TextSpan(startIndex, endIndex - startIndex),
31-
Content = nodes
31+
Content = nodes,
3232
};
3333

3434
return new PromptSyntaxTree(sourceText, rootNode, tokenizeResult.Tokens.ToList());
@@ -163,7 +163,7 @@ private NumberNode ParseNumber()
163163
{
164164
Raw = number,
165165
Span = new TextSpan(token.StartIndex, token.Length),
166-
Value = decimal.Parse(number, CultureInfo.InvariantCulture)
166+
Value = decimal.Parse(number, CultureInfo.InvariantCulture),
167167
};
168168
}
169169

@@ -181,8 +181,14 @@ openParenToken is null
181181

182182
while (MoreTokens())
183183
{
184+
// Check if no more tokens to consume.
184185
if (PeekToken() is not { } nextToken)
186+
{
187+
// Ensure we have length set
188+
if (node.Span.Length == 0)
189+
throw new InvalidOperationException("Unexpected end of input.");
185190
break;
191+
}
186192

187193
if (nextToken.Scopes.Contains("punctuation.separator.weight.prompt"))
188194
{
@@ -200,10 +206,17 @@ openParenToken is null
200206
node.Weight = ParseNumber();
201207
}
202208
// We're supposed to check `punctuation.definition.array.end.prompt` here, textmate is not parsing it
203-
// separately with current tmLanguage grammar, so use `meta.structure.weight.prompt` for now
209+
// separately always with current tmLanguage grammar, so ALSO use `meta.structure.weight.prompt` for now
204210
// We check this AFTER `punctuation.separator.weight.prompt` to avoid consuming the ':'
205-
else if (nextToken.Scopes.Contains("meta.structure.weight.prompt"))
211+
else if (
212+
nextToken.Scopes.Contains("punctuation.definition.array.end.prompt")
213+
|| nextToken.Scopes.Contains("meta.structure.weight.prompt")
214+
)
206215
{
216+
// Verify contents
217+
if (GetTextSubstring(nextToken) != ")")
218+
throw new InvalidOperationException("Expected closing parenthesis.");
219+
207220
ConsumeToken(); // Consume the ')'
208221
node.EndIndex = nextToken.EndIndex; // Set end index
209222
break;
@@ -287,7 +300,7 @@ endNetworkToken is null
287300
NetworkType = type,
288301
ModelName = name,
289302
ModelWeight = modelWeight,
290-
ClipWeight = clipWeight
303+
ClipWeight = clipWeight,
291304
};
292305
}
293306

@@ -299,7 +312,7 @@ private ArrayNode ParseArray()
299312

300313
var node = new ArrayNode
301314
{
302-
Span = new TextSpan(openBracket.StartIndex, 0) // Set start index
315+
Span = new TextSpan(openBracket.StartIndex, 0), // Set start index
303316
};
304317

305318
while (MoreTokens())
@@ -334,7 +347,7 @@ openBraceToken is null
334347

335348
var node = new WildcardNode
336349
{
337-
Span = new TextSpan(openBraceToken.StartIndex, 0) // Set start index
350+
Span = new TextSpan(openBraceToken.StartIndex, 0), // Set start index
338351
};
339352

340353
while (MoreTokens())
@@ -378,7 +391,7 @@ openBraceToken is null
378391
StartIndex = result.StartIndex,
379392
EndIndex = sourceText.Length,
380393
Length = sourceText.Length - result.StartIndex,
381-
Scopes = result.Scopes
394+
Scopes = result.Scopes,
382395
};
383396
}
384397
}
@@ -405,7 +418,7 @@ private IToken ConsumeToken()
405418
StartIndex = result.StartIndex,
406419
EndIndex = sourceText.Length,
407420
Length = sourceText.Length - result.StartIndex,
408-
Scopes = result.Scopes
421+
Scopes = result.Scopes,
409422
};
410423
}
411424
}

0 commit comments

Comments (0)