Skip to content

Commit 62e9fec

Browse files
authored
actual better fix
thanks C43H66N12O12S2
1 parent 29eff4a commit 62e9fec

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

modules/devices.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,9 @@ def torch_gc():
3939

4040
def enable_tf32():
4141
if torch.cuda.is_available():
42-
#TODO: make this better; find a way to check if it is a turing card
43-
turing = ["1630","1650","1660","Quadro RTX 3000","Quadro RTX 4000","Quadro RTX 4000","Quadro RTX 5000","Quadro RTX 5000","Quadro RTX 6000","Quadro RTX 6000","Quadro RTX 8000","Quadro RTX T400","Quadro RTX T400","Quadro RTX T600","Quadro RTX T1000","Quadro RTX T1000","2060","2070","2080","Titan RTX","Tesla T4","MX450","MX550"]
4442
for devid in range(0,torch.cuda.device_count()):
45-
for i in turing:
46-
if i in torch.cuda.get_device_name(devid):
47-
shd = True
43+
if torch.cuda.get_device_capability(devid) == (7, 5):
44+
shd = True
4845
if shd:
4946
torch.backends.cudnn.benchmark = True
5047
torch.backends.cudnn.enabled = True

0 commit comments

Comments
 (0)