Skip to content

Commit ab2d73b

Browse files
Provide all references for LSP References (#4105)
1 parent 55a3adf commit ab2d73b

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

private/buf/buflsp/symbol.go

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import (
2727
"slices"
2828
"strings"
2929

30-
"buf.build/go/standard/xslices"
3130
"github.com/bufbuild/protocompile/experimental/ast"
3231
"github.com/bufbuild/protocompile/experimental/ast/predeclared"
3332
"github.com/bufbuild/protocompile/experimental/ir"
@@ -124,22 +123,38 @@ func (s *symbol) Definition() protocol.Location {
124123
}
125124
}
126125

127-
// References returns the locations of references to the symbol, if applicable. Otherwise,
128-
// it just returns the location of the symbol itself.
126+
// References returns the locations of references to the symbol (including the definition), if
127+
// applicable. Otherwise, it just returns the location of the symbol itself.
129128
func (s *symbol) References() []protocol.Location {
130-
referenceable, ok := s.kind.(*referenceable)
131-
if !ok {
132-
return []protocol.Location{{
129+
var references []protocol.Location
130+
referenceableKind, ok := s.kind.(*referenceable)
131+
if !ok && s.def != nil {
132+
// If the symbol isn't referenceable itself, but has a referenceable definition, use the
133+
// definition for the references.
134+
referenceableKind, ok = s.def.kind.(*referenceable)
135+
}
136+
if ok {
137+
for _, reference := range referenceableKind.references {
138+
references = append(references, protocol.Location{
139+
URI: reference.file.uri,
140+
Range: reference.Range(),
141+
})
142+
}
143+
} else {
144+
// No referenceable kind; add the location of the symbol itself.
145+
references = append(references, protocol.Location{
133146
URI: s.file.uri,
134147
Range: s.Range(),
135-
}}
148+
})
136149
}
137-
return xslices.Map(referenceable.references, func(sym *symbol) protocol.Location {
138-
return protocol.Location{
139-
URI: sym.file.uri,
140-
Range: sym.Range(),
141-
}
142-
})
150+
// Add the definition of the symbol to the list of references.
151+
if s.def != nil {
152+
references = append(references, protocol.Location{
153+
URI: s.def.file.uri,
154+
Range: s.def.Range(),
155+
})
156+
}
157+
return references
143158
}
144159

145160
// LogValue provides the log value for a symbol.

0 commit comments

Comments
 (0)