@@ -1086,6 +1086,21 @@ struct numa_group {
 	unsigned long faults[0];
 };
 
+/*
+ * For functions that can be called in multiple contexts that permit reading
+ * ->numa_group (see struct task_struct for locking rules).
+ */
+static struct numa_group *deref_task_numa_group(struct task_struct *p)
+{
+	return rcu_dereference_check(p->numa_group, p == current ||
+		(lockdep_is_held(&task_rq(p)->lock) && !READ_ONCE(p->on_cpu)));
+}
+
+static struct numa_group *deref_curr_numa_group(struct task_struct *p)
+{
+	return rcu_dereference_protected(p->numa_group, p == current);
+}
+
 static inline unsigned long group_faults_priv(struct numa_group *ng);
 static inline unsigned long group_faults_shared(struct numa_group *ng);
 
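The two helpers above centralize how ->numa_group may be read: rcu_dereference_check() accepts an extra condition under which the access counts as safe (here: p is current, or p's runqueue lock is held and p is off-CPU), while rcu_dereference_protected() documents update-side access that needs no RCU read-side section at all. A minimal, hypothetical sketch of the same pattern follows; struct foo, foo_cfg, and the lock name are illustrative, not part of this patch.

/* Hypothetical sketch of the accessor pattern, not code from this patch. */
#include <linux/rcupdate.h>
#include <linux/spinlock.h>

struct foo_cfg {
	struct rcu_head rcu;
	int value;
};

struct foo {
	spinlock_t lock;
	struct foo_cfg __rcu *cfg;	/* RCU-protected pointer */
};

/* Callable from an RCU read-side section *or* with foo->lock held. */
static struct foo_cfg *deref_foo_cfg(struct foo *f)
{
	return rcu_dereference_check(f->cfg, lockdep_is_held(&f->lock));
}

/* Update-side only: the caller must hold foo->lock. */
static struct foo_cfg *deref_foo_cfg_locked(struct foo *f)
{
	return rcu_dereference_protected(f->cfg, lockdep_is_held(&f->lock));
}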
@@ -1129,17 +1144,20 @@ static unsigned int task_scan_start(struct task_struct *p)
 {
 	unsigned long smin = task_scan_min(p);
 	unsigned long period = smin;
+	struct numa_group *ng;
 
 	/* Scale the maximum scan period with the amount of shared memory. */
-	if (p->numa_group) {
-		struct numa_group *ng = p->numa_group;
+	rcu_read_lock();
+	ng = rcu_dereference(p->numa_group);
+	if (ng) {
 		unsigned long shared = group_faults_shared(ng);
 		unsigned long private = group_faults_priv(ng);
 
 		period *= refcount_read(&ng->refcount);
 		period *= shared + 1;
 		period /= private + shared + 1;
 	}
+	rcu_read_unlock();
 
 	return max(smin, period);
 }
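task_scan_start() runs in contexts where only RCU pins ->numa_group, so it now takes rcu_read_lock(), reads everything it needs, and drops the lock before returning. The same reader shape on the hypothetical struct foo from the sketch above:

/* Hypothetical reader: dereference under RCU and copy the data out. */
static int foo_read_value(struct foo *f)
{
	struct foo_cfg *cfg;
	int val = 0;

	rcu_read_lock();
	cfg = rcu_dereference(f->cfg);
	if (cfg)
		val = cfg->value;	/* cfg is only valid inside the section */
	rcu_read_unlock();

	return val;
}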
@@ -1148,13 +1166,14 @@ static unsigned int task_scan_max(struct task_struct *p)
 {
 	unsigned long smin = task_scan_min(p);
 	unsigned long smax;
+	struct numa_group *ng;
 
 	/* Watch for min being lower than max due to floor calculations */
 	smax = sysctl_numa_balancing_scan_period_max / task_nr_scan_windows(p);
 
 	/* Scale the maximum scan period with the amount of shared memory. */
-	if (p->numa_group) {
-		struct numa_group *ng = p->numa_group;
+	ng = deref_curr_numa_group(p);
+	if (ng) {
 		unsigned long shared = group_faults_shared(ng);
 		unsigned long private = group_faults_priv(ng);
 		unsigned long period = smax;
@@ -1186,7 +1205,7 @@ void init_numa_balancing(unsigned long clone_flags, struct task_struct *p)
 	p->numa_scan_period		= sysctl_numa_balancing_scan_delay;
 	p->numa_work.next		= &p->numa_work;
 	p->numa_faults			= NULL;
-	p->numa_group			= NULL;
+	RCU_INIT_POINTER(p->numa_group, NULL);
 	p->last_task_numa_placement	= 0;
 	p->last_sum_exec_runtime	= 0;
 
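RCU_INIT_POINTER() is sufficient here because the freshly forked task is not yet visible to concurrent readers, so no ordering against prior initialization is required; publishing a pointer that readers can already race with would use rcu_assign_pointer(). A hedged update-side sketch on the hypothetical struct foo (the real patch keeps the existing numa_group refcount and RCU free path; kfree_rcu() below is just the generic reclamation idiom and assumes <linux/slab.h>):

/* Hypothetical updater: initialise, then publish and replace under the lock. */
static void foo_init(struct foo *f)
{
	spin_lock_init(&f->lock);
	RCU_INIT_POINTER(f->cfg, NULL);		/* no readers can see f yet */
}

static int foo_set_value(struct foo *f, int value)
{
	struct foo_cfg *new_cfg, *old_cfg;

	new_cfg = kmalloc(sizeof(*new_cfg), GFP_KERNEL);
	if (!new_cfg)
		return -ENOMEM;
	new_cfg->value = value;

	spin_lock(&f->lock);
	old_cfg = rcu_dereference_protected(f->cfg, lockdep_is_held(&f->lock));
	rcu_assign_pointer(f->cfg, new_cfg);	/* orders init before publish */
	spin_unlock(&f->lock);

	if (old_cfg)
		kfree_rcu(old_cfg, rcu);	/* reclaim after a grace period */
	return 0;
}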
@@ -1233,7 +1252,16 @@ static void account_numa_dequeue(struct rq *rq, struct task_struct *p)
 
 pid_t task_numa_group_id(struct task_struct *p)
 {
-	return p->numa_group ? p->numa_group->gid : 0;
+	struct numa_group *ng;
+	pid_t gid = 0;
+
+	rcu_read_lock();
+	ng = rcu_dereference(p->numa_group);
+	if (ng)
+		gid = ng->gid;
+	rcu_read_unlock();
+
+	return gid;
 }
 
 /*
@@ -1258,11 +1286,13 @@ static inline unsigned long task_faults(struct task_struct *p, int nid)
 
 static inline unsigned long group_faults(struct task_struct *p, int nid)
 {
-	if (!p->numa_group)
+	struct numa_group *ng = deref_task_numa_group(p);
+
+	if (!ng)
 		return 0;
 
-	return p->numa_group->faults[task_faults_idx(NUMA_MEM, nid, 0)] +
-		p->numa_group->faults[task_faults_idx(NUMA_MEM, nid, 1)];
+	return ng->faults[task_faults_idx(NUMA_MEM, nid, 0)] +
+		ng->faults[task_faults_idx(NUMA_MEM, nid, 1)];
 }
 
 static inline unsigned long group_faults_cpu(struct numa_group *group, int nid)
@@ -1400,12 +1430,13 @@ static inline unsigned long task_weight(struct task_struct *p, int nid,
 static inline unsigned long group_weight(struct task_struct *p, int nid,
 					 int dist)
 {
+	struct numa_group *ng = deref_task_numa_group(p);
 	unsigned long faults, total_faults;
 
-	if (!p->numa_group)
+	if (!ng)
 		return 0;
 
-	total_faults = p->numa_group->total_faults;
+	total_faults = ng->total_faults;
 
 	if (!total_faults)
 		return 0;
@@ -1419,7 +1450,7 @@ static inline unsigned long group_weight(struct task_struct *p, int nid,
 bool should_numa_migrate_memory(struct task_struct *p, struct page * page,
 				int src_nid, int dst_cpu)
 {
-	struct numa_group *ng = p->numa_group;
+	struct numa_group *ng = deref_curr_numa_group(p);
 	int dst_nid = cpu_to_node(dst_cpu);
 	int last_cpupid, this_cpupid;
 
@@ -1600,13 +1631,14 @@ static bool load_too_imbalanced(long src_load, long dst_load,
 static void task_numa_compare(struct task_numa_env *env,
 			      long taskimp, long groupimp, bool maymove)
 {
+	struct numa_group *cur_ng, *p_ng = deref_curr_numa_group(env->p);
 	struct rq *dst_rq = cpu_rq(env->dst_cpu);
+	long imp = p_ng ? groupimp : taskimp;
 	struct task_struct *cur;
 	long src_load, dst_load;
-	long load;
-	long imp = env->p->numa_group ? groupimp : taskimp;
-	long moveimp = imp;
 	int dist = env->dist;
+	long moveimp = imp;
+	long load;
 
 	if (READ_ONCE(dst_rq->numa_migrate_on))
 		return;
@@ -1645,21 +1677,22 @@ static void task_numa_compare(struct task_numa_env *env,
 	 * If dst and source tasks are in the same NUMA group, or not
 	 * in any group then look only at task weights.
 	 */
-	if (cur->numa_group == env->p->numa_group) {
+	cur_ng = rcu_dereference(cur->numa_group);
+	if (cur_ng == p_ng) {
 		imp = taskimp + task_weight(cur, env->src_nid, dist) -
 		      task_weight(cur, env->dst_nid, dist);
 		/*
 		 * Add some hysteresis to prevent swapping the
 		 * tasks within a group over tiny differences.
 		 */
-		if (cur->numa_group)
+		if (cur_ng)
 			imp -= imp / 16;
 	} else {
 		/*
 		 * Compare the group weights. If a task is all by itself
 		 * (not part of a group), use the task weight instead.
 		 */
-		if (cur->numa_group && env->p->numa_group)
+		if (cur_ng && p_ng)
 			imp += group_weight(cur, env->src_nid, dist) -
 			       group_weight(cur, env->dst_nid, dist);
 		else
@@ -1757,11 +1790,12 @@ static int task_numa_migrate(struct task_struct *p)
 		.best_imp = 0,
 		.best_cpu = -1,
 	};
+	unsigned long taskweight, groupweight;
 	struct sched_domain *sd;
+	long taskimp, groupimp;
+	struct numa_group *ng;
 	struct rq *best_rq;
-	unsigned long taskweight, groupweight;
 	int nid, ret, dist;
-	long taskimp, groupimp;
 
 	/*
 	 * Pick the lowest SD_NUMA domain, as that would have the smallest
@@ -1807,7 +1841,8 @@ static int task_numa_migrate(struct task_struct *p)
 	 * multiple NUMA nodes; in order to better consolidate the group,
 	 * we need to check other locations.
 	 */
-	if (env.best_cpu == -1 || (p->numa_group && p->numa_group->active_nodes > 1)) {
+	ng = deref_curr_numa_group(p);
+	if (env.best_cpu == -1 || (ng && ng->active_nodes > 1)) {
 		for_each_online_node(nid) {
 			if (nid == env.src_nid || nid == p->numa_preferred_nid)
 				continue;
@@ -1840,7 +1875,7 @@ static int task_numa_migrate(struct task_struct *p)
 	 * A task that migrated to a second choice node will be better off
 	 * trying for a better one later. Do not set the preferred node here.
 	 */
-	if (p->numa_group) {
+	if (ng) {
 		if (env.best_cpu == -1)
 			nid = env.src_nid;
 		else
@@ -2135,6 +2170,7 @@ static void task_numa_placement(struct task_struct *p)
 	unsigned long total_faults;
 	u64 runtime, period;
 	spinlock_t *group_lock = NULL;
+	struct numa_group *ng;
 
 	/*
 	 * The p->mm->numa_scan_seq field gets updated without
@@ -2152,8 +2188,9 @@ static void task_numa_placement(struct task_struct *p)
 	runtime = numa_get_avg_runtime(p, &period);
 
 	/* If the task is part of a group prevent parallel updates to group stats */
-	if (p->numa_group) {
-		group_lock = &p->numa_group->lock;
+	ng = deref_curr_numa_group(p);
+	if (ng) {
+		group_lock = &ng->lock;
 		spin_lock_irq(group_lock);
 	}
 
@@ -2194,22 +2231,22 @@ static void task_numa_placement(struct task_struct *p)
 			p->numa_faults[cpu_idx] += f_diff;
 			faults += p->numa_faults[mem_idx];
 			p->total_numa_faults += diff;
-			if (p->numa_group) {
+			if (ng) {
 				/*
 				 * safe because we can only change our own group
 				 *
 				 * mem_idx represents the offset for a given
 				 * nid and priv in a specific region because it
 				 * is at the beginning of the numa_faults array.
 				 */
-				p->numa_group->faults[mem_idx] += diff;
-				p->numa_group->faults_cpu[mem_idx] += f_diff;
-				p->numa_group->total_faults += diff;
-				group_faults += p->numa_group->faults[mem_idx];
+				ng->faults[mem_idx] += diff;
+				ng->faults_cpu[mem_idx] += f_diff;
+				ng->total_faults += diff;
+				group_faults += ng->faults[mem_idx];
 			}
 		}
 
-		if (!p->numa_group) {
+		if (!ng) {
 			if (faults > max_faults) {
 				max_faults = faults;
 				max_nid = nid;
@@ -2220,8 +2257,8 @@ static void task_numa_placement(struct task_struct *p)
 		}
 	}
 
-	if (p->numa_group) {
-		numa_group_count_active_nodes(p->numa_group);
+	if (ng) {
+		numa_group_count_active_nodes(ng);
 		spin_unlock_irq(group_lock);
 		max_nid = preferred_group_nid(p, max_nid);
 	}
@@ -2255,7 +2292,7 @@ static void task_numa_group(struct task_struct *p, int cpupid, int flags,
 	int cpu = cpupid_to_cpu(cpupid);
 	int i;
 
-	if (unlikely(!p->numa_group)) {
+	if (unlikely(!deref_curr_numa_group(p))) {
 		unsigned int size = sizeof(struct numa_group) +
 				    4*nr_node_ids*sizeof(unsigned long);
 
@@ -2291,7 +2328,7 @@ static void task_numa_group(struct task_struct *p, int cpupid, int flags,
 	if (!grp)
 		goto no_join;
 
-	my_grp = p->numa_group;
+	my_grp = deref_curr_numa_group(p);
 	if (grp == my_grp)
 		goto no_join;
 
@@ -2362,7 +2399,8 @@ static void task_numa_group(struct task_struct *p, int cpupid, int flags,
  */
 void task_numa_free(struct task_struct *p, bool final)
 {
-	struct numa_group *grp = p->numa_group;
+	/* safe: p either is current or is being freed by current */
+	struct numa_group *grp = rcu_dereference_raw(p->numa_group);
 	unsigned long *numa_faults = p->numa_faults;
 	unsigned long flags;
 	int i;
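task_numa_free() is reached only when p is current or when current is tearing p down, so nothing can concurrently change ->numa_group; rcu_dereference_raw() states that no lock or RCU section is being claimed, and the added comment carries the justification that lockdep cannot express. A rough teardown sketch for the hypothetical struct foo, assuming the caller has exclusive access to f (e.g. after the last reference is gone):

/* Hypothetical teardown: exclusive access, so no protection is asserted. */
static void foo_destroy(struct foo *f)
{
	/* safe: no other task can reach f at this point */
	struct foo_cfg *cfg = rcu_dereference_raw(f->cfg);

	RCU_INIT_POINTER(f->cfg, NULL);
	if (cfg)
		kfree_rcu(cfg, rcu);	/* earlier readers may still hold cfg */
}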
@@ -2442,7 +2480,7 @@ void task_numa_fault(int last_cpupid, int mem_node, int pages, int flags)
 	 * actively using should be counted as local. This allows the
 	 * scan rate to slow down when a workload has settled down.
 	 */
-	ng = p->numa_group;
+	ng = deref_curr_numa_group(p);
 	if (!priv && !local && ng && ng->active_nodes > 1 &&
 				numa_is_active_node(cpu_node, ng) &&
 				numa_is_active_node(mem_node, ng))
@@ -10460,18 +10498,22 @@ void show_numa_stats(struct task_struct *p, struct seq_file *m)
 {
 	int node;
 	unsigned long tsf = 0, tpf = 0, gsf = 0, gpf = 0;
+	struct numa_group *ng;
 
+	rcu_read_lock();
+	ng = rcu_dereference(p->numa_group);
 	for_each_online_node(node) {
 		if (p->numa_faults) {
 			tsf = p->numa_faults[task_faults_idx(NUMA_MEM, node, 0)];
 			tpf = p->numa_faults[task_faults_idx(NUMA_MEM, node, 1)];
 		}
-		if (p->numa_group) {
-			gsf = p->numa_group->faults[task_faults_idx(NUMA_MEM, node, 0)],
-			gpf = p->numa_group->faults[task_faults_idx(NUMA_MEM, node, 1)];
+		if (ng) {
+			gsf = ng->faults[task_faults_idx(NUMA_MEM, node, 0)],
+			gpf = ng->faults[task_faults_idx(NUMA_MEM, node, 1)];
 		}
 		print_numa_stats(m, node, tsf, tpf, gsf, gpf);
 	}
+	rcu_read_unlock();
 }
 #endif /* CONFIG_NUMA_BALANCING */
 #endif /* CONFIG_SCHED_DEBUG */