Skip to content

Commit 4ea87fc

Browse files
committed
feat(xunix): add SameDirSymlinks, add symlinks pointing to mounted libs
1 parent a9acf2c commit 4ea87fc

File tree

2 files changed

+179
-0
lines changed

2 files changed

+179
-0
lines changed

xunix/gpu.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"os"
77
"path/filepath"
88
"regexp"
9+
"slices"
910
"sort"
1011
"strings"
1112

@@ -39,6 +40,7 @@ func GPUEnvs(ctx context.Context) []string {
3940

4041
func GPUs(ctx context.Context, log slog.Logger, usrLibDir string) ([]Device, []mount.MountPoint, error) {
4142
var (
43+
afs = GetFS(ctx)
4244
mounter = Mounter(ctx)
4345
devices = []Device{}
4446
binds = []mount.MountPoint{}
@@ -63,7 +65,19 @@ func GPUs(ctx context.Context, log slog.Logger, usrLibDir string) ([]Device, []m
6365
}
6466

6567
// If it's not in /dev treat it as a bind mount.
68+
links, err := SameDirSymlinks(afs, m.Path)
6669
binds = append(binds, m)
70+
if err != nil {
71+
log.Error(ctx, "find symlinks", slog.F("path", m.Path), slog.Error(err))
72+
} else {
73+
for _, link := range links {
74+
log.Debug(ctx, "found symlink", slog.F("link", link), slog.F("target", m.Path))
75+
binds = append(binds, mount.MountPoint{
76+
Path: link,
77+
Opts: []string{"ro"},
78+
})
79+
}
80+
}
6781
}
6882
}
6983

@@ -176,6 +190,68 @@ func recursiveSymlinks(afs FS, mountpoint string, path string) ([]string, error)
176190
return paths, nil
177191
}
178192

193+
// SameDirSymlinks returns all links in the same directory as `target` that
194+
// point to target, either indirectly or directly. Only symlinks in the same
195+
// directory as `target` are considered.
196+
func SameDirSymlinks(afs FS, target string) ([]string, error) {
197+
var (
198+
found = make([]string, 0)
199+
maxIterations = 10 // arbitrary upper limit to prevent infinite loops
200+
)
201+
for range maxIterations {
202+
foundThisTime := false
203+
fis, err := afero.ReadDir(afs, filepath.Dir(target))
204+
if err != nil {
205+
return nil, xerrors.Errorf("read dir %q: %w", filepath.Dir(target), err)
206+
}
207+
for _, fi := range fis {
208+
// Ignore the target itself.
209+
if fi.Name() == filepath.Base(target) {
210+
continue
211+
}
212+
// Ignore non-symlinks.
213+
if fi.Mode()&os.ModeSymlink == 0 {
214+
continue
215+
}
216+
// Get the target of the symlink.
217+
link, err := afs.Readlink(filepath.Join(filepath.Dir(target), fi.Name()))
218+
if err != nil {
219+
return nil, xerrors.Errorf("readlink %q: %w", fi.Name(), err)
220+
}
221+
// Make the link absolute.
222+
if !filepath.IsAbs(link) {
223+
link = filepath.Join(filepath.Dir(target), link)
224+
}
225+
// Ignore symlinks that point outside of target's directory.
226+
if filepath.Dir(link) != filepath.Dir(target) {
227+
continue
228+
}
229+
230+
// Check if the symlink points to to the target, or if it points
231+
// to one of the symlinks we've already found.
232+
if link != target {
233+
if !slices.Contains(found, link) {
234+
continue
235+
}
236+
}
237+
238+
// Have we already seen this target?
239+
fullPath := filepath.Join(filepath.Dir(target), fi.Name())
240+
if slices.Contains(found, fullPath) {
241+
continue
242+
}
243+
244+
found = append(found, filepath.Join(filepath.Dir(target), fi.Name()))
245+
foundThisTime = true
246+
}
247+
// If we didn't find any symlinks this time, we're done.
248+
if !foundThisTime {
249+
break
250+
}
251+
}
252+
return found, nil
253+
}
254+
179255
// TryUnmountProcGPUDrivers unmounts any GPU-related mounts under /proc as it causes
180256
// issues when creating any container in some cases. Errors encountered while
181257
// unmounting are treated as non-fatal.

xunix/gpu_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ package xunix_test
22

