Skip to content

Commit dc46be1

Browse files
fletcherwndbaker1
andauthored
feat(nodeadm): pass nvidia gpu startup labels to kubelet (#2607)
* feat(nodeadm): pass nvidia gpu startup labels to kubelet * docs: add feature gate and version-specific behavior table * tweak testing e2e test is failing in CI and passing locally. Replace it with a test using a faked file system. --------- Co-authored-by: Nick Baker <nbakerd@amazon.com>
1 parent 42bafe1 commit dc46be1

File tree

5 files changed

+148
-1
lines changed

5 files changed

+148
-1
lines changed

nodeadm/doc/api-concepts.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,18 @@ There are three levels of stability and support:
1818

1919
### Stable
2020
- Example: `v5`.
21-
- Support for a stable API will align with the support of a major version of Amazon Linux.
21+
- Support for a stable API will align with the support of a major version of Amazon Linux.
22+
23+
## Feature Gates
24+
25+
|Name|Default|Since|Until|
26+
|---|---|---|---|
27+
|`InstanceIdNodeName`|`false`|-|-|
28+
29+
## Behavioral Boundaries
30+
31+
|Description|Since|Until|
32+
|---|---|---|
33+
|Apply the `nvidia.com/gpu.present=true` node label on startup if NVIDIA devices are detected on the instance|1.35|-|
34+
|Enable CDI in the default containerd configuration|1.32|-|
35+
|Write user-provided kubelet config as a kubelet drop-in configuration file|1.29|-|

nodeadm/internal/kubelet/config.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,26 @@ func (ksc *kubeletConfig) withOutpostSetup(cfg *api.NodeConfig) error {
213213
return nil
214214
}
215215

216+
func (ksc *kubeletConfig) withNodeLabels(flags map[string]string, nodeLabelFuncs map[string]LabelProvider) {
217+
var nodeLabels []string
218+
for nodeLabelKey, provider := range nodeLabelFuncs {
219+
nodeLabelValue, ok, err := provider.Get()
220+
if err != nil {
221+
zap.L().Error("Failed to get node label value", zap.String("key", nodeLabelKey), zap.Error(err))
222+
continue
223+
}
224+
if !ok {
225+
continue
226+
}
227+
nodeLabel := fmt.Sprintf("%s=%s", nodeLabelKey, nodeLabelValue)
228+
zap.L().Info("Adding node label", zap.String("label", nodeLabel))
229+
nodeLabels = append(nodeLabels, nodeLabel)
230+
}
231+
if len(nodeLabels) > 0 {
232+
flags["node-labels"] = strings.Join(nodeLabels, ",")
233+
}
234+
}
235+
216236
func (ksc *kubeletConfig) withNodeIp(cfg *api.NodeConfig, flags map[string]string) error {
217237
nodeIp, err := getNodeIp(context.TODO(), cfg, imds.DefaultClient())
218238
if err != nil {
@@ -321,6 +341,13 @@ func (k *kubelet) GenerateKubeletConfig(cfg *api.NodeConfig) (*kubeletConfig, er
321341
kubeletConfig.withDefaultReservedResources(cfg, k.resources)
322342
kubeletConfig.withImageServiceEndpoint(cfg, k.resources)
323343

344+
nodeLabelFuncs := map[string]LabelProvider{}
345+
if semver.Compare(cfg.Status.KubeletVersion, "v1.35.0") >= 0 {
346+
// see: https://github.com/NVIDIA/gpu-operator/commit/e25291b86cf4542ac62d8635cda4bd653c4face3
347+
nodeLabelFuncs["nvidia.com/gpu.present"] = NvidiaGPULabel{fs: system.RealFileSystem{}}
348+
}
349+
kubeletConfig.withNodeLabels(k.flags, nodeLabelFuncs)
350+
324351
return &kubeletConfig, nil
325352
}
326353

nodeadm/internal/kubelet/labels.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package kubelet
2+
3+
import (
4+
"github.com/awslabs/amazon-eks-ami/nodeadm/internal/system"
5+
)
6+
7+
// LabelProvider supplies the value for a single node label.
type LabelProvider interface {
8+
// Get returns the label value, a flag reporting whether the label should be
// applied at all, and any error encountered while determining it.
Get() (string, bool, error)
9+
}
10+
11+
type NvidiaGPULabel struct {
12+
fs system.FileSystem
13+
}
14+
15+
func (n NvidiaGPULabel) Get() (string, bool, error) {
16+
ok, err := system.IsPCIVendorAttached(n.fs, system.NVIDIA_VENDOR_ID)
17+
if err != nil {
18+
return "", false, err
19+
}
20+
if !ok {
21+
return "", false, nil
22+
}
23+
return "true", true, nil
24+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package kubelet
2+
3+
import (
4+
"testing"
5+
6+
"github.com/awslabs/amazon-eks-ami/nodeadm/internal/system"
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestNvidiaGPULabel(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
files map[string]string
14+
expectedValue string
15+
expectedOk bool
16+
}{
17+
{
18+
name: "nvidia gpu present",
19+
files: map[string]string{
20+
"/sys/bus/pci/devices/0000:00:1e.0/vendor": "0x10de",
21+
},
22+
expectedValue: "true",
23+
expectedOk: true,
24+
},
25+
{
26+
name: "no nvidia gpu",
27+
files: map[string]string{
28+
"/sys/bus/pci/devices/0000:00:1e.0/vendor": "0x1234",
29+
},
30+
expectedValue: "",
31+
expectedOk: false,
32+
},
33+
{
34+
name: "no files at all",
35+
files: map[string]string{
36+
"/sys/bus/pci/devices/": system.EmptyDirectoryMarker,
37+
},
38+
expectedValue: "",
39+
expectedOk: false,
40+
},
41+
}
42+
43+
for _, tt := range tests {
44+
t.Run(tt.name, func(t *testing.T) {
45+
label := NvidiaGPULabel{fs: system.FakeFileSystem{Files: tt.files}}
46+
value, ok, err := label.Get()
47+
assert.NoError(t, err)
48+
assert.Equal(t, tt.expectedValue, value)
49+
assert.Equal(t, tt.expectedOk, ok)
50+
})
51+
}
52+
}

nodeadm/internal/system/devices.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package system
2+
3+
import (
4+
"strings"
5+
6+
"go.uber.org/zap"
7+
)
8+
9+
// NVIDIA_VENDOR_ID is NVIDIA's PCI vendor ID, written in the lowercase hex
// form that is compared against the contents of sysfs vendor files.
const NVIDIA_VENDOR_ID = "0x10de"
10+
11+
// IsPCIVendorAttached returns whether any pcie devices with a given vendor id
12+
// are attached to the instance.
13+
func IsPCIVendorAttached(fs FileSystem, vendorId string) (bool, error) {
14+
vendorPaths, err := fs.Glob("/sys/bus/pci/devices/*/vendor")
15+
if err != nil {
16+
return false, err
17+
}
18+
for _, vendorPath := range vendorPaths {
19+
// #nosec G304 // read only operation on sysfs path
20+
vendorIdBytes, err := fs.ReadFile(vendorPath)
21+
if err != nil {
22+
zap.L().Warn("failed to read vendor id", zap.Error(err))
23+
continue
24+
}
25+
if strings.TrimSpace(string(vendorIdBytes)) == vendorId {
26+
return true, nil
27+
}
28+
}
29+
return false, nil
30+
}

0 commit comments

Comments
 (0)