Skip to content

Commit b281d55

Browse files
committed
address #328
1 parent 3bee377 commit b281d55

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

denoising_diffusion_pytorch/attend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def __init__(
6060

6161
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
6262

63-
if device_properties.major == 8 and device_properties.minor == 0:
63+
device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
64+
65+
if device_version > version.parse('8.0'):
6466
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
6567
self.cuda_config = AttentionConfig(True, False, False)
6668
else:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '2.0.8'
1+
__version__ = '2.0.10'

0 commit comments

Comments
 (0)