@@ -228,13 +228,57 @@ static LIST_HEAD(global_svm_list);
228
228
list_for_each_entry((sdev), &(svm)->devs, list) \
229
229
if ((d) != (sdev)->dev) {} else
230
230
231
+ static int pasid_to_svm_sdev (struct device * dev , unsigned int pasid ,
232
+ struct intel_svm * * rsvm ,
233
+ struct intel_svm_dev * * rsdev )
234
+ {
235
+ struct intel_svm_dev * d , * sdev = NULL ;
236
+ struct intel_svm * svm ;
237
+
238
+ /* The caller should hold the pasid_mutex lock */
239
+ if (WARN_ON (!mutex_is_locked (& pasid_mutex )))
240
+ return - EINVAL ;
241
+
242
+ if (pasid == INVALID_IOASID || pasid >= PASID_MAX )
243
+ return - EINVAL ;
244
+
245
+ svm = ioasid_find (NULL , pasid , NULL );
246
+ if (IS_ERR (svm ))
247
+ return PTR_ERR (svm );
248
+
249
+ if (!svm )
250
+ goto out ;
251
+
252
+ /*
253
+ * If we found svm for the PASID, there must be at least one device
254
+ * bond.
255
+ */
256
+ if (WARN_ON (list_empty (& svm -> devs )))
257
+ return - EINVAL ;
258
+
259
+ rcu_read_lock ();
260
+ list_for_each_entry_rcu (d , & svm -> devs , list ) {
261
+ if (d -> dev == dev ) {
262
+ sdev = d ;
263
+ break ;
264
+ }
265
+ }
266
+ rcu_read_unlock ();
267
+
268
+ out :
269
+ * rsvm = svm ;
270
+ * rsdev = sdev ;
271
+
272
+ return 0 ;
273
+ }
274
+
231
275
int intel_svm_bind_gpasid (struct iommu_domain * domain , struct device * dev ,
232
276
struct iommu_gpasid_bind_data * data )
233
277
{
234
278
struct intel_iommu * iommu = device_to_iommu (dev , NULL , NULL );
279
+ struct intel_svm_dev * sdev = NULL ;
235
280
struct dmar_domain * dmar_domain ;
236
- struct intel_svm_dev * sdev ;
237
- struct intel_svm * svm ;
281
+ struct intel_svm * svm = NULL ;
238
282
int ret = 0 ;
239
283
240
284
if (WARN_ON (!iommu ) || !data )
@@ -261,35 +305,23 @@ int intel_svm_bind_gpasid(struct iommu_domain *domain, struct device *dev,
261
305
dmar_domain = to_dmar_domain (domain );
262
306
263
307
mutex_lock (& pasid_mutex );
264
- svm = ioasid_find (NULL , data -> hpasid , NULL );
265
- if (IS_ERR (svm )) {
266
- ret = PTR_ERR (svm );
308
+ ret = pasid_to_svm_sdev (dev , data -> hpasid , & svm , & sdev );
309
+ if (ret )
267
310
goto out ;
268
- }
269
-
270
- if (svm ) {
271
- /*
272
- * If we found svm for the PASID, there must be at
273
- * least one device bond, otherwise svm should be freed.
274
- */
275
- if (WARN_ON (list_empty (& svm -> devs ))) {
276
- ret = - EINVAL ;
277
- goto out ;
278
- }
279
311
312
+ if (sdev ) {
280
313
/*
281
314
* Do not allow multiple bindings of the same device-PASID since
282
315
* there is only one SL page tables per PASID. We may revisit
283
316
* once sharing PGD across domains are supported.
284
317
*/
285
- for_each_svm_dev (sdev , svm , dev ) {
286
- dev_warn_ratelimited (dev ,
287
- "Already bound with PASID %u\n" ,
288
- svm -> pasid );
289
- ret = - EBUSY ;
290
- goto out ;
291
- }
292
- } else {
318
+ dev_warn_ratelimited (dev , "Already bound with PASID %u\n" ,
319
+ svm -> pasid );
320
+ ret = - EBUSY ;
321
+ goto out ;
322
+ }
323
+
324
+ if (!svm ) {
293
325
/* We come here when PASID has never been bond to a device. */
294
326
svm = kzalloc (sizeof (* svm ), GFP_KERNEL );
295
327
if (!svm ) {
@@ -372,25 +404,17 @@ int intel_svm_unbind_gpasid(struct device *dev, int pasid)
372
404
struct intel_iommu * iommu = device_to_iommu (dev , NULL , NULL );
373
405
struct intel_svm_dev * sdev ;
374
406
struct intel_svm * svm ;
375
- int ret = - EINVAL ;
407
+ int ret ;
376
408
377
409
if (WARN_ON (!iommu ))
378
410
return - EINVAL ;
379
411
380
412
mutex_lock (& pasid_mutex );
381
- svm = ioasid_find (NULL , pasid , NULL );
382
- if (!svm ) {
383
- ret = - EINVAL ;
384
- goto out ;
385
- }
386
-
387
- if (IS_ERR (svm )) {
388
- ret = PTR_ERR (svm );
413
+ ret = pasid_to_svm_sdev (dev , pasid , & svm , & sdev );
414
+ if (ret )
389
415
goto out ;
390
- }
391
416
392
- for_each_svm_dev (sdev , svm , dev ) {
393
- ret = 0 ;
417
+ if (sdev ) {
394
418
if (iommu_dev_feature_enabled (dev , IOMMU_DEV_FEAT_AUX ))
395
419
sdev -> users -- ;
396
420
if (!sdev -> users ) {
@@ -414,7 +438,6 @@ int intel_svm_unbind_gpasid(struct device *dev, int pasid)
414
438
kfree (svm );
415
439
}
416
440
}
417
- break ;
418
441
}
419
442
out :
420
443
mutex_unlock (& pasid_mutex );
@@ -592,7 +615,7 @@ intel_svm_bind_mm(struct device *dev, int flags, struct svm_dev_ops *ops,
592
615
if (sd )
593
616
* sd = sdev ;
594
617
ret = 0 ;
595
- out :
618
+ out :
596
619
return ret ;
597
620
}
598
621
@@ -608,17 +631,11 @@ static int intel_svm_unbind_mm(struct device *dev, int pasid)
608
631
if (!iommu )
609
632
goto out ;
610
633
611
- svm = ioasid_find (NULL , pasid , NULL );
612
- if (!svm )
613
- goto out ;
614
-
615
- if (IS_ERR (svm )) {
616
- ret = PTR_ERR (svm );
634
+ ret = pasid_to_svm_sdev (dev , pasid , & svm , & sdev );
635
+ if (ret )
617
636
goto out ;
618
- }
619
637
620
- for_each_svm_dev (sdev , svm , dev ) {
621
- ret = 0 ;
638
+ if (sdev ) {
622
639
sdev -> users -- ;
623
640
if (!sdev -> users ) {
624
641
list_del_rcu (& sdev -> list );
@@ -647,10 +664,8 @@ static int intel_svm_unbind_mm(struct device *dev, int pasid)
647
664
kfree (svm );
648
665
}
649
666
}
650
- break ;
651
667
}
652
- out :
653
-
668
+ out :
654
669
return ret ;
655
670
}
656
671
0 commit comments