Skip to content

Commit 713e13b

Browse files
committed
Improve accuracy of source matching using line numbers
1 parent e7c276d commit 713e13b

File tree

5 files changed

+173
-18
lines changed

5 files changed

+173
-18
lines changed

src/main/kotlin/platform/mixin/handlers/injectionPoint/AtResolver.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,16 @@ class AtResolver(
228228
canDecompile = true,
229229
) ?: return emptyList()
230230
val targetPsiClass = targetElement.parentOfType<PsiClass>() ?: return emptyList()
231+
val targetPsiFile = targetPsiClass.containingFile ?: return emptyList()
231232

232233
val navigationVisitor = injectionPoint.createNavigationVisitor(at, target, targetPsiClass) ?: return emptyList()
233234
navigationVisitor.configureBytecodeTarget(targetClass, targetMethod)
234235
targetElement.accept(navigationVisitor)
235236

236237
return bytecodeResults.mapNotNull { bytecodeResult ->
237-
navigationVisitor.result.getOrNull(bytecodeResult.index)
238+
val matcher = bytecodeResult.sourceLocationInfo.createMatcher<PsiElement>(targetPsiFile)
239+
navigationVisitor.result.forEach(matcher::accept)
240+
matcher.result
238241
}
239242
}
240243

src/main/kotlin/platform/mixin/handlers/injectionPoint/InjectionPoint.kt

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ package com.demonwav.mcdev.platform.mixin.handlers.injectionPoint
2222

2323
import com.demonwav.mcdev.platform.mixin.reference.MixinSelector
2424
import com.demonwav.mcdev.platform.mixin.reference.toMixinString
25+
import com.demonwav.mcdev.platform.mixin.util.SourceCodeLocationInfo
2526
import com.demonwav.mcdev.platform.mixin.util.fakeResolve
2627
import com.demonwav.mcdev.platform.mixin.util.findOrConstructSourceMethod
2728
import com.demonwav.mcdev.util.constantStringValue
@@ -60,6 +61,7 @@ import com.intellij.util.xmlb.annotations.Attribute
6061
import org.objectweb.asm.tree.AbstractInsnNode
6162
import org.objectweb.asm.tree.ClassNode
6263
import org.objectweb.asm.tree.InsnList
64+
import org.objectweb.asm.tree.LineNumberNode
6365
import org.objectweb.asm.tree.MethodInsnNode
6466
import org.objectweb.asm.tree.MethodNode
6567

@@ -385,6 +387,7 @@ abstract class CollectVisitor<T : PsiElement>(protected val mode: Mode) {
385387

386388
private lateinit var method: MethodNode
387389
private var nextIndex = 0
390+
private val nextIndexByLine = mutableMapOf<Int, Int>()
388391
val result = mutableListOf<Result<T>>()
389392
private val resultFilters = mutableListOf<Pair<String, CollectResultFilter<T>>>()
390393
var filterToBlame: String? = null
@@ -416,14 +419,18 @@ abstract class CollectVisitor<T : PsiElement>(protected val mode: Mode) {
416419
}
417420
}
418421

