Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 217 additions & 0 deletions internal/patchpkg/elf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
package patchpkg

import (
"debug/elf"
"errors"
"fmt"
"io"
"iter"
"log/slog"
"os"
"path/filepath"
"strings"
)

var (
// SystemLibSearchPaths match the system library paths for common Linux
// distributions.
SystemLibSearchPaths = []string{
"/lib*/*-linux-gnu", // Debian
"/lib*", // Red Hat
"/var/lib*/*/lib*", // Docker
}

// EnvLibrarySearchPath matches the paths in the LIBRARY_PATH
// environment variable.
EnvLibrarySearchPath = filepath.SplitList(os.Getenv("LIBRARY_PATH"))

// EnvLDLibrarySearchPath matches the paths in the LD_LIBRARY_PATH
// environment variable.
EnvLDLibrarySearchPath = filepath.SplitList(os.Getenv("LD_LIBRARY_PATH"))

// CUDALibSearchPaths match the common installation directories for CUDA
// libraries.
CUDALibSearchPaths = []string{
// Common non-package manager locations.
"/opt/cuda/lib*",
"/opt/nvidia/lib*",
"/usr/local/cuda/lib*",
"/usr/local/nvidia/lib*",

// Unlikely, but might as well try.
"/lib*/nvidia",
"/lib*/cuda",
"/usr/lib*/nvidia",
"/usr/lib*/cuda",
"/usr/local/lib*",
"/usr/local/lib*/nvidia",
"/usr/local/lib*/cuda",
}
)

// SharedLibrary describes an ELF shared library (object).
//
// Note that the various name fields document the common naming and versioning
// conventions, but it is possible for a library to deviate from them.
//
// For an introduction to Linux shared libraries, see
// https://tldp.org/HOWTO/Program-Library-HOWTO/shared-libraries.html
type SharedLibrary struct {
*os.File

// LinkerName is the soname without any version suffix (libfoo.so). It
// is typically a symlink pointing to Soname. The build-time linker
// looks for this name by default.
LinkerName string

// Soname is the shared object name from the library's DT_SONAME field.
// It usually includes a version number suffix (libfoo.so.1). Other ELF
// binaries that depend on this library typically specify this name in
// the DT_NEEDED field.
Soname string

// RealName is the absolute path to the file that actually contains the
// library code. It is typically the soname plus a minor version and
// release number (libfoo.so.1.0.0).
RealName string
}

// OpenSharedLibrary opens a shared library file. Unlike with ld, name must be
// an exact path. To search for a library in the usual locations, use
// [FindSharedLibrary] instead.
func OpenSharedLibrary(name string) (SharedLibrary, error) {
lib := SharedLibrary{}
var err error
lib.File, err = os.Open(name)
if err != nil {
return lib, err
}

dir, file := filepath.Split(name)
i := strings.Index(file, ".so")
if i != -1 {
lib.LinkerName = dir + file[:i+3]
}

elfFile, err := elf.NewFile(lib)
if err == nil {
soname, _ := elfFile.DynString(elf.DT_SONAME)
if len(soname) != 0 {
lib.Soname = soname[0]
}
}

real, err := filepath.EvalSymlinks(name)
if err == nil {
lib.RealName, _ = filepath.Abs(real)
}
return lib, nil
}

// FindSharedLibrary searches the directories in searchPath for a shared
// library. It yields any libraries in the search path directories that have
// name as a prefix. For example, "libcuda.so" will match "libcuda.so",
// "libcuda.so.1", and "libcuda.so.550.107.02". The underlying file is only
// valid for a single iteration, after which it is closed.
//
// The search path may contain [filepath.Glob] patterns. See
// [SystemLibSearchPaths] for some predefined search paths. If name is an
// absolute path, then FindSharedLibrary opens it directly and doesn't perform
// any searching.
func FindSharedLibrary(name string, searchPath ...string) iter.Seq[SharedLibrary] {
return func(yield func(SharedLibrary) bool) {
if filepath.IsAbs(name) {
lib, err := OpenSharedLibrary(name)
if err == nil {
yield(lib)
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

print some warning about the error?

return
}

if libPath := os.Getenv("LD_LIBRARY_PATH"); libPath != "" {
searchPath = append(searchPath, filepath.SplitList(os.Getenv("LD_LIBRARY_PATH"))...)
}
if libPath := os.Getenv("LIBRARY_PATH"); libPath != "" {
searchPath = append(searchPath, filepath.SplitList(libPath)...)
}
searchPath = append(searchPath,
"/lib*/*-linux-gnu", // Debian
"/lib*", // Red Hat
)

suffix := globEscape(name) + "*"
patterns := make([]string, len(searchPath))
for i := range searchPath {
patterns[i] = filepath.Join(searchPath[i], suffix)
}
for match := range searchGlobs(patterns) {
lib, err := OpenSharedLibrary(match)
if err != nil {
continue
}
ok := yield(lib)
_ = lib.Close()
if !ok {
return
}
}
}
}

// CopyAndLink copies the shared library to dir and creates the LinkerName and
// Soname symlinks for it. It creates dir if it doesn't already exist.
func (lib SharedLibrary) CopyAndLink(dir string) error {
err := os.MkdirAll(dir, 0o755)
if err != nil {
return err
}

dstPath := filepath.Join(dir, filepath.Base(lib.RealName))
dst, err := os.OpenFile(dstPath, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0o666)
if err != nil {
return err
}
defer dst.Close()

_, err = io.Copy(dst, lib)
if err != nil {
return err
}
err = dst.Close()
if err != nil {
return err
}

sonameLink := filepath.Join(dir, lib.Soname)
var sonameErr error
if lib.Soname != "" {
// Symlink must be relative.
sonameErr = os.Symlink(filepath.Base(lib.RealName), sonameLink)
}

linkerNameLink := filepath.Join(dir, lib.LinkerName)
var linkerNameErr error
if lib.LinkerName != "" {
// Symlink must be relative.
if sonameErr == nil {
linkerNameErr = os.Symlink(filepath.Base(sonameLink), linkerNameLink)
} else {
linkerNameErr = os.Symlink(filepath.Base(dstPath), linkerNameLink)
}
}

err = errors.Join(sonameErr, linkerNameErr)
if err != nil {
return fmt.Errorf("patchpkg: create symlinks for shared library: %w", err)
}
return nil
}

func (lib SharedLibrary) LogValue() slog.Value {
return slog.GroupValue(
slog.String("lib.path", lib.Name()),
slog.String("lib.linkername", lib.LinkerName),
slog.String("lib.soname", lib.Soname),
slog.String("lib.realname", lib.RealName),
)
}
114 changes: 32 additions & 82 deletions internal/patchpkg/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"regexp"
"strings"
"sync"
"unicode/utf8"
)

