@@ -31,6 +31,7 @@ import (
3131 "github.com/google/go-tpm-tools/launcher/internal/healthmonitoring/nodeproblemdetector"
3232 "github.com/google/go-tpm-tools/launcher/internal/logging"
3333 "github.com/google/go-tpm-tools/launcher/internal/signaturediscovery"
34+ "github.com/google/go-tpm-tools/launcher/launcher/clock"
3435 "github.com/google/go-tpm-tools/launcher/launcherfile"
3536 "github.com/google/go-tpm-tools/launcher/registryauth"
3637 "github.com/google/go-tpm-tools/launcher/spec"
@@ -269,13 +270,14 @@ func enableMonitoring(enabled spec.MonitoringType, logger logging.Logger) error
269270 if enabled != spec .None {
270271 logger .Info ("Health Monitoring is enabled by the VM operator" )
271272
272- if enabled == spec .All {
273+ switch enabled {
274+ case spec .All :
273275 logger .Info ("All health monitoring metrics enabled" )
274276 if err := nodeproblemdetector .EnableAllConfig (); err != nil {
275277 logger .Error ("Failed to enable full monitoring config: %v" , err )
276278 return err
277279 }
278- } else if enabled == spec .MemoryOnly {
280+ case spec .MemoryOnly :
279281 logger .Info ("memory/bytes_used enabled" )
280282 }
281283
@@ -418,10 +420,29 @@ func (r *ContainerRunner) measureMemoryMonitor() error {
418420 return nil
419421}
420422
423+ type wellKnownFileLocationWriter struct {}
424+
425+ func (d wellKnownFileLocationWriter ) Write (p []byte ) (n int , err error ) {
426+ // Write to a temp file first.
427+ tmpTokenPath := path .Join (launcherfile .HostTmpPath , tokenFileTmp )
428+ if err = os .WriteFile (tmpTokenPath , p , 0644 ); err != nil {
429+ return 0 , fmt .Errorf ("failed to write a tmp token file: %v" , err )
430+ }
431+
432+ // Rename the temp file to the token file (to avoid race conditions).
433+ if err = os .Rename (tmpTokenPath , path .Join (launcherfile .HostTmpPath , launcherfile .AttestationVerifierTokenFilename )); err != nil {
434+ return 0 , fmt .Errorf ("failed to rename the token file: %v" , err )
435+ }
436+
437+ return len (p ), nil
438+ }
439+
440+ var _ io.Writer = (* wellKnownFileLocationWriter )(nil )
441+
421442// Retrieves the default OIDC token from the attestation service, and returns how long
422443// to wait before attempting to refresh it.
423444// The token is passed to writer; wellKnownFileLocationWriter writes it to a tmp file and then renames it.
424- func (r * ContainerRunner ) refreshToken (ctx context.Context ) (time.Duration , error ) {
445+ func (r * ContainerRunner ) refreshToken (ctx context.Context , writer io. Writer ) (time.Duration , error ) {
425446 if err := r .attestAgent .Refresh (ctx ); err != nil {
426447 return 0 , fmt .Errorf ("failed to refresh attestation agent: %v" , err )
427448 }
@@ -444,15 +465,9 @@ func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, erro
444465 return 0 , errors .New ("token is expired" )
445466 }
446467
447- // Write to a temp file first.
448- tmpTokenPath := path .Join (launcherfile .HostTmpPath , tokenFileTmp )
449- if err = os .WriteFile (tmpTokenPath , token , 0644 ); err != nil {
450- return 0 , fmt .Errorf ("failed to write a tmp token file: %v" , err )
451- }
452-
453- // Rename the temp file to the token file (to avoid race conditions).
454- if err = os .Rename (tmpTokenPath , path .Join (launcherfile .HostTmpPath , launcherfile .AttestationVerifierTokenFilename )); err != nil {
455- return 0 , fmt .Errorf ("failed to rename the token file: %v" , err )
468+ _ , err = writer .Write (token )
469+ if err != nil {
470+ return 0 , fmt .Errorf ("failed to write token: %v" , err )
456471 }
457472
458473 // Print out the claims in the jwt payload
@@ -467,56 +482,67 @@ func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, erro
467482 return getNextRefreshFromExpiration (time .Until (claims .ExpiresAt .Time ), rand .Float64 ()), nil
468483}
469484
470- // ctx must be a cancellable context.
471- func (r * ContainerRunner ) fetchAndWriteToken (ctx context.Context ) error {
472- return r .fetchAndWriteTokenWithRetry (ctx , defaultRetryPolicy )
473- }
474-
475485// ctx must be a cancellable context.
476486// retry specifies the refresher goroutine's retry policy.
477- func (r * ContainerRunner ) fetchAndWriteTokenWithRetry (ctx context.Context ,
478- retry func () * backoff.ExponentialBackOff ) error {
479- if err := os .MkdirAll (launcherfile .HostTmpPath , 0755 ); err != nil {
480- return err
481- }
482- duration , err := r .refreshToken (ctx )
483- if err != nil {
484- return err
485- }
487+ func (r * ContainerRunner ) startTokenRefresher (ctx context.Context , retry func () * backoff.ExponentialBackOff ,
488+ newTimer func (d time.Duration ) clock.Timer , tokenWriter io.Writer ) <- chan error {
489+ r .logger .Info ("Starting token refresh goroutine" )
490+
491+ initComplete := make (chan error , 1 )
486492
487- // Set a timer to refresh the token before it expires.
488- timer := time .NewTimer (duration )
489493 go func () {
494+ r .logger .Info ("token refresher goroutine started" )
495+ defer close (initComplete )
496+
497+ isInitialized := false // A flag to ensure we only send the initialization signal once.
498+ signalDone := func (err error ) {
499+ if ! isInitialized {
500+ initComplete <- err // Signal to the calling function that the first refresh is done.
501+ isInitialized = true
502+ }
503+ }
504+
505+ // Start with a timer that fires immediately to get the first token.
506+ timer := newTimer (0 )
507+ defer timer .Stop ()
508+
490509 for {
491510 select {
492511 case <- ctx .Done ():
493- timer .Stop ()
494- r .logger .Info ("token refreshing stopped" )
512+ r .logger .Info ("Token refreshing stopped" )
495513 return
496- case <- timer .C :
497- r .logger .Info ("refreshing attestation verifier OIDC token" )
514+ case <- timer .C () :
515+ r .logger .Info ("Refreshing attestation verifier OIDC token" )
498516 var duration time.Duration
499- // Refresh token with default retry policy.
517+
500518 err := backoff .RetryNotify (
501519 func () error {
502- duration , err = r .refreshToken (ctx )
503- return err
520+ var refreshErr error
521+ duration , refreshErr = r .refreshToken (ctx , tokenWriter )
522+ return refreshErr
504523 },
505524 retry (),
506525 func (err error , t time.Duration ) {
507- r .logger .Error (fmt .Sprintf ("failed to refresh attestation service token at time %v: %v" , t , err ))
526+ r .logger .Error (fmt .Sprintf ("failed to refresh token at time %v: %v" , t , err ))
508527 })
528+
529+ // Once the first refresh (including its retries) finishes, report its result to the caller; later calls are no-ops.
530+ signalDone (err )
531+
509532 if err != nil {
510- r .logger .Error (fmt .Sprintf ("failed all attempts to refresh attestation service token, stopping refresher: %v" , err ))
533+ // If all retry attempts fail, stop the refresher.
534+ r .logger .Error (fmt .Sprintf ("failed all attempts to get/refresh token, stopping refresher: %v" , err ))
511535 return
512536 }
513537
538+ // On success, reset the timer for the next refresh.
539+ r .logger .Info ("Resetting token refresh timer" )
514540 timer .Reset (duration )
515541 }
516542 }
517543 }()
518544
519- return nil
545+ return initComplete
520546}
521547
522548// getNextRefreshFromExpiration returns the Duration for the next run of the
@@ -583,7 +609,15 @@ func (r *ContainerRunner) Run(ctx context.Context) error {
583609
584610 // Only refresh token if agent has a default GCA client (not ITA use case).
585611 if r .launchSpec .ITARegion == "" {
586- if err := r .fetchAndWriteToken (ctx ); err != nil {
612+ // Create the well known token file location.
613+ if err := os .MkdirAll (launcherfile .HostTmpPath , 0755 ); err != nil {
614+ return err
615+ }
616+ r .logger .Info ("Created directory" , "path" , launcherfile .HostTmpPath )
617+
618+ errchan := r .startTokenRefresher (ctx , defaultRetryPolicy , clock .NewRealTimer , wellKnownFileLocationWriter {})
619+ err := <- errchan
620+ if err != nil {
587621 return fmt .Errorf ("failed to fetch and write OIDC token: %v" , err )
588622 }
589623 }
@@ -607,7 +641,6 @@ func (r *ContainerRunner) Run(ctx context.Context) error {
607641
608642 attestClients .GCA = gcaClient
609643 }
610-
611644 teeServer , err := teeserver .New (ctx , path .Join (launcherfile .HostTmpPath , teeServerSocket ), r .attestAgent , r .logger , r .launchSpec , attestClients )
612645 if err != nil {
613646 return fmt .Errorf ("failed to create the TEE server: %v" , err )
0 commit comments