Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions lib/activerecord-multi-tenant/query_rewriter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -247,19 +247,33 @@ def visit_MultiTenant_TenantJoinEnforcementClause(obj, collector)
module DatabaseStatements
def update(arel, name = nil, binds = [])
model = MultiTenant.multi_tenant_model_for_arel(arel)
if model.present? && !MultiTenant.with_write_only_mode_enabled? && MultiTenant.current_tenant_id.present?
if model.present? &&
!MultiTenant.with_write_only_mode_enabled? &&
MultiTenant.current_tenant_id.present? &&
!already_has_tenant_enforcement_clause?(arel)
arel.where(MultiTenant::TenantEnforcementClause.new(model.arel_table[model.partition_key]))
end
super
end

def delete(arel, name = nil, binds = [])
model = MultiTenant.multi_tenant_model_for_arel(arel)
if model.present? && !MultiTenant.with_write_only_mode_enabled? && MultiTenant.current_tenant_id.present?
if model.present? &&
!MultiTenant.with_write_only_mode_enabled? &&
MultiTenant.current_tenant_id.present? &&
!already_has_tenant_enforcement_clause?(arel)
arel.where(MultiTenant::TenantEnforcementClause.new(model.arel_table[model.partition_key]))
end
super
end

private

def already_has_tenant_enforcement_clause?(arel)
arel.try(:ast).try(:wheres).to_a.any? do |where|
where.is_a?(MultiTenant::BaseTenantEnforcementClause)
end
end
end
end

Expand Down
94 changes: 17 additions & 77 deletions lib/activerecord-multi-tenant/relation_extension.rb
Original file line number Diff line number Diff line change
@@ -1,86 +1,26 @@
# frozen_string_literal: true

module Arel
module ActiveRecordRelationExtension
# Overrides the delete_all method to include tenant scoping
def delete_all
model = MultiTenant.multi_tenant_model_for_table(table_name)

# Call the original delete_all method if the current tenant is identified by an ID
return super if model.nil? || MultiTenant.current_tenant_is_id? || MultiTenant.current_tenant.nil?

stmt = Arel::DeleteManager.new.from(table)
stmt.wheres = [generate_in_condition_subquery]

# Execute the delete statement using the connection and return the result
klass.connection.delete(stmt, "#{klass} Delete All").tap { reset }
end

# Overrides the update_all method to include tenant scoping
def update_all(updates)
model = MultiTenant.multi_tenant_model_for_table(table_name)

# Call the original update_all method if the current tenant is identified by an ID
return super if model.nil? || MultiTenant.current_tenant_is_id? || MultiTenant.current_tenant.nil?

stmt = Arel::UpdateManager.new
stmt.table(table)
stmt.set Arel.sql(klass.send(:sanitize_sql_for_assignment, updates))
stmt.wheres = [generate_in_condition_subquery]

klass.connection.update(stmt, "#{klass} Update All").tap { reset }
end

private

# The generate_in_condition_subquery method generates a subquery that selects
# records associated with the current tenant.
def generate_in_condition_subquery
# Get the tenant key and tenant ID based on the current tenant
tenant_key = MultiTenant.partition_key(MultiTenant.current_tenant_class)
tenant_id = MultiTenant.current_tenant_id

# Build an Arel query
arel = if eager_loading?
apply_join_dependency.arel
elsif ActiveRecord.gem_version >= Gem::Version.create('7.2.0')
build_arel(klass.connection)
else
build_arel
end

arel.source.left = table

# If the tenant ID is present and the tenant key is a column in the model,
# add a condition to only include records where the tenant key equals the tenant ID
if tenant_id && klass.column_names.include?(tenant_key)
tenant_condition = table[tenant_key].eq(tenant_id)
unless arel.constraints.any? { |node| node.to_sql.include?(tenant_condition.to_sql) }
arel = arel.where(tenant_condition)
module Arel # :nodoc: all
module Visitors
module ToSqlPatch
def prepare_update_statement(object)
if object.key && (has_limit_or_offset_or_orders?(object) || has_join_sources?(object))
stmt = super

model = MultiTenant.multi_tenant_model_for_table(MultiTenant::TableNode.table_name(object.relation.left))
if model.present? && !MultiTenant.with_write_only_mode_enabled? && MultiTenant.current_tenant_id.present?
stmt.wheres << MultiTenant::TenantEnforcementClause.new(model.arel_table[model.partition_key])
end

stmt
else
super
end
end

# Clone the query, clear its projections, and set its projection to the primary key of the table
subquery = arel.clone
subquery.projections.clear

if primary_key.is_a?(Array)
# For composite primary keys, project all primary key columns
primary_key_columns = primary_key.map { |pk| table[pk] }
subquery = subquery.project(*primary_key_columns)

# Create IN condition using composite primary key columns
Arel::Nodes::In.new(
Arel::Nodes::Grouping.new(primary_key_columns),
subquery.ast
)
else
subquery = subquery.project(table[primary_key])
Arel::Nodes::In.new(table[primary_key], subquery.ast)
end
alias prepare_delete_statement prepare_update_statement
end
end
end

# Patch ActiveRecord::Relation with the extension module
ActiveRecord::Relation.prepend(Arel::ActiveRecordRelationExtension)
Arel::Visitors::ToSql.prepend(Arel::Visitors::ToSqlPatch)
17 changes: 10 additions & 7 deletions spec/activerecord-multi-tenant/query_rewriter_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

it 'update_all the records with expected query' do
expected_query = <<-SQL.strip
UPDATE "projects" SET "name" = 'New Name' WHERE "projects"."id" IN
UPDATE "projects" SET "name" = 'New Name' WHERE ("projects"."id") IN
(SELECT "projects"."id" FROM "projects"
INNER JOIN "managers" ON "managers"."project_id" = "projects"."id"
and "managers"."account_id" = :account_id
Expand All @@ -67,7 +67,9 @@
@queries.each do |actual_query|
next unless actual_query.include?('UPDATE "projects" SET "name"')

expect(format_sql(actual_query)).to eq(format_sql(expected_query.gsub(':account_id', account.id.to_s)))
expect(format_sql(actual_query.gsub('$1', "'New Name'"))).to eq(format_sql(expected_query.gsub(
':account_id', account.id.to_s
)))
end
end

Expand All @@ -83,7 +85,7 @@
SET
"name" = 'New Name'
WHERE
"projects"."id" IN (
("projects"."id") IN (
SELECT
"projects"."id"
FROM
Expand All @@ -103,8 +105,9 @@
@queries.each do |actual_query|
next unless actual_query.include?('UPDATE "projects" SET "name"')

expect(format_sql(actual_query.gsub('$1',
limit.to_s)).strip).to eq(format_sql(expected_query).strip)
expect(
format_sql(actual_query.gsub('$1', "'#{new_name}'").gsub('$2', limit.to_s)).strip
).to eq(format_sql(expected_query).strip)
end
end
end
Expand All @@ -119,7 +122,7 @@

it 'delete_all the records' do
expected_query = <<-SQL.strip
DELETE FROM "projects" WHERE "projects"."id" IN
DELETE FROM "projects" WHERE ("projects"."id") IN
(SELECT "projects"."id" FROM "projects"
INNER JOIN "managers" ON "managers"."project_id" = "projects"."id"
and "managers"."account_id" = :account_id
Expand Down Expand Up @@ -172,7 +175,7 @@
DELETE FROM
"projects"
WHERE
"projects"."id" IN (
("projects"."id") IN (
SELECT
"projects"."id"
FROM
Expand Down
Loading