Skip to content

Commit 030575e

Browse files
committed
cuda-modules: fix and clean up multiplex builder package selection logic
Signed-off-by: Connor Baker <[email protected]>
1 parent 7109c07 commit 030575e

File tree

1 file changed

+50
-54
lines changed

1 file changed

+50
-54
lines changed

pkgs/development/cuda-modules/generic-builders/multiplex.nix

Lines changed: 50 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,7 @@
3030
shimsFn ? (throw "shimsFn must be provided"),
3131
}:
3232
let
33-
inherit (lib)
34-
attrsets
35-
lists
36-
modules
37-
strings
38-
;
39-
40-
inherit (stdenv) hostPlatform;
41-
42-
evaluatedModules = modules.evalModules {
33+
evaluatedModules = lib.modules.evalModules {
4334
modules = [
4435
../modules
4536
releasesModule
@@ -50,49 +41,55 @@ let
5041
# - Releases: ../modules/${pname}/releases/releases.nix
5142
# - Package: ../modules/${pname}/releases/package.nix
5243

53-
# FIXME: do this at the module system level
54-
propagatePlatforms = lib.mapAttrs (
55-
redistArch: packages: map (p: { inherit redistArch; } // p) packages
56-
);
44+
# redistArch :: String
45+
# Value is `"unsupported"` if the platform is not supported.
46+
redistArch = flags.getRedistArch stdenv.hostPlatform.system;
5747

58-
# All releases across all platforms
48+
# Check whether a package supports our CUDA version.
49+
# satisfiesCudaVersion :: Package -> Bool
50+
satisfiesCudaVersion =
51+
package:
52+
lib.versionAtLeast cudaMajorMinorVersion package.minCudaVersion
53+
&& lib.versionAtLeast package.maxCudaVersion cudaMajorMinorVersion;
54+
55+
# Releases for our platform and CUDA version.
5956
# See ../modules/${pname}/releases/releases.nix
60-
releaseSets = propagatePlatforms evaluatedModules.config.${pname}.releases;
57+
# allPackages :: List Package
58+
allPackages = lib.filter satisfiesCudaVersion (
59+
evaluatedModules.config.${pname}.releases.${redistArch} or [ ]
60+
);
6161

6262
# Compute versioned attribute name to be used in this package set
6363
# Patch version changes should not break the build, so we only use major and minor
6464
# computeName :: Package -> String
65-
computeName = { version, ... }: mkVersionedPackageName pname version;
66-
67-
# Check whether a package supports our CUDA version and platform.
68-
# isSupported :: Package -> Bool
69-
isSupported =
70-
package:
71-
redistArch == package.redistArch
72-
&& strings.versionAtLeast cudaMajorMinorVersion package.minCudaVersion
73-
&& strings.versionAtLeast package.maxCudaVersion cudaMajorMinorVersion;
65+
computeName = package: mkVersionedPackageName pname package.version;
7466

75-
# Get all of the packages for our given platform.
76-
# redistArch :: String
77-
# Value is `"unsupported"` if the platform is not supported.
78-
redistArch = flags.getRedistArch hostPlatform.system;
79-
80-
preferable =
81-
p1: p2: (isSupported p2 -> isSupported p1) && (strings.versionOlder p2.version p1.version);
82-
83-
# All the supported packages we can build for our platform.
84-
# perSystemReleases :: List Package
85-
allReleases = lib.pipe releaseSets [
86-
(lib.attrValues)
87-
(lists.flatten)
88-
(lib.groupBy (p: lib.versions.majorMinor p.version))
89-
(lib.mapAttrs (_: builtins.sort preferable))
90-
(lib.mapAttrs (_: lib.take 1))
91-
(lib.attrValues)
92-
(lib.concatMap lib.trivial.id)
93-
];
94-
95-
newest = builtins.head (builtins.sort preferable allReleases);
67+
# The newest package for each major-minor version, with newest first.
68+
# newestPackages :: List Package
69+
newestPackages =
70+
let
71+
newestForEachMajorMinorVersion = lib.foldl' (
72+
newestPackages: package:
73+
let
74+
majorMinorVersion = lib.versions.majorMinor package.version;
75+
existingPackage = newestPackages.${majorMinorVersion} or null;
76+
in
77+
newestPackages
78+
// {
79+
${majorMinorVersion} =
80+
# Only keep the existing package if it is newer than the one we are considering.
81+
if existingPackage != null && lib.versionOlder package.version existingPackage.version then
82+
existingPackage
83+
else
84+
package;
85+
}
86+
) { } allPackages;
87+
in
88+
# Sort the packages by version so the newest is first.
89+
# NOTE: builtins.sort requires a strict weak ordering, so we must use versionOlder rather than versionAtLeast.
90+
lib.sort (p1: p2: lib.versionOlder p2.version p1.version) (
91+
lib.attrValues newestForEachMajorMinorVersion
92+
);
9693

9794
extension =
9895
final: _:
@@ -102,25 +99,24 @@ let
10299
buildPackage =
103100
package:
104101
let
105-
shims = final.callPackage shimsFn {
106-
inherit package;
107-
inherit (package) redistArch;
108-
};
102+
shims = final.callPackage shimsFn { inherit package redistArch; };
109103
name = computeName package;
110104
drv = final.callPackage ./manifest.nix {
111105
inherit pname redistName;
112106
inherit (shims) redistribRelease featureRelease;
113107
};
114108
in
115-
attrsets.nameValuePair name drv;
109+
lib.nameValuePair name drv;
116110

117111
# versionedDerivations :: AttrSet Derivation
118-
versionedDerivations = builtins.listToAttrs (lists.map buildPackage allReleases);
112+
versionedDerivations = builtins.listToAttrs (lib.map buildPackage newestPackages);
119113

120114
defaultDerivation = {
121-
${pname} = (buildPackage newest).value;
115+
${pname} = (buildPackage (lib.head newestPackages)).value;
122116
};
123117
in
124-
versionedDerivations // defaultDerivation;
118+
# NOTE: Must condition on the length of newestPackages to avoid non-total function lib.head aborting if
119+
# newestPackages is empty.
120+
lib.optionalAttrs (lib.length newestPackages > 0) (versionedDerivations // defaultDerivation);
125121
in
126122
extension

0 commit comments

Comments
 (0)