Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 73 additions & 40 deletions launcher/container_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/google/go-tpm-tools/launcher/internal/healthmonitoring/nodeproblemdetector"
"github.com/google/go-tpm-tools/launcher/internal/logging"
"github.com/google/go-tpm-tools/launcher/internal/signaturediscovery"
"github.com/google/go-tpm-tools/launcher/launcher/clock"
"github.com/google/go-tpm-tools/launcher/launcherfile"
"github.com/google/go-tpm-tools/launcher/registryauth"
"github.com/google/go-tpm-tools/launcher/spec"
Expand Down Expand Up @@ -269,13 +270,14 @@ func enableMonitoring(enabled spec.MonitoringType, logger logging.Logger) error
if enabled != spec.None {
logger.Info("Health Monitoring is enabled by the VM operator")

if enabled == spec.All {
switch enabled {
case spec.All:
logger.Info("All health monitoring metrics enabled")
if err := nodeproblemdetector.EnableAllConfig(); err != nil {
logger.Error("Failed to enable full monitoring config: %v", err)
return err
}
} else if enabled == spec.MemoryOnly {
case spec.MemoryOnly:
logger.Info("memory/bytes_used enabled")
}

Expand Down Expand Up @@ -418,10 +420,29 @@ func (r *ContainerRunner) measureMemoryMonitor() error {
return nil
}

type wellKnownFileLocationWriter struct{}

func (d wellKnownFileLocationWriter) Write(p []byte) (n int, err error) {
// Write to a temp file first.
tmpTokenPath := path.Join(launcherfile.HostTmpPath, tokenFileTmp)
if err = os.WriteFile(tmpTokenPath, p, 0644); err != nil {
return 0, fmt.Errorf("failed to write a tmp token file: %v", err)
}

// Rename the temp file to the token file (to avoid race conditions).
if err = os.Rename(tmpTokenPath, path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename)); err != nil {
return 0, fmt.Errorf("failed to rename the token file: %v", err)
}

return len(p), nil
}

var _ io.Writer = (*wellKnownFileLocationWriter)(nil)

// Retrieves the default OIDC token from the attestation service, and returns how long
// to wait before attemping to refresh it.
// The token file will be written to a tmp file and then renamed.
func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, error) {
func (r *ContainerRunner) refreshToken(ctx context.Context, writer io.Writer) (time.Duration, error) {
if err := r.attestAgent.Refresh(ctx); err != nil {
return 0, fmt.Errorf("failed to refresh attestation agent: %v", err)
}
Expand All @@ -444,15 +465,9 @@ func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, erro
return 0, errors.New("token is expired")
}

// Write to a temp file first.
tmpTokenPath := path.Join(launcherfile.HostTmpPath, tokenFileTmp)
if err = os.WriteFile(tmpTokenPath, token, 0644); err != nil {
return 0, fmt.Errorf("failed to write a tmp token file: %v", err)
}

// Rename the temp file to the token file (to avoid race conditions).
if err = os.Rename(tmpTokenPath, path.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename)); err != nil {
return 0, fmt.Errorf("failed to rename the token file: %v", err)
_, err = writer.Write(token)
if err != nil {
return 0, fmt.Errorf("failed to write token: %v", err)
}

// Print out the claims in the jwt payload
Expand All @@ -467,56 +482,67 @@ func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, erro
return getNextRefreshFromExpiration(time.Until(claims.ExpiresAt.Time), rand.Float64()), nil
}

// ctx must be a cancellable context.
func (r *ContainerRunner) fetchAndWriteToken(ctx context.Context) error {
return r.fetchAndWriteTokenWithRetry(ctx, defaultRetryPolicy)
}

// ctx must be a cancellable context.
// retry specifies the refresher goroutine's retry policy.
func (r *ContainerRunner) fetchAndWriteTokenWithRetry(ctx context.Context,
retry func() *backoff.ExponentialBackOff) error {
if err := os.MkdirAll(launcherfile.HostTmpPath, 0755); err != nil {
return err
}
duration, err := r.refreshToken(ctx)
if err != nil {
return err
}
func (r *ContainerRunner) startTokenRefresher(ctx context.Context, retry func() *backoff.ExponentialBackOff,
newTimer func(d time.Duration) clock.Timer, tokenWriter io.Writer) <-chan error {
r.logger.Info("Starting token refresh goroutine")

initComplete := make(chan error, 1)

// Set a timer to refresh the token before it expires.
timer := time.NewTimer(duration)
go func() {
r.logger.Info("token refresher goroutine started")
defer close(initComplete)

isInitialized := false // A flag to ensure we only send the initialization signal once.
signalDone := func(err error) {
if !isInitialized {
initComplete <- err // Signal to the calling function that the first refresh is done.
isInitialized = true
}
}

// Start with a timer that fires immediately to get the first token.
timer := newTimer(0)
defer timer.Stop()

for {
select {
case <-ctx.Done():
timer.Stop()
r.logger.Info("token refreshing stopped")
r.logger.Info("Token refreshing stopped")
return
case <-timer.C:
r.logger.Info("refreshing attestation verifier OIDC token")
case <-timer.C():
r.logger.Info("Refreshing attestation verifier OIDC token")
var duration time.Duration
// Refresh token with default retry policy.

err := backoff.RetryNotify(
func() error {
duration, err = r.refreshToken(ctx)
return err
var refreshErr error
duration, refreshErr = r.refreshToken(ctx, tokenWriter)
return refreshErr
},
retry(),
func(err error, t time.Duration) {
r.logger.Error(fmt.Sprintf("failed to refresh attestation service token at time %v: %v", t, err))
r.logger.Error(fmt.Sprintf("failed to refresh token at time %v: %v", t, err))
})

// After the first attempt signal to the calling function that the refresh is done.
signalDone(err)

if err != nil {
r.logger.Error(fmt.Sprintf("failed all attempts to refresh attestation service token, stopping refresher: %v", err))
// If all retry attempts fail, stop the refresher.
r.logger.Error(fmt.Sprintf("failed all attempts to get/refresh token, stopping refresher: %v", err))
return
}

// On success, reset the timer for the next refresh.
r.logger.Info("Resetting token refresh timer")
timer.Reset(duration)
}
}
}()

return nil
return initComplete
}

// getNextRefreshFromExpiration returns the Duration for the next run of the
Expand Down Expand Up @@ -583,7 +609,15 @@ func (r *ContainerRunner) Run(ctx context.Context) error {

// Only refresh token if agent has a default GCA client (not ITA use case).
if r.launchSpec.ITARegion == "" {
if err := r.fetchAndWriteToken(ctx); err != nil {
// Create the well known token file location.
if err := os.MkdirAll(launcherfile.HostTmpPath, 0755); err != nil {
return err
}
r.logger.Info("Created directory", "path", launcherfile.HostTmpPath)

errchan := r.startTokenRefresher(ctx, defaultRetryPolicy, clock.NewRealTimer, wellKnownFileLocationWriter{})
err := <-errchan
if err != nil {
return fmt.Errorf("failed to fetch and write OIDC token: %v", err)
}
}
Expand All @@ -607,7 +641,6 @@ func (r *ContainerRunner) Run(ctx context.Context) error {

attestClients.GCA = gcaClient
}

teeServer, err := teeserver.New(ctx, path.Join(launcherfile.HostTmpPath, teeServerSocket), r.attestAgent, r.logger, r.launchSpec, attestClients)
if err != nil {
return fmt.Errorf("failed to create the TEE server: %v", err)
Expand Down
Loading
Loading