
Commit 562eb4b

use is_torch_version
1 parent 68f5c3c commit 562eb4b

File tree

1 file changed, +3 -3 lines changed

src/diffusers/models/activations.py

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@
 from torch import nn
 
 from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version
 
 
 if is_torch_npu_available():
@@ -79,7 +79,7 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b
         self.approximate = approximate
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type == "mps" and torch.__version__ < '2.0.0':
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
             # fp16 gelu not supported on mps before torch 2.0
             return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
         return F.gelu(gate, approximate=self.approximate)
@@ -105,7 +105,7 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type == "mps" and torch.__version__ < '2.0.0':
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
             # fp16 gelu not supported on mps before torch 2.0
             return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
         return F.gelu(gate)
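
Context for the change: comparing torch.__version__ against "2.0.0" with < is a lexicographic string comparison, not a semantic version comparison (for example, "1.9.0" < "1.13.0" is False as strings), and raw strings also carry suffixes such as "+cu117" or ".dev" builds. The is_torch_version helper parses both sides as versions before comparing. Below is a minimal sketch of that idea, assuming torch and packaging are installed; the helper name compare_torch_version is illustrative here, not the diffusers API.

    # Minimal sketch: parse versions before comparing instead of comparing raw strings.
    import operator

    import torch
    from packaging import version

    _OPS = {"<": operator.lt, "<=": operator.le, "==": operator.eq,
            ">=": operator.ge, ">": operator.gt}

    def compare_torch_version(op: str, required: str) -> bool:
        # Use base_version to drop local/dev suffixes, then compare parsed versions.
        installed = version.parse(version.parse(torch.__version__).base_version)
        return _OPS[op](installed, version.parse(required))

    # Why the old check was fragile: plain string comparison is lexicographic.
    print("1.9.0" < "1.13.0")                                # False (wrong)
    print(version.parse("1.9.0") < version.parse("1.13.0"))  # True (correct)
    print(compare_torch_version("<", "2.0.0"))               # depends on the installed torch
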

0 commit comments