Skip to content

Commit 882089c

Browse files
committed
Refactor tenant enforcement to use Arel visitor pattern
1 parent f0e239d commit 882089c

File tree

3 files changed

+43
-86
lines changed

3 files changed

+43
-86
lines changed

lib/activerecord-multi-tenant/query_rewriter.rb

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,19 +247,33 @@ def visit_MultiTenant_TenantJoinEnforcementClause(obj, collector)
247247
module DatabaseStatements
248248
def update(arel, name = nil, binds = [])
249249
model = MultiTenant.multi_tenant_model_for_arel(arel)
250-
if model.present? && !MultiTenant.with_write_only_mode_enabled? && MultiTenant.current_tenant_id.present?
250+
if model.present? &&
251+
!MultiTenant.with_write_only_mode_enabled? &&
252+
MultiTenant.current_tenant_id.present? &&
253+
!already_has_tenant_enforcement_clause?(arel)
251254
arel.where(MultiTenant::TenantEnforcementClause.new(model.arel_table[model.partition_key]))
252255
end
253256
super
254257
end
255258

256259
def delete(arel, name = nil, binds = [])
257260
model = MultiTenant.multi_tenant_model_for_arel(arel)
258-
if model.present? && !MultiTenant.with_write_only_mode_enabled? && MultiTenant.current_tenant_id.present?
261+
if model.present? &&
262+
!MultiTenant.with_write_only_mode_enabled? &&
263+
MultiTenant.current_tenant_id.present? &&
264+
!already_has_tenant_enforcement_clause?(arel)
259265
arel.where(MultiTenant::TenantEnforcementClause.new(model.arel_table[model.partition_key]))
260266
end
261267
super
262268
end
269+
270+
private
271+
272+
def already_has_tenant_enforcement_clause?(arel)
273+
arel.try(:ast).try(:wheres).to_a.any? do |where|
274+
where.is_a?(MultiTenant::BaseTenantEnforcementClause)
275+
end
276+
end
263277
end
264278
end
265279

Lines changed: 17 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,86 +1,26 @@
11
# frozen_string_literal: true
22

3-
module Arel
4-
module ActiveRecordRelationExtension
5-
# Overrides the delete_all method to include tenant scoping
6-
def delete_all
7-
model = MultiTenant.multi_tenant_model_for_table(table_name)
8-
9-
# Call the original delete_all method if the current tenant is identified by an ID
10-
return super if model.nil? || MultiTenant.current_tenant_is_id? || MultiTenant.current_tenant.nil?
11-
12-
stmt = Arel::DeleteManager.new.from(table)
13-
stmt.wheres = [generate_in_condition_subquery]
14-
15-
# Execute the delete statement using the connection and return the result
16-
klass.connection.delete(stmt, "#{klass} Delete All").tap { reset }
17-
end
18-
19-
# Overrides the update_all method to include tenant scoping
20-
def update_all(updates)
21-
model = MultiTenant.multi_tenant_model_for_table(table_name)
22-
23-
# Call the original update_all method if the current tenant is identified by an ID
24-
return super if model.nil? || MultiTenant.current_tenant_is_id? || MultiTenant.current_tenant.nil?
25-
26-
stmt = Arel::UpdateManager.new
27-
stmt.table(table)
28-
stmt.set Arel.sql(klass.send(:sanitize_sql_for_assignment, updates))
29-
stmt.wheres = [generate_in_condition_subquery]
30-
31-
klass.connection.update(stmt, "#{klass} Update All").tap { reset }
32-
end
33-
34-
private
35-
36-
# The generate_in_condition_subquery method generates a subquery that selects
37-
# records associated with the current tenant.
38-
def generate_in_condition_subquery
39-
# Get the tenant key and tenant ID based on the current tenant
40-
tenant_key = MultiTenant.partition_key(MultiTenant.current_tenant_class)
41-
tenant_id = MultiTenant.current_tenant_id
42-
43-
# Build an Arel query
44-
arel = if eager_loading?
45-
apply_join_dependency.arel
46-
elsif ActiveRecord.gem_version >= Gem::Version.create('7.2.0')
47-
build_arel(klass.connection)
48-
else
49-
build_arel
50-
end
51-
52-
arel.source.left = table
53-
54-
# If the tenant ID is present and the tenant key is a column in the model,
55-
# add a condition to only include records where the tenant key equals the tenant ID
56-
if tenant_id && klass.column_names.include?(tenant_key)
57-
tenant_condition = table[tenant_key].eq(tenant_id)
58-
unless arel.constraints.any? { |node| node.to_sql.include?(tenant_condition.to_sql) }
59-
arel = arel.where(tenant_condition)
3+
module Arel # :nodoc: all
4+
module Visitors
5+
module ToSqlPatch
6+
def prepare_update_statement(object)
7+
if object.key && (has_limit_or_offset_or_orders?(object) || has_join_sources?(object))
8+
stmt = super
9+
10+
model = MultiTenant.multi_tenant_model_for_table(MultiTenant::TableNode.table_name(object.relation.left))
11+
if model.present? && !MultiTenant.with_write_only_mode_enabled? && MultiTenant.current_tenant_id.present?
12+
stmt.wheres << MultiTenant::TenantEnforcementClause.new(model.arel_table[model.partition_key])
13+
end
14+
15+
stmt
16+
else
17+
super
6018
end
6119
end
6220