// maxFileSize limits the amount of data to load from a file when
Expand Down Expand Up @@ -83,96 +84,45 @@ func searchEnv(re *regexp.Regexp) string {
return ""
}

// SystemCUDALibraries returns an iterator over the system CUDA library paths.
// It yields them in priority order, where the first path is the most likely to
// be the correct version.
var SystemCUDALibraries iter.Seq[string] = func(yield func(string) bool) {
// Quick overview of Unix-like shared object versioning.
//
// Libraries have 3 different names (using libcuda as an example):
//
// 1. libcuda.so - the "linker name". Typically a symlink pointing to
// the soname. The compiler looks for this name.
// 2. libcuda.so.1 - the "soname". Typically a symlink pointing to the
// real name. The dynamic linker looks for this name.
// 3. libcuda.so.550.107.02 - the "real name". The actual ELF shared
// library. Usually never referred to directly because that would
// make versioning hard.
//
// Because we don't know what version of CUDA the user's program
// actually needs, we're going to try to find linker names (libcuda.so)
// and trust that the system is pointing it to the correct version.
// We'll fall back to sonames (libcuda.so.1) that we find if none of the
// linker names work.

// Common direct paths to try first.
linkerNames := []string{
"/usr/lib/x86_64-linux-gnu/libcuda.so", // Debian
"/usr/lib64/libcuda.so", // Red Hat
"/usr/lib/libcuda.so",
}
for _, path := range linkerNames {
// Return what the link is pointing to because the dynamic
// linker will want libcuda.so.1, not libcuda.so.
soname, err := os.Readlink(path)
if err != nil {
continue
}
if filepath.IsLocal(soname) {
soname = filepath.Join(filepath.Dir(path), soname)
}
if !yield(soname) {
return
}
}

// Directories to recursively search.
prefixes := []string{
"/usr/lib",
"/usr/lib64",
"/lib",
"/lib64",
"/usr/local/lib",
"/usr/local/lib64",
"/opt/cuda",
"/opt/nvidia",
"/usr/local/cuda",
"/usr/local/nvidia",
}
sonameRegex := regexp.MustCompile(`^libcuda\.so\.\d+$`)
var sonames []string
for _, path := range prefixes {
_ = filepath.WalkDir(path, func(path string, entry fs.DirEntry, err error) error {
// searchGlobs iterates over the paths matched by multiple [filepath.Glob]
// patterns. It will not yield a path more than once, even if the path matches
// multiple patterns. It silently ignores any pattern syntax errors.
func searchGlobs(patterns []string) iter.Seq[string] {
seen := make(map[string]bool, len(patterns))
return func(yield func(string) bool) {
for _, pattern := range patterns {
glob, err := filepath.Glob(pattern)
if err != nil {
return nil
continue
}
if entry.Name() == "libcuda.so" && isSymlink(entry.Type()) {
soname, err := os.Readlink(path)
if err != nil {
return nil
}
if filepath.IsLocal(soname) {
soname = filepath.Join(filepath.Dir(path), soname)
for _, match := range glob {
if seen[match] {
continue
}
if !yield(soname) {
return filepath.SkipAll
seen[match] = true

if !yield(match) {
return
}
}
}
}
}

// Save potential soname matches for later after we've
// exhausted all the potential linker names.
if sonameRegex.MatchString(entry.Name()) {
sonames = append(sonames, entry.Name())
}
return nil
})
// globEscape escapes all metacharacters ('*', '?', '\\', '[') in s so that they
// match their literal values in a [filepath.Glob] or [fs.Glob] pattern.
func globEscape(s string) string {
if !strings.ContainsAny(s, `*?\[`) {
return s
}

// We didn't find any symlinks named libcuda.so. Fall back to trying any
// sonames (e.g., libcuda.so.1) that we found.
for _, path := range sonames {
if !yield(path) {
return
b := make([]byte, 0, len(s)+1)
for _, r := range s {
switch r {
case '*', '?', '\\', '[':
b = append(b, '\\')
}
b = utf8.AppendRune(b, r)
}
return string(b)
}
Loading
Loading