Skip to content

Commit 9e03d9c

Browse files
Tighten loop in run_migraphx_program
- Tighten lock around run_async - Remove O(n) lookup with find and use unordered_set instead - Use optional to help tighten up lock
1 parent a9a09fb commit 9e03d9c

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,20 +1416,22 @@ static void run_migraphx_program(
14161416
migraphx::program_parameters& m,
14171417
const std::vector<std::size_t>& prog_output_indices)
14181418
{
1419-
1420-
// lock to avoid race condition
1421-
std::lock_guard<std::mutex> lock(*mgx_mu_ptr);
14221419
void* rocm_stream;
14231420
Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream));
1424-
auto prog_outputs = prog.run_async(m, static_cast<hipStream_t>(rocm_stream));
1421+
1422+
std::optional<migraphx::arguments> prog_outputs;
1423+
{ // lock to avoid race condition
1424+
std::lock_guard<std::mutex> lock(*mgx_mu_ptr);
1425+
prog_outputs = prog.run_async(m, static_cast<hipStream_t>(rocm_stream));
1426+
}
14251427

14261428
// In case of input parameters are reused as output parameter call hipMemcpy
1427-
auto output_num = prog_outputs.size();
1429+
auto output_num = prog_outputs->size();
14281430
if (prog_output_indices.size() < output_num) {
14291431
for (std::size_t i = 0; i < output_num; ++i) {
14301432
if (std::find(prog_output_indices.begin(), prog_output_indices.end(), static_cast<int>(i)) != prog_output_indices.end())
14311433
continue;
1432-
auto gpu_res = prog_outputs[i];
1434+
auto gpu_res = (*prog_outputs)[i];
14331435
migraphx::shape res_shape = gpu_res.get_shape();
14341436
auto res_lens = res_shape.lengths();
14351437
std::vector<int64_t> ort_shape{res_lens.begin(), res_lens.end()};

0 commit comments

Comments
 (0)