@@ -98,19 +98,21 @@ public VladAutomatic(IGithubApiCache githubApi, ISettingsManager settingsManager
9898 {
9999 Name = "Use DirectML if no compatible GPU is detected" ,
100100 Type = LaunchOptionType . Bool ,
101- InitialValue = ! HardwareHelper . HasNvidiaGpu ( ) && HardwareHelper . HasAmdGpu ( ) ,
101+ InitialValue = PreferDirectML ( ) ,
102102 Options = new ( ) { "--use-directml" }
103103 } ,
104104 new ( )
105105 {
106106 Name = "Force use of Nvidia CUDA backend" ,
107107 Type = LaunchOptionType . Bool ,
108+ InitialValue = HardwareHelper . HasNvidiaGpu ( ) ,
108109 Options = new ( ) { "--use-cuda" }
109110 } ,
110111 new ( )
111112 {
112113 Name = "Force use of AMD ROCm backend" ,
113114 Type = LaunchOptionType . Bool ,
115+ InitialValue = PreferRocm ( ) ,
114116 Options = new ( ) { "--use-rocm" }
115117 } ,
116118 new ( )
@@ -136,6 +138,16 @@ public VladAutomatic(IGithubApiCache githubApi, ISettingsManager settingsManager
136138
137139 public override string ExtraLaunchArguments => "" ;
138140
141+ // Set ROCm for default if AMD and Linux
142+ private static bool PreferRocm ( ) => ! HardwareHelper . HasNvidiaGpu ( )
143+ && HardwareHelper . HasAmdGpu ( )
144+ && Compat . IsLinux ;
145+
146+ // Set DirectML for default if AMD and Windows
147+ private static bool PreferDirectML ( ) => ! HardwareHelper . HasNvidiaGpu ( )
148+ && HardwareHelper . HasAmdGpu ( )
149+ && Compat . IsWindows ;
150+
139151 public override Task < string > GetLatestVersion ( ) => Task . FromResult ( "master" ) ;
140152
141153 public override async Task < IEnumerable < PackageVersion > > GetAllVersions ( bool isReleaseMode = true )
@@ -150,42 +162,38 @@ public override async Task<IEnumerable<PackageVersion>> GetAllVersions(bool isRe
150162
151163 public override async Task InstallPackage ( IProgress < ProgressReport > ? progress = null )
152164 {
153- progress ? . Report ( new ProgressReport ( - 1f , "Installing dependencies ..." , isIndeterminate : true ) ) ;
165+ progress ? . Report ( new ProgressReport ( - 1f , "Installing package ..." , isIndeterminate : true ) ) ;
154166 // Setup venv
155167 var venvRunner = new PyVenvRunner ( Path . Combine ( InstallLocation , "venv" ) ) ;
156168 venvRunner . WorkingDirectory = InstallLocation ;
157- if ( ! venvRunner . Exists ( ) )
169+ venvRunner . EnvironmentVariables = SettingsManager . Settings . EnvironmentVariables ;
170+
171+ await venvRunner . Setup ( ) . ConfigureAwait ( false ) ;
172+
173+ // Run initial install
174+ if ( HardwareHelper . HasNvidiaGpu ( ) )
158175 {
159- await venvRunner . Setup ( ) . ConfigureAwait ( false ) ;
176+ // CUDA
177+ await venvRunner . CustomInstall ( "launch.py --use-cuda --debug --test" , OnConsoleOutput )
178+ . ConfigureAwait ( false ) ;
160179 }
161-
162- // Install torch / xformers based on gpu info
163- var gpus = HardwareHelper . IterGpuInfo ( ) . ToList ( ) ;
164- if ( gpus . Any ( g => g . IsNvidia ) )
180+ else if ( PreferRocm ( ) )
165181 {
166- Logger . Info ( "Starting torch install (CUDA)..." ) ;
167- await venvRunner . PipInstall ( PyVenvRunner . TorchPipInstallArgsCuda , OnConsoleOutput )
182+ // ROCm
183+ await venvRunner . CustomInstall ( "launch.py --use-rocm --debug --test" , OnConsoleOutput )
168184 . ConfigureAwait ( false ) ;
169-
170- Logger . Info ( "Installing xformers..." ) ;
171- await venvRunner . PipInstall ( "xformers" , OnConsoleOutput ) . ConfigureAwait ( false ) ;
172185 }
173- else if ( gpus . Any ( g => g . IsAmd ) )
186+ else if ( PreferDirectML ( ) )
174187 {
175- Logger . Info ( "Starting torch install ( DirectML)..." ) ;
176- await venvRunner . PipInstall ( PyVenvRunner . TorchPipInstallArgsDirectML , OnConsoleOutput )
188+ // DirectML
189+ await venvRunner . CustomInstall ( "launch.py --use-directml --debug --test" , OnConsoleOutput )
177190 . ConfigureAwait ( false ) ;
178191 }
179192 else
180193 {
181- Logger . Info ( "Starting torch install (CPU)..." ) ;
182- await venvRunner . PipInstall ( PyVenvRunner . TorchPipInstallArgsCpu , OnConsoleOutput )
194+ await venvRunner . CustomInstall ( "launch.py --debug --test" , OnConsoleOutput )
183195 . ConfigureAwait ( false ) ;
184196 }
185-
186- // Install requirements file
187- Logger . Info ( "Installing requirements.txt" ) ;
188- await venvRunner . PipInstall ( $ "-r requirements.txt", OnConsoleOutput ) . ConfigureAwait ( false ) ;
189197
190198 progress ? . Report ( new ProgressReport ( 1 , isIndeterminate : false ) ) ;
191199 }
0 commit comments