 
 #include "iommu-priv.h"
 
+/*
+ * Return the fault parameter of a device if it exists. Otherwise, return NULL.
+ * On a successful return, the caller takes a reference of this parameter and
+ * should put it after use by calling iopf_put_dev_fault_param().
+ */
+static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
+{
+	struct dev_iommu *param = dev->iommu;
+	struct iommu_fault_param *fault_param;
+
+	rcu_read_lock();
+	fault_param = rcu_dereference(param->fault_param);
+	if (fault_param && !refcount_inc_not_zero(&fault_param->users))
+		fault_param = NULL;
+	rcu_read_unlock();
+
+	return fault_param;
+}
+
+/* Caller must hold a reference of the fault parameter. */
+static void iopf_put_dev_fault_param(struct iommu_fault_param *fault_param)
+{
+	if (refcount_dec_and_test(&fault_param->users))
+		kfree_rcu(fault_param, rcu);
+}
+
 void iopf_free_group(struct iopf_group *group)
 {
 	struct iopf_fault *iopf, *next;
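The two helpers above are the standard RCU-plus-refcount lookup idiom: readers locate the object under rcu_read_lock() and take a reference only if the count is still non-zero, so a lookup can never resurrect an object that is already being torn down. A minimal self-contained sketch of the same pattern, with illustrative names (obj, obj_get, obj_put are not part of this patch):

#include <linux/rcupdate.h>
#include <linux/refcount.h>
#include <linux/slab.h>

/* Illustrative object guarded by RCU for lookup, a refcount for lifetime. */
struct obj {
	refcount_t users;
	struct rcu_head rcu;
};

static struct obj __rcu *global_obj;

static struct obj *obj_get(void)
{
	struct obj *o;

	rcu_read_lock();
	o = rcu_dereference(global_obj);
	/* Fails once the last reference is dropped; no resurrection. */
	if (o && !refcount_inc_not_zero(&o->users))
		o = NULL;
	rcu_read_unlock();

	return o;
}

static void obj_put(struct obj *o)
{
	/* kfree_rcu() defers the free until all current RCU readers finish. */
	if (refcount_dec_and_test(&o->users))
		kfree_rcu(o, rcu);
}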
@@ -22,6 +48,8 @@ void iopf_free_group(struct iopf_group *group)
 		kfree(iopf);
 	}
 
+	/* Pair with iommu_report_device_fault(). */
+	iopf_put_dev_fault_param(group->fault_param);
 	kfree(group);
 }
 EXPORT_SYMBOL_GPL(iopf_free_group);
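The put in iopf_free_group() pairs with the get in iommu_report_device_fault() below: when a last page request successfully forms an iopf_group, the reference taken at report time is handed over to the group and dropped only here, after the group has been responded to and freed. A rough outline of the reference flow introduced by this patch (comments only, not literal code):

/*
 * Reference flow (simplified):
 *
 * iommu_report_device_fault()
 *   fault_param = iopf_get_dev_fault_param(dev);          users++
 *   success, last page request:
 *     group->fault_param = fault_param;                   ref kept by group
 *     later: iopf_free_group() -> iopf_put_dev_fault_param()      users--
 *   success, not a last page request:
 *     iopf_put_dev_fault_param(fault_param);              users--
 *   failure:
 *     iopf_put_dev_fault_param(fault_param);              users--
 */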
@@ -135,7 +163,7 @@ static int iommu_handle_iopf(struct iommu_fault *fault,
 		goto cleanup_partial;
 	}
 
-	group->dev = dev;
+	group->fault_param = iopf_param;
 	group->last_fault.fault = *fault;
 	INIT_LIST_HEAD(&group->faults);
 	group->domain = domain;
@@ -178,64 +206,61 @@ static int iommu_handle_iopf(struct iommu_fault *fault,
  */
 int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
 {
+	bool last_prq = evt->fault.type == IOMMU_FAULT_PAGE_REQ &&
+		(evt->fault.prm.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE);
 	struct iommu_fault_param *fault_param;
-	struct iopf_fault *evt_pending = NULL;
-	struct dev_iommu *param = dev->iommu;
-	int ret = 0;
+	struct iopf_fault *evt_pending;
+	int ret;
 
-	mutex_lock(&param->lock);
-	fault_param = param->fault_param;
-	if (!fault_param) {
-		mutex_unlock(&param->lock);
+	fault_param = iopf_get_dev_fault_param(dev);
+	if (!fault_param)
 		return -EINVAL;
-	}
 
 	mutex_lock(&fault_param->lock);
-	if (evt->fault.type == IOMMU_FAULT_PAGE_REQ &&
-	    (evt->fault.prm.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE)) {
+	if (last_prq) {
 		evt_pending = kmemdup(evt, sizeof(struct iopf_fault),
 				      GFP_KERNEL);
 		if (!evt_pending) {
 			ret = -ENOMEM;
-			goto done_unlock;
+			goto err_unlock;
 		}
 		list_add_tail(&evt_pending->list, &fault_param->faults);
 	}
 
 	ret = iommu_handle_iopf(&evt->fault, fault_param);
-	if (ret && evt_pending) {
+	if (ret)
+		goto err_free;
+
+	mutex_unlock(&fault_param->lock);
+	/* The reference count of fault_param is now held by iopf_group. */
+	if (!last_prq)
+		iopf_put_dev_fault_param(fault_param);
+
+	return 0;
+err_free:
+	if (last_prq) {
 		list_del(&evt_pending->list);
 		kfree(evt_pending);
 	}
-done_unlock:
+err_unlock:
 	mutex_unlock(&fault_param->lock);
-	mutex_unlock(&param->lock);
+	iopf_put_dev_fault_param(fault_param);
 
 	return ret;
 }
 EXPORT_SYMBOL_GPL(iommu_report_device_fault);
 
-int iommu_page_response(struct device *dev,
-			struct iommu_page_response *msg)
+static int iommu_page_response(struct iopf_group *group,
+			       struct iommu_page_response *msg)
 {
 	bool needs_pasid;
 	int ret = -EINVAL;
 	struct iopf_fault *evt;
 	struct iommu_fault_page_request *prm;
-	struct dev_iommu *param = dev->iommu;
-	struct iommu_fault_param *fault_param;
+	struct device *dev = group->fault_param->dev;
 	const struct iommu_ops *ops = dev_iommu_ops(dev);
 	bool has_pasid = msg->flags & IOMMU_PAGE_RESP_PASID_VALID;
-
-	if (!ops->page_response)
-		return -ENODEV;
-
-	mutex_lock(&param->lock);
-	fault_param = param->fault_param;
-	if (!fault_param) {
-		mutex_unlock(&param->lock);
-		return -EINVAL;
-	}
+	struct iommu_fault_param *fault_param = group->fault_param;
 
 	/* Only send response if there is a fault report pending */
 	mutex_lock(&fault_param->lock);
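With the per-device param->lock gone from the report path, a driver's fault handler now pays only for an RCU read-side critical section and one atomic increment. A hedged sketch of a driver-side caller, assuming the existing IOMMU_FAULT_* page-request constants (the function name here is made up for illustration):

/* Sketch: an IOMMU driver reporting the last request of a PRQ group. */
static int example_report_last_prq(struct device *dev, u32 pasid,
				   u32 grpid, u64 addr)
{
	struct iopf_fault evt = {
		.fault = {
			.type = IOMMU_FAULT_PAGE_REQ,
			.prm = {
				.flags = IOMMU_FAULT_PAGE_REQUEST_PASID_VALID |
					 IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE,
				.pasid = pasid,
				.grpid = grpid,
				.perm = IOMMU_FAULT_PERM_READ,
				.addr = addr,
			},
		},
	};

	/* Takes a fault_param reference internally; see above. */
	return iommu_report_device_fault(dev, &evt);
}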
@@ -276,10 +301,9 @@ int iommu_page_response(struct device *dev,
 
 done_unlock:
 	mutex_unlock(&fault_param->lock);
-	mutex_unlock(&param->lock);
+
 	return ret;
 }
-EXPORT_SYMBOL_GPL(iommu_page_response);
 
 /**
  * iopf_queue_flush_dev - Ensure that all queued faults have been processed
@@ -295,22 +319,20 @@ EXPORT_SYMBOL_GPL(iommu_page_response);
  */
 int iopf_queue_flush_dev(struct device *dev)
 {
-	int ret = 0;
 	struct iommu_fault_param *iopf_param;
-	struct dev_iommu *param = dev->iommu;
 
-	if (!param)
+	/*
+	 * It's a driver bug to be here after iopf_queue_remove_device().
+	 * Therefore, it's safe to dereference the fault parameter without
+	 * holding the lock.
+	 */
+	iopf_param = rcu_dereference_check(dev->iommu->fault_param, true);
+	if (WARN_ON(!iopf_param))
 		return -ENODEV;
 
-	mutex_lock(&param->lock);
-	iopf_param = param->fault_param;
-	if (iopf_param)
-		flush_workqueue(iopf_param->queue->wq);
-	else
-		ret = -ENODEV;
-	mutex_unlock(&param->lock);
+	flush_workqueue(iopf_param->queue->wq);
 
-	return ret;
+	return 0;
 }
 EXPORT_SYMBOL_GPL(iopf_queue_flush_dev);
 
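The rcu_dereference_check(..., true) above deliberately passes an always-true condition: it documents, rather than enforces, that callers must themselves serialize against iopf_queue_remove_device(). When a real lock does protect the pointer, the idiomatic form names it, as the queue add/remove paths below do:

/* Typical form: the second argument names the lock that makes this legal. */
fault_param = rcu_dereference_check(param->fault_param,
				    lockdep_is_held(&param->lock));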
@@ -335,7 +357,7 @@ int iopf_group_response(struct iopf_group *group,
 	    (iopf->fault.prm.flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID))
 		resp.flags = IOMMU_PAGE_RESP_PASID_VALID;
 
-	return iommu_page_response(group->dev, &resp);
+	return iommu_page_response(group, &resp);
 }
 EXPORT_SYMBOL_GPL(iopf_group_response);
 
@@ -384,10 +406,15 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
 	int ret = 0;
 	struct dev_iommu *param = dev->iommu;
 	struct iommu_fault_param *fault_param;
+	const struct iommu_ops *ops = dev_iommu_ops(dev);
+
+	if (!ops->page_response)
+		return -ENODEV;
 
 	mutex_lock(&queue->lock);
 	mutex_lock(&param->lock);
-	if (param->fault_param) {
+	if (rcu_dereference_check(param->fault_param,
+				  lockdep_is_held(&param->lock))) {
 		ret = -EBUSY;
 		goto done_unlock;
 	}
@@ -402,10 +429,11 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
 	INIT_LIST_HEAD(&fault_param->faults);
 	INIT_LIST_HEAD(&fault_param->partial);
 	fault_param->dev = dev;
+	refcount_set(&fault_param->users, 1);
 	list_add(&fault_param->queue_list, &queue->devices);
 	fault_param->queue = queue;
 
-	param->fault_param = fault_param;
+	rcu_assign_pointer(param->fault_param, fault_param);
 
 done_unlock:
 	mutex_unlock(&param->lock);
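rcu_assign_pointer() publishes fault_param only after every field, including the initial refcount, has been written. Paired with rcu_dereference() in iopf_get_dev_fault_param(), this is the standard RCU publish/subscribe idiom, so a concurrent lookup can never observe a half-initialized structure. In outline (a simplified restatement of the code above, not a new API):

/* Publish side, under param->lock: initialize fully, then publish. */
refcount_set(&fault_param->users, 1);
rcu_assign_pointer(param->fault_param, fault_param);

/* Read side, under rcu_read_lock(): the ordering in rcu_dereference()
 * guarantees the initialized fields are visible before use. */
fault_param = rcu_dereference(param->fault_param);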
@@ -429,10 +457,12 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
 	int ret = 0;
 	struct iopf_fault *iopf, *next;
 	struct dev_iommu *param = dev->iommu;
-	struct iommu_fault_param *fault_param = param->fault_param;
+	struct iommu_fault_param *fault_param;
 
 	mutex_lock(&queue->lock);
 	mutex_lock(&param->lock);
+	fault_param = rcu_dereference_check(param->fault_param,
+					    lockdep_is_held(&param->lock));
 	if (!fault_param) {
 		ret = -ENODEV;
 		goto unlock;
@@ -454,8 +484,9 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
 	list_for_each_entry_safe(iopf, next, &fault_param->partial, list)
 		kfree(iopf);
 
-	param->fault_param = NULL;
-	kfree(fault_param);
+	/* dec the ref owned by iopf_queue_add_device() */
+	rcu_assign_pointer(param->fault_param, NULL);
+	iopf_put_dev_fault_param(fault_param);
 unlock:
 	mutex_unlock(&param->lock);
 	mutex_unlock(&queue->lock);
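Teardown is the mirror image of publication: the pointer is cleared first so new lookups fail, then the creation reference from iopf_queue_add_device() is dropped. If a fault report is in flight and still holds a reference, the structure survives until that path puts it; either way kfree_rcu() waits out any reader still inside its RCU section. In outline:

/*
 * Teardown ordering (simplified):
 *
 * rcu_assign_pointer(param->fault_param, NULL);   new lookups return NULL
 * iopf_put_dev_fault_param(fault_param);          drop the creation ref
 *   - not necessarily the last put: an in-flight report may still
 *     hold a reference and drops the final one itself
 *   - the last put uses kfree_rcu(), so a reader that fetched the
 *     pointer before it was cleared never touches freed memory
 */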