Skip to content

Commit 4878301

Browse files
committed
fix(xunix): tighten up gpuExtraRegex to avoid extraneous matches
1 parent ad7c946 commit 4878301

File tree

2 files changed

+31
-12
lines changed

2 files changed

+31
-12
lines changed

xunix/gpu.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@ import (
1818
)
1919

2020
var (
21-
gpuMountRegex = regexp.MustCompile("(?i)(nvidia|vulkan|cuda)")
22-
gpuExtraRegex = regexp.MustCompile("(?i)(libgl|nvidia|vulkan|cuda)")
23-
gpuEnvRegex = regexp.MustCompile("(?i)nvidia")
24-
sharedObjectRegex = regexp.MustCompile(`\.so(\.[0-9\.]+)?$`)
21+
gpuMountRegex = regexp.MustCompile(`(?i)(nvidia|vulkan|cuda)`)
22+
gpuExtraRegex = regexp.MustCompile(`(?i)(libgl(e|sx|\.)|nvidia|vulkan|cuda)`)
23+
gpuEnvRegex = regexp.MustCompile(`(?i)nvidia`)
2524
)
2625

2726
func GPUEnvs(ctx context.Context) []string {
@@ -65,8 +64,12 @@ func GPUs(ctx context.Context, log slog.Logger, usrLibDir string) ([]Device, []m
6564
}
6665

6766
// If it's not in /dev treat it as a bind mount.
68-
links, err := SameDirSymlinks(afs, m.Path)
6967
binds = append(binds, m)
68+
// We also want to find any symlinks that point to the target.
69+
// This is important for the nvidia driver as it mounts the driver
70+
// files with the driver version appended to the end, and creates
71+
// symlinks that point to the actual files.
72+
links, err := SameDirSymlinks(afs, m.Path)
7073
if err != nil {
7174
log.Error(ctx, "find symlinks", slog.F("path", m.Path), slog.Error(err))
7275
} else {
@@ -118,7 +121,7 @@ func usrLibGPUs(ctx context.Context, log slog.Logger, usrLibDir string) ([]mount
118121
return nil
119122
}
120123

121-
if !sharedObjectRegex.MatchString(path) || !gpuExtraRegex.MatchString(path) {
124+
if !gpuExtraRegex.MatchString(path) {
122125
return nil
123126
}
124127

xunix/gpu_test.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,17 @@ func TestGPUs(t *testing.T) {
6262
filepath.Join(usrLibMountpoint, "nvidia", "libglxserver_nvidia.so.1"),
6363
}
6464

65-
// fakeUsrLibFiles are files that should be written to the "mounted"
66-
// /usr/lib directory. It includes files that shouldn't be returned.
67-
fakeUsrLibFiles = append([]string{
65+
// fakeUsrLibFiles are files that we do not expect to be returned
66+
// bind mounts for.
67+
fakeUsrLibFiles = []string{
6868
filepath.Join(usrLibMountpoint, "libcurl-gnutls.so"),
69-
}, expectedUsrLibFiles...)
69+
filepath.Join(usrLibMountpoint, "libglib.so"),
70+
}
71+
72+
// allUsrLibFiles are all the files that should be written to the
73+
// "mounted" /usr/lib directory. It includes files that shouldn't
74+
// be returned.
75+
allUsrLibFiles = append(expectedUsrLibFiles, fakeUsrLibFiles...)
7076
)
7177

7278
ctx := xunix.WithFS(context.Background(), fs)
@@ -93,15 +99,19 @@ func TestGPUs(t *testing.T) {
9399
err := fs.MkdirAll(filepath.Join(usrLibMountpoint, "nvidia"), 0o755)
94100
require.NoError(t, err)
95101

96-
for _, file := range fakeUsrLibFiles {
102+
for _, file := range allUsrLibFiles {
97103
_, err = fs.Create(file)
98104
require.NoError(t, err)
99105
}
106+
for _, mp := range mounter.MountPoints {
107+
_, err = fs.Create(mp.Path)
108+
require.NoError(t, err)
109+
}
100110

101111
devices, binds, err := xunix.GPUs(ctx, log, usrLibMountpoint)
102112
require.NoError(t, err)
103113
require.Len(t, devices, 2, "unexpected 2 nvidia devices")
104-
require.Len(t, binds, 4, "expected 4 nvidia binds")
114+
require.Len(t, binds, 5, "expected 5 nvidia binds")
105115
require.Contains(t, binds, mount.MountPoint{
106116
Device: "/dev/sda1",
107117
Path: "/usr/local/nvidia",
@@ -113,6 +123,12 @@ func TestGPUs(t *testing.T) {
113123
Opts: []string{"ro"},
114124
})
115125
}
126+
for _, file := range fakeUsrLibFiles {
127+
require.NotContains(t, binds, mount.MountPoint{
128+
Path: file,
129+
Opts: []string{"ro"},
130+
})
131+
}
116132
})
117133
}
118134

0 commit comments

Comments
 (0)