33
import (
44
"context"
5+
"os"
56
"path/filepath"
7+
"sort"
68
"testing"
79

810
"github.com/spf13/afero"
11+
"github.com/stretchr/testify/assert"
912
"github.com/stretchr/testify/require"
1013
"k8s.io/mount-utils"
1114

@@ -112,3 +115,103 @@ func TestGPUs(t *testing.T) {
112115
}
113116
})
114117
}
118+
119+
func Test_SameDirSymlinks(t *testing.T) {
120+
t.Parallel()
121+
122+
var (
123+
ctx = context.Background()
124+
// We need to test with a real filesystem as the fake one doesn't
125+
// support creating symlinks.
126+
tmpDir = t.TempDir()
127+
// We do test with the interface though!
128+
afs = xunix.GetFS(ctx)
129+
)
130+
131+
// Create some files in the temporary directory.
132+
_, err := os.Create(filepath.Join(tmpDir, "file1.real"))
133+
require.NoError(t, err, "create file")
134+
_, err = os.Create(filepath.Join(tmpDir, "file2.real"))
135+
require.NoError(t, err, "create file2")
136+
_, err = os.Create(filepath.Join(tmpDir, "file3.real"))
137+
require.NoError(t, err, "create file3")
138+
_, err = os.Create(filepath.Join(tmpDir, "file4.real"))
139+
require.NoError(t, err, "create file4")
140+
141+
// Create a symlink to the file in the temporary directory.
142+
// This needs to be done by the real os package.
143+
err = os.Symlink(filepath.Join(tmpDir, "file1.real"), filepath.Join(tmpDir, "file1.link1"))
144+
require.NoError(t, err, "create first symlink")
145+
146+
// Create another symlink to the previous symlink.
147+
err = os.Symlink(filepath.Join(tmpDir, "file1.link1"), filepath.Join(tmpDir, "file1.link2"))
148+
require.NoError(t, err, "create second symlink")
149+
150+
// Create a symlink to a file outside of the temporary directory.
151+
err = os.MkdirAll(filepath.Join(tmpDir, "dir"), 0o755)
152+
require.NoError(t, err, "create dir")
153+
// Create a symlink from file2 to inside the dir.
154+
err = os.Symlink(filepath.Join(tmpDir, "file2.real"), filepath.Join(tmpDir, "dir", "file2.link1"))
155+
require.NoError(t, err, "create dir symlink")
156+
157+
// Create a symlink with a relative path. To do this, we need to
158+
// change the working directory to the temporary directory.
159+
oldWorkingDir, err := os.Getwd()
160+
require.NoError(t, err, "get working dir")
161+
// Change the working directory to the temporary directory.
162+
require.NoError(t, os.Chdir(tmpDir), "change working dir")
163+
err = os.Symlink(filepath.Join(tmpDir, "file4.real"), "file4.link1")
164+
require.NoError(t, err, "create relative symlink")
165+
// Change the working directory back to the original.
166+
require.NoError(t, os.Chdir(oldWorkingDir), "change working dir back")
167+
168+
for _, tt := range []struct {
169+
name string
170+
expected []string
171+
}{
172+
{
173+
// Two symlinks to the same file.
174+
name: "file1.real",
175+
expected: []string{
176+
filepath.Join(tmpDir, "file1.link1"),
177+
filepath.Join(tmpDir, "file1.link2"),
178+
},
179+
},
180+
{
181+
// Mid-way in the symlink chain.
182+
name: "file1.link1",
183+
expected: []string{
184+
filepath.Join(tmpDir, "file1.link2"),
185+
},
186+
},
187+
{
188+
// End of the symlink chain.
189+
name: "file1.link2",
190+
expected: []string{},
191+
},
192+
{
193+
// Symlink to a file outside of the temporary directory.
194+
name: "file2.real",
195+
expected: []string{},
196+
},
197+
{
198+
// No symlinks to this file.
199+
name: "file3.real",
200+
expected: []string{},
201+
},
202+
{
203+
// One relative symlink.
204+
name: "file4.real",
205+
expected: []string{filepath.Join(tmpDir, "file4.link1")},
206+
},
207+
} {
208+
t.Run(tt.name, func(t *testing.T) {
209+
t.Parallel()
210+
fullPath := filepath.Join(tmpDir, tt.name)
211+
actual, err := xunix.SameDirSymlinks(afs, fullPath)
212+
require.NoError(t, err, "find symlink")
213+
sort.Strings(actual)
214+
assert.Equal(t, tt.expected, actual, "find symlinks %q", tt.name)
215+
})
216+
}
217+
}

0 commit comments

Comments
 (0)