Skip to content

Commit 8346341

Browse files
authored
expanded the change to rocm6.4 to include AMD GPU's to not have version contraint for pytorch, so latest pytorch is installed when using AMD
1 parent b706078 commit 8346341

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

StabilityMatrix.Core/Models/Packages/SDWebForge.cs

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -170,25 +170,30 @@ public override async Task InstallPackage(
170170
);
171171
}
172172

173-
var torchIndex = options.PythonOptions.TorchIndex ?? GetRecommendedTorchVersion();
174-
var isBlackwell =
175-
torchIndex is TorchIndex.Cuda
176-
&& (SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu() ?? HardwareHelper.HasBlackwellGpu());
173+
var torchIndex = options.PythonOptions.TorchIndex ?? GetRecommendedTorchVersion();
174+
var isBlackwell =
175+
torchIndex is TorchIndex.Cuda
176+
&& (SettingsManager.Settings.PreferredGpu?. IsBlackwellGpu() ?? HardwareHelper.HasBlackwellGpu());
177177

178-
var config = new PipInstallConfig
179-
{
180-
PrePipInstallArgs = ["joblib"],
181-
RequirementsFilePaths = requirementsPaths,
182-
TorchVersion = isBlackwell ? "" : "==2.3.1",
183-
TorchvisionVersion = isBlackwell ? "" : "==0.18.1",
184-
CudaIndex = isBlackwell ? "cu128" : "cu121",
185-
RocmIndex = "rocm6.4",
186-
ExtraPipArgs =
187-
[
188-
"https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip",
189-
],
190-
PostInstallPipArgs = ["numpy==1.26.4"],
191-
};
178+
var isAmd = torchIndex is TorchIndex.Rocm;
179+
180+
// Use latest PyTorch for Blackwell and AMD ROCm, otherwise use 2.3.1
181+
var useLatestPyTorch = isBlackwell || isAmd;
182+
183+
var config = new PipInstallConfig
184+
{
185+
PrePipInstallArgs = ["joblib"],
186+
RequirementsFilePaths = requirementsPaths,
187+
TorchVersion = useLatestPyTorch ? "" : "==2.3.1",
188+
TorchvisionVersion = useLatestPyTorch ? "" : "==0.18.1",
189+
CudaIndex = isBlackwell ? "cu128" : "cu121",
190+
RocmIndex = "rocm6.4",
191+
ExtraPipArgs =
192+
[
193+
"https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip",
194+
],
195+
PostInstallPipArgs = ["numpy==1.26.4"],
196+
};
192197

193198
await StandardPipInstallProcessAsync(
194199
venvRunner,

0 commit comments

Comments
 (0)