diff --git a/lib/ruby_indexer/lib/ruby_indexer/reference_finder.rb b/lib/ruby_indexer/lib/ruby_indexer/reference_finder.rb index e521a4d07..2469454b8 100644 --- a/lib/ruby_indexer/lib/ruby_indexer/reference_finder.rb +++ b/lib/ruby_indexer/lib/ruby_indexer/reference_finder.rb @@ -282,17 +282,53 @@ def on_instance_variable_target_node_enter(node) #: (Prism::CallNode node) -> void def on_call_node_enter(node) - if @target.is_a?(MethodTarget) && (name = node.name.to_s) == @target.method_name - @references << Reference.new( - name, - node.message_loc, #: as !nil - declaration: false, - ) + return unless @target.is_a?(MethodTarget) + + if (name = node.name.to_s) == @target.method_name + @references << Reference.new(name, node.message_loc, declaration: false) + elsif attr_method_references?(node) + @references << Reference.new(@target.method_name, node.message_loc, declaration: true) end end private + #: (Prism::CallNode node) -> bool + def attr_method_references?(node) + case node.name + when :attr_reader + attr_reader_references?(unescaped_argument_names(node)) + when :attr_writer + attr_writer_references?(unescaped_argument_names(node)) + when :attr_accessor + attr_accessor_references?(unescaped_argument_names(node)) + else + false + end + end + + #: (Prism::CallNode node) -> Array[String] + def unescaped_argument_names(node) + return [] if node.arguments.nil? + + node.arguments.arguments.select { |arg| arg.respond_to?(:unescaped) }.map(&:unescaped) + end + + #: (Array[String] argument_names) -> bool + def attr_reader_references?(argument_names) + argument_names.include?(@target.method_name) + end + + #: (Array[String] argument_names) -> bool + def attr_writer_references?(argument_names) + argument_names.any? { |arg| "#{arg}=" == @target.method_name } + end + + #: (Array[String] argument_names) -> bool + def attr_accessor_references?(argument_names) + argument_names.any? { |arg| ["#{arg}=", arg].include?(@target.method_name) } + end + #: (String name, Prism::Location location) -> void def collect_constant_references(name, location) return unless @target.is_a?(ConstTarget) diff --git a/lib/ruby_indexer/test/reference_finder_test.rb b/lib/ruby_indexer/test/reference_finder_test.rb index ed5028d5a..35637cbc0 100644 --- a/lib/ruby_indexer/test/reference_finder_test.rb +++ b/lib/ruby_indexer/test/reference_finder_test.rb @@ -143,6 +143,218 @@ def baz assert_equal(9, refs[1].location.start_line) end + def test_matches_attr_writer_with_call_node_argument + refs = find_method_references("foo=", <<~RUBY) + class Bar + attr_writer :foo, bar + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + assert_equal(2, refs.size) + + assert_equal("foo=", refs[0].name) + assert_equal(2, refs[0].location.start_line) + assert_equal(true, refs[0].declaration) + + assert_equal("foo=", refs[1].name) + assert_equal(5, refs[1].location.start_line) + assert_equal(false, refs[1].declaration) + end + + def test_matches_attr_writer + refs = find_method_references("foo=", <<~RUBY) + class Bar + def foo + end + + attr_writer :foo + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + # We want to match `foo=` but not `foo` + assert_equal(2, refs.size) + + assert_equal("foo=", refs[0].name) + assert_equal(5, refs[0].location.start_line) + assert_equal(true, refs[0].declaration) + + assert_equal("foo=", refs[1].name) + assert_equal(8, refs[1].location.start_line) + assert_equal(false, refs[1].declaration) + end + + def test_matches_attr_reader + refs = find_method_references("foo", <<~RUBY) + class Bar + def foo=(value) + end + + attr_reader :foo + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + # We want to match `foo=` but not `foo` + assert_equal(2, refs.size) + + assert_equal("foo", refs[0].name) + assert_equal(5, refs[0].location.start_line) + assert_equal(true, refs[0].declaration) + + assert_equal("foo", refs[1].name) + assert_equal(9, refs[1].location.start_line) + assert_equal(false, refs[1].declaration) + end + + def test_matches_attr_accessor + refs = find_method_references("foo=", <<~RUBY) + class Bar + attr_accessor :foo + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + # We want to match `foo=` but not `foo` + assert_equal(2, refs.size) + + assert_equal("foo=", refs[0].name) + assert_equal(2, refs[0].location.start_line) + assert_equal(true, refs[0].declaration) + + assert_equal("foo=", refs[1].name) + assert_equal(5, refs[1].location.start_line) + assert_equal(false, refs[1].declaration) + + refs = find_method_references("foo", <<~RUBY) + class Bar + attr_accessor :foo + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + assert_equal("foo", refs[0].name) + assert_equal(2, refs[0].location.start_line) + assert_equal(true, refs[0].declaration) + + assert_equal("foo", refs[1].name) + assert_equal(6, refs[1].location.start_line) + assert_equal(false, refs[1].declaration) + end + + def test_matches_attr_accessor_multi + refs = find_method_references("foo=", <<~RUBY) + class Bar + attr_accessor :bar, :foo + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + # We want to match `foo=` but not `foo` + assert_equal(2, refs.size) + + assert_equal("foo=", refs[0].name) + assert_equal(2, refs[0].location.start_line) + assert_equal(true, refs[0].declaration) + + assert_equal("foo=", refs[1].name) + assert_equal(5, refs[1].location.start_line) + assert_equal(false, refs[1].declaration) + + refs = find_method_references("foo", <<~RUBY) + class Bar + attr_accessor :bar, :foo + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + assert_equal("foo", refs[0].name) + assert_equal(2, refs[0].location.start_line) + assert_equal(true, refs[0].declaration) + + assert_equal("foo", refs[1].name) + assert_equal(6, refs[1].location.start_line) + assert_equal(false, refs[1].declaration) + end + + def test_matches_attr_emtpy + ruby_code = <<~RUBY + class Bar + def foo=(value) + end + + attr_reader + + def baz + foo + end + end + RUBY + refs = find_method_references("foo", ruby_code) + assert_equal(1, refs.size) + assert_equal(8, refs[0].location.start_line) + assert_equal(false, refs[0].declaration) + refs = find_method_references("foo=", ruby_code) + assert_equal(1, refs.size) + assert_equal(2, refs[0].location.start_line) + assert_equal(true, refs[0].declaration) + refs = find_method_references("baz", ruby_code) + assert_equal(1, refs.size) + assert_equal(7, refs[0].location.start_line) + assert_equal(true, refs[0].declaration) + end + + def test_matches_attr_string + ruby_code = <<~RUBY + class Bar + def foo=(value) + end + + attr_reader 'foo' + + def baz + foo + end + end + RUBY + + refs = find_method_references("foo", ruby_code) + assert_equal(2, refs.size) + assert_equal(5, refs[0].location.start_line) + assert_equal(true, refs[0].declaration) + assert_equal(8, refs[1].location.start_line) + assert_equal(false, refs[1].declaration) + end + def test_find_inherited_methods refs = find_method_references("foo", <<~RUBY) class Bar