@@ -27,7 +27,6 @@ import (
2727 "slices"
2828 "strings"
2929
30- "buf.build/go/standard/xslices"
3130 "github.com/bufbuild/protocompile/experimental/ast"
3231 "github.com/bufbuild/protocompile/experimental/ast/predeclared"
3332 "github.com/bufbuild/protocompile/experimental/ir"
@@ -124,22 +123,38 @@ func (s *symbol) Definition() protocol.Location {
124123 }
125124}
126125
127- // References returns the locations of references to the symbol, if applicable. Otherwise,
128- // it just returns the location of the symbol itself.
126+ // References returns the locations of references to the symbol (including the definition), if
127+ // applicable. Otherwise, it just returns the location of the symbol itself.
129128func (s * symbol ) References () []protocol.Location {
130- referenceable , ok := s .kind .(* referenceable )
131- if ! ok {
132- return []protocol.Location {{
129+ var references []protocol.Location
130+ referenceableKind , ok := s .kind .(* referenceable )
131+ if ! ok && s .def != nil {
132+ // If the symbol isn't referenceable itself, but has a referenceable definition, use the
133+ // definition for the references.
134+ referenceableKind , ok = s .def .kind .(* referenceable )
135+ }
136+ if ok {
137+ for _ , reference := range referenceableKind .references {
138+ references = append (references , protocol.Location {
139+ URI : reference .file .uri ,
140+ Range : reference .Range (),
141+ })
142+ }
143+ } else {
144+ // No referenceable kind; add the location of the symbol itself.
145+ references = append (references , protocol.Location {
133146 URI : s .file .uri ,
134147 Range : s .Range (),
135- }}
148+ })
136149 }
137- return xslices .Map (referenceable .references , func (sym * symbol ) protocol.Location {
138- return protocol.Location {
139- URI : sym .file .uri ,
140- Range : sym .Range (),
141- }
142- })
150+ // Add the definition of the symbol to the list of references.
151+ if s .def != nil {
152+ references = append (references , protocol.Location {
153+ URI : s .def .file .uri ,
154+ Range : s .def .Range (),
155+ })
156+ }
157+ return references
143158}
144159
145160// LogValue provides the log value for a symbol.
0 commit comments