diff --git a/ml/infra_component/llm_generate.cpp b/ml/infra_component/llm_generate.cpp index 4e24da70e..e6fa4a04f 100644 --- a/ml/infra_component/llm_generate.cpp +++ b/ml/infra_component/llm_generate.cpp @@ -230,22 +230,26 @@ TextGenerator::~TextGenerator() { } bool TextGenerator::InitializeONNX() { - m_env = std::make_unique(ORT_LOGGING_LEVEL_WARNING, "TextGenerator"); - m_sessionOptions = std::make_unique(); + try { + m_env = std::make_unique(ORT_LOGGING_LEVEL_WARNING, "TextGenerator"); + m_sessionOptions = std::make_unique(); - int numThreads = std::max(1, static_cast(std::thread::hardware_concurrency())); - m_sessionOptions->SetIntraOpNumThreads(numThreads); - m_sessionOptions->SetInterOpNumThreads(1); - m_sessionOptions->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); - m_sessionOptions->EnableMemPattern(); - m_sessionOptions->EnableCpuMemArena(); + int numThreads = std::max(1, static_cast(std::thread::hardware_concurrency())); + m_sessionOptions->SetIntraOpNumThreads(numThreads); + m_sessionOptions->SetInterOpNumThreads(1); + m_sessionOptions->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); + m_sessionOptions->EnableMemPattern(); + m_sessionOptions->EnableCpuMemArena(); #ifdef _WIN32 - std::wstring wModelPath(m_modelPath.begin(), m_modelPath.end()); - m_session = std::make_unique(*m_env, wModelPath.c_str(), *m_sessionOptions); + std::wstring wModelPath(m_modelPath.begin(), m_modelPath.end()); + m_session = std::make_unique(*m_env, wModelPath.c_str(), *m_sessionOptions); #else - m_session = std::make_unique(*m_env, m_modelPath.c_str(), *m_sessionOptions); + m_session = std::make_unique(*m_env, m_modelPath.c_str(), *m_sessionOptions); #endif + } catch (...) { + return true; + } return false; }