99 "bytes"
1010 "context"
1111 "crypto"
12+ "crypto/rand"
1213 "encoding/base64"
1314 "fmt"
1415 "io"
@@ -24,6 +25,8 @@ import (
2425 tg "github.com/google/go-tdx-guest/client"
2526 tlabi "github.com/google/go-tdx-guest/client/linuxabi"
2627
28+ "github.com/NVIDIA/go-nvml/pkg/nvml"
29+ "github.com/confidentsecurity/go-nvtrust/pkg/gonvtrust/gpu"
2730 "github.com/google/go-tpm-tools/cel"
2831 "github.com/google/go-tpm-tools/client"
2932 "github.com/google/go-tpm-tools/internal"
@@ -223,6 +226,8 @@ func (a *agent) AttestWithClient(ctx context.Context, opts AttestAgentOpts, clie
223226 v .IntermediateCerts = certChain
224227 v .AkCert = a .fetchedAK .CertDERBytes ()
225228 req .TDCCELAttestation = v
229+ // collect GPU attestation
230+ collectGpuAttestation (a .logger )
226231 default :
227232 return nil , fmt .Errorf ("received an unsupported attestation type! %v" , v )
228233 }
@@ -406,3 +411,52 @@ func (c *sigsCache) get() []oci.Signature {
406411 defer c .mu .RUnlock ()
407412 return c .items
408413}
414+
415+ func collectGpuAttestation (logger logging.Logger ) {
416+ handler := & gpu.DefaultNVMLHandler {}
417+ gpuAdmin , err := gpu .NewNvmlGPUAdmin (handler )
418+ if err != nil {
419+ logger .Error ("Failed to create GPU admin: %v\n " , err )
420+ }
421+ defer gpuAdmin .Shutdown ()
422+
423+ // Generate a random nonce (32 bytes)
424+ nonce := make ([]byte , 32 )
425+ if _ , err := rand .Read (nonce ); err != nil {
426+ logger .Error ("Failed to generate nonce: %v\n " , err )
427+ }
428+
429+ deviceInfos , err := gpuAdmin .CollectEvidence (nonce )
430+ if err != nil {
431+ logger .Error ("Failed to collect GPU evidence\n : %w" , err )
432+ }
433+
434+ for i , deviceInfo := range deviceInfos {
435+ device , ret := handler .DeviceGetHandleByIndex (i )
436+ if ret != nvml .SUCCESS {
437+ logger .Error ("Failed to get GPU device: %w\n " , nvml .ErrorString (ret ))
438+ }
439+ uuid , ret := device .GetUUID ()
440+ if ret != nvml .SUCCESS {
441+ logger .Error ("Failed to get UUID: %w\n " , nvml .ErrorString (ret ))
442+ }
443+
444+ vbiosVersion , ret := device .GetVbiosVersion ()
445+ if ret != nvml .SUCCESS {
446+ logger .Error ("Failed to get vbios version: %w\n " , nvml .ErrorString (ret ))
447+ }
448+
449+ driverVersion , ret := handler .SystemGetDriverVersion ()
450+ if ret != nvml .SUCCESS {
451+ logger .Error ("Failed to get vbios version: %w\n " , nvml .ErrorString (ret ))
452+ }
453+ logger .Info ("Found GPU UUID [%s] at index %d\n " , uuid , i )
454+ logger .Info ("Found GPU VBIOS version [%s] at index %d\n " , vbiosVersion , i )
455+ logger .Info ("Found GPU DRIVER version [%s] at index %d\n " , driverVersion , i )
456+ // The following attestation data can be accessed by device Info
457+ logger .Info ("Found GPU Arch [%s] at index %d\n " , deviceInfo .Arch (), i )
458+ logger .Info ("Found GPU attetation data size [%d] at index %d\n " , len (deviceInfo .AttestationReport ()), i )
459+ b64CertChainData , _ := deviceInfo .Certificate ().EncodeBase64 ()
460+ logger .Info ("Found GPU attestation cert chain data [%s] at index %d\n " , b64CertChainData , i )
461+ }
462+ }
0 commit comments