diff --git a/launcher/container_runner.go b/launcher/container_runner.go index b96741bec..3a9d633d2 100644 --- a/launcher/container_runner.go +++ b/launcher/container_runner.go @@ -240,7 +240,7 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To asAddr := launchSpec.AttestationServiceAddr var verifierClient verifier.Client - if launchSpec.ITARegion == "" { + if launchSpec.ITAConfig.ITARegion == "" { gcaClient, err := util.NewRESTClient(ctx, asAddr, launchSpec.ProjectID, launchSpec.Region) if err != nil { return nil, fmt.Errorf("failed to create REST verifier client: %v", err) @@ -582,7 +582,7 @@ func (r *ContainerRunner) Run(ctx context.Context) error { } // Only refresh token if agent has a default GCA client (not ITA use case). - if r.launchSpec.ITARegion == "" { + if r.launchSpec.ITAConfig.ITARegion == "" { if err := r.fetchAndWriteToken(ctx); err != nil { return fmt.Errorf("failed to fetch and write OIDC token: %v", err) } @@ -591,9 +591,9 @@ func (r *ContainerRunner) Run(ctx context.Context) error { // create and start the TEE server r.logger.Info("EnableOnDemandAttestation is enabled: initializing TEE server.") - attestClients := &teeserver.AttestClients{} - if r.launchSpec.ITARegion != "" { - itaClient, err := ita.NewClient(r.launchSpec.ITARegion, r.launchSpec.ITAKey) + attestClients := teeserver.AttestClients{} + if r.launchSpec.ITAConfig.ITARegion != "" { + itaClient, err := ita.NewClient(r.launchSpec.ITAConfig) if err != nil { return fmt.Errorf("failed to create ITA client: %v", err) } diff --git a/launcher/go.sum b/launcher/go.sum index 964a5e01f..321df145d 100644 --- a/launcher/go.sum +++ b/launcher/go.sum @@ -155,6 +155,7 @@ github.com/caarlos0/ctrlc v1.0.0/go.mod h1:CdXpj4rmq0q/1Eb44M9zi2nKB0QraNKuRGYGr github.com/campoy/unique v0.0.0-20180121183637-88950e537e7e/go.mod h1:9IOqJGCPMSc6E5ydlp5NIonxObaeu/Iub/X03EKPVYo= github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= github.com/cavaliercoder/go-cpio v0.0.0-20180626203310-925f9528c45e/go.mod h1:oDpT4efm8tSYHXV5tHSdRvBet/b/QzxZ+XyyPehvm3A= +github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= diff --git a/launcher/spec/launch_spec.go b/launcher/spec/launch_spec.go index 008ce89c0..328c1f674 100644 --- a/launcher/spec/launch_spec.go +++ b/launcher/spec/launch_spec.go @@ -21,6 +21,7 @@ import ( "github.com/google/go-tpm-tools/launcher/internal/launchermount" "github.com/google/go-tpm-tools/launcher/internal/logging" "github.com/google/go-tpm-tools/launcher/launcherfile" + "github.com/google/go-tpm-tools/verifier" "github.com/google/go-tpm-tools/verifier/util" ) @@ -124,8 +125,7 @@ type LaunchSpec struct { MonitoringEnabled MonitoringType LogRedirect LogRedirectLocation Mounts []launchermount.Mount - ITARegion string - ITAKey string + ITAConfig verifier.ITAConfig // DevShmSize is specified in kiB. DevShmSize int64 AddedCapabilities []string @@ -252,16 +252,14 @@ func (s *LaunchSpec) UnmarshalJSON(b []byte) error { itaRegionVal, itaRegionOK := unmarshaledMap[itaRegion] itaKeyVal, itaKeyOK := unmarshaledMap[itaKey] + // If key and region are both not in the map, do not set up ITA config. if itaRegionOK != itaKeyOK { - return fmt.Errorf("ITA fields %s and %s must both be provided", itaRegion, itaKey) + return fmt.Errorf("ITA fields %s and %s must both be provided and non-empty", itaRegion, itaKey) } - if itaRegionOK { - s.ITARegion = itaRegionVal - } - - if itaKeyOK { - s.ITAKey = itaKeyVal + s.ITAConfig = verifier.ITAConfig{ + ITARegion: itaRegionVal, + ITAKey: itaKeyVal, } } @@ -290,7 +288,7 @@ func (s *LaunchSpec) UnmarshalJSON(b []byte) error { // LogFriendly creates a copy of the spec that is safe to log by censoring func (s *LaunchSpec) LogFriendly() LaunchSpec { safeSpec := *s - safeSpec.ITAKey = strings.Repeat("*", len(s.ITAKey)) + safeSpec.ITAConfig.ITAKey = strings.Repeat("*", len(s.ITAConfig.ITAKey)) return safeSpec } diff --git a/launcher/spec/launch_spec_test.go b/launcher/spec/launch_spec_test.go index 90fd449ec..e4b4151ef 100644 --- a/launcher/spec/launch_spec_test.go +++ b/launcher/spec/launch_spec_test.go @@ -7,6 +7,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-tpm-tools/launcher/internal/experiments" "github.com/google/go-tpm-tools/launcher/internal/launchermount" + "github.com/google/go-tpm-tools/verifier" ) func TestLaunchSpecUnmarshalJSONHappyCases(t *testing.T) { @@ -64,8 +65,10 @@ func TestLaunchSpecUnmarshalJSONHappyCases(t *testing.T) { DevShmSize: 234234, Mounts: []launchermount.Mount{launchermount.TmpfsMount{Destination: "/tmpmount", Size: 0}, launchermount.TmpfsMount{Destination: "/sized", Size: 222}}, - ITARegion: "US", - ITAKey: "test-api-key", + ITAConfig: verifier.ITAConfig{ + ITARegion: "US", + ITAKey: "test-api-key", + }, Experiments: experiments.Experiments{ EnableItaVerifier: true, }, diff --git a/launcher/teeserver/tee_server.go b/launcher/teeserver/tee_server.go index f4d6fc659..19a0ab37c 100644 --- a/launcher/teeserver/tee_server.go +++ b/launcher/teeserver/tee_server.go @@ -14,11 +14,15 @@ import ( "github.com/google/go-tpm-tools/launcher/spec" "github.com/google/go-tpm-tools/verifier" "github.com/google/go-tpm-tools/verifier/models" - "github.com/google/go-tpm-tools/verifier/util" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) +const ( + gcaEndpoint = "/v1/token" + itaEndpoint = "/v1/intel/token" +) + var clientErrorCodes = map[codes.Code]struct{}{ codes.InvalidArgument: {}, codes.FailedPrecondition: {}, @@ -43,7 +47,7 @@ type attestHandler struct { // defaultTokenFile string logger logging.Logger launchSpec spec.LaunchSpec - clients *AttestClients + clients AttestClients } // TeeServer is a server that can be called from a container through a unix @@ -54,7 +58,7 @@ type TeeServer struct { } // New takes in a socket and start to listen to it, and create a server -func New(ctx context.Context, unixSock string, a agent.AttestationAgent, logger logging.Logger, launchSpec spec.LaunchSpec, clients *AttestClients) (*TeeServer, error) { +func New(ctx context.Context, unixSock string, a agent.AttestationAgent, logger logging.Logger, launchSpec spec.LaunchSpec, clients AttestClients) (*TeeServer, error) { var err error nl, err := net.Listen("unix", unixSock) if err != nil { @@ -84,8 +88,8 @@ func (a *attestHandler) Handler() http.Handler { // curl -d '{"audience":"", "nonces":[""]}' -H "Content-Type: application/json" -X POST // --unix-socket /tmp/container_launcher/teeserver.sock http://localhost/v1/token - mux.HandleFunc("/v1/token", a.getToken) - mux.HandleFunc("/v1/intel/token", a.getITAToken) + mux.HandleFunc(gcaEndpoint, a.getToken) + mux.HandleFunc(itaEndpoint, a.getITAToken) return mux } @@ -101,16 +105,13 @@ func (a *attestHandler) logAndWriteError(errStr string, status int, w http.Respo func (a *attestHandler) getToken(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") - // If the handler does not have a GCA client, create one. - if a.clients.GCA == nil { - gcaClient, err := util.NewRESTClient(a.ctx, a.launchSpec.AttestationServiceAddr, a.launchSpec.ProjectID, a.launchSpec.Region) - if err != nil { - errStr := fmt.Sprintf("failed to create REST verifier client: %v", err) - a.logAndWriteError(errStr, http.StatusInternalServerError, w) - return - } + a.logger.Info(fmt.Sprintf("%s called", gcaEndpoint)) - a.clients.GCA = gcaClient + // If the handler does not have an GCA client, return error. + if a.clients.GCA == nil { + errStr := "no GCA verifier client present, please try rebooting your VM" + a.logAndWriteError(errStr, http.StatusInternalServerError, w) + return } a.attest(w, r, a.clients.GCA) @@ -120,10 +121,12 @@ func (a *attestHandler) getToken(w http.ResponseWriter, r *http.Request) { func (a *attestHandler) getITAToken(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") + a.logger.Info(fmt.Sprintf("%s called", itaEndpoint)) + // If the handler does not have an ITA client, return error. if a.clients.ITA == nil { errStr := "no ITA verifier client present - ensure ITA Region and Key are defined in metadata" - a.logAndWriteError(errStr, http.StatusPreconditionFailed, w) + a.logAndWriteError(errStr, http.StatusInternalServerError, w) return } @@ -173,11 +176,10 @@ func (a *attestHandler) attest(w http.ResponseWriter, r *http.Request, client ve } // Do not check that TokenTypeOptions matches TokenType in the launcher. - - tok, err := a.attestAgent.AttestWithClient(a.ctx, agent.AttestAgentOpts{ + opts := agent.AttestAgentOpts{ TokenOptions: &tokenOptions, - }, client) - + } + tok, err := a.attestAgent.AttestWithClient(a.ctx, opts, client) if err != nil { a.handleAttestError(w, err, "failed to retrieve custom attestation service token") return diff --git a/launcher/teeserver/tee_server_test.go b/launcher/teeserver/tee_server_test.go index 4eab983af..6d6eaedc8 100644 --- a/launcher/teeserver/tee_server_test.go +++ b/launcher/teeserver/tee_server_test.go @@ -66,7 +66,7 @@ func TestGetDefaultToken(t *testing.T) { ah := attestHandler{ logger: logging.SimpleLogger(), - clients: &AttestClients{ + clients: AttestClients{ GCA: &fakeVerifierClient{}, }, attestAgent: fakeAttestationAgent{ @@ -93,11 +93,9 @@ func TestGetDefaultToken(t *testing.T) { } func TestGetDefaultTokenServerError(t *testing.T) { - // An empty attestHandler is fine for now as it is not being used - // in the handler. ah := attestHandler{ logger: logging.SimpleLogger(), - clients: &AttestClients{ + clients: AttestClients{ GCA: &fakeVerifierClient{}, }, attestAgent: fakeAttestationAgent{ @@ -202,29 +200,174 @@ func TestCustomToken(t *testing.T) { }, } - for i, test := range tests { - ah := attestHandler{ - logger: logging.SimpleLogger(), - clients: &AttestClients{ - GCA: &fakeVerifierClient{}, - }, - attestAgent: fakeAttestationAgent{ - attestWithClientFunc: test.attestWithClientFunc, - }} + verifiers := []struct { + name string + url string + tokenMethod func(ah *attestHandler, w http.ResponseWriter, r *http.Request) + }{ + { + name: "GCA Handler", + url: "/v1/token", + tokenMethod: (*attestHandler).getToken, + }, + { + name: "ITA Handler", + url: "/v1/intel/token", + tokenMethod: (*attestHandler).getITAToken, + }, + } - b := strings.NewReader(test.body) + for _, vf := range verifiers { + t.Run(vf.name, func(t *testing.T) { + for _, test := range tests { + ah := attestHandler{ + logger: logging.SimpleLogger(), + clients: AttestClients{ + GCA: &fakeVerifierClient{}, + ITA: &fakeVerifierClient{}, + }, + attestAgent: fakeAttestationAgent{ + attestWithClientFunc: test.attestWithClientFunc, + }} - req := httptest.NewRequest(http.MethodPost, "/v1/token", b) - w := httptest.NewRecorder() - ah.getToken(w, req) - _, err := io.ReadAll(w.Result().Body) - if err != nil { - t.Error(err) - } + b := strings.NewReader(test.body) - if w.Code != test.want { - t.Errorf("testcase %d, '%v': got return code: %d, want: %d", i, test.testName, w.Code, test.want) - } + req := httptest.NewRequest(http.MethodPost, vf.url, b) + w := httptest.NewRecorder() + + vf.tokenMethod(&ah, w, req) + + _, err := io.ReadAll(w.Result().Body) + if err != nil { + t.Error(err) + } + + if w.Code != test.want { + t.Errorf("testcase '%v': got return code: %d, want: %d", test.testName, w.Code, test.want) + } + } + }) + } +} + +func TestHandleAttestError(t *testing.T) { + body := `{ + "audience": "audience", + "nonces": ["thisIsAcustomNonce"], + "token_type": "OIDC" + }` + + errorCases := []struct { + name string + err error + wantStatusCode int + }{ + { + name: "FailedPrecondition error", + err: status.New(codes.FailedPrecondition, "bad state").Err(), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "PermissionDenied error", + err: status.New(codes.PermissionDenied, "denied").Err(), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "Internal error", + err: status.New(codes.Internal, "internal server error").Err(), + wantStatusCode: http.StatusInternalServerError, + }, + { + name: "Unavailable error", + err: status.New(codes.Unavailable, "service unavailable").Err(), + wantStatusCode: http.StatusInternalServerError, + }, + { + name: "non-gRPC error", + err: errors.New("a generic error"), + wantStatusCode: http.StatusInternalServerError, + }, + } + + verifiers := []struct { + name string + url string + tokenMethod func(ah *attestHandler, w http.ResponseWriter, r *http.Request) + }{ + { + name: "GCA Handler", + url: "/v1/token", + tokenMethod: (*attestHandler).getToken, + }, + { + name: "ITA Handler", + url: "/v1/intel/token", + tokenMethod: (*attestHandler).getITAToken, + }, + } + + for _, vf := range verifiers { + t.Run(vf.name, func(t *testing.T) { + for _, tc := range errorCases { + t.Run(tc.name, func(t *testing.T) { + ah := attestHandler{ + logger: logging.SimpleLogger(), + clients: AttestClients{ + GCA: &fakeVerifierClient{}, + ITA: &fakeVerifierClient{}, + }, + attestAgent: fakeAttestationAgent{ + attestWithClientFunc: func(context.Context, agent.AttestAgentOpts, verifier.Client) ([]byte, error) { + return nil, tc.err + }, + }, + } + + req := httptest.NewRequest(http.MethodPost, vf.url, strings.NewReader(body)) + w := httptest.NewRecorder() + + vf.tokenMethod(&ah, w, req) + + if w.Code != tc.wantStatusCode { + t.Errorf("got status code %d, want %d", w.Code, tc.wantStatusCode) + } + + _, err := io.ReadAll(w.Result().Body) + if err != nil { + t.Errorf("failed to read response body: %v", err) + } + }) + } + }) + } +} + +func TestHandleAttestError_NilClient(t *testing.T) { + verifiers := []struct { + name string + url string + handler func(ah *attestHandler, w http.ResponseWriter, r *http.Request) + }{ + {name: "GCA Handler", url: "/v1/token", handler: (*attestHandler).getToken}, + {name: "ITA Handler", url: "/v1/intel/token", handler: (*attestHandler).getITAToken}, + } + + for _, vf := range verifiers { + t.Run(vf.name, func(t *testing.T) { + ah := attestHandler{ + logger: logging.SimpleLogger(), + clients: AttestClients{}, // No clients defined + } + + req := httptest.NewRequest(http.MethodPost, vf.url, strings.NewReader("")) + w := httptest.NewRecorder() + vf.handler(&ah, w, req) + + const wantStatusCode = http.StatusInternalServerError + if w.Code != wantStatusCode { + t.Errorf("got status code %d, want %d", w.Code, wantStatusCode) + } + }) } } @@ -336,7 +479,7 @@ func TestCustomTokenDataParsedSuccessfully(t *testing.T) { for i, test := range tests { ah := attestHandler{ logger: logging.SimpleLogger(), - clients: &AttestClients{ + clients: AttestClients{ GCA: &fakeVerifierClient{}, }, attestAgent: fakeAttestationAgent{ @@ -364,72 +507,3 @@ func TestCustomTokenDataParsedSuccessfully(t *testing.T) { } } } - -func TestCustomHandleAttestError(t *testing.T) { - body := `{ - "audience": "audience", - "nonces": ["thisIsAcustomNonce"], - "token_type": "OIDC" - }` - - testcases := []struct { - name string - err error - wantStatusCode int - }{ - { - name: "FailedPrecondition error", - err: status.New(codes.FailedPrecondition, "bad state").Err(), - wantStatusCode: http.StatusBadRequest, - }, - { - name: "PermissionDenied error", - err: status.New(codes.PermissionDenied, "denied").Err(), - wantStatusCode: http.StatusBadRequest, - }, - { - name: "Internal error", - err: status.New(codes.Internal, "internal server error").Err(), - wantStatusCode: http.StatusInternalServerError, - }, - { - name: "Unavailable error", - err: status.New(codes.Unavailable, "service unavailable").Err(), - wantStatusCode: http.StatusInternalServerError, - }, - { - name: "non-gRPC error", - err: errors.New("a generic error"), - wantStatusCode: http.StatusInternalServerError, - }, - } - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - ah := attestHandler{ - logger: logging.SimpleLogger(), - clients: &AttestClients{ - GCA: &fakeVerifierClient{}, - }, - attestAgent: fakeAttestationAgent{ - attestWithClientFunc: func(context.Context, agent.AttestAgentOpts, verifier.Client) ([]byte, error) { - return nil, tc.err - }, - }, - } - - req := httptest.NewRequest(http.MethodPost, "/v1/token", strings.NewReader(body)) - w := httptest.NewRecorder() - - ah.getToken(w, req) - - if w.Code != tc.wantStatusCode { - t.Errorf("got status code %d, want %d", w.Code, tc.wantStatusCode) - } - - _, err := io.ReadAll(w.Result().Body) - if err != nil { - t.Errorf("failed to read response body: %v", err) - } - }) - } -} diff --git a/verifier/client.go b/verifier/client.go index 6c6217aae..da9a43b07 100644 --- a/verifier/client.go +++ b/verifier/client.go @@ -66,3 +66,22 @@ type VerifyAttestationResponse struct { ClaimsToken []byte PartialErrs []*status.Status } + +// ITAConfig represents the configuration needed to integrate with ITA as a verifier. +type ITAConfig struct { + ITARegion string + ITAKey string +} + +// AttestClients contains clients for supported verifier services that can be used to +// get attestation tokens. +type AttestClients struct { + GCA Client + ITA Client +} + +// HasThirdPartyClient returns true if AttestClients contains an initialzied +// third-party verifier client. +func (ac *AttestClients) HasThirdPartyClient() bool { + return ac.ITA != nil +} diff --git a/verifier/ita/client.go b/verifier/ita/client.go index 018bb7b84..942ad69d1 100644 --- a/verifier/ita/client.go +++ b/verifier/ita/client.go @@ -42,7 +42,6 @@ func urlFromRegion(region string) (string, error) { if region == "" { return "", errors.New("API region required to initialize ITA client") } - url, ok := regionalURLs[strings.ToUpper(region)] if !ok { // Create list of allowed regions. @@ -56,8 +55,30 @@ func urlFromRegion(region string) (string, error) { return url, nil } -func NewClient(region string, key string) (verifier.Client, error) { - url, err := urlFromRegion(region) +// Confirm that client implements verifier.Client interface. +var _ verifier.Client = (*client)(nil) + +type itaNonce struct { + Val []byte `json:"val"` + Iat []byte `json:"iat"` + Signature []byte `json:"signature"` +} + +// The ITA evidence nonce is a concatenation+hash of Val and Iat. See references below: +// https://github.com/intel/trustauthority-client-for-go/blob/main/go-connector/attest.go#L22 +// https://github.com/intel/trustauthority-client-for-go/blob/main/go-tdx/tdx_adapter.go#L37 +func createHashedNonce(nonce *itaNonce) ([]byte, error) { + hash := sha512.New() + _, err := hash.Write(append(nonce.Val, nonce.Iat...)) + if err != nil { + return nil, fmt.Errorf("error hashing ITA nonce: %v", err) + } + + return hash.Sum(nil), err +} + +func NewClient(itaConfig verifier.ITAConfig) (verifier.Client, error) { //region string, key string) (verifier.Client, error) { + url, err := urlFromRegion(itaConfig.ITARegion) if err != nil { return nil, err } @@ -78,37 +99,14 @@ func NewClient(region string, key string) (verifier.Client, error) { }, }, apiURL: url, - apiKey: key, + apiKey: itaConfig.ITAKey, }, nil } -// Confirm that client implements verifier.Client interface. -var _ verifier.Client = (*client)(nil) - -type itaNonce struct { - Val []byte `json:"val"` - Iat []byte `json:"iat"` - Signature []byte `json:"signature"` -} - -// The ITA evidence nonce is a concatenation+hash of Val and Iat. See references below: -// https://github.com/intel/trustauthority-client-for-go/blob/main/go-connector/attest.go#L22 -// https://github.com/intel/trustauthority-client-for-go/blob/main/go-tdx/tdx_adapter.go#L37 -func createHashedNonce(nonce *itaNonce) ([]byte, error) { - hash := sha512.New() - _, err := hash.Write(append(nonce.Val, nonce.Iat...)) - if err != nil { - return nil, fmt.Errorf("error hashing ITA nonce: %v", err) - } - - return hash.Sum(nil), err -} - func (c *client) CreateChallenge(_ context.Context) (*verifier.Challenge, error) { url := c.apiURL + nonceEndpoint headers := map[string]string{ - apiKeyHeader: c.apiKey, acceptHeader: applicationJSON, } @@ -131,65 +129,6 @@ func (c *client) CreateChallenge(_ context.Context) (*verifier.Challenge, error) }, nil } -func convertRequestToTokenRequest(request verifier.VerifyAttestationRequest) tokenRequest { - // Trim trailing 0xFF bytes from CCEL Data. - data := request.TDCCELAttestation.CcelData - trimIndex := len(data) - - for ; trimIndex >= 0; trimIndex-- { - c := data[trimIndex-1] - // Proceed until 0xFF padding ends. - if c != byte(255) { - break - } - } - - tokenReq := tokenRequest{ - PolicyMatch: true, - TDX: tdxEvidence{ - EventLog: data[:trimIndex], - CanonicalEventLog: request.TDCCELAttestation.CanonicalEventLog, - Quote: request.TDCCELAttestation.TdQuote, - VerifierNonce: nonce{ - Val: request.Challenge.Val, - Iat: request.Challenge.Iat, - Signature: request.Challenge.Signature, - }, - }, - SigAlg: "RS256", // Figure out what this should be. - GCP: gcpData{ - AKCert: request.TDCCELAttestation.AkCert, - IntermediateCerts: request.TDCCELAttestation.IntermediateCerts, - CSInfo: confidentialSpaceInfo{ - TokenOpts: tokenOptions{}, - }, - }, - } - - if request.TokenOptions != nil { - tokenReq.GCP.CSInfo.TokenOpts = tokenOptions{ - Audience: request.TokenOptions.Audience, - Nonces: request.TokenOptions.Nonces, - TokenType: request.TokenOptions.TokenType, - TokenTypeOpts: tokenTypeOptions{}, - } - } - - for _, token := range request.GcpCredentials { - tokenReq.GCP.GcpCredentials = append(tokenReq.GCP.GcpCredentials, string(token)) - } - - for _, sig := range request.ContainerImageSignatures { - itaSig := containerSignature{ - Payload: sig.Payload, - Signature: sig.Signature, - } - tokenReq.GCP.CSInfo.SignedEntities = append(tokenReq.GCP.CSInfo.SignedEntities, itaSig) - } - - return tokenReq -} - func (c *client) VerifyAttestation(_ context.Context, request verifier.VerifyAttestationRequest) (*verifier.VerifyAttestationResponse, error) { if request.TDCCELAttestation == nil { return nil, errors.New("TDX required for ITA attestation") @@ -214,10 +153,6 @@ func (c *client) VerifyAttestation(_ context.Context, request verifier.VerifyAtt }, nil } -func (c *client) VerifyConfidentialSpace(ctx context.Context, request verifier.VerifyAttestationRequest) (*verifier.VerifyAttestationResponse, error) { - return c.VerifyAttestation(ctx, request) -} - func (c *client) doHTTPRequest(method string, url string, reqStruct any, headers map[string]string, respStruct any) error { // Create HTTP request. var req *http.Request @@ -240,6 +175,7 @@ func (c *client) doHTTPRequest(method string, url string, reqStruct any, headers } // Add headers to request. + headers[apiKeyHeader] = string(c.apiKey) for key, val := range headers { req.Header.Add(key, val) } @@ -248,12 +184,16 @@ func (c *client) doHTTPRequest(method string, url string, reqStruct any, headers if err != nil { return fmt.Errorf("HTTP request error: %v", err) } + defer resp.Body.Close() // Read and unmarshal response body. respBody, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("error reading response body: %v", err) } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("HTTP request failed with status code %d, response body %s", resp.StatusCode, string(respBody)) + } if err := json.Unmarshal(respBody, respStruct); err != nil { return fmt.Errorf("error unmarshaling response: %v", err) @@ -261,3 +201,66 @@ func (c *client) doHTTPRequest(method string, url string, reqStruct any, headers return nil } + +func convertRequestToTokenRequest(request verifier.VerifyAttestationRequest) tokenRequest { + // Trim trailing 0xFF bytes from CCEL Data. + data := request.TDCCELAttestation.CcelData + trimIndex := len(data) + + for ; trimIndex >= 0; trimIndex-- { + c := data[trimIndex-1] + // Proceed until 0xFF padding ends. + if c != byte(255) { + break + } + } + + tokenReq := tokenRequest{ + PolicyMatch: true, + TDX: tdxEvidence{ + EventLog: data[:trimIndex], + CanonicalEventLog: request.TDCCELAttestation.CanonicalEventLog, + Quote: request.TDCCELAttestation.TdQuote, + VerifierNonce: nonce{ + Val: request.Challenge.Val, + Iat: request.Challenge.Iat, + Signature: request.Challenge.Signature, + }, + }, + SigAlg: "RS256", // Figure out what this should be. + GCP: gcpData{ + AKCert: request.TDCCELAttestation.AkCert, + IntermediateCerts: request.TDCCELAttestation.IntermediateCerts, + CSInfo: confidentialSpaceInfo{ + TokenOpts: tokenOptions{}, + }, + }, + } + + if request.TokenOptions != nil { + tokenReq.GCP.CSInfo.TokenOpts = tokenOptions{ + Audience: request.TokenOptions.Audience, + Nonces: request.TokenOptions.Nonces, + TokenType: request.TokenOptions.TokenType, + TokenTypeOpts: tokenTypeOptions{}, + } + } + + for _, token := range request.GcpCredentials { + tokenReq.GCP.GcpCredentials = append(tokenReq.GCP.GcpCredentials, string(token)) + } + + for _, sig := range request.ContainerImageSignatures { + itaSig := containerSignature{ + Payload: sig.Payload, + Signature: sig.Signature, + } + tokenReq.GCP.CSInfo.SignedEntities = append(tokenReq.GCP.CSInfo.SignedEntities, itaSig) + } + + return tokenReq +} + +func (c *client) VerifyConfidentialSpace(ctx context.Context, request verifier.VerifyAttestationRequest) (*verifier.VerifyAttestationResponse, error) { + return c.VerifyAttestation(ctx, request) +} diff --git a/verifier/ita/client_test.go b/verifier/ita/client_test.go index 6082595bc..77ad3d82f 100644 --- a/verifier/ita/client_test.go +++ b/verifier/ita/client_test.go @@ -4,7 +4,7 @@ import ( "bytes" "context" "encoding/json" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -144,7 +144,7 @@ func TestVerifyAttestation(t *testing.T) { // Verify HTTP Request body. defer r.Body.Close() - reqBody, err := ioutil.ReadAll(r.Body) + reqBody, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("Error reading HTTP request body: %s", err) } @@ -212,7 +212,7 @@ func TestDoHTTPRequest(t *testing.T) { // Verify HTTP Request body. defer r.Body.Close() - reqBody, err := ioutil.ReadAll(r.Body) + reqBody, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("Error reading HTTP request body: %s", err) } diff --git a/verifier/ita/evidence.go b/verifier/ita/evidence.go index 8adc2b891..dd5e55b8c 100644 --- a/verifier/ita/evidence.go +++ b/verifier/ita/evidence.go @@ -33,8 +33,8 @@ type tokenTypeOptions struct { type tokenOptions struct { Audience string `json:"audience"` Nonces []string `json:"nonce"` - TokenType string `json:"tokenType"` - TokenTypeOpts tokenTypeOptions `json:"tokenTypeOptions"` + TokenType string `json:"token_type"` + TokenTypeOpts tokenTypeOptions `json:"token_type_options"` } type confidentialSpaceInfo struct {