99
1010#include < iostream>
1111#include < string>
12+ #include < thread>
1213
1314using namespace device ;
1415
@@ -42,24 +43,25 @@ void ConcreteAPI::initDevices() {
4243 return compare (c1->queueBuffer .getDefaultQueue ().get_device (), c2->queueBuffer .getDefaultQueue ().get_device ());
4344 });
4445
45- this ->setDevice (this -> currentDeviceId );
46+ this ->setDevice (0 );
4647 this ->deviceInitialized = true ;
4748}
4849
4950void ConcreteAPI::setDevice (int id) {
51+ {
52+ std::lock_guard guard (this ->apiMutex );
5053
51- if (id < 0 || id >= this ->getNumDevices ()) {
52- throw std::out_of_range{" Device index out of range" };
53- }
54+ if (id < 0 || id >= this ->getNumDevices ()) {
55+ throw std::out_of_range{" Device index out of range" };
56+ }
5457
55- this ->currentDeviceId = id;
56- auto *next = this ->availableDevices [id];
57- this ->currentStatistics = &next->statistics ;
58- this ->currentQueueBuffer = &next->queueBuffer ;
59- this ->currentDefaultQueue = &this ->currentQueueBuffer ->getDefaultQueue ();
60- this ->currentMemoryToSizeMap = &next->memoryToSizeMap ;
58+ if (deviceMap.empty ()) {
59+ // only print the first time
60+ printer.printInfo () << " Switched to device: " << this ->getDeviceName (id) << " by index " << id;
61+ }
6162
62- printer.printInfo () << " Switched to device: " << this ->getDeviceName (id) << " by index " << id;
63+ deviceMap[std::this_thread::get_id ()] = id;
64+ }
6365}
6466
6567void ConcreteAPI::initialize () {}
@@ -77,10 +79,7 @@ void ConcreteAPI::finalize() {
7779
7880 this ->graphs .clear ();
7981
80- this ->currentStatistics = nullptr ;
81- this ->currentQueueBuffer = nullptr ;
82- this ->currentDefaultQueue = nullptr ;
83- this ->currentMemoryToSizeMap = nullptr ;
82+ this ->deviceMap .clear ();
8483
8584 this ->m_isFinalized = true ;
8685 this ->deviceInitialized = false ;
@@ -92,15 +91,20 @@ int ConcreteAPI::getDeviceId() {
9291 if (!deviceInitialized) {
9392 logError () << " Device has not been selected. Please, select device before requesting device Id" ;
9493 }
95- return currentDeviceId;
94+ const auto myId = std::this_thread::get_id ();
95+ auto findResult = deviceMap.find (myId);
96+ if (findResult == deviceMap.end ()) {
97+ logError () << " Thread device context not initialized. Error." ;
98+ }
99+ return findResult->second ;;
96100}
97101
98102unsigned int ConcreteAPI::getGlobMemAlignment () {
99- auto device = this ->currentDefaultQueue -> get_device ();
103+ auto device = this ->currentDefaultQueue (). get_device ();
100104 return 128 ; // ToDo: find attribute; not: device.get_info<info::device::mem_base_addr_align>();
101105}
102106
103- void ConcreteAPI::syncDevice () { this ->currentQueueBuffer -> syncAllQueuesWithHost (); }
107+ void ConcreteAPI::syncDevice () { this ->currentQueueBuffer (). syncAllQueuesWithHost (); }
104108
105109std::string ConcreteAPI::getDeviceInfoAsText (int id) {
106110 if (id < 0 || id >= this ->getNumDevices ())
@@ -109,7 +113,7 @@ std::string ConcreteAPI::getDeviceInfoAsText(int id) {
109113 auto device = this ->availableDevices [id]->queueBuffer .getDefaultQueue ().get_device ();
110114 return this ->getDeviceInfoAsTextInternal (device);
111115}
112- std::string ConcreteAPI::getCurrentDeviceInfoAsText () { return this ->getDeviceInfoAsText (this -> currentDeviceId ); }
116+ std::string ConcreteAPI::getCurrentDeviceInfoAsText () { return this ->getDeviceInfoAsText (getDeviceId () ); }
113117
114118std::string ConcreteAPI::getDeviceInfoAsTextInternal (sycl::device& dev) {
115119 std::ostringstream info{};
@@ -126,7 +130,7 @@ std::string ConcreteAPI::getDeviceInfoAsTextInternal(sycl::device& dev) {
126130
127131bool ConcreteAPI::isUnifiedMemoryDefault () {
128132 // suboptimal (i.e. we'd need to query if USM needs to be migrated or not), but there's probably nothing better for now
129- auto device = this ->availableDevices [this -> currentDeviceId ]->queueBuffer .getDefaultQueue ().get_device ();
133+ auto device = this ->availableDevices [getDeviceId () ]->queueBuffer .getDefaultQueue ().get_device ();
130134 return device.has (sycl::aspect::usm_system_allocations);
131135}
132136
0 commit comments