63-
# Clone the query, clear its projections, and set its projection to the primary key of the table
64-
subquery = arel.clone
65-
subquery.projections.clear
66-
67-
if primary_key.is_a?(Array)
68-
# For composite primary keys, project all primary key columns
69-
primary_key_columns = primary_key.map { |pk| table[pk] }
70-
subquery = subquery.project(*primary_key_columns)
71-
72-
# Create IN condition using composite primary key columns
73-
Arel::Nodes::In.new(
74-
Arel::Nodes::Grouping.new(primary_key_columns),
75-
subquery.ast
76-
)
77-
else
78-
subquery = subquery.project(table[primary_key])
79-
Arel::Nodes::In.new(table[primary_key], subquery.ast)
80-
end
21+
alias prepare_delete_statement prepare_update_statement
8122
end
8223
end
8324
end
8425

85-
# Patch ActiveRecord::Relation with the extension module
86-
ActiveRecord::Relation.prepend(Arel::ActiveRecordRelationExtension)
26+
Arel::Visitors::ToSql.prepend(Arel::Visitors::ToSqlPatch)

spec/activerecord-multi-tenant/query_rewriter_spec.rb

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050
it 'update_all the records with expected query' do
5151
expected_query = <<-SQL.strip
52-
UPDATE "projects" SET "name" = 'New Name' WHERE "projects"."id" IN
52+
UPDATE "projects" SET "name" = 'New Name' WHERE ("projects"."id") IN
5353
(SELECT "projects"."id" FROM "projects"
5454
INNER JOIN "managers" ON "managers"."project_id" = "projects"."id"
5555
and "managers"."account_id" = :account_id
@@ -67,7 +67,9 @@
6767
@queries.each do |actual_query|
6868
next unless actual_query.include?('UPDATE "projects" SET "name"')
6969

70-
expect(format_sql(actual_query)).to eq(format_sql(expected_query.gsub(':account_id', account.id.to_s)))
70+
expect(format_sql(actual_query.gsub('$1', "'New Name'"))).to eq(format_sql(expected_query.gsub(
71+
':account_id', account.id.to_s
72+
)))
7173
end
7274
end
7375

@@ -83,7 +85,7 @@
8385
SET
8486
"name" = 'New Name'
8587
WHERE
86-
"projects"."id" IN (
88+
("projects"."id") IN (
8789
SELECT
8890
"projects"."id"
8991
FROM
@@ -103,8 +105,9 @@
103105
@queries.each do |actual_query|
104106
next unless actual_query.include?('UPDATE "projects" SET "name"')
105107

106-
expect(format_sql(actual_query.gsub('$1',
107-
limit.to_s)).strip).to eq(format_sql(expected_query).strip)
108+
expect(
109+
format_sql(actual_query.gsub('$1', "'#{new_name}'").gsub('$2', limit.to_s)).strip
110+
).to eq(format_sql(expected_query).strip)
108111
end
109112
end
110113
end
@@ -119,7 +122,7 @@
119122

120123
it 'delete_all the records' do
121124
expected_query = <<-SQL.strip
122-
DELETE FROM "projects" WHERE "projects"."id" IN
125+
DELETE FROM "projects" WHERE ("projects"."id") IN
123126
(SELECT "projects"."id" FROM "projects"
124127
INNER JOIN "managers" ON "managers"."project_id" = "projects"."id"
125128
and "managers"."account_id" = :account_id
@@ -172,7 +175,7 @@
172175
DELETE FROM
173176
"projects"
174177
WHERE
175-
"projects"."id" IN (
178+
("projects"."id") IN (
176179
SELECT
177180
"projects"."id"
178181
FROM

0 commit comments

Comments
 (0)