@@ -24,6 +24,11 @@ namespace fs = filesystem;
2424
2525namespace ur_loader {
2626
27+ struct ur_device_tuple {
28+ ur_adapter_backend_t backend;
29+ ur_device_type_t device;
30+ };
31+
2732// Helper struct representing a ONEAPI_DEVICE_SELECTOR filter term.
2833struct FilterTerm {
2934 std::string backend;
@@ -37,7 +42,7 @@ struct FilterTerm {
3742 {" native_cpu" , UR_ADAPTER_BACKEND_NATIVE_CPU},
3843 };
3944
40- bool matchesBackend (const ur_adapter_manifest &manifest ) const {
45+ bool matchesBackend (const ur_adapter_backend_t &match_backend ) const {
4146 if (backend.front () == ' *' ) {
4247 return true ;
4348 }
@@ -49,7 +54,7 @@ struct FilterTerm {
4954 backend);
5055 return false ;
5156 }
52- if (backendIter->second == manifest. backend ) {
57+ if (backendIter->second == match_backend ) {
5358 return true ;
5459 }
5560 return false ;
@@ -60,12 +65,7 @@ struct FilterTerm {
6065 {" gpu" , UR_DEVICE_TYPE_GPU},
6166 {" fpga" , UR_DEVICE_TYPE_FPGA}};
6267
63- bool matchesDevices (const ur_adapter_manifest &manifest) const {
64- // If the adapter can report all device types then it matches.
65- if (std::find (manifest.device_types .begin (), manifest.device_types .end (),
66- UR_DEVICE_TYPE_ALL) != manifest.device_types .end ()) {
67- return true ;
68- }
68+ bool matchesDevices (const ur_device_type_t &match_device) const {
6969 for (auto deviceString : devices) {
7070 // We don't have a way to determine anything about device indices or
7171 // sub-devices at this stage so just match any numeric value we get.
@@ -79,20 +79,19 @@ struct FilterTerm {
7979 deviceString);
8080 continue ;
8181 }
82- if (std::find (manifest.device_types .begin (), manifest.device_types .end (),
83- deviceIter->second ) != manifest.device_types .end ()) {
82+ if (deviceIter->second == match_device) {
8483 return true ;
8584 }
8685 }
8786 return false ;
8887 }
8988
90- bool matches (const ur_adapter_manifest &manifest ) const {
91- if (!matchesBackend (manifest )) {
89+ bool matches (const ur_device_tuple &device_tuple ) const {
90+ if (!matchesBackend (device_tuple. backend )) {
9291 return false ;
9392 }
9493
95- return matchesDevices (manifest );
94+ return matchesDevices (device_tuple. device );
9695 }
9796};
9897
@@ -280,22 +279,31 @@ class AdapterRegistry {
280279 if (PositiveFilter) {
281280 positiveFilters.push_back ({backend, termPair.second });
282281 } else {
283- // To preserve the behaviour of the original pre-filter implementation,
284- // we interpret all negative filters as backend only. This isn't
285- // correct, see https://github.com/intel/llvm/issues/17086
286- negativeFilters.push_back ({backend, {" *" }});
282+ negativeFilters.push_back ({backend, termPair.second });
287283 }
288284 }
289285
286+ // If ONEAPI_DEVICE_SELECTOR only specified negative filters then we
287+ // implicitly add a positive filter accepting all backends and devices.
288+ if (positiveFilters.empty ()) {
289+ positiveFilters.push_back ({" *" , {" *" }});
290+ }
291+
290292 for (const auto &manifest : ur_adapter_manifests) {
291- auto matchesFilter = [manifest](const FilterTerm &f) -> bool {
292- return f.matches (manifest);
293- };
294- if (std::any_of (positiveFilters.begin (), positiveFilters.end (),
295- matchesFilter) &&
296- std::none_of (negativeFilters.begin (), negativeFilters.end (),
297- matchesFilter)) {
298- adapterNames.insert (manifest.library );
293+ // Check each device in the manifest.
294+ for (const auto &device : manifest.device_types ) {
295+ ur_device_tuple single_device = {manifest.backend , device};
296+
297+ auto matchesFilter = [single_device](const FilterTerm &f) -> bool {
298+ return f.matches (single_device);
299+ };
300+
301+ if (std::any_of (positiveFilters.begin (), positiveFilters.end (),
302+ matchesFilter) &&
303+ std::none_of (negativeFilters.begin (), negativeFilters.end (),
304+ matchesFilter)) {
305+ adapterNames.insert (manifest.library );
306+ }
299307 }
300308 }
301309
0 commit comments