
Commit f8d872a

Removed TF from pyfast
1 parent 57e7e25 commit f8d872a

File tree: 15 files changed, +254 / -77 lines


CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ option(FAST_SIGN_CODE "Whether to sign binaries" OFF)
 set(FAST_Python_Version "" CACHE STRING "Python version to use for building pyFAST")

 # Base URL for downloading prebuilt dependencies
-set(FAST_PREBUILT_DEPENDENCY_DOWNLOAD_URL_NEW "https://github.com/smistad/FAST-dependencies/releases/download/v${VERSION_MAJOR}.0.0/")
+set(FAST_DEPENDENCY_DOWNLOAD_BASE_URL "https://github.com/FAST-Imaging/FAST-dependencies/releases/download/v${VERSION_MAJOR}.0.0/")
 if(WIN32)
     set(FAST_DEPENDENCY_TOOLSET msvc142)
 elseif(APPLE)

cmake/ExternalLibJPEG.cmake

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@ include(cmake/Externals.cmake)
 if(WIN32)
     ExternalProject_Add(libjpeg
         PREFIX ${FAST_EXTERNAL_BUILD_DIR}/libjpeg
-        URL "https://github.com/smistad/FAST-dependencies/releases/download/v4.0.0/LibJPEG-9d-Win-pc064.zip"
+        URL "https://github.com/FAST-Imaging/FAST-dependencies/releases/download/v4.0.0/LibJPEG-9d-Win-pc064.zip"
         UPDATE_COMMAND ""
         CONFIGURE_COMMAND ""
         BUILD_COMMAND ""

cmake/FetchSwig.cmake

Lines changed: 4 additions & 4 deletions

@@ -5,27 +5,27 @@ set(FILENAME swig_4.0.2_${FAST_DEPENDENCY_TOOLSET}.tar.xz)
 if(WIN32)
     FetchContent_Declare(
         swig
-        URL ${FAST_PREBUILT_DEPENDENCY_DOWNLOAD_URL_NEW}/${FILENAME}
+        URL ${FAST_DEPENDENCY_DOWNLOAD_BASE_URL}/${FILENAME}
         URL_HASH SHA256=56c1f0eae1e25643cc98b2489a22e78c2970a04007d02849eb355e865384ad18
     )
 elseif(APPLE)
     if(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64")
         FetchContent_Declare(
             swig
-            URL ${FAST_PREBUILT_DEPENDENCY_DOWNLOAD_URL_NEW}/${FILENAME}
+            URL ${FAST_DEPENDENCY_DOWNLOAD_BASE_URL}/${FILENAME}
             URL_HASH SHA256=407aeef24b8a88994a51e565d9d5f018abbae897a0f0f69ed288e190a8336c08
         )
     else()
         FetchContent_Declare(
             swig
-            URL ${FAST_PREBUILT_DEPENDENCY_DOWNLOAD_URL_NEW}/${FILENAME}
+            URL ${FAST_DEPENDENCY_DOWNLOAD_BASE_URL}/${FILENAME}
             URL_HASH SHA256=105be24d7967ef68e3a9763b2012f401bdc3dc009ccd06e1a9fd16edeca940d7
         )
     endif()
 else()
     FetchContent_Declare(
         swig
-        URL ${FAST_PREBUILT_DEPENDENCY_DOWNLOAD_URL_NEW}/${FILENAME}
+        URL ${FAST_DEPENDENCY_DOWNLOAD_BASE_URL}/${FILENAME}
         URL_HASH SHA256=d0b84c72dff878ee18a03d580b0cc9786b37aadbac019912fca1c812755d9bea
     )
 endif()

cmake/Macros.cmake

Lines changed: 3 additions & 3 deletions

@@ -172,7 +172,7 @@ macro(fast_download_dependency NAME VERSION SHA)
     if(${NAME} STREQUAL qt5)
         ExternalProject_Add(${NAME}
             PREFIX ${FAST_EXTERNAL_BUILD_DIR}/${NAME}
-            URL ${FAST_PREBUILT_DEPENDENCY_DOWNLOAD_URL_NEW}/${FILENAME}
+            URL ${FAST_DEPENDENCY_DOWNLOAD_BASE_URL}/${FILENAME}
             URL_HASH SHA256=${SHA}
             UPDATE_COMMAND ""
             CONFIGURE_COMMAND ""
@@ -188,7 +188,7 @@ macro(fast_download_dependency NAME VERSION SHA)
     else()
         ExternalProject_Add(${NAME}
             PREFIX ${FAST_EXTERNAL_BUILD_DIR}/${NAME}
-            URL ${FAST_PREBUILT_DEPENDENCY_DOWNLOAD_URL_NEW}/${FILENAME}
+            URL ${FAST_DEPENDENCY_DOWNLOAD_BASE_URL}/${FILENAME}
             URL_HASH SHA256=${SHA}
             UPDATE_COMMAND ""
             CONFIGURE_COMMAND ""
@@ -205,7 +205,7 @@ macro(fast_download_dependency NAME VERSION SHA)
         # copy_directory doesn't support symlinks, use cp on linux/apple:
         ExternalProject_Add(${NAME}
             PREFIX ${FAST_EXTERNAL_BUILD_DIR}/${NAME}
-            URL ${FAST_PREBUILT_DEPENDENCY_DOWNLOAD_URL_NEW}/${FILENAME}
+            URL ${FAST_DEPENDENCY_DOWNLOAD_BASE_URL}/${FILENAME}
             URL_HASH SHA256=${SHA}
             UPDATE_COMMAND ""
             CONFIGURE_COMMAND ""

source/FAST/Algorithms/NeuralNetwork/InferenceEngineManager.cpp

Lines changed: 171 additions & 30 deletions

@@ -1,6 +1,16 @@
 #include <FAST/Config.hpp>
 #include <FAST/Utility.hpp>
 #include <FAST/DeviceManager.hpp>
+#ifdef FAST_MODULE_VISUALIZATION
+#include <iomanip>
+#include <QNetworkAccessManager>
+#include <QNetworkRequest>
+#include <QNetworkReply>
+#include <QElapsedTimer>
+#include <QEventLoop>
+#include <QFile>
+#include <QStandardPaths>
+#endif
 #ifdef WIN32
 // For TensorFlow AVX2 check - Taken from https://docs.microsoft.com/en-us/cpp/intrinsics/cpuid-cpuidex?view=msvc-160
 #include <vector>
@@ -200,6 +210,7 @@ namespace fast {

 bool InferenceEngineManager::m_loaded = false;
 std::unordered_map<std::string, std::function<InferenceEngine*()>> InferenceEngineManager::m_engines;
+std::unordered_map<std::string, std::string> InferenceEngineManager::m_possibleEngines; // key = name, value = errors

 #ifdef WIN32
 //Returns the last Win32 error, in string format. Returns an empty string if there is no error.
@@ -252,51 +263,91 @@ void InferenceEngineManager::loadAll() {
     }
 #endif
     Reporter::info() << "Loading inference engines in folder " << Config::getLibraryPath() << Reporter::end();
+    m_possibleEngines.clear();
     for(auto&& item : getDirectoryList(Config::getLibraryPath(), true, false)) {
         auto path = join(Config::getLibraryPath(), item);
         if(item.substr(0, prefix.size()) == prefix) {
             std::string name = item.substr(prefix.size(), item.rfind('.') - prefix.size());
+            m_possibleEngines[name] = "";
+            if(m_engines.count(name) > 0) // Already loaded skip
+                continue;
             Reporter::info() << "Loading inference engine " << name << " from shared library " << path << Reporter::end();
 #ifdef WIN32
+            SetErrorMode(SEM_FAILCRITICALERRORS); // To avoid diaglog box, when not able to load a DLL
             if(name == "TensorFlow") {
                 if(!InstructionSet::AVX2()) {
+                    m_possibleEngines[name] = "You CPU does not support AVX2, unable to load TensorFlow inference engine.";
                     Reporter::warning() << "You CPU does not support AVX2, unable to load TensorFlow inference engine." << Reporter::end();
                     continue;
                 }
+                // When we put libraries in another path than the libInferenceEngineTensorFlow so file, we have to do this:
+                std::string path2 = Config::getKernelBinaryPath() + "../lib/tensorflow/";
+                SetDllDirectory(path2.c_str());
+                auto handle = LoadLibrary(join(path2, "tensorflow_cc.dll").c_str());
+                if(!handle) {
+                    m_possibleEngines[name] = "Failed to load tensorflow library because: " + GetLastErrorAsString();
+                    continue;
+                }
+                SetDllDirectory(NULL);
             }
-            SetErrorMode(SEM_FAILCRITICALERRORS); // TODO To avoid diaglog box, when not able to load a DLL
             SetDllDirectory(Config::getLibraryPath().c_str());
             auto handle = LoadLibrary(path.c_str());
-            SetDllDirectory("");
+            SetDllDirectory(NULL);
             if(!handle) {
-                Reporter::warning() << "Failed to load plugin because " << GetLastErrorAsString() << Reporter::end();
+                //Reporter::warning() << "Failed to load plugin because " << GetLastErrorAsString() << Reporter::end();
+                m_possibleEngines[name] = "Failed to load inference engine because: " + GetLastErrorAsString();
                 continue;
             }
             auto load = (InferenceEngine* (*)())GetProcAddress(handle, "load");
             if(!load) {
                 FreeLibrary(handle);
-                Reporter::warning() << "Failed to get adress to load function because " << GetLastErrorAsString() << Reporter::end();
+                //Reporter::warning() << "Failed to get address to load function because " << GetLastErrorAsString() << Reporter::end();
+                m_possibleEngines[name] = "Failed to get address to load function because: " + GetLastErrorAsString();
                 continue;
             }
 #else
             if(name == "TensorFlow") {
 #ifdef __arm64__
 #else
                 if(!__builtin_cpu_supports("avx2")) {
-                    Reporter::warning() << "You CPU does not support AVX2, unable to load TensorFlow inference engine." << Reporter::end();
+                    m_possibleEngines[name] = "You CPU does not support AVX2, unable to load TensorFlow inference engine.";
+                    //Reporter::warning() << "You CPU does not support AVX2, unable to load TensorFlow inference engine." << Reporter::end();
                     continue;
                 }
 #endif
             }
             auto handle = dlopen(path.c_str(), RTLD_LAZY);
             if(!handle) {
-                Reporter::warning() << "Failed to load plugin because " << dlerror() << Reporter::end();
-                continue;
+                if(name == "TensorFlow") {
+                    // When we put libraries in another path than the libInferenceEngineTensorFlow so file, we have to do this:
+                    // Try to load libtensorflow manually
+#if defined(__APPLE__) || defined(__MACOSX)
+                    std::string path2 = Config::getKernelBinaryPath() + "/../lib/tensorflow/libtensorflow_cc.dylib";
+#else
+                    std::string path2 = Config::getKernelBinaryPath() + "/../lib/tensorflow/libtensorflow_cc.so";
+#endif
+                    auto handle2 = dlopen(path2.c_str(), RTLD_LAZY);
+                    if(!handle2) {
+                        m_possibleEngines[name] = "Failed to load inference engine because " + std::string(dlerror());
+                        continue;
+                    }
+                    // Then try to load again
+                    handle = dlopen(path.c_str(), RTLD_LAZY);
+                    if(!handle) {
+                        m_possibleEngines[name] = "Failed to load inference engine because " + std::string(dlerror());
+                        continue;
+                    }
+                } else {
+                    m_possibleEngines[name] = "Failed to load inference engine because " + std::string(dlerror());
+                    //Reporter::warning() << "Failed to load plugin because " << dlerror() << Reporter::end();
+                    continue;
+                }
             }
             auto load = (InferenceEngine* (*)())dlsym(handle, "load");
             if(!load) {
                 dlclose(handle);
-                Reporter::warning() << "Failed to get address of load function because " << Reporter::end();// dlerror() << Reporter::end();
+                //Reporter::warning() << "Failed to get address of load function because " << Reporter::end();// dlerror() << Reporter::end();
                 m_possibleEngines[name] = "Failed to get address of load function because " + std::string(dlerror());
                 continue;
             }
 #endif
@@ -312,16 +363,7 @@ std::shared_ptr<InferenceEngine> InferenceEngineManager::loadBestAvailableEngine
     if(m_engines.empty())
         throw Exception("No inference engines available on the system");

-    if(isEngineAvailable("TensorFlow")) {
-        // Default is tensorflow if GPU support is enabled
-        auto engine = loadEngine("TensorFlow");
-        auto devices = engine->getDeviceList();
-        for (auto &&device : devices) {
-            if (device.type == InferenceDeviceType::GPU)
-                return engine;
-        }
-    }
-
+    // TensorRT is default inference engine if available
     if(isEngineAvailable("TensorRT"))
         return loadEngine("TensorRT");

@@ -341,14 +383,9 @@ std::shared_ptr<InferenceEngine> InferenceEngineManager::loadBestAvailableEngine
     if(m_engines.empty())
        throw Exception("No inference engines available on the system");

-    if(isEngineAvailable("TensorFlow") && loadEngine("TensorFlow")->isModelFormatSupported(modelFormat)) {
-        // Default is tensorflow if GPU support is enabled
-        auto engine = loadEngine("TensorFlow");
-        auto devices = engine->getDeviceList();
-        for (auto &&device : devices) {
-            if (device.type == InferenceDeviceType::GPU)
-                return engine;
-        }
+    // If TensorFlow model format
+    if(modelFormat == ModelFormat::SAVEDMODEL || modelFormat == ModelFormat::PROTOBUF) {
+        return loadEngine("TensorFlow");
     }

     if(isEngineAvailable("TensorRT") && loadEngine("TensorRT")->isModelFormatSupported(modelFormat))
@@ -370,17 +407,121 @@ std::shared_ptr<InferenceEngine> InferenceEngineManager::loadBestAvailableEngine
     throw Exception("No engine for model format found");
 }

-
-
 bool InferenceEngineManager::isEngineAvailable(std::string name) {
     loadAll();
     return m_engines.count(name) > 0;
 }

+#ifdef FAST_MODULE_VISUALIZATION
+static void downloadAndExtractZipFile(const std::string& URL, const std::string& destination, const std::string& name) {
+    QNetworkAccessManager manager;
+    QNetworkRequest request(QUrl(QString::fromStdString(URL)));
+    request.setAttribute(QNetworkRequest::RedirectPolicyAttribute, QNetworkRequest::NoLessSafeRedirectPolicy);
+    auto timer = new QElapsedTimer;
+    timer->start();
+    auto reply = manager.get(request);
+    Progress progress(100);
+    progress.setText("Downloading");
+    progress.setUnit("MB", 1.0/(1024*1024), 0);
+    const int totalWidth = getConsoleWidth();
+    std::string blank;
+    for(int i = 0; i < totalWidth-1; ++i)
+        blank += " ";
+    QObject::connect(reply, &QNetworkReply::downloadProgress, [&](quint64 current, quint64 max) {
+        progress.setMax(max);
+        progress.update(current);
+    });
+    auto tempLocation = QStandardPaths::writableLocation(QStandardPaths::TempLocation) + "/" + name.c_str() + ".zip";
+    QFile file(tempLocation);
+    if(!file.open(QIODevice::WriteOnly)) {
+        throw Exception("Could not write to " + tempLocation.toStdString());
+    }
+    QObject::connect(reply, &QNetworkReply::readyRead, [&reply, &file]() {
+        file.write(reply->read(reply->bytesAvailable()));
+    });
+    QObject::connect(&manager, &QNetworkAccessManager::finished, [blank, reply, &file, destination]() {
+        std::cout << "\n";
+        if(reply->error() != QNetworkReply::NoError) {
+            std::cout << "\r" << blank;
+            std::cout << "\rERROR: Download failed! : " << reply->errorString().toStdString() << std::endl;
+            file.close();
+            file.remove();
+            return;
+        }
+        file.close();
+        std::cout << "\r" << blank;
+        std::cout << "\rExtracting zip file...";
+        std::cout.flush();
+        try {
+            extractZipFile(file.fileName().toStdString(), destination);
+        } catch(Exception& e) {
+            std::cout << "\r" << blank;
+            std::cout << "\rERROR: Zip extraction failed!" << std::endl;
+            file.remove();
+            return;
+        }
+
+        file.remove();
+        std::cout << "\r" << blank;
+        std::cout << "\rComplete." << std::endl;
+    });
+
+    auto eventLoop = new QEventLoop(&manager);
+
+    // Make sure to quit the event loop when download is finished
+    QObject::connect(&manager, &QNetworkAccessManager::finished, eventLoop, &QEventLoop::quit);
+
+    // Wait for it to finish
+    eventLoop->exec();
+}
+#else
+static void downloadAndExtractZipFile(const std::string& URL, const std::string& destination) {
+    throw NotImplementedException();
+}
+#endif
+
+
 std::shared_ptr<InferenceEngine> InferenceEngineManager::loadEngine(std::string name) {
     loadAll();
-    if(m_engines.count(name) == 0)
-        throw Exception("Inference engine with name " + name + " is not available");
+    if(name == "TensorFlow") {
+        if(m_engines.count(name) == 0) {
+#if defined(_M_ARM64) || defined(__aarch64__)
+            const std::string arch = "arm64";
+#else
+            const std::string arch = "x86_64";
+#endif
+#ifdef WIN32
+            const std::string OS = "windows";
+#elif defined(__APPLE__) || defined(__MACOSX)
+            const std::string OS = "macos";
+#else
+            const std::string OS = "linux";
+#endif
+            // If not apple arm:
+            // TensorFlow should be available; thus tensorflow lib is missing
+            // Download it and try to load tensorflow inference engine
+            if(!(arch == "arm64" && OS == "macos")) {
+                std::cout << "TensorFlow was not bundled with this distribution." << std::endl;
+                const std::string destination = Config::getKernelBinaryPath() + "/../lib/tensorflow/";
+                std::cout << "FAST will now download TensorFlow .." << std::endl;
+                createDirectories(destination);
+                downloadAndExtractZipFile("https://github.com/FAST-Imaging/FAST-dependencies/releases/download/v4.0.0/tensorflow_2.4.0_" + OS + "_" + arch + ".zip", destination, "tensorflow");
+                // Try to load TensorFlow inference engine again
+                m_loaded = false;
+                loadAll();
+            }
+        }
+    }
+    if(m_engines.count(name) == 0) {
+        if(m_possibleEngines.count(name) > 0) {
+            throw Exception("Inference engine with name " + name + " was not able to load. Make sure you have all required dependencies installed. Error message: \n" + m_possibleEngines[name]);
+        } else {
+            std::string engineList;
+            for(const auto item : m_possibleEngines)
+                engineList += item.first + ", ";
+            throw Exception("No inference engine with name " + name + " was available on this system. Engines available are: " + engineList);
+        }
+    }
     // Call the load function which the map stores a handle to
     return std::shared_ptr<InferenceEngine>(m_engines.at(name)());
 }
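Taken together, the InferenceEngineManager changes above mean that TensorFlow is no longer preferred by loadBestAvailableEngine(), and that an explicit request for the TensorFlow engine can now fetch the runtime on demand. The following is a minimal sketch of how calling code hits these paths; the include path and static-call style are inferred from the files touched in this commit, and the error handling is illustrative rather than the library's documented behaviour:

    // Sketch only: exercises the engine-selection behaviour changed in this commit.
    #include <FAST/Algorithms/NeuralNetwork/InferenceEngineManager.hpp>
    #include <exception>
    #include <iostream>

    int main() {
        using namespace fast;

        // TensorFlow is no longer tried first here; TensorRT is the default when available.
        try {
            auto best = InferenceEngineManager::loadBestAvailableEngine();
        } catch(std::exception& e) {
            std::cout << e.what() << std::endl; // e.g. "No inference engines available on the system"
        }

        // Requesting TensorFlow explicitly: if the runtime was not bundled,
        // loadEngine() now downloads tensorflow_2.4.0_<OS>_<arch>.zip, extracts it
        // into <kernel binary path>/../lib/tensorflow/ and retries loadAll().
        try {
            auto tf = InferenceEngineManager::loadEngine("TensorFlow");
        } catch(std::exception& e) {
            // On failure the message now carries the per-engine error
            // recorded in m_possibleEngines during loadAll().
            std::cout << e.what() << std::endl;
        }
        return 0;
    }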

source/FAST/Algorithms/NeuralNetwork/InferenceEngineManager.hpp

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@ class FAST_EXPORT InferenceEngineManager {
     private:
         static bool m_loaded;
         static std::unordered_map<std::string, std::function<InferenceEngine*()>> m_engines;
+        static std::unordered_map<std::string, std::string> m_possibleEngines; // key = name, value = errors
 };

 }
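The new m_possibleEngines map records every engine discovered on disk: an empty value means its shared library loaded, while a non-empty value holds the error that prevented loading, which loadEngine() then surfaces in its exception message. A small self-contained illustration of that shape follows; the engine names and error text are made up for the example, not real output:

    #include <iostream>
    #include <string>
    #include <unordered_map>

    // Illustrative only: the kind of state InferenceEngineManager::m_possibleEngines
    // holds after loadAll().
    int main() {
        std::unordered_map<std::string, std::string> possibleEngines = {
            {"OpenVINO", ""}, // empty value: engine library loaded successfully
            {"TensorFlow", "Failed to load inference engine because libtensorflow_cc.so: "
                           "cannot open shared object file: No such file or directory"},
        };
        for(const auto& engine : possibleEngines) {
            if(engine.second.empty())
                std::cout << engine.first << ": loaded" << std::endl;
            else
                std::cout << engine.first << ": " << engine.second << std::endl;
        }
        return 0;
    }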

source/FAST/DataHub.cpp

Lines changed: 1 addition & 0 deletions

@@ -315,6 +315,7 @@ DataHub::Download DataHub::download(std::string itemID, bool force) {
     }
     auto folder = join(m_storageDirectory, itemObject.id);
     auto downloadName = "'" + itemObject.name + "'";
+    download.items.push_back(itemObject.name);
     download.paths.push_back(folder);
     createDirectories(folder);
     if(!getDirectoryList(folder, true, true).empty()) {
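The DataHub change simply records each item's name next to its download path in the returned Download struct. A hypothetical caller reading both fields is sketched below; the header location, default constructor, explicit force argument and vector-like member types are assumptions, since only the items and paths members appear in this diff:

    // Sketch only: iterating the Download result from DataHub::download().
    #include <FAST/DataHub.hpp> // assumed header location
    #include <cstddef>
    #include <iostream>

    int main() {
        fast::DataHub hub; // assumed default-constructible
        auto result = hub.download("some-item-id", false);
        for(std::size_t i = 0; i < result.items.size() && i < result.paths.size(); ++i) {
            // items[i]: item name (added in this commit); paths[i]: download folder
            std::cout << result.items[i] << " -> " << result.paths[i] << std::endl;
        }
        return 0;
    }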
