diff --git a/clip_benchmark/datasets/tfds.py b/clip_benchmark/datasets/tfds.py index 320776e..146804f 100644 --- a/clip_benchmark/datasets/tfds.py +++ b/clip_benchmark/datasets/tfds.py @@ -9,8 +9,11 @@ def download_tfds_dataset(name, data_dir=None): builder.download_and_prepare() def disable_gpus_on_tensorflow(): - import tensorflow as tf - tf.config.set_visible_devices([], 'GPU') + try: + import tensorflow as tf + tf.config.set_visible_devices([], 'GPU') + except ImportError: + pass class VTABIterableDataset(torch.utils.data.IterableDataset):