Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit d72c73e

Browse files
authored
Fix tf v1 gpu import error (#267)
1 parent 07b68d0 commit d72c73e

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

integrations/pytorch/notebooks/sparse_quantized_transfer_learning.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,4 +454,4 @@
454454
},
455455
"nbformat": 4,
456456
"nbformat_minor": 4
457-
}
457+
}

src/sparseml/tensorflow_v1/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,13 @@ def check_tensorflow_install(
102102
raise tensorflow_err
103103
return False
104104

105-
return check_version("tensorflow", min_version, max_version, raise_on_error)
105+
return check_version(
106+
"tensorflow",
107+
min_version,
108+
max_version,
109+
raise_on_error,
110+
alternate_package_names=["tensorflow-gpu"],
111+
)
106112

107113

108114
def check_tf2onnx_install(

0 commit comments

Comments
 (0)