Skip to content

Commit 57312c0

Browse files
authored
patchpkg: patch python to use devbox CUDA libs (#2296)
Automatically patch python to use any `cudaPackages.*` packages that are in devbox.json. This will only work if the CUDA drivers are already installed on the host system. The patching process is: 1. When generating the patch flake, look for the system’s `libcuda.so` library (installed by the driver) and copy it into the flake’s directory. 2. Nix copies the flake’s source directory (and therefore libcuda.so) into the nix store when building it. 3. The flake calls `devbox patch` which adds a `DT_NEEDED` entry to the python binary for `libcuda.so`. It also adds the lib directories of any other `cudaPackages.*` packages that it finds in the environment.
1 parent 58ed80e commit 57312c0

File tree

6 files changed

+692
-30
lines changed

6 files changed

+692
-30
lines changed

internal/devpkg/package.go

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,7 @@ func resolve(pkg *Package) error {
171171
if err != nil {
172172
return err
173173
}
174-
175-
// TODO savil. Check with Greg about setting the user-specified outputs
176-
// somehow here.
174+
parsed.Outputs = strings.Join(pkg.outputs.selectedNames, ",")
177175

178176
pkg.setInstallable(parsed, pkg.lockfile.ProjectDir())
179177
return nil
@@ -308,7 +306,10 @@ func (p *Package) InstallableForOutput(output string) (string, error) {
308306
// a valid flake reference parsable by ParseFlakeRef, optionally followed by an
309307
// #attrpath and/or an ^output.
310308
func (p *Package) FlakeInstallable() (flake.Installable, error) {
311-
return flake.ParseInstallable(p.Raw)
309+
if err := p.resolve(); err != nil {
310+
return flake.Installable{}, err
311+
}
312+
return p.installable, nil
312313
}
313314

314315
// urlForInstall is used during `nix profile install`.
@@ -322,15 +323,16 @@ func (p *Package) urlForInstall() (string, error) {
322323
}
323324

324325
func (p *Package) NormalizedDevboxPackageReference() (string, error) {
325-
if err := p.resolve(); err != nil {
326+
installable, err := p.FlakeInstallable()
327+
if err != nil {
326328
return "", err
327329
}
328-
if p.installable.AttrPath == "" {
330+
if installable.AttrPath == "" {
329331
return "", nil
330332
}
331-
clone := p.installable
332-
clone.AttrPath = fmt.Sprintf("legacyPackages.%s.%s", nix.System(), clone.AttrPath)
333-
return clone.String(), nil
333+
installable.AttrPath = fmt.Sprintf("legacyPackages.%s.%s", nix.System(), installable.AttrPath)
334+
installable.Outputs = ""
335+
return installable.String(), nil
334336
}
335337

336338
// PackageAttributePath returns the short attribute path for a package which
@@ -376,19 +378,19 @@ func (p *Package) NormalizedPackageAttributePath() (string, error) {
376378
// normalizePackageAttributePath calls nix search to find the normalized attribute
377379
// path. It may be an expensive call (~100ms).
378380
func (p *Package) normalizePackageAttributePath() (string, error) {
379-
if err := p.resolve(); err != nil {
381+
installable, err := p.FlakeInstallable()
382+
if err != nil {
380383
return "", err
381384
}
382-
383-
query := p.installable.String()
385+
installable.Outputs = ""
386+
query := installable.String()
384387
if query == "" {
385388
query = p.Raw
386389
}
387390

388391
// We prefer nix.Search over just trying to parse the package's "URL" because
389392
// nix.Search will guarantee that the package exists for the current system.
390393
var infos map[string]*nix.Info
391-
var err error
392394
if p.IsDevboxPackage && !p.IsRunX() {
393395
// Perf optimization: For queries of the form nixpkgs/<commit>#foo, we can
394396
// use a nix.Search cache.

internal/patchpkg/builder.go

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"path"
1818
"path/filepath"
1919
"regexp"
20+
"strings"
2021
)
2122

2223
//go:embed glibc-patch.bash
@@ -42,6 +43,10 @@ type DerivationBuilder struct {
4243

4344
RestoreRefs bool
4445
bytePatches map[string][]fileSlice
46+
47+
// src contains the source files of the derivation. For flakes, this is
48+
// anything in the flake.nix directory.
49+
src *packageFS
4550
}
4651

4752
// NewDerivationBuilder initializes a new DerivationBuilder from the current
@@ -79,6 +84,9 @@ func (d *DerivationBuilder) init() error {
7984
return fmt.Errorf("patchpkg: can't patch gcc using %s: %v", d.Gcc, err)
8085
}
8186
}
87+
if src := os.Getenv("src"); src != "" {
88+
d.src = newPackageFS(src)
89+
}
8290
return nil
8391
}
8492

@@ -95,6 +103,11 @@ func (d *DerivationBuilder) Build(ctx context.Context, pkgStorePath string) erro
95103
}
96104

97105
func (d *DerivationBuilder) build(ctx context.Context, pkg, out *packageFS) error {
106+
// Create the derivation's $out directory.
107+
if err := d.copyDir(out, "."); err != nil {
108+
return err
109+
}
110+
98111
if d.RestoreRefs {
99112
if err := d.restoreMissingRefs(ctx, pkg); err != nil {
100113
// Don't break the flake build if we're unable to
@@ -103,12 +116,19 @@ func (d *DerivationBuilder) build(ctx context.Context, pkg, out *packageFS) erro
103116
slog.ErrorContext(ctx, "unable to restore all removed refs", "err", err)
104117
}
105118
}
119+
if err := d.findCUDA(ctx, out); err != nil {
120+
slog.ErrorContext(ctx, "unable to patch CUDA libraries", "err", err)
121+
}
106122

107123
var err error
108124
for path, entry := range allFiles(pkg, ".") {
109125
if ctx.Err() != nil {
110126
return ctx.Err()
111127
}
128+
if path == "." {
129+
// Skip the $out directory - we already created it.
130+
continue
131+
}
112132

113133
switch {
114134
case entry.IsDir():
@@ -167,7 +187,7 @@ func (d *DerivationBuilder) copyDir(out *packageFS, path string) error {
167187
if err != nil {
168188
return err
169189
}
170-
return os.Mkdir(path, 0o777)
190+
return os.MkdirAll(path, 0o777)
171191
}
172192

173193
func (d *DerivationBuilder) copyFile(ctx context.Context, pkg, out *packageFS, path string) error {
@@ -302,6 +322,69 @@ func (d *DerivationBuilder) findRemovedRefs(ctx context.Context, pkg *packageFS)
302322
return refs, nil
303323
}
304324

325+
func (d *DerivationBuilder) findCUDA(ctx context.Context, out *packageFS) error {
326+
if d.src == nil {
327+
return fmt.Errorf("patch flake didn't set $src to the path to its source tree")
328+
}
329+
330+
glob, err := fs.Glob(d.src, "lib/libcuda.so*")
331+
if err != nil {
332+
return fmt.Errorf("glob system libraries: %v", err)
333+
}
334+
if len(glob) != 0 {
335+
err := d.copyDir(out, "lib")
336+
if err != nil {
337+
return fmt.Errorf("copy system library: %v", err)
338+
}
339+
}
340+
for _, lib := range glob {
341+
slog.DebugContext(ctx, "found system CUDA library in flake", "path", lib)
342+
343+
err := d.copyFile(ctx, d.src, out, lib)
344+
if err != nil {
345+
return fmt.Errorf("copy system library: %v", err)
346+
}
347+
need, err := out.OSPath(lib)
348+
if err != nil {
349+
return fmt.Errorf("get absolute path to library: %v", err)
350+
}
351+
d.glibcPatcher.needed = append(d.glibcPatcher.needed, need)
352+
353+
slog.DebugContext(ctx, "added DT_NEEDED entry for system CUDA library", "path", need)
354+
}
355+
356+
slog.DebugContext(ctx, "looking for nix libraries in $patchDependencies")
357+
deps := os.Getenv("patchDependencies")
358+
if strings.TrimSpace(deps) == "" {
359+
slog.DebugContext(ctx, "$patchDependencies is empty")
360+
return nil
361+
}
362+
for _, pkg := range strings.Split(deps, " ") {
363+
slog.DebugContext(ctx, "checking for nix libraries in package", "pkg", pkg)
364+
365+
pkgFS := newPackageFS(pkg)
366+
libs, err := fs.Glob(pkgFS, "lib*/*.so*")
367+
if err != nil {
368+
return fmt.Errorf("glob nix package libraries: %v", err)
369+
}
370+
371+
sonameRegexp := regexp.MustCompile(`(^|/).+\.so\.\d+`)
372+
for _, lib := range libs {
373+
if !sonameRegexp.MatchString(lib) {
374+
continue
375+
}
376+
need, err := pkgFS.OSPath(lib)
377+
if err != nil {
378+
return fmt.Errorf("get absolute path to nix package library: %v", err)
379+
}
380+
d.glibcPatcher.needed = append(d.glibcPatcher.needed, need)
381+
382+
slog.DebugContext(ctx, "added DT_NEEDED entry for nix library", "path", need)
383+
}
384+
}
385+
return nil
386+
}
387+
305388
// packageFS is the tree of files for a package in the Nix store.
306389
type packageFS struct {
307390
fs.FS

internal/patchpkg/search.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ import (
44
"fmt"
55
"io"
66
"io/fs"
7+
"iter"
78
"os"
9+
"path/filepath"
810
"regexp"
911
"strings"
1012
"sync"
@@ -80,3 +82,97 @@ func searchEnv(re *regexp.Regexp) string {
8082
}
8183
return ""
8284
}
85+
86+
// SystemCUDALibraries returns an iterator over the system CUDA library paths.
87+
// It yields them in priority order, where the first path is the most likely to
88+
// be the correct version.
89+
var SystemCUDALibraries iter.Seq[string] = func(yield func(string) bool) {
90+
// Quick overview of Unix-like shared object versioning.
91+
//
92+
// Libraries have 3 different names (using libcuda as an example):
93+
//
94+
// 1. libcuda.so - the "linker name". Typically a symlink pointing to
95+
// the soname. The compiler looks for this name.
96+
// 2. libcuda.so.1 - the "soname". Typically a symlink pointing to the
97+
// real name. The dynamic linker looks for this name.
98+
// 3. libcuda.so.550.107.02 - the "real name". The actual ELF shared
99+
// library. Usually never referred to directly because that would
100+
// make versioning hard.
101+
//
102+
// Because we don't know what version of CUDA the user's program
103+
// actually needs, we're going to try to find linker names (libcuda.so)
104+
// and trust that the system is pointing it to the correct version.
105+
// We'll fall back to sonames (libcuda.so.1) that we find if none of the
106+
// linker names work.
107+
108+
// Common direct paths to try first.
109+
linkerNames := []string{
110+
"/usr/lib/x86_64-linux-gnu/libcuda.so", // Debian
111+
"/usr/lib64/libcuda.so", // Red Hat
112+
"/usr/lib/libcuda.so",
113+
}
114+
for _, path := range linkerNames {
115+
// Return what the link is pointing to because the dynamic
116+
// linker will want libcuda.so.1, not libcuda.so.
117+
soname, err := os.Readlink(path)
118+
if err != nil {
119+
continue
120+
}
121+
if filepath.IsLocal(soname) {
122+
soname = filepath.Join(filepath.Dir(path), soname)
123+
}
124+
if !yield(soname) {
125+
return
126+
}
127+
}
128+
129+
// Directories to recursively search.
130+
prefixes := []string{
131+
"/usr/lib",
132+
"/usr/lib64",
133+
"/lib",
134+
"/lib64",
135+
"/usr/local/lib",
136+
"/usr/local/lib64",
137+
"/opt/cuda",
138+
"/opt/nvidia",
139+
"/usr/local/cuda",
140+
"/usr/local/nvidia",
141+
}
142+
sonameRegex := regexp.MustCompile(`^libcuda\.so\.\d+$`)
143+
var sonames []string
144+
for _, path := range prefixes {
145+
_ = filepath.WalkDir(path, func(path string, entry fs.DirEntry, err error) error {
146+
if err != nil {
147+
return nil
148+
}
149+
if entry.Name() == "libcuda.so" && isSymlink(entry.Type()) {
150+
soname, err := os.Readlink(path)
151+
if err != nil {
152+
return nil
153+
}
154+
if filepath.IsLocal(soname) {
155+
soname = filepath.Join(filepath.Dir(path), soname)
156+
}
157+
if !yield(soname) {
158+
return filepath.SkipAll
159+
}
160+
}
161+
162+
// Save potential soname matches for later after we've
163+
// exhausted all the potential linker names.
164+
if sonameRegex.MatchString(entry.Name()) {
165+
sonames = append(sonames, entry.Name())
166+
}
167+
return nil
168+
})
169+
}
170+
171+
// We didn't find any symlinks named libcuda.so. Fall back to trying any
172+
// sonames (e.g., libcuda.so.1) that we found.
173+
for _, path := range sonames {
174+
if !yield(path) {
175+
return
176+
}
177+
}
178+
}

0 commit comments

Comments
 (0)