Skip to content

Commit 3e38c06

Browse files
committed
fix for running llamasharp inside containers based on nvidia/cuda and detecting version
1 parent 9f330c2 commit 3e38c06

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

LLama/Native/Load/SystemInfo.cs

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,31 +41,31 @@ public static SystemInfo Get()
4141

4242
return new SystemInfo(platform, GetCudaMajorVersion(), GetVulkanVersion());
4343
}
44-
44+
4545
#region Vulkan version
4646
private static string? GetVulkanVersion()
4747
{
4848
// Get Vulkan Summary
4949
string? vulkanSummary = GetVulkanSummary();
50-
50+
5151
// If we have a Vulkan summary
5252
if (vulkanSummary != null)
5353
{
5454
// Extract Vulkan version from summary
5555
string? vulkanVersion = ExtractVulkanVersionFromSummary(vulkanSummary);
56-
56+
5757
// If we have a Vulkan version
5858
if (vulkanVersion != null)
5959
{
6060
// Return the Vulkan version
6161
return vulkanVersion;
6262
}
6363
}
64-
64+
6565
// Return null if we failed to get the Vulkan version
6666
return null;
6767
}
68-
68+
6969
private static string? GetVulkanSummary()
7070
{
7171
// Note: on Linux, this requires `vulkan-tools` to be installed. (`sudo apt install vulkan-tools`)
@@ -102,19 +102,19 @@ public static SystemInfo Get()
102102
// We have three ways of parsing the Vulkan version from the summary (the output differs between Windows and Linux)
103103
// For now, I have decided to go with the full version number, and leave it up to the user to parse it further if needed
104104
// I have left the other patterns in, in case we need them in the future
105-
105+
106106
// Output on linux : 4206847 (1.3.255)
107107
// Output on windows : 1.3.255
108108
string pattern = @"apiVersion\s*=\s*([^\r\n]+)";
109-
109+
110110
// Output on linux : 4206847
111111
// Output on windows : 1.3.255
112112
//string pattern = @"apiVersion\s*=\s*([\d\.]+)";
113-
113+
114114
// Output on linux : 1.3.255
115115
// Output on windows : 1.3.255
116116
//string pattern = @"apiVersion\s*=\s*(?:\d+\s*)?(?:\(\s*)?([\d]+\.[\d]+\.[\d]+)(?:\s*\))?";
117-
117+
118118
// Create a Regex object to match the pattern
119119
Regex regex = new Regex(pattern);
120120

@@ -158,24 +158,30 @@ private static int GetCudaMajorVersion()
158158
}
159159
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
160160
{
161+
string? env_version = Environment.GetEnvironmentVariable("CUDA_VERSION");
162+
if (env_version is not null)
163+
{
164+
return ExtractMajorVersion(ref env_version);
165+
}
166+
161167
// List of default cuda paths
162168
string[] defaultCudaPaths =
163169
[
164170
"/usr/local/bin/cuda",
165171
"/usr/local/cuda",
166172
];
167-
173+
168174
// Loop through every default path to find the version
169175
foreach (var path in defaultCudaPaths)
170176
{
171177
// Attempt to get the version from the path
172178
version = GetCudaVersionFromPath(path);
173-
179+
174180
// If a CUDA version is found, break the loop
175181
if (!string.IsNullOrEmpty(version))
176182
break;
177183
}
178-
184+
179185
if (string.IsNullOrEmpty(version))
180186
{
181187
cudaPath = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH");
@@ -197,6 +203,11 @@ private static int GetCudaMajorVersion()
197203
if (string.IsNullOrEmpty(version))
198204
return -1;
199205

206+
return ExtractMajorVersion(ref version);
207+
}
208+
209+
private static int ExtractMajorVersion(ref string version)
210+
{
200211
version = version.Split('.')[0];
201212
if (int.TryParse(version, out var majorVersion))
202213
return majorVersion;

0 commit comments

Comments
 (0)