422+
val index = nextIndex++
423+
val lineNumber = getLineNumber(insn)
424+
val indexInLineNumber = lineNumber?.let { nextIndexByLine.merge(it, 1, Int::plus)!! - 1 } ?: index
419425
val result = Result(
420-
nextIndex++,
426+
SourceCodeLocationInfo(index, lineNumber, indexInLineNumber),
421427
insn,
422428
shiftedInsn ?: return,
423429
element,
424430
qualifier,
425431
if (insn === shiftedInsn) decorations else emptyMap()
426432
)
433+
427434
var isFiltered = false
428435
for ((name, filter) in resultFilters) {
429436
if (!filter(result, method)) {
@@ -442,6 +449,18 @@ abstract class CollectVisitor<T : PsiElement>(protected val mode: Mode) {
442449
}
443450
}
444451

452+
private fun getLineNumber(insn: AbstractInsnNode): Int? {
453+
var i: AbstractInsnNode? = insn
454+
while (i != null) {
455+
if (i is LineNumberNode) {
456+
return i.line
457+
}
458+
i = i.previous
459+
}
460+
461+
return null
462+
}
463+
445464
@Suppress("MemberVisibilityCanBePrivate")
446465
protected fun stopWalking() {
447466
throw StopWalkingException()
@@ -454,13 +473,15 @@ abstract class CollectVisitor<T : PsiElement>(protected val mode: Mode) {
454473
}
455474

456475
data class Result<T : PsiElement>(
457-
val index: Int,
476+
val sourceLocationInfo: SourceCodeLocationInfo,
458477
val originalInsn: AbstractInsnNode,
459478
val insn: AbstractInsnNode,
460479
val target: T,
461480
val qualifier: String? = null,
462481
val decorations: Map<String, Any?>
463-
)
482+
) {
483+
val index: Int get() = sourceLocationInfo.index
484+
}
464485

465486
enum class Mode { MATCH_ALL, MATCH_FIRST, COMPLETION }
466487
}

src/main/kotlin/platform/mixin/util/AsmUtil.kt

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ import com.demonwav.mcdev.util.findMethods
3030
import com.demonwav.mcdev.util.findModule
3131
import com.demonwav.mcdev.util.findQualifiedClass
3232
import com.demonwav.mcdev.util.fullQualifiedName
33-
import com.demonwav.mcdev.util.hasSyntheticMethod
3433
import com.demonwav.mcdev.util.isErasureEquivalentTo
3534
import com.demonwav.mcdev.util.lockedCached
3635
import com.demonwav.mcdev.util.loggerForTopLevel
@@ -96,6 +95,7 @@ import org.objectweb.asm.tree.FieldNode
9695
import org.objectweb.asm.tree.InsnList
9796
import org.objectweb.asm.tree.InsnNode
9897
import org.objectweb.asm.tree.InvokeDynamicInsnNode
98+
import org.objectweb.asm.tree.LineNumberNode
9999
import org.objectweb.asm.tree.MethodInsnNode
100100
import org.objectweb.asm.tree.MethodNode
101101
import org.objectweb.asm.tree.VarInsnNode
@@ -660,13 +660,18 @@ fun MethodNode.findSuperConstructorCall(): AbstractInsnNode? {
660660
return null
661661
}
662662

