
Commit a74c077

Lu Baolu authored and Joerg Roedel committed
iommu: Use refcount for fault data access
The per-device fault data structure stores information about faults occurring on a device. Its lifetime spans from IOPF enablement to disablement. Multiple paths, including IOPF reporting, handling, and responding, may access it concurrently. Previously, a mutex protected the fault data against use after free, but this is not performance friendly given how critical the IOPF handling paths are.

Refine this with a refcount-based approach. The fault data pointer is obtained within an RCU read region together with a refcount, and is returned to the caller only when the pointer is valid and the refcount was successfully taken. The fault data is freed with kfree_rcu(), ensuring it is released only after all RCU critical regions have completed.

An IOPF handling work starts once an iopf group is created. The handling work continues until iommu_page_response() is called to respond to the IOPF and the iopf group is freed. During this time, the device fault parameter should always be available. Add a pointer to the device fault parameter in the iopf_group structure and hold the reference until the iopf_group is freed.

Make iommu_page_response() static, as it is only used in io-pgfault.c.

Co-developed-by: Jason Gunthorpe <[email protected]>
Signed-off-by: Jason Gunthorpe <[email protected]>
Signed-off-by: Lu Baolu <[email protected]>
Reviewed-by: Jason Gunthorpe <[email protected]>
Reviewed-by: Kevin Tian <[email protected]>
Tested-by: Yan Zhao <[email protected]>
Link: https://lore.kernel.org/r/[email protected]
Signed-off-by: Joerg Roedel <[email protected]>
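The core of the change is a lookup/put pair: readers obtain the per-device fault parameter under rcu_read_lock() and keep it only if they can still take a reference, and the final put frees the structure with kfree_rcu(). The two helpers below are excerpted from the io-pgfault.c diff that follows, shown here as a minimal sketch of the pattern (the struct members involved are added to include/linux/iommu.h in this same commit):

/*
 * Return the fault parameter of a device if it exists. Otherwise, return NULL.
 * On a successful return, the caller takes a reference of this parameter and
 * should put it after use by calling iopf_put_dev_fault_param().
 */
static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
{
	struct dev_iommu *param = dev->iommu;
	struct iommu_fault_param *fault_param;

	rcu_read_lock();
	fault_param = rcu_dereference(param->fault_param);
	/* A refcount already at zero means teardown has started; treat as absent. */
	if (fault_param && !refcount_inc_not_zero(&fault_param->users))
		fault_param = NULL;
	rcu_read_unlock();

	return fault_param;
}

/* Caller must hold a reference of the fault parameter. */
static void iopf_put_dev_fault_param(struct iommu_fault_param *fault_param)
{
	/* kfree_rcu() defers the free until all RCU readers have finished. */
	if (refcount_dec_and_test(&fault_param->users))
		kfree_rcu(fault_param, rcu);
}

iommu_report_device_fault() takes the reference; for last-page requests it is handed over to the iopf_group and released in iopf_free_group(), otherwise it is put as soon as the fault has been handled.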
1 parent cc7338e commit a74c077

File tree: 3 files changed, +88 −58 lines

drivers/iommu/io-pgfault.c

Lines changed: 79 additions & 48 deletions
@@ -13,6 +13,32 @@
 
 #include "iommu-priv.h"
 
+/*
+ * Return the fault parameter of a device if it exists. Otherwise, return NULL.
+ * On a successful return, the caller takes a reference of this parameter and
+ * should put it after use by calling iopf_put_dev_fault_param().
+ */
+static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
+{
+	struct dev_iommu *param = dev->iommu;
+	struct iommu_fault_param *fault_param;
+
+	rcu_read_lock();
+	fault_param = rcu_dereference(param->fault_param);
+	if (fault_param && !refcount_inc_not_zero(&fault_param->users))
+		fault_param = NULL;
+	rcu_read_unlock();
+
+	return fault_param;
+}
+
+/* Caller must hold a reference of the fault parameter. */
+static void iopf_put_dev_fault_param(struct iommu_fault_param *fault_param)
+{
+	if (refcount_dec_and_test(&fault_param->users))
+		kfree_rcu(fault_param, rcu);
+}
+
 void iopf_free_group(struct iopf_group *group)
 {
 	struct iopf_fault *iopf, *next;
@@ -22,6 +48,8 @@ void iopf_free_group(struct iopf_group *group)
 		kfree(iopf);
 	}
 
+	/* Pair with iommu_report_device_fault(). */
+	iopf_put_dev_fault_param(group->fault_param);
 	kfree(group);
 }
 EXPORT_SYMBOL_GPL(iopf_free_group);
@@ -135,7 +163,7 @@ static int iommu_handle_iopf(struct iommu_fault *fault,
 		goto cleanup_partial;
 	}
 
-	group->dev = dev;
+	group->fault_param = iopf_param;
 	group->last_fault.fault = *fault;
 	INIT_LIST_HEAD(&group->faults);
 	group->domain = domain;
@@ -178,64 +206,61 @@ static int iommu_handle_iopf(struct iommu_fault *fault,
  */
 int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
 {
+	bool last_prq = evt->fault.type == IOMMU_FAULT_PAGE_REQ &&
+		(evt->fault.prm.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE);
 	struct iommu_fault_param *fault_param;
-	struct iopf_fault *evt_pending = NULL;
-	struct dev_iommu *param = dev->iommu;
-	int ret = 0;
+	struct iopf_fault *evt_pending;
+	int ret;
 
-	mutex_lock(&param->lock);
-	fault_param = param->fault_param;
-	if (!fault_param) {
-		mutex_unlock(&param->lock);
+	fault_param = iopf_get_dev_fault_param(dev);
+	if (!fault_param)
 		return -EINVAL;
-	}
 
 	mutex_lock(&fault_param->lock);
-	if (evt->fault.type == IOMMU_FAULT_PAGE_REQ &&
-	    (evt->fault.prm.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE)) {
+	if (last_prq) {
 		evt_pending = kmemdup(evt, sizeof(struct iopf_fault),
				      GFP_KERNEL);
 		if (!evt_pending) {
 			ret = -ENOMEM;
-			goto done_unlock;
+			goto err_unlock;
 		}
 		list_add_tail(&evt_pending->list, &fault_param->faults);
 	}
 
 	ret = iommu_handle_iopf(&evt->fault, fault_param);
-	if (ret && evt_pending) {
+	if (ret)
+		goto err_free;
+
+	mutex_unlock(&fault_param->lock);
+	/* The reference count of fault_param is now held by iopf_group. */
+	if (!last_prq)
+		iopf_put_dev_fault_param(fault_param);
+
+	return 0;
+err_free:
+	if (last_prq) {
 		list_del(&evt_pending->list);
 		kfree(evt_pending);
 	}
-done_unlock:
+err_unlock:
 	mutex_unlock(&fault_param->lock);
-	mutex_unlock(&param->lock);
+	iopf_put_dev_fault_param(fault_param);
 
 	return ret;
 }
 EXPORT_SYMBOL_GPL(iommu_report_device_fault);
 
-int iommu_page_response(struct device *dev,
-			struct iommu_page_response *msg)
+static int iommu_page_response(struct iopf_group *group,
+			       struct iommu_page_response *msg)
 {
 	bool needs_pasid;
 	int ret = -EINVAL;
 	struct iopf_fault *evt;
 	struct iommu_fault_page_request *prm;
-	struct dev_iommu *param = dev->iommu;
-	struct iommu_fault_param *fault_param;
+	struct device *dev = group->fault_param->dev;
 	const struct iommu_ops *ops = dev_iommu_ops(dev);
 	bool has_pasid = msg->flags & IOMMU_PAGE_RESP_PASID_VALID;
-
-	if (!ops->page_response)
-		return -ENODEV;
-
-	mutex_lock(&param->lock);
-	fault_param = param->fault_param;
-	if (!fault_param) {
-		mutex_unlock(&param->lock);
-		return -EINVAL;
-	}
+	struct iommu_fault_param *fault_param = group->fault_param;
 
 	/* Only send response if there is a fault report pending */
 	mutex_lock(&fault_param->lock);
@@ -276,10 +301,9 @@ int iommu_page_response(struct device *dev,
 
 done_unlock:
 	mutex_unlock(&fault_param->lock);
-	mutex_unlock(&param->lock);
+
 	return ret;
 }
-EXPORT_SYMBOL_GPL(iommu_page_response);
 
 /**
  * iopf_queue_flush_dev - Ensure that all queued faults have been processed
@@ -295,22 +319,20 @@ EXPORT_SYMBOL_GPL(iommu_page_response);
  */
 int iopf_queue_flush_dev(struct device *dev)
 {
-	int ret = 0;
 	struct iommu_fault_param *iopf_param;
-	struct dev_iommu *param = dev->iommu;
 
-	if (!param)
+	/*
+	 * It's a driver bug to be here after iopf_queue_remove_device().
+	 * Therefore, it's safe to dereference the fault parameter without
+	 * holding the lock.
+	 */
+	iopf_param = rcu_dereference_check(dev->iommu->fault_param, true);
+	if (WARN_ON(!iopf_param))
 		return -ENODEV;
 
-	mutex_lock(&param->lock);
-	iopf_param = param->fault_param;
-	if (iopf_param)
-		flush_workqueue(iopf_param->queue->wq);
-	else
-		ret = -ENODEV;
-	mutex_unlock(&param->lock);
+	flush_workqueue(iopf_param->queue->wq);
 
-	return ret;
+	return 0;
 }
 EXPORT_SYMBOL_GPL(iopf_queue_flush_dev);
 
@@ -335,7 +357,7 @@ int iopf_group_response(struct iopf_group *group,
 	    (iopf->fault.prm.flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID))
 		resp.flags = IOMMU_PAGE_RESP_PASID_VALID;
 
-	return iommu_page_response(group->dev, &resp);
+	return iommu_page_response(group, &resp);
 }
 EXPORT_SYMBOL_GPL(iopf_group_response);
 
@@ -384,10 +406,15 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
 	int ret = 0;
 	struct dev_iommu *param = dev->iommu;
 	struct iommu_fault_param *fault_param;
+	const struct iommu_ops *ops = dev_iommu_ops(dev);
+
+	if (!ops->page_response)
+		return -ENODEV;
 
 	mutex_lock(&queue->lock);
 	mutex_lock(&param->lock);
-	if (param->fault_param) {
+	if (rcu_dereference_check(param->fault_param,
+				  lockdep_is_held(&param->lock))) {
 		ret = -EBUSY;
 		goto done_unlock;
 	}
@@ -402,10 +429,11 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
 	INIT_LIST_HEAD(&fault_param->faults);
 	INIT_LIST_HEAD(&fault_param->partial);
 	fault_param->dev = dev;
+	refcount_set(&fault_param->users, 1);
 	list_add(&fault_param->queue_list, &queue->devices);
 	fault_param->queue = queue;
 
-	param->fault_param = fault_param;
+	rcu_assign_pointer(param->fault_param, fault_param);
 
 done_unlock:
 	mutex_unlock(&param->lock);
@@ -429,10 +457,12 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
 	int ret = 0;
 	struct iopf_fault *iopf, *next;
 	struct dev_iommu *param = dev->iommu;
-	struct iommu_fault_param *fault_param = param->fault_param;
+	struct iommu_fault_param *fault_param;
 
 	mutex_lock(&queue->lock);
 	mutex_lock(&param->lock);
+	fault_param = rcu_dereference_check(param->fault_param,
					    lockdep_is_held(&param->lock));
 	if (!fault_param) {
 		ret = -ENODEV;
 		goto unlock;
@@ -454,8 +484,9 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
 	list_for_each_entry_safe(iopf, next, &fault_param->partial, list)
 		kfree(iopf);
 
-	param->fault_param = NULL;
-	kfree(fault_param);
+	/* dec the ref owned by iopf_queue_add_device() */
+	rcu_assign_pointer(param->fault_param, NULL);
+	iopf_put_dev_fault_param(fault_param);
 unlock:
 	mutex_unlock(&param->lock);
 	mutex_unlock(&queue->lock);

drivers/iommu/iommu-sva.c

Lines changed: 1 addition & 1 deletion
@@ -251,7 +251,7 @@ static void iommu_sva_handle_iopf(struct work_struct *work)
 
 static int iommu_sva_iopf_handler(struct iopf_group *group)
 {
-	struct iommu_fault_param *fault_param = group->dev->iommu->fault_param;
+	struct iommu_fault_param *fault_param = group->fault_param;
 
 	INIT_WORK(&group->work, iommu_sva_handle_iopf);
 	if (!queue_work(fault_param->queue->wq, &group->work))

include/linux/iommu.h

Lines changed: 8 additions & 9 deletions
@@ -41,6 +41,7 @@ struct iommu_dirty_ops;
 struct notifier_block;
 struct iommu_sva;
 struct iommu_dma_cookie;
+struct iommu_fault_param;
 
 #define IOMMU_FAULT_PERM_READ	(1 << 0) /* read */
 #define IOMMU_FAULT_PERM_WRITE	(1 << 1) /* write */
@@ -129,8 +130,9 @@ struct iopf_group {
 	struct iopf_fault last_fault;
 	struct list_head faults;
 	struct work_struct work;
-	struct device *dev;
 	struct iommu_domain *domain;
+	/* The device's fault data parameter. */
+	struct iommu_fault_param *fault_param;
 };
 
 /**
@@ -679,6 +681,8 @@ struct iommu_device {
 /**
  * struct iommu_fault_param - per-device IOMMU fault data
  * @lock: protect pending faults list
+ * @users: user counter to manage the lifetime of the data
+ * @rcu: rcu head for kfree_rcu()
  * @dev: the device that owns this param
  * @queue: IOPF queue
  * @queue_list: index into queue->devices
@@ -688,6 +692,8 @@ struct iommu_device {
  */
 struct iommu_fault_param {
 	struct mutex lock;
+	refcount_t users;
+	struct rcu_head rcu;
 
 	struct device *dev;
 	struct iopf_queue *queue;
@@ -715,7 +721,7 @@ struct iommu_fault_param {
 */
 struct dev_iommu {
 	struct mutex lock;
-	struct iommu_fault_param *fault_param;
+	struct iommu_fault_param __rcu *fault_param;
 	struct iommu_fwspec *fwspec;
 	struct iommu_device *iommu_dev;
 	void *priv;
@@ -1543,7 +1549,6 @@ void iopf_queue_free(struct iopf_queue *queue);
 int iopf_queue_discard_partial(struct iopf_queue *queue);
 void iopf_free_group(struct iopf_group *group);
 int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt);
-int iommu_page_response(struct device *dev, struct iommu_page_response *msg);
 int iopf_group_response(struct iopf_group *group,
			enum iommu_page_response_code status);
 #else
@@ -1588,12 +1593,6 @@ iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
 	return -ENODEV;
 }
 
-static inline int
-iommu_page_response(struct device *dev, struct iommu_page_response *msg)
-{
-	return -ENODEV;
-}
-
 static inline int iopf_group_response(struct iopf_group *group,
				      enum iommu_page_response_code status)
 {
