Skip to content

Commit f95ba20

Browse files
authored
Do not use the base version by default in _compare_version (#10051)
1 parent 2259893 commit f95ba20

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

pytorch_lightning/utilities/imports.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _module_available(module_path: str) -> bool:
4444
return False
4545

4646

47-
def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = True) -> bool:
47+
def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
4848
"""Compare package version with some requirements.
4949
5050
>>> _compare_version("torch", operator.ge, "0.1")

tests/utilities/test_imports.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import operator
1415

1516
from pytorch_lightning.utilities import _module_available
17+
from pytorch_lightning.utilities.imports import _compare_version
1618

1719

1820
def test_module_exists():
@@ -22,3 +24,24 @@ def test_module_exists():
2224
assert not _module_available("torch.nn.asdf")
2325
assert not _module_available("asdf")
2426
assert not _module_available("asdf.bla.asdf")
27+
28+
29+
def test_compare_version(monkeypatch):
30+
from pytorch_lightning.utilities.imports import torch
31+
32+
monkeypatch.setattr(torch, "__version__", "1.8.9")
33+
assert not _compare_version("torch", operator.ge, "1.10.0")
34+
assert _compare_version("torch", operator.lt, "1.10.0")
35+
36+
monkeypatch.setattr(torch, "__version__", "1.10.0.dev123")
37+
assert _compare_version("torch", operator.ge, "1.10.0.dev123")
38+
assert not _compare_version("torch", operator.ge, "1.10.0.dev124")
39+
40+
assert _compare_version("torch", operator.ge, "1.10.0.dev123", use_base_version=True)
41+
assert _compare_version("torch", operator.ge, "1.10.0.dev124", use_base_version=True)
42+
43+
monkeypatch.setattr(torch, "__version__", "1.10.0a0+0aef44c") # dev version before rc
44+
assert _compare_version("torch", operator.ge, "1.10.0.rc0", use_base_version=True)
45+
assert not _compare_version("torch", operator.ge, "1.10.0.rc0")
46+
assert _compare_version("torch", operator.ge, "1.10.0", use_base_version=True)
47+
assert not _compare_version("torch", operator.ge, "1.10.0")

0 commit comments

Comments
 (0)