Skip to content

Commit 5f0aacd

Browse files
author
Joshua Krstic
committed
Refactor container runner to not depend on file system or use a real timer
Add test for the wellknownfilewriter
1 parent 31cda11 commit 5f0aacd

File tree

5 files changed

+304
-190
lines changed

5 files changed

+304
-190
lines changed

launcher/container_runner.go

Lines changed: 73 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"github.com/google/go-tpm-tools/launcher/internal/healthmonitoring/nodeproblemdetector"
3232
"github.com/google/go-tpm-tools/launcher/internal/logging"
3333
"github.com/google/go-tpm-tools/launcher/internal/signaturediscovery"
34+
"github.com/google/go-tpm-tools/launcher/launcher/clock"
3435
"github.com/google/go-tpm-tools/launcher/launcherfile"
3536
"github.com/google/go-tpm-tools/launcher/registryauth"
3637
"github.com/google/go-tpm-tools/launcher/spec"
@@ -269,13 +270,14 @@ func enableMonitoring(enabled spec.MonitoringType, logger logging.Logger) error
269270
if enabled != spec.None {
270271
logger.Info("Health Monitoring is enabled by the VM operator")
271272

272-
if enabled == spec.All {
273+
switch enabled {
274+
case spec.All:
273275
logger.Info("All health monitoring metrics enabled")
274276
if err := nodeproblemdetector.EnableAllConfig(); err != nil {
275277
logger.Error("Failed to enable full monitoring config: %v", err)
276278
return err
277279
}
278-
} else if enabled == spec.MemoryOnly {
280+
case spec.MemoryOnly:
279281
logger.Info("memory/bytes_used enabled")
280282
}
281283

@@ -418,10 +420,29 @@ func (r *ContainerRunner) measureMemoryMonitor() error {
418420
return nil
419421
}
420422

423+
type wellKnownFileLocationWriter struct{}
424+
425+
func (d wellKnownFileLocationWriter) Write(p []byte) (n int, err error) {
426+
// Write to a temp file first.
427+
tmpTokenPath := path.Join(launcherfile.HostTmpPath, tokenFileTmp)
428+
if err = os.WriteFile(tmpTokenPath, p, 0644); err != nil {
429+
return 0, fmt.Errorf("failed to write a tmp token file: %v", err)
430+
}
431+
432+
// Rename the temp file to the token file (to avoid race conditions).
433+
if err = os.Rename(tmpTokenPath, path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename)); err != nil {
434+
return 0, fmt.Errorf("failed to rename the token file: %v", err)
435+
}
436+
437+
return len(p), nil
438+
}
439+
440+
var _ io.Writer = (*wellKnownFileLocationWriter)(nil)
441+
421442
// Retrieves the default OIDC token from the attestation service, and returns how long
422443
// to wait before attempting to refresh it.
423444
// The token file will be written to a tmp file and then renamed.
424-
func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, error) {
445+
func (r *ContainerRunner) refreshToken(ctx context.Context, writer io.Writer) (time.Duration, error) {
425446
if err := r.attestAgent.Refresh(ctx); err != nil {
426447
return 0, fmt.Errorf("failed to refresh attestation agent: %v", err)
427448
}
@@ -444,15 +465,9 @@ func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, erro
444465
return 0, errors.New("token is expired")
445466
}
446467

447-
// Write to a temp file first.
448-
tmpTokenPath := path.Join(launcherfile.HostTmpPath, tokenFileTmp)
449-
if err = os.WriteFile(tmpTokenPath, token, 0644); err != nil {
450-
return 0, fmt.Errorf("failed to write a tmp token file: %v", err)
451-
}
452-
453-
// Rename the temp file to the token file (to avoid race conditions).
454-
if err = os.Rename(tmpTokenPath, path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename)); err != nil {
455-
return 0, fmt.Errorf("failed to rename the token file: %v", err)
468+
_, err = writer.Write(token)
469+
if err != nil {
470+
return 0, fmt.Errorf("failed to write token: %v", err)
456471
}
457472

458473
// Print out the claims in the jwt payload
@@ -467,56 +482,67 @@ func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, erro
467482
return getNextRefreshFromExpiration(time.Until(claims.ExpiresAt.Time), rand.Float64()), nil
468483
}
469484

