We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3bee377 commit b281d55Copy full SHA for b281d55
denoising_diffusion_pytorch/attend.py
@@ -60,7 +60,9 @@ def __init__(
60
61
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
62
63
- if device_properties.major == 8 and device_properties.minor == 0:
+ device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
64
+
65
+ if device_version > version.parse('8.0'):
66
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
67
self.cuda_config = AttentionConfig(True, False, False)
68
else:
denoising_diffusion_pytorch/version.py
@@ -1 +1 @@
1
-__version__ = '2.0.8'
+__version__ = '2.0.10'
0 commit comments