@@ -78,43 +78,61 @@ int llvm::omp::target::ompt::getDeviceId(ompt_device_t *Device) {
7878 std::unique_lock<std::mutex> Lock (DeviceIdWritingMutex);
7979 auto DeviceIterator = Devices.find (Device);
8080 if (Device == nullptr || DeviceIterator == Devices.end ()) {
81- REPORT (" Failed to get ID for device =%p\n " , Device);
81+ REPORT (" Failed to get ID for Device =%p\n " , Device);
8282 return -1 ;
8383 }
8484 return DeviceIterator->second ;
8585}
8686
8787void llvm::omp::target::ompt::setDeviceId (ompt_device_t *Device,
8888 int32_t DeviceId) {
89- assert (Device && " Mapping device id to nullptr is not allowed" );
90- if (Device == nullptr ) {
91- REPORT (" Failed to set ID for nullptr device \n " );
89+ assert (Device && " Mapping device ID to nullptr is not allowed" );
90+ if (Device == nullptr || DeviceId < 0 ) {
91+ REPORT (" Failed to set ID=%d for Device=%p \n " , DeviceId, Device );
9292 return ;
9393 }
9494 std::unique_lock<std::mutex> Lock (DeviceIdWritingMutex);
95+ auto DeviceIterator = Devices.find (Device);
96+ if (DeviceIterator != Devices.end ()) {
97+ auto CurrentDeviceId = DeviceIterator->second ;
98+ if (DeviceId == CurrentDeviceId)
99+ REPORT (" Tried to duplicate OMPT Device=%p (ID=%d)\n " , Device, DeviceId);
100+ else
101+ REPORT (" Tried to overwrite OMPT Device=%p (ID=%d with new ID=%d)\n " ,
102+ Device, CurrentDeviceId, DeviceId);
103+ return ;
104+ }
95105 Devices.emplace (Device, DeviceId);
96106}
97107
98108void llvm::omp::target::ompt::removeDeviceId (ompt_device_t *Device) {
99- if (Device == nullptr ) {
100- REPORT (" Failed to remove ID for nullptr device\n " );
109+ int DeviceId = getDeviceId (Device);
110+ if (DeviceId < 0 ) {
111+ REPORT (" Failed to remove Device=%p (ID=%d)\n " , Device, DeviceId);
101112 return ;
102113 }
103114 std::unique_lock<std::mutex> Lock (DeviceIdWritingMutex);
104115 Devices.erase (Device);
116+ TracedDevices.erase (DeviceId);
105117}
106118
107119OMPT_API_ROUTINE ompt_set_result_t ompt_set_trace_ompt (ompt_device_t *Device,
108120 unsigned int Enable,
109121 unsigned int EventTy) {
110122 DP (" Executing ompt_set_trace_ompt\n " );
111123
112- // TODO handle device
124+ int DeviceId = getDeviceId (Device);
125+ if (DeviceId < 0 ) {
126+ REPORT (" Failed to set trace events for Device=%p (Unknown device)\n " ,
127+ Device);
128+ return ompt_set_never;
129+ }
130+
113131 std::unique_lock<std::mutex> Lock (ompt_set_trace_ompt_mutex);
114132 ensureFuncPtrLoaded<libomptarget_ompt_set_trace_ompt_t >(
115133 " libomptarget_ompt_set_trace_ompt" , &ompt_set_trace_ompt_fn);
116134 assert (ompt_set_trace_ompt_fn && " libomptarget_ompt_set_trace_ompt loaded" );
117- return ompt_set_trace_ompt_fn (Device , Enable, EventTy);
135+ return ompt_set_trace_ompt_fn (DeviceId , Enable, EventTy);
118136}
119137
120138OMPT_API_ROUTINE int
@@ -123,12 +141,18 @@ ompt_start_trace(ompt_device_t *Device, ompt_callback_buffer_request_t Request,
123141 DP (" Executing ompt_start_trace\n " );
124142
125143 int DeviceId = getDeviceId (Device);
144+ if (DeviceId < 0 ) {
145+ REPORT (" Failed to start trace for Device=%p (Unknown device)\n " , Device);
146+ // Indicate failure
147+ return 0 ;
148+ }
149+
126150 {
127151 // Protect the function pointer
128152 std::unique_lock<std::mutex> Lock (ompt_start_trace_mutex);
129153
130154 if (Request && Complete) {
131- llvm::omp::target::ompt::setTracingState ( /* Enabled= */ true );
155+ llvm::omp::target::ompt::enableDeviceTracing (DeviceId );
132156 // Enable asynchronous memory copy profiling
133157 setOmptAsyncCopyProfile (/* Enable=*/ true );
134158 // Enable queue dispatch profiling
@@ -150,7 +174,6 @@ ompt_start_trace(ompt_device_t *Device, ompt_callback_buffer_request_t Request,
150174OMPT_API_ROUTINE int ompt_flush_trace (ompt_device_t *Device) {
151175 DP (" Executing ompt_flush_trace\n " );
152176
153- // TODO handle device
154177 std::unique_lock<std::mutex> Lock (ompt_flush_trace_mutex);
155178 ensureFuncPtrLoaded<libomptarget_ompt_flush_trace_t >(
156179 " libomptarget_ompt_flush_trace" , &ompt_flush_trace_fn);
@@ -161,15 +184,20 @@ OMPT_API_ROUTINE int ompt_flush_trace(ompt_device_t *Device) {
161184OMPT_API_ROUTINE int ompt_stop_trace (ompt_device_t *Device) {
162185 DP (" Executing ompt_stop_trace\n " );
163186
164- // TODO handle device
187+ int DeviceId = getDeviceId (Device);
188+ if (DeviceId < 0 ) {
189+ REPORT (" Failed to stop trace for Device=%p (Unknown device)\n " , Device);
190+ // Indicate failure
191+ return 0 ;
192+ }
193+
165194 {
166195 // Protect the function pointer
167196 std::unique_lock<std::mutex> Lock (ompt_stop_trace_mutex);
168- llvm::omp::target::ompt::setTracingState ( /* Enabled= */ false );
197+ llvm::omp::target::ompt::disableDeviceTracing (DeviceId );
169198 // Disable asynchronous memory copy profiling
170199 setOmptAsyncCopyProfile (/* Enable=*/ false );
171200 // Disable queue dispatch profiling
172- int DeviceId = getDeviceId (Device);
173201 if (DeviceId >= 0 )
174202 setGlobalOmptKernelProfile (Device, /* Enable=*/ 0 );
175203 else
@@ -179,7 +207,7 @@ OMPT_API_ROUTINE int ompt_stop_trace(ompt_device_t *Device) {
179207 " libomptarget_ompt_stop_trace" , &ompt_stop_trace_fn);
180208 assert (ompt_stop_trace_fn && " libomptarget_ompt_stop_trace loaded" );
181209 }
182- return ompt_stop_trace_fn (getDeviceId (Device) );
210+ return ompt_stop_trace_fn (DeviceId );
183211}
184212
185213OMPT_API_ROUTINE ompt_record_ompt_t *
0 commit comments