Skip to content

Commit 42cc909

Browse files
lmonetadpiparo
authored andcommitted
Protect PyMva tutorials for a failurte in importing tensorflow
Disable also batchgenerator tests for Python version less than 3.9 since tensorflow is not working for these lower python versions
1 parent 10aff3f commit 42cc909

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

bindings/pyroot/pythonizations/test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,6 @@ ROOT_ADD_PYUNITTEST(pyroot_memory memory.py)
191191
# rbatchgenerator tests
192192
# TODO: We currently do not support TensorFlow for Python >= 3.12 (see requirements.txt)
193193
# Update here once that is fixed.
194-
if (NOT MSVC AND Python3_VERSION VERSION_LESS 3.12)
194+
if (NOT MSVC AND tmva)
195195
ROOT_ADD_PYUNITTEST(batchgen rbatchgenerator_completeness.py PYTHON_DEPS numpy tensorflow torch)
196196
endif()

tutorials/machine_learning/TMVA_CNN_Classification.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@
3737
useKerasCNN = False
3838
print("TMVA_CNN_Classificaton","Skip using Keras since tensorflow is not installed")
3939
else:
40-
import tensorflow
40+
try:
41+
import tensorflow
42+
except:
43+
ROOT.Warning("TMVA_CNN_Classification", "Skip using Keras since tensorflow cannot be imported")
44+
useKerasCNN = False
4145

4246
# PyTorch has to be imported before ROOT to avoid crashes because of clashing
4347
# std::regexp symbols that are exported by cppyy.
@@ -47,7 +51,11 @@
4751
usePyTorchCNN = False
4852
print("TMVA_CNN_Classificaton","Skip using PyTorch since torch is not installed")
4953
else:
50-
import torch
54+
try:
55+
import torch
56+
except:
57+
ROOT.Warning("TMVA_CNN_Classification", "Skip using PyTorch since it cannot be imported")
58+
usePyTorchCNN = False
5159

5260

5361
import ROOT

tutorials/machine_learning/TMVA_RNN_Classification.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,14 @@ def MakeTimeData(n, ntime, ndim):
160160

161161
tf_spec = importlib.util.find_spec("tensorflow")
162162
if tf_spec is None:
163-
useKeras = False
164-
ROOT.Warning("TMVA_RNN_Classificaton","Skip using Keras since tensorflow is not installed")
163+
useKeras = False
164+
ROOT.Warning("TMVA_RNN_Classificaton","Skip using Keras since tensorflow is not installed")
165+
else:
166+
try:
167+
import tensorflow
168+
except:
169+
ROOT.Warning("TMVA_RNN_Classification", "Skip using Keras since tensorflow cannot be imported")
170+
useKeras = False
165171

166172

167173
rnn_types = ["RNN", "LSTM", "GRU"]

0 commit comments

Comments
 (0)