diff --git a/activerecord/lib/active_record/associations.rb b/activerecord/lib/active_record/associations.rb index fd9af301f7a6f..a8bb4e17cfede 100644 --- a/activerecord/lib/active_record/associations.rb +++ b/activerecord/lib/active_record/associations.rb @@ -324,6 +324,11 @@ def association(name) # :nodoc: association end + def load_async(*associations) # TODO: doc + associations.map { |name| association(name) }.each(&:async_load_target) + self + end + def association_cached?(name) # :nodoc: @association_cache.key?(name) end diff --git a/activerecord/lib/active_record/associations/association.rb b/activerecord/lib/active_record/associations/association.rb index d4ca692cbf235..19299e12c3e5b 100644 --- a/activerecord/lib/active_record/associations/association.rb +++ b/activerecord/lib/active_record/associations/association.rb @@ -34,7 +34,7 @@ module Associations # the reflection object represents a :has_many macro. class Association # :nodoc: attr_accessor :owner - attr_reader :target, :reflection, :disable_joins + attr_reader :reflection, :disable_joins delegate :options, to: :reflection @@ -50,6 +50,13 @@ def initialize(owner, reflection) @skip_strict_loading = nil end + def target + if @target&.is_a?(Promise) + @target = @target.value + end + @target + end + # Resets the \loaded flag to +false+ and sets the \target to +nil+. def reset @loaded = false @@ -172,7 +179,7 @@ def extensions # ActiveRecord::RecordNotFound is rescued within the method, and it is # not reraised. The proxy is \reset and +nil+ is the return value. def load_target - @target = find_target if (@stale_state && stale_target?) || find_target? + @target = find_target(async: false) if (@stale_state && stale_target?) || find_target? loaded! unless loaded? target @@ -180,6 +187,13 @@ def load_target reset end + def async_load_target + @target = find_target(async: true) if (@stale_state && stale_target?) || find_target? + + loaded! unless loaded? + nil + end + # We can't dump @reflection and @through_reflection since it contains the scope proc def marshal_dump ivars = (instance_variables - [:@reflection, :@through_reflection]).map { |name| [name, instance_variable_get(name)] } @@ -217,13 +231,19 @@ def ensure_klass_exists! klass end - def find_target + def find_target(async: false) if violates_strict_loading? Base.strict_loading_violation!(owner: owner.class, reflection: reflection) end scope = self.scope - return scope.to_a if skip_statement_cache?(scope) + if skip_statement_cache?(scope) + if async + return scope.load_async.then(&:to_a) + else + return scope.to_a + end + end sc = reflection.association_scope_cache(klass, owner) do |params| as = AssociationScope.create { params.bind } @@ -232,7 +252,7 @@ def find_target binds = AssociationScope.get_bind_values(owner, reflection.chain) klass.with_connection do |c| - sc.execute(binds, c) do |record| + sc.execute(binds, c, async: async) do |record| set_inverse_instance(record) if owner.strict_loading_n_plus_one_only? && reflection.macro == :has_many record.strict_loading! diff --git a/activerecord/lib/active_record/associations/has_many_through_association.rb b/activerecord/lib/active_record/associations/has_many_through_association.rb index 845e5e564266a..28335ccd824dc 100644 --- a/activerecord/lib/active_record/associations/has_many_through_association.rb +++ b/activerecord/lib/active_record/associations/has_many_through_association.rb @@ -216,7 +216,8 @@ def delete_through_records(records) end end - def find_target + def find_target(async: false) + raise NotImplementedError if async return [] unless target_reflection_has_associated_record? return scope.to_a if disable_joins super diff --git a/activerecord/lib/active_record/associations/singular_association.rb b/activerecord/lib/active_record/associations/singular_association.rb index f89936d0d06ea..59bb0128e70aa 100644 --- a/activerecord/lib/active_record/associations/singular_association.rb +++ b/activerecord/lib/active_record/associations/singular_association.rb @@ -18,6 +18,7 @@ def reader def reset super @target = nil + @future_target = nil end # Implements the writer method, e.g. foo.bar= for Foo.belongs_to :bar @@ -43,11 +44,12 @@ def scope_for_create super.except!(*Array(klass.primary_key)) end - def find_target + def find_target(async: false) if disable_joins + raise NotImplementedError if async scope.first else - super.first + super.then(&:first) end end diff --git a/activerecord/lib/active_record/core.rb b/activerecord/lib/active_record/core.rb index 7510cea6aa9c2..956e5a185f049 100644 --- a/activerecord/lib/active_record/core.rb +++ b/activerecord/lib/active_record/core.rb @@ -431,7 +431,7 @@ def cached_find_by(keys, values) } begin - statement.execute(values.flatten, lease_connection, allow_retry: true).first + statement.execute(values.flatten, lease_connection, allow_retry: true).then(&:first) rescue TypeError raise ActiveRecord::StatementInvalid end diff --git a/activerecord/lib/active_record/querying.rb b/activerecord/lib/active_record/querying.rb index 27d0603f3a9d9..e6d85195b891c 100644 --- a/activerecord/lib/active_record/querying.rb +++ b/activerecord/lib/active_record/querying.rb @@ -52,8 +52,8 @@ def find_by_sql(sql, binds = [], preparable: nil, allow_retry: false, &block) end # Same as #find_by_sql but perform the query asynchronously and returns an ActiveRecord::Promise. - def async_find_by_sql(sql, binds = [], preparable: nil, &block) - _query_by_sql(sql, binds, preparable: preparable, async: true).then do |result| + def async_find_by_sql(sql, binds = [], preparable: nil, allow_retry: false, &block) + _query_by_sql(sql, binds, preparable: preparable, allow_retry: allow_retry, async: true).then do |result| _load_from_sql(result, &block) end end diff --git a/activerecord/lib/active_record/relation.rb b/activerecord/lib/active_record/relation.rb index 229be6c173492..4458707ef1bab 100644 --- a/activerecord/lib/active_record/relation.rb +++ b/activerecord/lib/active_record/relation.rb @@ -810,6 +810,16 @@ def load_async self end + def then(&block) + if @future_result + @future_result.then do + yield self + end + else + super + end + end + # Returns true if the relation was scheduled on the background # thread pool. def scheduled? diff --git a/activerecord/lib/active_record/statement_cache.rb b/activerecord/lib/active_record/statement_cache.rb index 411a073a72c96..1c428b4870ca7 100644 --- a/activerecord/lib/active_record/statement_cache.rb +++ b/activerecord/lib/active_record/statement_cache.rb @@ -142,14 +142,18 @@ def initialize(query_builder, bind_map, klass) @klass = klass end - def execute(params, connection, allow_retry: false, &block) + def execute(params, connection, allow_retry: false, async: false, &block) bind_values = bind_map.bind params sql = query_builder.sql_for bind_values, connection - klass.find_by_sql(sql, bind_values, preparable: true, allow_retry: allow_retry, &block) + if async + klass.async_find_by_sql(sql, bind_values, preparable: true, allow_retry: allow_retry, &block) + else + klass.find_by_sql(sql, bind_values, preparable: true, allow_retry: allow_retry, &block) + end rescue ::RangeError - [] + async ? Promise.wrap([]) : [] end def self.unsupported_value?(value) diff --git a/activerecord/test/cases/associations/belongs_to_associations_test.rb b/activerecord/test/cases/associations/belongs_to_associations_test.rb index 86cc88c884d6c..bdcfef01989e7 100644 --- a/activerecord/test/cases/associations/belongs_to_associations_test.rb +++ b/activerecord/test/cases/associations/belongs_to_associations_test.rb @@ -1841,3 +1841,35 @@ def test_destroy_linked_models assert_not Author.exists?(author.id) end end + +class AsyncBelongsToAssociationsTest < ActiveRecord::TestCase + include WaitForAsyncTestHelper + + fixtures :companies + + self.use_transactional_tests = false + + def test_async_load_belongs_to + client = Client.find(3) + first_firm = companies(:first_firm) + + promise = client.load_async(:firm) + wait_for_async_query + + events = [] + callback = -> (event) do + events << event unless event.payload[:name] == "SCHEMA" + end + ActiveSupport::Notifications.subscribed(callback, "sql.active_record") do + client.firm + end + + assert_no_queries do + assert_equal first_firm, client.firm + assert_equal first_firm.name, client.firm.name + end + + assert_equal 1, events.size + assert_equal true, events.first.payload[:async] + end +end diff --git a/activerecord/test/cases/associations/has_many_associations_test.rb b/activerecord/test/cases/associations/has_many_associations_test.rb index 663749cbd288f..df5367aca6877 100644 --- a/activerecord/test/cases/associations/has_many_associations_test.rb +++ b/activerecord/test/cases/associations/has_many_associations_test.rb @@ -3237,3 +3237,34 @@ def force_signal37_to_load_all_clients_of_firm companies(:first_firm).clients_of_firm.load_target end end + +class AsyncHasOneAssociationsTest < ActiveRecord::TestCase + include WaitForAsyncTestHelper + + fixtures :companies + + self.use_transactional_tests = false + + def test_async_load_has_many + firm = companies(:first_firm) + + promise = firm.load_async(:clients) + wait_for_async_query + + events = [] + callback = -> (event) do + events << event unless event.payload[:name] == "SCHEMA" + end + + ActiveSupport::Notifications.subscribed(callback, "sql.active_record") do + assert_equal 3, firm.clients.size + end + + assert_no_queries do + assert_not_nil firm.clients[2] + end + + assert_equal 1, events.size + assert_equal true, events.first.payload[:async] + end +end diff --git a/activerecord/test/cases/associations/has_one_associations_test.rb b/activerecord/test/cases/associations/has_one_associations_test.rb index fa8317fadeaa1..25d3c6f105674 100644 --- a/activerecord/test/cases/associations/has_one_associations_test.rb +++ b/activerecord/test/cases/associations/has_one_associations_test.rb @@ -941,3 +941,35 @@ def test_has_one_with_touch_option_on_nonpersisted_built_associations_doesnt_upd MESSAGE end end + +class AsyncHasOneAssociationsTest < ActiveRecord::TestCase + include WaitForAsyncTestHelper + + fixtures :companies, :accounts + + self.use_transactional_tests = false + + def test_async_load_has_one + firm = companies(:first_firm) + first_account = Account.find(1) + + promise = firm.load_async(:account) + wait_for_async_query + + events = [] + callback = -> (event) do + events << event unless event.payload[:name] == "SCHEMA" + end + ActiveSupport::Notifications.subscribed(callback, "sql.active_record") do + firm.account + end + + assert_no_queries do + assert_equal first_account, firm.account + assert_equal first_account.credit_limit, firm.account.credit_limit + end + + assert_equal 1, events.size + assert_equal true, events.first.payload[:async] + end +end diff --git a/activerecord/test/cases/helper.rb b/activerecord/test/cases/helper.rb index 1b71cf018b134..8537fea530577 100644 --- a/activerecord/test/cases/helper.rb +++ b/activerecord/test/cases/helper.rb @@ -40,36 +40,53 @@ ActiveRecord::ConnectionAdapters.register("abstract", "ActiveRecord::ConnectionAdapters::AbstractAdapter", "active_record/connection_adapters/abstract_adapter") ActiveRecord::ConnectionAdapters.register("fake", "FakeActiveRecordAdapter", File.expand_path("../support/fake_adapter.rb", __dir__)) -class SQLSubscriber - attr_reader :logged - attr_reader :payloads +class ActiveRecord::TestCase + class SQLSubscriber + attr_reader :logged + attr_reader :payloads + + def initialize + @logged = [] + @payloads = [] + end + + def start(name, id, payload) + @payloads << payload + @logged << [payload[:sql].squish, payload[:name], payload[:binds]] + end - def initialize - @logged = [] - @payloads = [] + def finish(name, id, payload); end end - def start(name, id, payload) - @payloads << payload - @logged << [payload[:sql].squish, payload[:name], payload[:binds]] + module InTimeZone + private + def in_time_zone(zone) + old_zone = Time.zone + old_tz = ActiveRecord::Base.time_zone_aware_attributes + + Time.zone = zone ? ActiveSupport::TimeZone[zone] : nil + ActiveRecord::Base.time_zone_aware_attributes = !zone.nil? + yield + ensure + Time.zone = old_zone + ActiveRecord::Base.time_zone_aware_attributes = old_tz + end end - def finish(name, id, payload); end -end + module WaitForAsyncTestHelper + private + def wait_for_async_query(connection = ActiveRecord::Base.lease_connection, timeout: 5) + return unless connection.async_enabled? -module InTimeZone - private - def in_time_zone(zone) - old_zone = Time.zone - old_tz = ActiveRecord::Base.time_zone_aware_attributes - - Time.zone = zone ? ActiveSupport::TimeZone[zone] : nil - ActiveRecord::Base.time_zone_aware_attributes = !zone.nil? - yield - ensure - Time.zone = old_zone - ActiveRecord::Base.time_zone_aware_attributes = old_tz - end + executor = connection.pool.async_executor + (timeout * 100).times do + return unless executor.scheduled_task_count > executor.completed_task_count + sleep 0.01 + end + + raise Timeout::Error, "The async executor wasn't drained after #{timeout} seconds" + end + end end # Encryption diff --git a/activerecord/test/cases/relation/load_async_test.rb b/activerecord/test/cases/relation/load_async_test.rb index 99eccfd1739a4..c4250268f87fc 100644 --- a/activerecord/test/cases/relation/load_async_test.rb +++ b/activerecord/test/cases/relation/load_async_test.rb @@ -7,21 +7,6 @@ require "models/other_dog" module ActiveRecord - module WaitForAsyncTestHelper - private - def wait_for_async_query(connection = ActiveRecord::Base.lease_connection, timeout: 5) - return unless connection.async_enabled? - - executor = connection.pool.async_executor - (timeout * 100).times do - return unless executor.scheduled_task_count > executor.completed_task_count - sleep 0.01 - end - - raise Timeout::Error, "The async executor wasn't drained after #{timeout} seconds" - end - end - class LoadAsyncTest < ActiveRecord::TestCase include WaitForAsyncTestHelper