1
+ import os
1
2
import argparse
2
3
import sys
3
4
import torch
@@ -35,11 +36,13 @@ def __init__(self):
35
36
self .iscolab ,
36
37
self .noparallel ,
37
38
self .noautoopen ,
39
+ self .dml
38
40
) = self .arg_parse ()
39
- self .instead = ""
41
+ self .instead = ""
40
42
self .x_pad , self .x_query , self .x_center , self .x_max = self .device_config ()
41
43
42
- def arg_parse (self ) -> tuple :
44
+ @staticmethod
45
+ def arg_parse () -> tuple :
43
46
exe = sys .executable or "python"
44
47
parser = argparse .ArgumentParser ()
45
48
parser .add_argument ("--port" , type = int , default = 7865 , help = "Listen port" )
@@ -61,13 +64,14 @@ def arg_parse(self) -> tuple:
61
64
cmd_opts = parser .parse_args ()
62
65
63
66
cmd_opts .port = cmd_opts .port if 0 <= cmd_opts .port <= 65535 else 7865
64
- self . dml = cmd_opts . dml
67
+
65
68
return (
66
69
cmd_opts .pycmd ,
67
70
cmd_opts .port ,
68
71
cmd_opts .colab ,
69
72
cmd_opts .noparallel ,
70
73
cmd_opts .noautoopen ,
74
+ cmd_opts .dml
71
75
)
72
76
73
77
# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
@@ -112,12 +116,12 @@ def device_config(self) -> tuple:
112
116
f .write (strr )
113
117
elif self .has_mps ():
114
118
print ("No supported Nvidia GPU found" )
115
- self .device = self .instead = "mps"
119
+ self .device = self .instead = "mps"
116
120
self .is_half = False
117
121
use_fp32_config ()
118
122
else :
119
123
print ("No supported Nvidia GPU found" )
120
- self .device = self .instead = "cpu"
124
+ self .device = self .instead = "cpu"
121
125
self .is_half = False
122
126
use_fp32_config ()
123
127
@@ -137,25 +141,34 @@ def device_config(self) -> tuple:
137
141
x_center = 38
138
142
x_max = 41
139
143
140
- if self .gpu_mem != None and self .gpu_mem <= 4 :
144
+ if self .gpu_mem is not None and self .gpu_mem <= 4 :
141
145
x_pad = 1
142
146
x_query = 5
143
147
x_center = 30
144
148
x_max = 32
145
- if ( self .dml == True ) :
149
+ if self .dml :
146
150
print ("use DirectML instead" )
147
- try :os .rename ("runtime\Lib\site-packages\onnxruntime" ,"runtime\Lib\site-packages\onnxruntime-cuda" )
148
- except :pass
149
- try :os .rename ("runtime\Lib\site-packages\onnxruntime-dml" ,"runtime\Lib\site-packages\onnxruntime" )
150
- except :pass
151
+ try :
152
+ os .rename ("runtime\Lib\site-packages\onnxruntime" ,"runtime\Lib\site-packages\onnxruntime-cuda" )
153
+ except :
154
+ pass
155
+ try :
156
+ os .rename ("runtime\Lib\site-packages\onnxruntime-dml" ,"runtime\Lib\site-packages\onnxruntime" )
157
+ except :
158
+
159
+ pass
151
160
import torch_directml
152
- self .device = torch_directml .device (torch_directml .default_device ())
153
- self .is_half = False
161
+ self .device = torch_directml .device (torch_directml .default_device ())
162
+ self .is_half = False
154
163
else :
155
- if (self .instead ):
156
- print ("use %s instead" % self .instead )
157
- try :os .rename ("runtime\Lib\site-packages\onnxruntime" ,"runtime\Lib\site-packages\onnxruntime-cuda" )
158
- except :pass
159
- try :os .rename ("runtime\Lib\site-packages\onnxruntime-dml" ,"runtime\Lib\site-packages\onnxruntime" )
160
- except :pass
164
+ if self .instead :
165
+ print (f"use { self .instead } instead" )
166
+ try :
167
+ os .rename ("runtime\Lib\site-packages\onnxruntime" ,"runtime\Lib\site-packages\onnxruntime-cuda" )
168
+ except :
169
+ pass
170
+ try :
171
+ os .rename ("runtime\Lib\site-packages\onnxruntime-dml" ,"runtime\Lib\site-packages\onnxruntime" )
172
+ except :
173
+ pass
161
174
return x_pad , x_query , x_center , x_max
0 commit comments