@@ -170,25 +170,30 @@ public override async Task InstallPackage(
170170 ) ;
171171 }
172172
173- var torchIndex = options . PythonOptions . TorchIndex ?? GetRecommendedTorchVersion ( ) ;
174- var isBlackwell =
175- torchIndex is TorchIndex . Cuda
176- && ( SettingsManager . Settings . PreferredGpu ? . IsBlackwellGpu ( ) ?? HardwareHelper . HasBlackwellGpu ( ) ) ;
173+ var torchIndex = options . PythonOptions . TorchIndex ?? GetRecommendedTorchVersion ( ) ;
174+ var isBlackwell =
175+ torchIndex is TorchIndex . Cuda
176+ && ( SettingsManager . Settings . PreferredGpu ? . IsBlackwellGpu ( ) ?? HardwareHelper . HasBlackwellGpu ( ) ) ;
177177
178- var config = new PipInstallConfig
179- {
180- PrePipInstallArgs = [ "joblib" ] ,
181- RequirementsFilePaths = requirementsPaths ,
182- TorchVersion = isBlackwell ? "" : "==2.3.1" ,
183- TorchvisionVersion = isBlackwell ? "" : "==0.18.1" ,
184- CudaIndex = isBlackwell ? "cu128" : "cu121" ,
185- RocmIndex = "rocm6.4" ,
186- ExtraPipArgs =
187- [
188- "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip" ,
189- ] ,
190- PostInstallPipArgs = [ "numpy==1.26.4" ] ,
191- } ;
178+ var isAmd = torchIndex is TorchIndex . Rocm ;
179+
180+ // Use latest PyTorch for Blackwell and AMD ROCm, otherwise use 2.3.1
181+ var useLatestPyTorch = isBlackwell || isAmd ;
182+
183+ var config = new PipInstallConfig
184+ {
185+ PrePipInstallArgs = [ "joblib" ] ,
186+ RequirementsFilePaths = requirementsPaths ,
187+ TorchVersion = useLatestPyTorch ? "" : "==2.3.1" ,
188+ TorchvisionVersion = useLatestPyTorch ? "" : "==0.18.1" ,
189+ CudaIndex = isBlackwell ? "cu128" : "cu121" ,
190+ RocmIndex = "rocm6.4" ,
191+ ExtraPipArgs =
192+ [
193+ "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip" ,
194+ ] ,
195+ PostInstallPipArgs = [ "numpy==1.26.4" ] ,
196+ } ;
192197
193198 await StandardPipInstallProcessAsync (
194199 venvRunner ,
0 commit comments