diff --git a/sherpa-onnx/csrc/provider.cc b/sherpa-onnx/csrc/provider.cc index 3baed32c13..d5865ad560 100644 --- a/sherpa-onnx/csrc/provider.cc +++ b/sherpa-onnx/csrc/provider.cc @@ -28,6 +28,8 @@ Provider StringToProvider(std::string s) { return Provider::kTRT; } else if (s == "directml") { return Provider::kDirectML; + } else if (s == "rocm") { + return Provider::kROCM; } else { SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str()); return Provider::kCPU; diff --git a/sherpa-onnx/csrc/provider.h b/sherpa-onnx/csrc/provider.h index 2b85b8a2e0..87a5b66e58 100644 --- a/sherpa-onnx/csrc/provider.h +++ b/sherpa-onnx/csrc/provider.h @@ -21,6 +21,7 @@ enum class Provider { kNNAPI = 4, // NnapiExecutionProvider kTRT = 5, // TensorRTExecutionProvider kDirectML = 6, // DmlExecutionProvider + kROCM = 7 // ROCMExecutionProvider }; /** diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index a33594f0b5..ea26eb39c4 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -177,6 +177,21 @@ Ort::SessionOptions GetSessionOptionsImpl( } break; } + case Provider::kROCM: { + if (std::find(available_providers.begin(), available_providers.end(), + "ROCMExecutionProvider") != available_providers.end()) { + OrtROCMProviderOptions options; + options.device_id = + provider_config ? provider_config->device : 0; + sess_opts.AppendExecutionProvider_ROCM(options); + } else { + SHERPA_ONNX_LOGE( + "Please compile with ort enable ROCM EP. Available " + "providers: %s. Fallback to cpu!", + os.str().c_str()); + } + break; + } case Provider::kDirectML: { #if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1 sess_opts.DisableMemPattern();