470-
// ctx must be a cancellable context.
471-
func (r *ContainerRunner) fetchAndWriteToken(ctx context.Context) error {
472-
return r.fetchAndWriteTokenWithRetry(ctx, defaultRetryPolicy)
473-
}
474-
475485
// ctx must be a cancellable context.
476486
// retry specifies the refresher goroutine's retry policy.
477-
func (r *ContainerRunner) fetchAndWriteTokenWithRetry(ctx context.Context,
478-
retry func() *backoff.ExponentialBackOff) error {
479-
if err := os.MkdirAll(launcherfile.HostTmpPath, 0755); err != nil {
480-
return err
481-
}
482-
duration, err := r.refreshToken(ctx)
483-
if err != nil {
484-
return err
485-
}
487+
func (r *ContainerRunner) startTokenRefresher(ctx context.Context, retry func() *backoff.ExponentialBackOff,
488+
newTimer func(d time.Duration) clock.Timer, tokenWriter io.Writer) <-chan error {
489+
r.logger.Info("Starting token refresh goroutine")
490+
491+
initComplete := make(chan error, 1)
486492

487-
// Set a timer to refresh the token before it expires.
488-
timer := time.NewTimer(duration)
489493
go func() {
494+
r.logger.Info("token refresher goroutine started")
495+
defer close(initComplete)
496+
497+
isInitialized := false // A flag to ensure we only send the initialization signal once.
498+
signalDone := func(err error) {
499+
if !isInitialized {
500+
initComplete <- err // Signal to the calling function that the first refresh is done.
501+
isInitialized = true
502+
}
503+
}
504+
505+
// Start with a timer that fires immediately to get the first token.
506+
timer := newTimer(0)
507+
defer timer.Stop()
508+
490509
for {
491510
select {
492511
case <-ctx.Done():
493-
timer.Stop()
494-
r.logger.Info("token refreshing stopped")
512+
r.logger.Info("Token refreshing stopped")
495513
return
496-
case <-timer.C:
497-
r.logger.Info("refreshing attestation verifier OIDC token")
514+
case <-timer.C():
515+
r.logger.Info("Refreshing attestation verifier OIDC token")
498516
var duration time.Duration
499-
// Refresh token with default retry policy.
517+
500518
err := backoff.RetryNotify(
501519
func() error {
502-
duration, err = r.refreshToken(ctx)
503-
return err
520+
var refreshErr error
521+
duration, refreshErr = r.refreshToken(ctx, tokenWriter)
522+
return refreshErr
504523
},
505524
retry(),
506525
func(err error, t time.Duration) {
507-
r.logger.Error(fmt.Sprintf("failed to refresh attestation service token at time %v: %v", t, err))
526+
r.logger.Error(fmt.Sprintf("failed to refresh token at time %v: %v", t, err))
508527
})
528+
529+
// After the first attempt signal to the calling function that the refresh is done.
530+
signalDone(err)
531+
509532
if err != nil {
510-
r.logger.Error(fmt.Sprintf("failed all attempts to refresh attestation service token, stopping refresher: %v", err))
533+
// If all retry attempts fail, stop the refresher.
534+
r.logger.Error(fmt.Sprintf("failed all attempts to get/refresh token, stopping refresher: %v", err))
511535
return
512536
}
513537

538+
// On success, reset the timer for the next refresh.
539+
r.logger.Info("Resetting token refresh timer")
514540
timer.Reset(duration)
515541
}
516542
}
517543
}()
518544

519-
return nil
545+
return initComplete
520546
}
521547

522548
// getNextRefreshFromExpiration returns the Duration for the next run of the
@@ -583,7 +609,15 @@ func (r *ContainerRunner) Run(ctx context.Context) error {
583609

584610
// Only refresh token if agent has a default GCA client (not ITA use case).
585611
if r.launchSpec.ITARegion == "" {
586-
if err := r.fetchAndWriteToken(ctx); err != nil {
612+
// Create the well known token file location.
613+
if err := os.MkdirAll(launcherfile.HostTmpPath, 0755); err != nil {
614+
return err
615+
}
616+
r.logger.Info("Created directory", "path", launcherfile.HostTmpPath)
617+
618+
errchan := r.startTokenRefresher(ctx, defaultRetryPolicy, clock.NewRealTimer, wellKnownFileLocationWriter{})
619+
err := <-errchan
620+
if err != nil {
587621
return fmt.Errorf("failed to fetch and write OIDC token: %v", err)
588622
}
589623
}
@@ -607,7 +641,6 @@ func (r *ContainerRunner) Run(ctx context.Context) error {
607641

608642
attestClients.GCA = gcaClient
609643
}
610-
611644
teeServer, err := teeserver.New(ctx, path.Join(launcherfile.HostTmpPath, teeServerSocket), r.attestAgent, r.logger, r.launchSpec, attestClients)
612645
if err != nil {
613646
return fmt.Errorf("failed to create the TEE server: %v", err)

0 commit comments

Comments
 (0)