#include <FAST/Config.hpp>
#include <FAST/Utility.hpp>
#include <FAST/DeviceManager.hpp>
+ #ifdef FAST_MODULE_VISUALIZATION
+ #include <iomanip>
+ #include <QNetworkAccessManager>
+ #include <QNetworkRequest>
+ #include <QNetworkReply>
+ #include <QElapsedTimer>
+ #include <QEventLoop>
+ #include <QFile>
+ #include <QStandardPaths>
+ #endif
#ifdef WIN32
// For TensorFlow AVX2 check - Taken from https://docs.microsoft.com/en-us/cpp/intrinsics/cpuid-cpuidex?view=msvc-160
#include <vector>
@@ -200,6 +210,7 @@ namespace fast {

bool InferenceEngineManager::m_loaded = false;
std::unordered_map<std::string, std::function<InferenceEngine*()>> InferenceEngineManager::m_engines;
+ std::unordered_map<std::string, std::string> InferenceEngineManager::m_possibleEngines; // key = name, value = errors

#ifdef WIN32
// Returns the last Win32 error, in string format. Returns an empty string if there is no error.
@@ -252,51 +263,91 @@ void InferenceEngineManager::loadAll() {
    }
#endif
    Reporter::info() << "Loading inference engines in folder " << Config::getLibraryPath() << Reporter::end();
+     m_possibleEngines.clear();
    for(auto&& item : getDirectoryList(Config::getLibraryPath(), true, false)) {
        auto path = join(Config::getLibraryPath(), item);
        if(item.substr(0, prefix.size()) == prefix) {
            std::string name = item.substr(prefix.size(), item.rfind('.') - prefix.size());
+             m_possibleEngines[name] = "";
+             if(m_engines.count(name) > 0) // Already loaded skip
+                 continue;
            Reporter::info() << "Loading inference engine " << name << " from shared library " << path << Reporter::end();
#ifdef WIN32
+             SetErrorMode(SEM_FAILCRITICALERRORS); // To avoid diaglog box, when not able to load a DLL
            if(name == "TensorFlow") {
                if(!InstructionSet::AVX2()) {
+                     m_possibleEngines[name] = "You CPU does not support AVX2, unable to load TensorFlow inference engine.";
                    Reporter::warning() << "You CPU does not support AVX2, unable to load TensorFlow inference engine." << Reporter::end();
                    continue;
                }
+                 // When we put libraries in another path than the libInferenceEngineTensorFlow so file, we have to do this:
+                 std::string path2 = Config::getKernelBinaryPath() + "../lib/tensorflow/";
+                 SetDllDirectory(path2.c_str());
+                 auto handle = LoadLibrary(join(path2, "tensorflow_cc.dll").c_str());
+                 if(!handle) {
+                     m_possibleEngines[name] = "Failed to load tensorflow library because: " + GetLastErrorAsString();
+                     continue;
+                 }
+                 SetDllDirectory(NULL);
            }
-             SetErrorMode(SEM_FAILCRITICALERRORS); // TODO To avoid diaglog box, when not able to load a DLL
            SetDllDirectory(Config::getLibraryPath().c_str());
            auto handle = LoadLibrary(path.c_str());
-             SetDllDirectory("");
+             SetDllDirectory(NULL);
            if(!handle) {
-                 Reporter::warning() << "Failed to load plugin because " << GetLastErrorAsString() << Reporter::end();
+                 // Reporter::warning() << "Failed to load plugin because " << GetLastErrorAsString() << Reporter::end();
+                 m_possibleEngines[name] = "Failed to load inference engine because: " + GetLastErrorAsString();
                continue;
            }
            auto load = (InferenceEngine* (*)())GetProcAddress(handle, "load");
            if(!load) {
                FreeLibrary(handle);
-                 Reporter::warning() << "Failed to get adress to load function because " << GetLastErrorAsString() << Reporter::end();
+                 // Reporter::warning() << "Failed to get address to load function because " << GetLastErrorAsString() << Reporter::end();
+                 m_possibleEngines[name] = "Failed to get address to load function because: " + GetLastErrorAsString();
                continue;
            }
#else
            if(name == "TensorFlow") {
#ifdef __arm64__
#else
                if(!__builtin_cpu_supports("avx2")) {
-                     Reporter::warning() << "You CPU does not support AVX2, unable to load TensorFlow inference engine." << Reporter::end();
+                     m_possibleEngines[name] = "You CPU does not support AVX2, unable to load TensorFlow inference engine.";
+                     // Reporter::warning() << "You CPU does not support AVX2, unable to load TensorFlow inference engine." << Reporter::end();
                    continue;
                }
#endif
            }
            auto handle = dlopen(path.c_str(), RTLD_LAZY);
            if(!handle) {
-                 Reporter::warning() << "Failed to load plugin because " << dlerror() << Reporter::end();
-                 continue;
+                 if(name == "TensorFlow") {
+                     // When we put libraries in another path than the libInferenceEngineTensorFlow so file, we have to do this:
+                     // Try to load libtensorflow manually
+ #if defined(__APPLE__) || defined(__MACOSX)
+                     std::string path2 = Config::getKernelBinaryPath() + "/../lib/tensorflow/libtensorflow_cc.dylib";
+ #else
+                     std::string path2 = Config::getKernelBinaryPath() + "/../lib/tensorflow/libtensorflow_cc.so";
+ #endif
+                     auto handle2 = dlopen(path2.c_str(), RTLD_LAZY);
+                     if(!handle2) {
+                         m_possibleEngines[name] = "Failed to load inference engine because " + std::string(dlerror());
+                         continue;
+                     }
+                     // Then try to load again
+                     handle = dlopen(path.c_str(), RTLD_LAZY);
+                     if(!handle) {
+                         m_possibleEngines[name] = "Failed to load inference engine because " + std::string(dlerror());
+                         continue;
+                     }
+                 } else {
+                     m_possibleEngines[name] = "Failed to load inference engine because " + std::string(dlerror());
+                     // Reporter::warning() << "Failed to load plugin because " << dlerror() << Reporter::end();
+                     continue;
+                 }
            }
            auto load = (InferenceEngine* (*)())dlsym(handle, "load");
            if(!load) {
                dlclose(handle);
-                 Reporter::warning() << "Failed to get address of load function because " << Reporter::end();// dlerror() << Reporter::end();
+                 // Reporter::warning() << "Failed to get address of load function because " << Reporter::end();// dlerror() << Reporter::end();
+                 m_possibleEngines[name] = "Failed to get address of load function because " + std::string(dlerror());
                continue;
            }
#endif
@@ -312,16 +363,7 @@ std::shared_ptr<InferenceEngine> InferenceEngineManager::loadBestAvailableEngine
    if(m_engines.empty())
        throw Exception("No inference engines available on the system");

-     if(isEngineAvailable("TensorFlow")) {
-         // Default is tensorflow if GPU support is enabled
-         auto engine = loadEngine("TensorFlow");
-         auto devices = engine->getDeviceList();
-         for(auto&& device : devices) {
-             if(device.type == InferenceDeviceType::GPU)
-                 return engine;
-         }
-     }
- 
+     // TensorRT is default inference engine if available
    if(isEngineAvailable("TensorRT"))
        return loadEngine("TensorRT");

@@ -341,14 +383,9 @@ std::shared_ptr<InferenceEngine> InferenceEngineManager::loadBestAvailableEngine
    if(m_engines.empty())
        throw Exception("No inference engines available on the system");

-     if(isEngineAvailable("TensorFlow") && loadEngine("TensorFlow")->isModelFormatSupported(modelFormat)) {
-         // Default is tensorflow if GPU support is enabled
-         auto engine = loadEngine("TensorFlow");
-         auto devices = engine->getDeviceList();
-         for(auto&& device : devices) {
-             if(device.type == InferenceDeviceType::GPU)
-                 return engine;
-         }
+     // If TensorFlow model format
+     if(modelFormat == ModelFormat::SAVEDMODEL || modelFormat == ModelFormat::PROTOBUF) {
+         return loadEngine("TensorFlow");
    }

    if(isEngineAvailable("TensorRT") && loadEngine("TensorRT")->isModelFormatSupported(modelFormat))
@@ -370,17 +407,121 @@ std::shared_ptr<InferenceEngine> InferenceEngineManager::loadBestAvailableEngine
    throw Exception("No engine for model format found");
}

- 
- 
bool InferenceEngineManager::isEngineAvailable(std::string name) {
    loadAll();
    return m_engines.count(name) > 0;
}

+ #ifdef FAST_MODULE_VISUALIZATION
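+ // Downloads the zip file at URL to a temporary file (reporting progress to the console), then extracts it into destination.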
+ static void downloadAndExtractZipFile(const std::string& URL, const std::string& destination, const std::string& name) {
+     QNetworkAccessManager manager;
+     QNetworkRequest request(QUrl(QString::fromStdString(URL)));
+     request.setAttribute(QNetworkRequest::RedirectPolicyAttribute, QNetworkRequest::NoLessSafeRedirectPolicy);
+     auto timer = new QElapsedTimer;
+     timer->start();
+     auto reply = manager.get(request);
+     Progress progress(100);
+     progress.setText("Downloading");
+     progress.setUnit("MB", 1.0/(1024*1024), 0);
+     const int totalWidth = getConsoleWidth();
+     std::string blank;
+     for(int i = 0; i < totalWidth-1; ++i)
+         blank += " ";
+     QObject::connect(reply, &QNetworkReply::downloadProgress, [&](quint64 current, quint64 max) {
+         progress.setMax(max);
+         progress.update(current);
+     });
+     auto tempLocation = QStandardPaths::writableLocation(QStandardPaths::TempLocation) + "/" + name.c_str() + ".zip";
+     QFile file(tempLocation);
+     if(!file.open(QIODevice::WriteOnly)) {
+         throw Exception("Could not write to " + tempLocation.toStdString());
+     }
+     QObject::connect(reply, &QNetworkReply::readyRead, [&reply, &file]() {
+         file.write(reply->read(reply->bytesAvailable()));
+     });
+     QObject::connect(&manager, &QNetworkAccessManager::finished, [blank, reply, &file, destination]() {
+         std::cout << "\n";
+         if(reply->error() != QNetworkReply::NoError) {
+             std::cout << "\r" << blank;
+             std::cout << "\rERROR: Download failed! : " << reply->errorString().toStdString() << std::endl;
+             file.close();
+             file.remove();
+             return;
+         }
+         file.close();
+         std::cout << "\r" << blank;
+         std::cout << "\rExtracting zip file...";
+         std::cout.flush();
+         try {
+             extractZipFile(file.fileName().toStdString(), destination);
+         } catch(Exception& e) {
+             std::cout << "\r" << blank;
+             std::cout << "\rERROR: Zip extraction failed!" << std::endl;
+             file.remove();
+             return;
+         }
+ 
+         file.remove();
+         std::cout << "\r" << blank;
+         std::cout << "\rComplete." << std::endl;
+     });
+ 
+     auto eventLoop = new QEventLoop(&manager);
+ 
+     // Make sure to quit the event loop when download is finished
+     QObject::connect(&manager, &QNetworkAccessManager::finished, eventLoop, &QEventLoop::quit);
+ 
+     // Wait for it to finish
+     eventLoop->exec();
+ }
+ #else
+ static void downloadAndExtractZipFile(const std::string& URL, const std::string& destination) {
+     throw NotImplementedException();
+ }
+ #endif
+ 
+ 
std::shared_ptr<InferenceEngine> InferenceEngineManager::loadEngine(std::string name) {
    loadAll();
-     if(m_engines.count(name) == 0)
-         throw Exception("Inference engine with name " + name + " is not available");
+     if(name == "TensorFlow") {
+         if(m_engines.count(name) == 0) {
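+             // Determine which prebuilt TensorFlow archive matches this platform.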
+ #if defined(_M_ARM64) || defined(__aarch64__)
+             const std::string arch = "arm64";
+ #else
+             const std::string arch = "x86_64";
+ #endif
+ #ifdef WIN32
+             const std::string OS = "windows";
+ #elif defined(__APPLE__) || defined(__MACOSX)
+             const std::string OS = "macos";
+ #else
+             const std::string OS = "linux";
+ #endif
+             // If not apple arm:
+             // TensorFlow should be available; thus tensorflow lib is missing
+             // Download it and try to load tensorflow inference engine
+             if(!(arch == "arm64" && OS == "macos")) {
+                 std::cout << "TensorFlow was not bundled with this distribution." << std::endl;
+                 const std::string destination = Config::getKernelBinaryPath() + "/../lib/tensorflow/";
+                 std::cout << "FAST will now download TensorFlow .." << std::endl;
+                 createDirectories(destination);
+                 downloadAndExtractZipFile("https://github.com/FAST-Imaging/FAST-dependencies/releases/download/v4.0.0/tensorflow_2.4.0_" + OS + "_" + arch + ".zip", destination, "tensorflow");
+                 // Try to load TensorFlow inference engine again
+                 m_loaded = false;
+                 loadAll();
+             }
+         }
+     }
+     if(m_engines.count(name) == 0) {
+         if(m_possibleEngines.count(name) > 0) {
+             throw Exception("Inference engine with name " + name + " was not able to load. Make sure you have all required dependencies installed. Error message: \n" + m_possibleEngines[name]);
+         } else {
+             std::string engineList;
+             for(const auto item : m_possibleEngines)
+                 engineList += item.first + ", ";
+             throw Exception("No inference engine with name " + name + " was available on this system. Engines available are: " + engineList);
+         }
+     }
    // Call the load function which the map stores a handle to
    return std::shared_ptr<InferenceEngine>(m_engines.at(name)());
}
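
For reference, a minimal sketch of how the entry points touched by this commit are typically driven from application code. The header path is an assumption; only names visible in this diff (loadEngine, loadBestAvailableEngine, ModelFormat, Exception) are used, and with this change a TensorFlow request may trigger a one-time download of the TensorFlow runtime.

#include <FAST/Algorithms/NeuralNetwork/InferenceEngineManager.hpp> // assumed header location
#include <iostream>

using namespace fast;

int main() {
    // Pick an engine for a TensorFlow SavedModel; after this commit the model
    // format decides the engine (SAVEDMODEL/PROTOBUF -> TensorFlow).
    auto engine = InferenceEngineManager::loadBestAvailableEngine(ModelFormat::SAVEDMODEL);

    // Request TensorFlow explicitly; if the shared library is missing,
    // loadEngine() now tries to download and extract it before giving up.
    try {
        auto tf = InferenceEngineManager::loadEngine("TensorFlow");
    } catch(Exception& e) {
        // The exception message now carries the LoadLibrary/dlopen error
        // recorded in m_possibleEngines for that engine.
        std::cout << e.what() << std::endl;
    }
    return 0;
}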