Skip to content

Commit e839205

Browse files
committed
Add validating admission webhook to verify opaque configs
This is a straight copy of this PR kubernetes-sigs/dra-example-driver#75, with minimal changes to make it work in this repo. Signed-off-by: Kevin Klues <kklues@nvidia.com>
1 parent 20db628 commit e839205

21 files changed

+3235
-0
lines changed

README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,22 @@ As of today, the recommended installation method is via Helm.
4646
Detailed instructions can (for now) be found [here](https://github.com/NVIDIA/k8s-dra-driver-gpu/discussions/249).
4747
In the future, this driver will be included in the [NVIDIA GPU Operator](https://github.com/NVIDIA/gpu-operator) and will no longer need to be installed separately.
4848

49+
### Validating Admission Webhook
50+
51+
The validating admission webhook is disabled by default. To enable it, install cert-manager and its CRDs, then set the `webhook.enabled=true` value when the nvidia-dra-driver-gpu chart is installed.
52+
53+
```bash
54+
helm install \
55+
--repo https://charts.jetstack.io \
56+
--version v1.16.3 \
57+
--create-namespace \
58+
--namespace cert-manager \
59+
--wait \
60+
--set crds.enabled=true \
61+
cert-manager \
62+
cert-manager
63+
```
64+
4965
## A (kind) demo
5066

5167
Below, we demonstrate a basic use case: sharing a single GPU across two containers running in the same Kubernetes pod.

cmd/webhook/main.go

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
Copyright 2025 NVIDIA Corporation.
4+
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
*/
17+
18+
package main
19+
20+
import (
21+
"encoding/json"
22+
"fmt"
23+
"io"
24+
"net/http"
25+
"os"
26+
"strings"
27+
28+
"github.com/urfave/cli/v2"
29+
30+
admissionv1 "k8s.io/api/admission/v1"
31+
resourceapi "k8s.io/api/resource/v1beta1"
32+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
33+
"k8s.io/apimachinery/pkg/runtime"
34+
"k8s.io/apimachinery/pkg/runtime/serializer"
35+
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
36+
"k8s.io/klog/v2"
37+
38+
nvapi "github.com/NVIDIA/k8s-dra-driver-gpu/api/nvidia.com/resource/v1beta1"
39+
"github.com/NVIDIA/k8s-dra-driver-gpu/pkg/flags"
40+
)
41+
42+
// DriverName is the DRA driver name handled by this webhook. Only opaque
// device configurations whose driver field equals this value are validated.
const DriverName = "gpu.nvidia.com"
45+
46+
var (
47+
resourceClaimResource = metav1.GroupVersionResource{
48+
Group: resourceapi.SchemeGroupVersion.Group,
49+
Version: resourceapi.SchemeGroupVersion.Version,
50+
Resource: "resourceclaims",
51+
}
52+
resourceClaimTemplateResource = metav1.GroupVersionResource{
53+
Group: resourceapi.SchemeGroupVersion.Group,
54+
Version: resourceapi.SchemeGroupVersion.Version,
55+
Resource: "resourceclaimtemplates",
56+
}
57+
)
58+
59+
type Flags struct {
60+
loggingConfig *flags.LoggingConfig
61+
featureGateConfig *flags.FeatureGateConfig
62+
63+
certFile string
64+
keyFile string
65+
port int
66+
}
67+
68+
var scheme = runtime.NewScheme()
69+
var codecs = serializer.NewCodecFactory(scheme)
70+
71+
func init() {
72+
utilruntime.Must(admissionv1.AddToScheme(scheme))
73+
}
74+
75+
func main() {
76+
if err := newApp().Run(os.Args); err != nil {
77+
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
78+
os.Exit(1)
79+
}
80+
}
81+
82+
func newApp() *cli.App {
83+
flags := &Flags{
84+
loggingConfig: flags.NewLoggingConfig(),
85+
featureGateConfig: flags.NewFeatureGateConfig(),
86+
}
87+
cliFlags := []cli.Flag{
88+
&cli.StringFlag{
89+
Name: "tls-cert-file",
90+
Usage: "File containing the default x509 Certificate for HTTPS. (CA cert, if any, concatenated after server cert).",
91+
Destination: &flags.certFile,
92+
Required: true,
93+
},
94+
&cli.StringFlag{
95+
Name: "tls-private-key-file",
96+
Usage: "File containing the default x509 private key matching --tls-cert-file.",
97+
Destination: &flags.keyFile,
98+
Required: true,
99+
},
100+
&cli.IntFlag{
101+
Name: "port",
102+
Usage: "Secure port that the webhook listens on",
103+
Value: 443,
104+
Destination: &flags.port,
105+
},
106+
}
107+
cliFlags = append(cliFlags, flags.loggingConfig.Flags()...)
108+
cliFlags = append(cliFlags, flags.featureGateConfig.Flags()...)
109+
110+
app := &cli.App{
111+
Name: "webhook",
112+
Usage: "webhook implements a validating admission webhook complementing a DRA driver plugin.",
113+
ArgsUsage: " ",
114+
HideHelpCommand: true,
115+
Flags: cliFlags,
116+
Before: func(c *cli.Context) error {
117+
if c.Args().Len() > 0 {
118+
return fmt.Errorf("arguments not supported: %v", c.Args().Slice())
119+
}
120+
return flags.loggingConfig.Apply()
121+
},
122+
Action: func(c *cli.Context) error {
123+
server := &http.Server{
124+
Handler: newMux(),
125+
Addr: fmt.Sprintf(":%d", flags.port),
126+
}
127+
klog.Info("starting webhook server on", server.Addr)
128+
return server.ListenAndServeTLS(flags.certFile, flags.keyFile)
129+
},
130+
}
131+
132+
return app
133+
}
134+
135+
func newMux() *http.ServeMux {
136+
mux := http.NewServeMux()
137+
mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim)
138+
mux.HandleFunc("/readyz", func(w http.ResponseWriter, req *http.Request) {
139+
_, err := w.Write([]byte("ok"))
140+
if err != nil {
141+
http.Error(w, err.Error(), http.StatusInternalServerError)
142+
return
143+
}
144+
})
145+
return mux
146+
}
147+
148+
func serveResourceClaim(w http.ResponseWriter, r *http.Request) {
149+
serve(w, r, admitResourceClaimParameters)
150+
}
151+
152+
// serve handles the http portion of a request prior to handing to an admit
153+
// function.
154+
func serve(w http.ResponseWriter, r *http.Request, admit func(admissionv1.AdmissionReview) *admissionv1.AdmissionResponse) {
155+
var body []byte
156+
if r.Body != nil {
157+
data, err := io.ReadAll(r.Body)
158+
if err != nil {
159+
klog.Error(err)
160+
http.Error(w, err.Error(), http.StatusInternalServerError)
161+
return
162+
}
163+
body = data
164+
}
165+
166+
// verify the content type is accurate
167+
contentType := r.Header.Get("Content-Type")
168+
if contentType != "application/json" {
169+
msg := fmt.Sprintf("contentType=%s, expected application/json", contentType)
170+
klog.Error(msg)
171+
http.Error(w, msg, http.StatusUnsupportedMediaType)
172+
return
173+
}
174+
175+
klog.V(2).Infof("handling request: %s", body)
176+
177+
requestedAdmissionReview, err := readAdmissionReview(body)
178+
if err != nil {
179+
msg := fmt.Sprintf("failed to read AdmissionReview from request body: %v", err)
180+
klog.Error(msg)
181+
http.Error(w, msg, http.StatusBadRequest)
182+
return
183+
}
184+
responseAdmissionReview := &admissionv1.AdmissionReview{}
185+
responseAdmissionReview.SetGroupVersionKind(requestedAdmissionReview.GroupVersionKind())
186+
responseAdmissionReview.Response = admit(*requestedAdmissionReview)
187+
responseAdmissionReview.Response.UID = requestedAdmissionReview.Request.UID
188+
189+
klog.V(2).Infof("sending response: %v", responseAdmissionReview)
190+
respBytes, err := json.Marshal(responseAdmissionReview)
191+
if err != nil {
192+
klog.Error(err)
193+
http.Error(w, err.Error(), http.StatusInternalServerError)
194+
return
195+
}
196+
w.Header().Set("Content-Type", "application/json")
197+
if _, err := w.Write(respBytes); err != nil {
198+
klog.Error(err)
199+
}
200+
}
201+
202+
func readAdmissionReview(data []byte) (*admissionv1.AdmissionReview, error) {
203+
deserializer := codecs.UniversalDeserializer()
204+
obj, gvk, err := deserializer.Decode(data, nil, nil)
205+
if err != nil {
206+
return nil, fmt.Errorf("request could not be decoded: %w", err)
207+
}
208+
209+
if *gvk != admissionv1.SchemeGroupVersion.WithKind("AdmissionReview") {
210+
return nil, fmt.Errorf("unsupported group version kind: %v", gvk)
211+
}
212+
213+
requestedAdmissionReview, ok := obj.(*admissionv1.AdmissionReview)
214+
if !ok {
215+
return nil, fmt.Errorf("expected v1.AdmissionReview but got: %T", obj)
216+
}
217+
218+
return requestedAdmissionReview, nil
219+
}
220+
221+
// admitResourceClaimParameters accepts both ResourceClaims and ResourceClaimTemplates and validates their
222+
// opaque device configuration parameters for this driver.
223+
func admitResourceClaimParameters(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse {
224+
klog.V(2).Info("admitting resource claim parameters")
225+
226+
var deviceConfigs []resourceapi.DeviceClaimConfiguration
227+
var specPath string
228+
229+
raw := ar.Request.Object.Raw
230+
deserializer := codecs.UniversalDeserializer()
231+
232+
switch ar.Request.Resource {
233+
case resourceClaimResource:
234+
claim := resourceapi.ResourceClaim{}
235+
if _, _, err := deserializer.Decode(raw, nil, &claim); err != nil {
236+
klog.Error(err)
237+
return &admissionv1.AdmissionResponse{
238+
Result: &metav1.Status{
239+
Message: err.Error(),
240+
Reason: metav1.StatusReasonBadRequest,
241+
},
242+
}
243+
}
244+
deviceConfigs = claim.Spec.Devices.Config
245+
specPath = "spec"
246+
case resourceClaimTemplateResource:
247+
claimTemplate := resourceapi.ResourceClaimTemplate{}
248+
if _, _, err := deserializer.Decode(raw, nil, &claimTemplate); err != nil {
249+
klog.Error(err)
250+
return &admissionv1.AdmissionResponse{
251+
Result: &metav1.Status{
252+
Message: err.Error(),
253+
Reason: metav1.StatusReasonBadRequest,
254+
},
255+
}
256+
}
257+
deviceConfigs = claimTemplate.Spec.Spec.Devices.Config
258+
specPath = "spec.spec"
259+
default:
260+
msg := fmt.Sprintf("expected resource to be %s or %s, got %s", resourceClaimResource, resourceClaimTemplateResource, ar.Request.Resource)
261+
klog.Error(msg)
262+
return &admissionv1.AdmissionResponse{
263+
Result: &metav1.Status{
264+
Message: msg,
265+
Reason: metav1.StatusReasonBadRequest,
266+
},
267+
}
268+
}
269+
270+
var errs []error
271+
for configIndex, config := range deviceConfigs {
272+
if config.Opaque == nil || config.Opaque.Driver != DriverName {
273+
continue
274+
}
275+
276+
fieldPath := fmt.Sprintf("%s.devices.config[%d].opaque.parameters", specPath, configIndex)
277+
decodedConfig, err := runtime.Decode(nvapi.Decoder, config.Opaque.Parameters.Raw)
278+
if err != nil {
279+
errs = append(errs, fmt.Errorf("error decoding object at %s: %w", fieldPath, err))
280+
continue
281+
}
282+
gpuConfig, ok := decodedConfig.(*nvapi.GpuConfig)
283+
if !ok {
284+
errs = append(errs, fmt.Errorf("expected v1beta1.GpuConfig at %s but got: %T", fieldPath, decodedConfig))
285+
continue
286+
}
287+
err = gpuConfig.Validate()
288+
if err != nil {
289+
errs = append(errs, fmt.Errorf("object at %s is invalid: %w", fieldPath, err))
290+
}
291+
}
292+
293+
if len(errs) > 0 {
294+
var errMsgs []string
295+
for _, err := range errs {
296+
errMsgs = append(errMsgs, err.Error())
297+
}
298+
msg := fmt.Sprintf("%d configs failed to validate: %s", len(errs), strings.Join(errMsgs, "; "))
299+
klog.Error(msg)
300+
return &admissionv1.AdmissionResponse{
301+
Result: &metav1.Status{
302+
Message: msg,
303+
Reason: metav1.StatusReason(metav1.StatusReasonInvalid),
304+
},
305+
}
306+
}
307+
308+
return &admissionv1.AdmissionResponse{
309+
Allowed: true,
310+
}
311+
}

0 commit comments

Comments
 (0)