663-
private fun findContainingMethod(clazz: ClassNode, lambdaMethod: MethodNode): Pair<MethodNode, Int>? {
663+
private fun findContainingMethod(clazz: ClassNode, lambdaMethod: MethodNode): Pair<MethodNode, SourceCodeLocationInfo>? {
664664
if (!lambdaMethod.hasAccess(Opcodes.ACC_SYNTHETIC)) {
665665
return null
666666
}
667667
clazz.methods?.forEach { method ->
668668
var lambdaCount = 0
669+
var lineNumber: Int? = null
670+
val lambdaCountPerLine = mutableMapOf<Int, Int>()
669671
method.instructions?.iterator()?.forEach nextInsn@{ insn ->
672+
if (insn is LineNumberNode) {
673+
lineNumber = insn.line
674+
}
670675
if (insn !is InvokeDynamicInsnNode) return@nextInsn
671676
if (insn.bsm.owner != "java/lang/invoke/LambdaMetafactory") return@nextInsn
672677
val invokedMethod = when (insn.bsm.name) {
@@ -691,9 +696,13 @@ private fun findContainingMethod(clazz: ClassNode, lambdaMethod: MethodNode): Pa
691696
}
692697

693698
lambdaCount++
699+
val lambdaCountThisLine =
700+
lineNumber?.let { lambdaCountPerLine.merge(it, 1, Int::plus) } ?: lambdaCount
694701

695702
if (invokedMethod.name == lambdaMethod.name && invokedMethod.desc == lambdaMethod.desc) {
696-
return@findContainingMethod method to (lambdaCount - 1)
703+
val locationInfo =
704+
SourceCodeLocationInfo(lambdaCount - 1, lineNumber, lambdaCountThisLine - 1)
705+
return@findContainingMethod method to locationInfo
697706
}
698707
}
699708
}
@@ -704,12 +713,13 @@ private fun findContainingMethod(clazz: ClassNode, lambdaMethod: MethodNode): Pa
704713
private fun findAssociatedLambda(psiClass: PsiClass, clazz: ClassNode, lambdaMethod: MethodNode): PsiElement? {
705714
return RecursionManager.doPreventingRecursion(lambdaMethod, false) {
706715
val pair = findContainingMethod(clazz, lambdaMethod) ?: return@doPreventingRecursion null
707-
val (containingMethod, index) = pair
716+
val (containingMethod, locationInfo) = pair
708717
val parent = findAssociatedLambda(psiClass, clazz, containingMethod)
709718
?: psiClass.findMethods(containingMethod.memberReference).firstOrNull()
710719
?: return@doPreventingRecursion null
711-
var i = 0
712-
var result: PsiElement? = null
720+
721+
val psiFile = psiClass.containingFile ?: return@doPreventingRecursion null
722+
val matcher = locationInfo.createMatcher<PsiElement>(psiFile)
713723
parent.accept(
714724
object : JavaRecursiveElementWalkingVisitor() {
715725
override fun visitAnonymousClass(aClass: PsiAnonymousClass) {
@@ -721,8 +731,7 @@ private fun findAssociatedLambda(psiClass: PsiClass, clazz: ClassNode, lambdaMet
721731
}
722732

723733
override fun visitLambdaExpression(expression: PsiLambdaExpression) {
724-
if (i++ == index) {
725-
result = expression
734+
if (matcher.accept(expression)) {
726735
stopWalking()
727736
}
728737
// skip walking inside the lambda
@@ -732,16 +741,14 @@ private fun findAssociatedLambda(psiClass: PsiClass, clazz: ClassNode, lambdaMet
732741
// walk inside the reference first, visits the qualifier first (it's first in the bytecode)
733742
super.visitMethodReferenceExpression(expression)
734743

735-
if (expression.hasSyntheticMethod) {
736-
if (i++ == index) {
737-
result = expression
738-
stopWalking()
739-
}
744+
if (matcher.accept(expression)) {
745+
stopWalking()
740746
}
741747
}
742748
},
743749
)
744-
result
750+
751+
matcher.result
745752
}
746753
}
747754

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Minecraft Development for IntelliJ
3+
*
4+
* https://mcdev.io/
5+
*
6+
* Copyright (C) 2025 minecraft-dev
7+
*
8+
* This program is free software: you can redistribute it and/or modify
9+
* it under the terms of the GNU Lesser General Public License as published
10+
* by the Free Software Foundation, version 3.0 only.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU Lesser General Public License
18+
* along with this program. If not, see <https://www.gnu.org/licenses/>.
19+
*/
20+
21+
package com.demonwav.mcdev.platform.mixin.util
22+
23+
import com.demonwav.mcdev.util.findDocument
24+
import com.demonwav.mcdev.util.lineNumber
25+
import com.demonwav.mcdev.util.remapLineNumber
26+
import com.intellij.psi.PsiElement
27+
import com.intellij.psi.PsiFile
28+
29+
/**
30+
* Info returned from the bytecode to help locate an element from the source code.
31+
*
32+
* For example, if searching for a lambda in the source code, this contains the index of the lambda in the method (i.e.
33+
* how many lambdas were before this one), the starting line number of the lambda (or `null` if there were no line
34+
* numbers), and the index of the lambda within this line number (i.e. how many lambdas there were before this one in
35+
* the same line).
36+
*
37+
* The line number stored is unmapped, and may need remapping via [remapLineNumber].
38+
* [createMatcher] does this internally.
39+
*/
40+
class SourceCodeLocationInfo(val index: Int, val lineNumber: Int?, val indexInLineNumber: Int) {
41+
interface Matcher<T: PsiElement> {
42+
fun accept(t: T): Boolean
43+
44+
val result: T?
45+
}
46+
47+
fun <T: PsiElement> createMatcher(psiFile: PsiFile): Matcher<T> {
48+
val lineNumber = this.lineNumber?.let(psiFile::remapLineNumber)
49+
val document = psiFile.findDocument()
50+
51+
return object : Matcher<T> {
52+
private var count = 0
53+
private var currentLine: Int? = null
54+
private var countThisLine = 0
55+
private var myResult: T? = null
56+
57+
override fun accept(t: T): Boolean {
58+
val line = document?.let(t::lineNumber)
59+
if (line != null) {
60+
if (line != currentLine) {
61+
countThisLine = 0
62+
currentLine = line
63+
}
64+
65+
countThisLine++
66+
if (line == lineNumber && countThisLine == indexInLineNumber + 1) {
67+
myResult = t
68+
return true
69+
}
70+
}
71+
72+
if (count++ == index) {
73+
myResult = t
74+
if (lineNumber == null) {
75+
return true
76+
}
77+
}
78+
79+
return false
80+
}
81+
82+
override val result get() = myResult
83+
}
84+
}
85+
}

src/main/kotlin/util/psi-utils.kt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ import com.demonwav.mcdev.facet.MinecraftFacet
2424
import com.demonwav.mcdev.platform.mcp.McpModule
2525
import com.demonwav.mcdev.platform.mcp.McpModuleType
2626
import com.intellij.codeInsight.lookup.LookupElementBuilder
27+
import com.intellij.debugger.impl.DebuggerUtilsEx
28+
import com.intellij.ide.highlighter.JavaClassFileType
2729
import com.intellij.lang.injection.InjectedLanguageManager
30+
import com.intellij.openapi.editor.Document
2831
import com.intellij.openapi.module.Module
2932
import com.intellij.openapi.module.ModuleManager
3033
import com.intellij.openapi.module.ModuleUtilCore
@@ -41,6 +44,7 @@ import com.intellij.psi.JavaPsiFacade
4144
import com.intellij.psi.PsiAnnotation
4245
import com.intellij.psi.PsiClass
4346
import com.intellij.psi.PsiDirectory
47+
import com.intellij.psi.PsiDocumentManager
4448
import com.intellij.psi.PsiElement
4549
import com.intellij.psi.PsiElementFactory
4650
import com.intellij.psi.PsiElementResolveResult
@@ -222,6 +226,41 @@ fun isAccessModifier(@ModifierConstant modifier: String): Boolean {
222226
return modifier in ACCESS_MODIFIERS
223227
}
224228

229+
fun PsiElement.findDocument(containingFile: PsiFile = this.containingFile): Document? {
230+
return containingFile.viewProvider.document ?: PsiDocumentManager.getInstance(project).getDocument(containingFile)
231+
}
232+
233+
/**
234+
* Remaps line numbers if the file is decompiled. Line numbers are 1-indexed
235+
*/
236+
fun PsiFile.remapLineNumber(lineNumber: Int): Int {
237+
val originalFile = this.originalFile
238+
if (originalFile.virtualFile?.fileType != JavaClassFileType.INSTANCE) {
239+
// not decompiled
240+
return lineNumber
241+
}
242+
243+
val mappedLineNumber = DebuggerUtilsEx.bytecodeToSourceLine(originalFile, lineNumber - 1)
244+
if (mappedLineNumber < 0) {
245+
return lineNumber
246+
}
247+
248+
return mappedLineNumber + 1
249+
}
250+
251+
/**
252+
* Returns the line number of the start of this `PsiElement`'s text range, with line numbers starting at 1
253+
*/
254+
fun PsiElement.lineNumber(): Int? = findDocument()?.let(this::lineNumber)
255+
256+
fun PsiElement.lineNumber(document: Document): Int? {
257+
val index = this.textRange.startOffset
258+
if (index > document.textLength) {
259+
return null
260+
}
261+
return document.getLineNumber(index) + 1
262+
}
263+
225264
infix fun PsiElement.equivalentTo(other: PsiElement?): Boolean {
226265
return manager.areElementsEquivalent(this, other)
227266
}

0 commit comments

Comments
 (0)