-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathflake.nix
More file actions
129 lines (114 loc) · 3.46 KB
/
flake.nix
File metadata and controls
129 lines (114 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
{
  description = "PSI development shells";

  inputs = {
    # Include this repository's git submodules in the flake source so that
    # ./third_party/SIMPLE is present at evaluation time. Needs a Nix that
    # understands `inputs.self.submodules`.
    self.submodules = true;
    nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
    # SIMPLE is a vendored path flake; pin it to our nixpkgs so both share
    # one package set.
    simple = {
      url = "path:./third_party/SIMPLE";
      inputs.nixpkgs.follows = "nixpkgs";
    };
  };

  outputs = {
    self,
    nixpkgs,
    simple,
  }:
    let
      # Only x86_64-linux shells are defined.
      system = "x86_64-linux";
      pkgs = import nixpkgs {
        inherit system;
        config = {
          allowUnfree = true;
          cudaSupport = true;
        };
      };
      lib = pkgs.lib;
      runtimeLib = import ./nix/lib/runtime.nix {
        inherit lib pkgs;
      };

      # Runtime descriptions composed into the shells below.
      psiBaseRuntime = import ./nix/runtime-base.nix { inherit pkgs; };
      hostGpuRuntime = import ./nix/runtime-host-gpu.nix { inherit pkgs; };
      simpleBaseRuntime = simple.lib.${system}.baseRuntime;

      # Tools added to every shell.
      commonPackages = [
        pkgs.bashInteractive
        pkgs.coreutils
        pkgs.curl
        pkgs.git
        pkgs.git-lfs
        pkgs.uv
        pkgs.wget
      ];

      # Shell hook shared by every shell:
      #  1. auto-export every variable assigned in ~/.env, if that file exists;
      #  2. if TORCH_CUDA_ARCH_LIST is unset and all visible GPUs report the
      #     same compute capability, export TORCH_CUDA_ARCH_LIST="<cap>+PTX".
      #     Mixed capabilities, a failing nvidia-smi, or any unparsable value
      #     leave the variable unset rather than guessing.
      commonShellHook = ''
        if [ -f "$HOME/.env" ]; then
          set -a
          . "$HOME/.env"
          set +a
        fi
        _psi_detect_torch_cuda_arch_list() {
          # An explicit setting (for example from ~/.env) always wins.
          if [ -n "''${TORCH_CUDA_ARCH_LIST:-}" ]; then
            return 0
          fi
          if ! command -v nvidia-smi >/dev/null 2>&1; then
            return 1
          fi
          # Function-local so nothing leaks into the interactive shell.
          local raw caps count
          # Check the exit status: on failure (for example an NVML
          # driver/library mismatch) nvidia-smi can print an error message
          # on stdout, which must not be taken as a compute capability.
          if ! raw="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits 2>/dev/null)"; then
            return 1
          fi
          # Trim whitespace, drop blank lines, de-duplicate.
          caps="$(
            printf '%s\n' "$raw" \
              | sed 's/^[[:space:]]*//; s/[[:space:]]*$//; /^$/d' \
              | sort -u
          )"
          if [ -z "$caps" ]; then
            return 1
          fi
          # Any line that is not of the form major.minor (error text, [N/A])
          # means the query cannot be trusted.
          if [ -n "$(printf '%s\n' "$caps" | sed -n '/^[0-9][0-9]*\.[0-9][0-9]*$/!p')" ]; then
            return 1
          fi
          count="$(printf '%s\n' "$caps" | wc -l)"
          if [ "$count" -ne 1 ]; then
            return 1
          fi
          export TORCH_CUDA_ARCH_LIST="$caps+PTX"
        }
        _psi_detect_torch_cuda_arch_list || true
        unset -f _psi_detect_torch_cuda_arch_list
      '';

      # Build a runtime shell with the shared defaults (common packages,
      # common hook, GCC 13 stdenv); attributes in `args` override them.
      mkPsiShell = args:
        runtimeLib.mkRuntimeShell ({
          extraPackages = commonPackages;
          extraShellHook = commonShellHook;
          stdenv = pkgs.gcc13Stdenv;
        } // args);

      # PSI and SIMPLE runtimes composed into one shell; exposed both as the
      # default shell and as `integrated`.
      integratedShell = mkPsiShell {
        name = "psi+simple";
        runtimes = [ psiBaseRuntime simpleBaseRuntime hostGpuRuntime ];
        extraShellHook = ''
          ${commonShellHook}
          echo "SIMPLE runtime composed into root shell"
        '';
      };
    in {
      devShells.${system} = {
        default = integratedShell;
        integrated = integratedShell;
        psi = mkPsiShell {
          name = "psi";
          runtimes = [ psiBaseRuntime hostGpuRuntime ];
        };
        simple = mkPsiShell {
          name = "simple";
          runtimes = [ simpleBaseRuntime hostGpuRuntime ];
        };
      };
    };
}