diff --git a/lib/ruby_indexer/lib/ruby_indexer/reference_finder.rb b/lib/ruby_indexer/lib/ruby_indexer/reference_finder.rb index 956b166b24..64a90f3018 100644 --- a/lib/ruby_indexer/lib/ruby_indexer/reference_finder.rb +++ b/lib/ruby_indexer/lib/ruby_indexer/reference_finder.rb @@ -264,13 +264,51 @@ def on_def_node_leave(node) sig { params(node: Prism::CallNode).void } def on_call_node_enter(node) - if @target.is_a?(MethodTarget) && (name = node.name.to_s) == @target.method_name + return unless @target.is_a?(MethodTarget) + + if (name = node.name.to_s) == @target.method_name @references << Reference.new(name, T.must(node.message_loc), declaration: false) + elsif attr_method_references?(node) + @references << Reference.new(@target.method_name, T.must(node.message_loc), declaration: true) end end private + sig { params(node: Prism::CallNode).returns(T::Boolean) } + def attr_method_references?(node) + case node.name + when :attr_reader + attr_reader_references?(unescaped_argument_names(node)) + when :attr_writer + attr_writer_references?(unescaped_argument_names(node)) + when :attr_accessor + attr_accessor_references?(unescaped_argument_names(node)) + else + false + end + end + + sig { params(node: Prism::CallNode).returns(T::Array[String]) } + def unescaped_argument_names(node) + node.arguments.arguments.select { |arg| arg.respond_to?(:unescaped) }.map(&:unescaped) + end + + sig { params(argument_names: T::Array[String]).returns(T::Boolean) } + def attr_reader_references?(argument_names) + argument_names.include?(@target.method_name) + end + + sig { params(argument_names: T::Array[String]).returns(T::Boolean) } + def attr_writer_references?(argument_names) + argument_names.any? { |arg| "#{arg}=" == @target.method_name } + end + + sig { params(argument_names: T::Array[String]).returns(T::Boolean) } + def attr_accessor_references?(argument_names) + argument_names.any? { |arg| "#{arg}=" == @target.method_name || arg == @target.method_name } + end + sig { params(name: String).returns(T::Array[String]) } def actual_nesting(name) nesting = @stack + [name] diff --git a/lib/ruby_indexer/test/reference_finder_test.rb b/lib/ruby_indexer/test/reference_finder_test.rb index 0d4627a8f7..8c3d385b47 100644 --- a/lib/ruby_indexer/test/reference_finder_test.rb +++ b/lib/ruby_indexer/test/reference_finder_test.rb @@ -143,6 +143,152 @@ def baz assert_equal(9, refs[1].location.start_line) end + def test_matches_attr_writer_with_call_node_argument + refs = find_method_references("foo=", <<~RUBY) + class Bar + attr_reader :foo, bar + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + assert_equal(1, refs.size) + + assert_equal("foo=", refs[0].name) + assert_equal(5, refs[0].location.start_line) + end + + def test_matches_attr_writer + refs = find_method_references("foo=", <<~RUBY) + class Bar + def foo + end + + attr_writer :foo + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + # We want to match `foo=` but not `foo` + assert_equal(2, refs.size) + + assert_equal("foo=", refs[0].name) + assert_equal(5, refs[0].location.start_line) + + assert_equal("foo=", refs[1].name) + assert_equal(8, refs[1].location.start_line) + end + + def test_matches_attr_reader + refs = find_method_references("foo", <<~RUBY) + class Bar + def foo=(value) + end + + attr_reader :foo + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + # We want to match `foo=` but not `foo` + assert_equal(2, refs.size) + + assert_equal("foo", refs[0].name) + assert_equal(5, refs[0].location.start_line) + + assert_equal("foo", refs[1].name) + assert_equal(9, refs[1].location.start_line) + end + + def test_matches_attr_accessor + refs = find_method_references("foo=", <<~RUBY) + class Bar + attr_accessor :foo + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + # We want to match `foo=` but not `foo` + assert_equal(2, refs.size) + + assert_equal("foo=", refs[0].name) + assert_equal(2, refs[0].location.start_line) + + assert_equal("foo=", refs[1].name) + assert_equal(5, refs[1].location.start_line) + + refs = find_method_references("foo", <<~RUBY) + class Bar + attr_accessor :foo + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + assert_equal("foo", refs[0].name) + assert_equal(2, refs[0].location.start_line) + + assert_equal("foo", refs[1].name) + assert_equal(6, refs[1].location.start_line) + end + + def test_matches_attr_accessor_multi + refs = find_method_references("foo=", <<~RUBY) + class Bar + attr_accessor :bar, :foo + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + # We want to match `foo=` but not `foo` + assert_equal(2, refs.size) + + assert_equal("foo=", refs[0].name) + assert_equal(2, refs[0].location.start_line) + + assert_equal("foo=", refs[1].name) + assert_equal(5, refs[1].location.start_line) + + refs = find_method_references("foo", <<~RUBY) + class Bar + attr_accessor :bar, :foo + + def baz + self.foo = 1 + self.foo + end + end + RUBY + + assert_equal("foo", refs[0].name) + assert_equal(2, refs[0].location.start_line) + + assert_equal("foo", refs[1].name) + assert_equal(6, refs[1].location.start_line) + end + def test_find_inherited_methods refs = find_method_references("foo", <<~RUBY) class Bar