Skip to content

Commit f637bb8

Browse files
authored
Cleanup config.py (#992)
* Update config.py * miss
1 parent 5b9265d commit f637bb8

File tree

1 file changed

+32
-19
lines changed

1 file changed

+32
-19
lines changed

config.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import argparse
23
import sys
34
import torch
@@ -35,11 +36,13 @@ def __init__(self):
3536
self.iscolab,
3637
self.noparallel,
3738
self.noautoopen,
39+
self.dml
3840
) = self.arg_parse()
39-
self.instead=""
41+
self.instead = ""
4042
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
4143

42-
def arg_parse(self) -> tuple:
44+
@staticmethod
45+
def arg_parse() -> tuple:
4346
exe = sys.executable or "python"
4447
parser = argparse.ArgumentParser()
4548
parser.add_argument("--port", type=int, default=7865, help="Listen port")
@@ -61,13 +64,14 @@ def arg_parse(self) -> tuple:
6164
cmd_opts = parser.parse_args()
6265

6366
cmd_opts.port = cmd_opts.port if 0 <= cmd_opts.port <= 65535 else 7865
64-
self.dml=cmd_opts.dml
67+
6568
return (
6669
cmd_opts.pycmd,
6770
cmd_opts.port,
6871
cmd_opts.colab,
6972
cmd_opts.noparallel,
7073
cmd_opts.noautoopen,
74+
cmd_opts.dml
7175
)
7276

7377
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
@@ -112,12 +116,12 @@ def device_config(self) -> tuple:
112116
f.write(strr)
113117
elif self.has_mps():
114118
print("No supported Nvidia GPU found")
115-
self.device = self.instead="mps"
119+
self.device = self.instead = "mps"
116120
self.is_half = False
117121
use_fp32_config()
118122
else:
119123
print("No supported Nvidia GPU found")
120-
self.device = self.instead="cpu"
124+
self.device = self.instead = "cpu"
121125
self.is_half = False
122126
use_fp32_config()
123127

@@ -137,25 +141,34 @@ def device_config(self) -> tuple:
137141
x_center = 38
138142
x_max = 41
139143

140-
if self.gpu_mem != None and self.gpu_mem <= 4:
144+
if self.gpu_mem is not None and self.gpu_mem <= 4:
141145
x_pad = 1
142146
x_query = 5
143147
x_center = 30
144148
x_max = 32
145-
if(self.dml==True):
149+
if self.dml:
146150
print("use DirectML instead")
147-
try:os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-cuda")
148-
except:pass
149-
try:os.rename("runtime\Lib\site-packages\onnxruntime-dml","runtime\Lib\site-packages\onnxruntime")
150-
except:pass
151+
try:
152+
os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-cuda")
153+
except:
154+
pass
155+
try:
156+
os.rename("runtime\Lib\site-packages\onnxruntime-dml","runtime\Lib\site-packages\onnxruntime")
157+
except:
158+
159+
pass
151160
import torch_directml
152-
self.device= torch_directml.device(torch_directml.default_device())
153-
self.is_half=False
161+
self.device = torch_directml.device(torch_directml.default_device())
162+
self.is_half = False
154163
else:
155-
if(self.instead):
156-
print("use %s instead"%self.instead)
157-
try:os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-cuda")
158-
except:pass
159-
try:os.rename("runtime\Lib\site-packages\onnxruntime-dml","runtime\Lib\site-packages\onnxruntime")
160-
except:pass
164+
if self.instead:
165+
print(f"use {self.instead} instead")
166+
try:
167+
os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-cuda")
168+
except:
169+
pass
170+
try:
171+
os.rename("runtime\Lib\site-packages\onnxruntime-dml","runtime\Lib\site-packages\onnxruntime")
172+
except:
173+
pass
161174
return x_pad, x_query, x_center, x_max

0 commit comments

Comments
 (0)