@@ -2,6 +2,7 @@ package client
22
33import (
44 "bytes"
5+ "context"
56 "encoding/json"
67 "fmt"
78 "io"
@@ -97,13 +98,13 @@ func NewOAuthClient(agentMetadata *api.AgentMetadata, credentials *OAuthCredenti
9798 }, nil
9899}
99100
100- func (c * OAuthClient ) PostDataReadingsWithOptions (readings []* api.DataReading , opts Options ) error {
101- return c .PostDataReadings (opts .OrgID , opts .ClusterID , readings )
101+ func (c * OAuthClient ) PostDataReadingsWithOptions (ctx context. Context , readings []* api.DataReading , opts Options ) error {
102+ return c .PostDataReadings (ctx , opts .OrgID , opts .ClusterID , readings )
102103}
103104
104105// PostDataReadings uploads the slice of api.DataReading to the Jetstack Secure backend to be processed for later
105106// viewing in the user-interface.
106- func (c * OAuthClient ) PostDataReadings (orgID , clusterID string , readings []* api.DataReading ) error {
107+ func (c * OAuthClient ) PostDataReadings (ctx context. Context , orgID , clusterID string , readings []* api.DataReading ) error {
107108 payload := api.DataReadingsPost {
108109 AgentMetadata : c .agentMetadata ,
109110 DataGatherTime : time .Now ().UTC (),
@@ -114,7 +115,7 @@ func (c *OAuthClient) PostDataReadings(orgID, clusterID string, readings []*api.
114115 return err
115116 }
116117
117- res , err := c .Post (filepath .Join ("/api/v1/org" , orgID , "datareadings" , clusterID ), bytes .NewBuffer (data ))
118+ res , err := c .Post (ctx , filepath .Join ("/api/v1/org" , orgID , "datareadings" , clusterID ), bytes .NewBuffer (data ))
118119 if err != nil {
119120 return err
120121 }
@@ -134,13 +135,13 @@ func (c *OAuthClient) PostDataReadings(orgID, clusterID string, readings []*api.
134135}
135136
136137// Post performs an HTTP POST request.
137- func (c * OAuthClient ) Post (path string , body io.Reader ) (* http.Response , error ) {
138- token , err := c .getValidAccessToken ()
138+ func (c * OAuthClient ) Post (ctx context. Context , path string , body io.Reader ) (* http.Response , error ) {
139+ token , err := c .getValidAccessToken (ctx )
139140 if err != nil {
140141 return nil , err
141142 }
142143
143- req , err := http .NewRequest ( http .MethodPost , fullURL (c .baseURL , path ), body )
144+ req , err := http .NewRequestWithContext ( ctx , http .MethodPost , fullURL (c .baseURL , path ), body )
144145 if err != nil {
145146 return nil , err
146147 }
@@ -157,9 +158,9 @@ func (c *OAuthClient) Post(path string, body io.Reader) (*http.Response, error)
157158// getValidAccessToken returns a valid access token. It will fetch a new access
158159// token from the auth server in case the current access token does not exist
159160// or it is expired.
160- func (c * OAuthClient ) getValidAccessToken () (* accessToken , error ) {
161+ func (c * OAuthClient ) getValidAccessToken (ctx context. Context ) (* accessToken , error ) {
161162 if c .accessToken .needsRenew () {
162- err := c .renewAccessToken ()
163+ err := c .renewAccessToken (ctx )
163164 if err != nil {
164165 return nil , err
165166 }
@@ -168,7 +169,7 @@ func (c *OAuthClient) getValidAccessToken() (*accessToken, error) {
168169 return c .accessToken , nil
169170}
170171
171- func (c * OAuthClient ) renewAccessToken () error {
172+ func (c * OAuthClient ) renewAccessToken (ctx context. Context ) error {
172173 tokenURL := fmt .Sprintf ("https://%s/oauth/token" , c .credentials .AuthServerDomain )
173174 audience := "https://preflight.jetstack.io/api/v1"
174175 payload := url.Values {}
@@ -178,7 +179,7 @@ func (c *OAuthClient) renewAccessToken() error {
178179 payload .Set ("audience" , audience )
179180 payload .Set ("username" , c .credentials .UserID )
180181 payload .Set ("password" , c .credentials .UserSecret )
181- req , err := http .NewRequest ( "POST" , tokenURL , strings .NewReader (payload .Encode ()))
182+ req , err := http .NewRequestWithContext ( ctx , "POST" , tokenURL , strings .NewReader (payload .Encode ()))
182183 if err != nil {
183184 return errors .WithStack (err )
184185 }
@@ -188,7 +189,8 @@ func (c *OAuthClient) renewAccessToken() error {
188189 if err != nil {
189190 return errors .WithStack (err )
190191 }
191-
192+ // TODO(wallrj): This will block. Read the body incrementally and check for
193+ // context cancellation.
192194 body , err := io .ReadAll (res .Body )
193195 if err != nil {
194196 return errors .WithStack (err )
0 commit comments