
Commit cb361d8

thejh authored and Ingo Molnar committed
sched/fair: Use RCU accessors consistently for ->numa_group
The old code used RCU annotations and accessors inconsistently for
->numa_group, which can lead to use-after-frees and NULL dereferences.

Let all accesses to ->numa_group use proper RCU helpers to prevent such
issues.

Signed-off-by: Jann Horn <[email protected]>
Signed-off-by: Peter Zijlstra (Intel) <[email protected]>
Cc: Linus Torvalds <[email protected]>
Cc: Peter Zijlstra <[email protected]>
Cc: Petr Mladek <[email protected]>
Cc: Sergey Senozhatsky <[email protected]>
Cc: Thomas Gleixner <[email protected]>
Cc: Will Deacon <[email protected]>
Fixes: 8c8a743 ("sched/numa: Use {cpu, pid} to create task groups for shared faults")
Link: https://lkml.kernel.org/r/[email protected]
Signed-off-by: Ingo Molnar <[email protected]>
1 parent 16d51a5 commit cb361d8
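
Background, not part of the commit: the fix applies the kernel's standard RCU discipline to ->numa_group -- annotate the pointer __rcu and read it only through RCU accessors inside a read-side critical section. A minimal reader-side sketch of that pattern (struct example and example_read are invented names for illustration):

#include <linux/rcupdate.h>

struct example {
        int val;
};

/* __rcu: sparse and RCU-lockdep now flag plain loads/stores of this pointer */
static struct example __rcu *example_ptr;

static int example_read(void)
{
        struct example *e;
        int val = 0;

        rcu_read_lock();                        /* delimit the read-side critical section */
        e = rcu_dereference(example_ptr);       /* checked, dependency-ordered load */
        if (e)
                val = e->val;
        rcu_read_unlock();

        return val;
}

A plain load such as the old "if (p->numa_group)" gives none of these guarantees: the group can be freed between the check and the later uses, which is the use-after-free class the commit message describes.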

2 files changed, +90 −40 lines


include/linux/sched.h

Lines changed: 9 additions & 1 deletion
@@ -1092,7 +1092,15 @@ struct task_struct {
         u64 last_sum_exec_runtime;
         struct callback_head numa_work;
 
-        struct numa_group *numa_group;
+        /*
+         * This pointer is only modified for current in syscall and
+         * pagefault context (and for tasks being destroyed), so it can be read
+         * from any of the following contexts:
+         *  - RCU read-side critical section
+         *  - current->numa_group from everywhere
+         *  - task's runqueue locked, task not running
+         */
+        struct numa_group __rcu *numa_group;
 
         /*
          * numa_faults is an array split into four regions:

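The new comment lists the contexts that may read ->numa_group; they are safe because the pointer is only modified by the task itself and old groups are freed only after an RCU grace period. A sketch of that update side, for illustration only (publish_group and struct group are invented names; the real numa_group lifetime in kernel/sched/fair.c is additionally refcounted):

#include <linux/rcupdate.h>
#include <linux/slab.h>

struct group {
        int gid;
        struct rcu_head rcu;
};

static struct group __rcu *group_ptr;

static void publish_group(int gid)
{
        struct group *new, *old;

        new = kzalloc(sizeof(*new), GFP_KERNEL);
        if (!new)
                return;
        new->gid = gid;

        /* Updates are serialized by the caller, so _protected() is enough here. */
        old = rcu_dereference_protected(group_ptr, true);
        rcu_assign_pointer(group_ptr, new);     /* publish with release ordering */
        if (old)
                kfree_rcu(old, rcu);            /* free only after current readers finish */
}
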
kernel/sched/fair.c

Lines changed: 81 additions & 39 deletions
@@ -1086,6 +1086,21 @@ struct numa_group {
         unsigned long faults[0];
 };
 
+/*
+ * For functions that can be called in multiple contexts that permit reading
+ * ->numa_group (see struct task_struct for locking rules).
+ */
+static struct numa_group *deref_task_numa_group(struct task_struct *p)
+{
+        return rcu_dereference_check(p->numa_group, p == current ||
+                (lockdep_is_held(&task_rq(p)->lock) && !READ_ONCE(p->on_cpu)));
+}
+
+static struct numa_group *deref_curr_numa_group(struct task_struct *p)
+{
+        return rcu_dereference_protected(p->numa_group, p == current);
+}
+
 static inline unsigned long group_faults_priv(struct numa_group *ng);
 static inline unsigned long group_faults_shared(struct numa_group *ng);

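For illustration only (good_peek and bad_peek are invented callers, not part of the patch): the condition passed to rcu_dereference_check() above is what RCU-lockdep verifies. With CONFIG_PROVE_RCU, a caller that satisfies neither the explicit condition nor the implicit rcu_read_lock_held() check gets a "suspicious rcu_dereference_check() usage" splat:

static pid_t bad_peek(struct task_struct *p)
{
        /* Not current, no rq lock, no rcu_read_lock(): lockdep complains. */
        struct numa_group *ng = deref_task_numa_group(p);

        return ng ? ng->gid : 0;
}

static pid_t good_peek(struct task_struct *p)
{
        struct numa_group *ng;
        pid_t gid = 0;

        rcu_read_lock();        /* satisfies the implicit rcu_read_lock_held() check */
        ng = deref_task_numa_group(p);
        if (ng)
                gid = ng->gid;
        rcu_read_unlock();

        return gid;
}
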
@@ -1129,17 +1144,20 @@ static unsigned int task_scan_start(struct task_struct *p)
 {
         unsigned long smin = task_scan_min(p);
         unsigned long period = smin;
+        struct numa_group *ng;
 
         /* Scale the maximum scan period with the amount of shared memory. */
-        if (p->numa_group) {
-                struct numa_group *ng = p->numa_group;
+        rcu_read_lock();
+        ng = rcu_dereference(p->numa_group);
+        if (ng) {
                 unsigned long shared = group_faults_shared(ng);
                 unsigned long private = group_faults_priv(ng);
 
                 period *= refcount_read(&ng->refcount);
                 period *= shared + 1;
                 period /= private + shared + 1;
         }
+        rcu_read_unlock();
 
         return max(smin, period);
 }
@@ -1148,13 +1166,14 @@ static unsigned int task_scan_max(struct task_struct *p)
 {
         unsigned long smin = task_scan_min(p);
         unsigned long smax;
+        struct numa_group *ng;
 
         /* Watch for min being lower than max due to floor calculations */
         smax = sysctl_numa_balancing_scan_period_max / task_nr_scan_windows(p);
 
         /* Scale the maximum scan period with the amount of shared memory. */
-        if (p->numa_group) {
-                struct numa_group *ng = p->numa_group;
+        ng = deref_curr_numa_group(p);
+        if (ng) {
                 unsigned long shared = group_faults_shared(ng);
                 unsigned long private = group_faults_priv(ng);
                 unsigned long period = smax;
@@ -1186,7 +1205,7 @@ void init_numa_balancing(unsigned long clone_flags, struct task_struct *p)
         p->numa_scan_period = sysctl_numa_balancing_scan_delay;
         p->numa_work.next = &p->numa_work;
         p->numa_faults = NULL;
-        p->numa_group = NULL;
+        RCU_INIT_POINTER(p->numa_group, NULL);
         p->last_task_numa_placement = 0;
         p->last_sum_exec_runtime = 0;

@@ -1233,7 +1252,16 @@ static void account_numa_dequeue(struct rq *rq, struct task_struct *p)
 
 pid_t task_numa_group_id(struct task_struct *p)
 {
-        return p->numa_group ? p->numa_group->gid : 0;
+        struct numa_group *ng;
+        pid_t gid = 0;
+
+        rcu_read_lock();
+        ng = rcu_dereference(p->numa_group);
+        if (ng)
+                gid = ng->gid;
+        rcu_read_unlock();
+
+        return gid;
 }
 
 /*
@@ -1258,11 +1286,13 @@ static inline unsigned long task_faults(struct task_struct *p, int nid)
 
 static inline unsigned long group_faults(struct task_struct *p, int nid)
 {
-        if (!p->numa_group)
+        struct numa_group *ng = deref_task_numa_group(p);
+
+        if (!ng)
                 return 0;
 
-        return p->numa_group->faults[task_faults_idx(NUMA_MEM, nid, 0)] +
-                p->numa_group->faults[task_faults_idx(NUMA_MEM, nid, 1)];
+        return ng->faults[task_faults_idx(NUMA_MEM, nid, 0)] +
+                ng->faults[task_faults_idx(NUMA_MEM, nid, 1)];
 }
 
 static inline unsigned long group_faults_cpu(struct numa_group *group, int nid)
@@ -1400,12 +1430,13 @@ static inline unsigned long task_weight(struct task_struct *p, int nid,
 static inline unsigned long group_weight(struct task_struct *p, int nid,
                                          int dist)
 {
+        struct numa_group *ng = deref_task_numa_group(p);
         unsigned long faults, total_faults;
 
-        if (!p->numa_group)
+        if (!ng)
                 return 0;
 
-        total_faults = p->numa_group->total_faults;
+        total_faults = ng->total_faults;
 
         if (!total_faults)
                 return 0;
@@ -1419,7 +1450,7 @@ static inline unsigned long group_weight(struct task_struct *p, int nid,
 bool should_numa_migrate_memory(struct task_struct *p, struct page * page,
                                 int src_nid, int dst_cpu)
 {
-        struct numa_group *ng = p->numa_group;
+        struct numa_group *ng = deref_curr_numa_group(p);
         int dst_nid = cpu_to_node(dst_cpu);
         int last_cpupid, this_cpupid;

@@ -1600,13 +1631,14 @@ static bool load_too_imbalanced(long src_load, long dst_load,
 static void task_numa_compare(struct task_numa_env *env,
                               long taskimp, long groupimp, bool maymove)
 {
+        struct numa_group *cur_ng, *p_ng = deref_curr_numa_group(env->p);
         struct rq *dst_rq = cpu_rq(env->dst_cpu);
+        long imp = p_ng ? groupimp : taskimp;
         struct task_struct *cur;
         long src_load, dst_load;
-        long load;
-        long imp = env->p->numa_group ? groupimp : taskimp;
-        long moveimp = imp;
         int dist = env->dist;
+        long moveimp = imp;
+        long load;
 
         if (READ_ONCE(dst_rq->numa_migrate_on))
                 return;
@@ -1645,21 +1677,22 @@ static void task_numa_compare(struct task_numa_env *env,
          * If dst and source tasks are in the same NUMA group, or not
          * in any group then look only at task weights.
          */
-        if (cur->numa_group == env->p->numa_group) {
+        cur_ng = rcu_dereference(cur->numa_group);
+        if (cur_ng == p_ng) {
                 imp = taskimp + task_weight(cur, env->src_nid, dist) -
                       task_weight(cur, env->dst_nid, dist);
                 /*
                  * Add some hysteresis to prevent swapping the
                  * tasks within a group over tiny differences.
                  */
-                if (cur->numa_group)
+                if (cur_ng)
                         imp -= imp / 16;
         } else {
                 /*
                  * Compare the group weights. If a task is all by itself
                  * (not part of a group), use the task weight instead.
                  */
-                if (cur->numa_group && env->p->numa_group)
+                if (cur_ng && p_ng)
                         imp += group_weight(cur, env->src_nid, dist) -
                                group_weight(cur, env->dst_nid, dist);
                 else
@@ -1757,11 +1790,12 @@ static int task_numa_migrate(struct task_struct *p)
                 .best_imp = 0,
                 .best_cpu = -1,
         };
+        unsigned long taskweight, groupweight;
         struct sched_domain *sd;
+        long taskimp, groupimp;
+        struct numa_group *ng;
         struct rq *best_rq;
-        unsigned long taskweight, groupweight;
         int nid, ret, dist;
-        long taskimp, groupimp;
 
         /*
          * Pick the lowest SD_NUMA domain, as that would have the smallest
@@ -1807,7 +1841,8 @@ static int task_numa_migrate(struct task_struct *p)
          * multiple NUMA nodes; in order to better consolidate the group,
          * we need to check other locations.
          */
-        if (env.best_cpu == -1 || (p->numa_group && p->numa_group->active_nodes > 1)) {
+        ng = deref_curr_numa_group(p);
+        if (env.best_cpu == -1 || (ng && ng->active_nodes > 1)) {
                 for_each_online_node(nid) {
                         if (nid == env.src_nid || nid == p->numa_preferred_nid)
                                 continue;
@@ -1840,7 +1875,7 @@ static int task_numa_migrate(struct task_struct *p)
          * A task that migrated to a second choice node will be better off
          * trying for a better one later. Do not set the preferred node here.
          */
-        if (p->numa_group) {
+        if (ng) {
                 if (env.best_cpu == -1)
                         nid = env.src_nid;
                 else
@@ -2135,6 +2170,7 @@ static void task_numa_placement(struct task_struct *p)
         unsigned long total_faults;
         u64 runtime, period;
         spinlock_t *group_lock = NULL;
+        struct numa_group *ng;
 
         /*
          * The p->mm->numa_scan_seq field gets updated without
@@ -2152,8 +2188,9 @@ static void task_numa_placement(struct task_struct *p)
         runtime = numa_get_avg_runtime(p, &period);
 
         /* If the task is part of a group prevent parallel updates to group stats */
-        if (p->numa_group) {
-                group_lock = &p->numa_group->lock;
+        ng = deref_curr_numa_group(p);
+        if (ng) {
+                group_lock = &ng->lock;
                 spin_lock_irq(group_lock);
         }

@@ -2194,22 +2231,22 @@ static void task_numa_placement(struct task_struct *p)
                         p->numa_faults[cpu_idx] += f_diff;
                         faults += p->numa_faults[mem_idx];
                         p->total_numa_faults += diff;
-                        if (p->numa_group) {
+                        if (ng) {
                                 /*
                                  * safe because we can only change our own group
                                  *
                                  * mem_idx represents the offset for a given
                                  * nid and priv in a specific region because it
                                  * is at the beginning of the numa_faults array.
                                  */
-                                p->numa_group->faults[mem_idx] += diff;
-                                p->numa_group->faults_cpu[mem_idx] += f_diff;
-                                p->numa_group->total_faults += diff;
-                                group_faults += p->numa_group->faults[mem_idx];
+                                ng->faults[mem_idx] += diff;
+                                ng->faults_cpu[mem_idx] += f_diff;
+                                ng->total_faults += diff;
+                                group_faults += ng->faults[mem_idx];
                         }
                 }
 
-                if (!p->numa_group) {
+                if (!ng) {
                         if (faults > max_faults) {
                                 max_faults = faults;
                                 max_nid = nid;
@@ -2220,8 +2257,8 @@ static void task_numa_placement(struct task_struct *p)
                 }
         }
 
-        if (p->numa_group) {
-                numa_group_count_active_nodes(p->numa_group);
+        if (ng) {
+                numa_group_count_active_nodes(ng);
                 spin_unlock_irq(group_lock);
                 max_nid = preferred_group_nid(p, max_nid);
         }
@@ -2255,7 +2292,7 @@ static void task_numa_group(struct task_struct *p, int cpupid, int flags,
         int cpu = cpupid_to_cpu(cpupid);
         int i;
 
-        if (unlikely(!p->numa_group)) {
+        if (unlikely(!deref_curr_numa_group(p))) {
                 unsigned int size = sizeof(struct numa_group) +
                                     4*nr_node_ids*sizeof(unsigned long);

@@ -2291,7 +2328,7 @@ static void task_numa_group(struct task_struct *p, int cpupid, int flags,
         if (!grp)
                 goto no_join;
 
-        my_grp = p->numa_group;
+        my_grp = deref_curr_numa_group(p);
         if (grp == my_grp)
                 goto no_join;

@@ -2362,7 +2399,8 @@ static void task_numa_group(struct task_struct *p, int cpupid, int flags,
  */
 void task_numa_free(struct task_struct *p, bool final)
 {
-        struct numa_group *grp = p->numa_group;
+        /* safe: p either is current or is being freed by current */
+        struct numa_group *grp = rcu_dereference_raw(p->numa_group);
         unsigned long *numa_faults = p->numa_faults;
         unsigned long flags;
         int i;
@@ -2442,7 +2480,7 @@ void task_numa_fault(int last_cpupid, int mem_node, int pages, int flags)
          * actively using should be counted as local. This allows the
          * scan rate to slow down when a workload has settled down.
          */
-        ng = p->numa_group;
+        ng = deref_curr_numa_group(p);
         if (!priv && !local && ng && ng->active_nodes > 1 &&
             numa_is_active_node(cpu_node, ng) &&
             numa_is_active_node(mem_node, ng))
@@ -10460,18 +10498,22 @@ void show_numa_stats(struct task_struct *p, struct seq_file *m)
 {
         int node;
         unsigned long tsf = 0, tpf = 0, gsf = 0, gpf = 0;
+        struct numa_group *ng;
 
+        rcu_read_lock();
+        ng = rcu_dereference(p->numa_group);
         for_each_online_node(node) {
                 if (p->numa_faults) {
                         tsf = p->numa_faults[task_faults_idx(NUMA_MEM, node, 0)];
                         tpf = p->numa_faults[task_faults_idx(NUMA_MEM, node, 1)];
                 }
-                if (p->numa_group) {
-                        gsf = p->numa_group->faults[task_faults_idx(NUMA_MEM, node, 0)],
-                        gpf = p->numa_group->faults[task_faults_idx(NUMA_MEM, node, 1)];
+                if (ng) {
+                        gsf = ng->faults[task_faults_idx(NUMA_MEM, node, 0)],
+                        gpf = ng->faults[task_faults_idx(NUMA_MEM, node, 1)];
                 }
                 print_numa_stats(m, node, tsf, tpf, gsf, gpf);
         }
+        rcu_read_unlock();
 }
 #endif /* CONFIG_NUMA_BALANCING */
 #endif /* CONFIG_SCHED_DEBUG */
