Skip to content

Commit d51d107

Browse files
authored
Merge pull request #1222 from Crelex/crelex-nvidia-cuda-container-version-get-fix
Fix for getting CUDA Version inside nvidia/cuda containers
2 parents 8ff2e89 + e59bf81 commit d51d107

File tree

1 file changed

+11
-26
lines changed

1 file changed

+11
-26
lines changed

LLama/Native/Load/SystemInfo.cs

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,31 +41,26 @@ public static SystemInfo Get()
4141

4242
return new SystemInfo(platform, GetCudaMajorVersion(), GetVulkanVersion());
4343
}
44-
4544
#region Vulkan version
4645
private static string? GetVulkanVersion()
4746
{
4847
// Get Vulkan Summary
4948
string? vulkanSummary = GetVulkanSummary();
50-
5149
// If we have a Vulkan summary
5250
if (vulkanSummary != null)
5351
{
5452
// Extract Vulkan version from summary
5553
string? vulkanVersion = ExtractVulkanVersionFromSummary(vulkanSummary);
56-
5754
// If we have a Vulkan version
5855
if (vulkanVersion != null)
5956
{
6057
// Return the Vulkan version
6158
return vulkanVersion;
6259
}
6360
}
64-
6561
// Return null if we failed to get the Vulkan version
6662
return null;
6763
}
68-
6964
private static string? GetVulkanSummary()
7065
{
7166
// Note: on Linux, this requires `vulkan-tools` to be installed. (`sudo apt install vulkan-tools`)
@@ -85,7 +80,6 @@ public static SystemInfo Get()
8580
}
8681
};
8782
var (exitCode, output, error, ok) = process.SafeRun(TimeSpan.FromSeconds(12));
88-
8983
// If ok return the output else return null
9084
return ok ? output :
9185
null;
@@ -96,38 +90,30 @@ public static SystemInfo Get()
9690
return null;
9791
}
9892
}
99-
10093
static string? ExtractVulkanVersionFromSummary(string vulkanSummary)
10194
{
10295
// We have three ways of parsing the Vulkan version from the summary (output is a different between Windows and Linux)
10396
// For now, I have decided to go with the full version number, and leave it up to the user to parse it further if needed
10497
// I have left the other patterns in, in case we need them in the future
105-
10698
// Output on linux : 4206847 (1.3.255)
10799
// Output on windows : 1.3.255
108100
string pattern = @"apiVersion\s*=\s*([^\r\n]+)";
109-
110101
// Output on linux : 4206847
111102
// Output on windows : 1.3.255
112103
//string pattern = @"apiVersion\s*=\s*([\d\.]+)";
113-
114104
// Output on linux : 1.3.255
115105
// Output on windows : 1.3.255
116106
//string pattern = @"apiVersion\s*=\s*(?:\d+\s*)?(?:\(\s*)?([\d]+\.[\d]+\.[\d]+)(?:\s*\))?";
117-
118107
// Create a Regex object to match the pattern
119108
Regex regex = new Regex(pattern);
120-
121109
// Match the pattern in the input string
122110
Match match = regex.Match(vulkanSummary);
123-
124111
// If a match is found
125112
if (match.Success)
126113
{
127114
// Return the version number
128115
return match.Groups[1].Value;
129116
}
130-
131117
// Return null if no match is found
132118
return null;
133119
}
@@ -145,37 +131,37 @@ private static int GetCudaMajorVersion()
145131
{
146132
return -1;
147133
}
148-
149134
//Ensuring cuda bin path is reachable. Especially for MAUI environment.
150135
string cudaBinPath = Path.Combine(cudaPath, "bin");
151-
152136
if (Directory.Exists(cudaBinPath))
153137
{
154138
AddDllDirectory(cudaBinPath);
155139
}
156-
157140
version = GetCudaVersionFromPath(cudaPath);
158141
}
159142
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
160143
{
144+
string? env_version = Environment.GetEnvironmentVariable("CUDA_VERSION");
145+
if (env_version is not null)
146+
{
147+
return ExtractMajorVersion(ref env_version);
148+
}
161149
// List of default cuda paths
162150
string[] defaultCudaPaths =
163151
[
164152
"/usr/local/bin/cuda",
165153
"/usr/local/cuda",
166154
];
167-
168155
// Loop through every default path to find the version
169156
foreach (var path in defaultCudaPaths)
170157
{
171158
// Attempt to get the version from the path
172159
version = GetCudaVersionFromPath(path);
173-
160+
174161
// If a CUDA version is found, break the loop
175162
if (!string.IsNullOrEmpty(version))
176163
break;
177164
}
178-
179165
if (string.IsNullOrEmpty(version))
180166
{
181167
cudaPath = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH");
@@ -193,17 +179,18 @@ private static int GetCudaMajorVersion()
193179
}
194180
}
195181
}
196-
197182
if (string.IsNullOrEmpty(version))
198183
return -1;
199184

185+
return ExtractMajorVersion(ref version);
186+
}
187+
private static int ExtractMajorVersion(ref string version)
188+
{
200189
version = version.Split('.')[0];
201190
if (int.TryParse(version, out var majorVersion))
202191
return majorVersion;
203-
204192
return -1;
205193
}
206-
207194
private static string GetCudaVersionFromPath(string cudaPath)
208195
{
209196
try
@@ -226,12 +213,10 @@ private static string GetCudaVersionFromPath(string cudaPath)
226213
return string.Empty;
227214
}
228215
}
229-
230216
// Put it here to avoid calling NativeApi when getting the cuda version.
231217
[DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
232218
internal static extern int AddDllDirectory(string NewDirectory);
233-
234219
private const string cudaVersionFile = "version.json";
235220
#endregion
236221
}
237-
}
222+
}

0 commit comments

Comments
 (0)