@@ -5,16 +5,36 @@ import (
55 "encoding/json"
66 "errors"
77 "fmt"
8+ "io"
9+ "log"
10+ "net/http"
11+ "net/url"
812 "strings"
913 "time"
1014
1115 "golang.org/x/oauth2"
1216)
1317
18+ const (
19+ grantTypeJwtBearer = "urn:ietf:params:oauth:grant-type:jwt-bearer"
20+ clientAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
21+ )
22+
1423type tokenPayload struct {
1524 Expiration int64 `json:"exp"`
1625}
1726
27+ type JWTAssertionTokenSource struct { // revive:disable-line:exported
28+ Assertion string
29+ ClientAssertion string
30+ ClientID string
31+ ClientSecret string
32+ GrantType string
33+ Scopes []string
34+ TokenURL string
35+ HTTPClient * http.Client
36+ }
37+
1838func AccessTokenExpiration (accessToken string ) (time.Time , error ) {
1939 tp := strings .Split (accessToken , "." )
2040 if len (tp ) != 3 {
@@ -64,3 +84,107 @@ func ToOAuth2Token(accessToken, refreshToken string) (*oauth2.Token, error) {
6484 }
6585 return oAuthToken , nil
6686}
87+
88+ func (s * JWTAssertionTokenSource ) Token () (* oauth2.Token , error ) {
89+ data := url.Values {}
90+
91+ if s .TokenURL == "" {
92+ return nil , fmt .Errorf ("token URL is required" )
93+ }
94+ if s .GrantType == "" {
95+ data .Set ("grant_type" , grantTypeJwtBearer )
96+ } else {
97+ data .Set ("grant_type" , s .GrantType )
98+ }
99+ // Assertion is required for JWT Bearer grant type
100+ if s .Assertion == "" && (s .GrantType == grantTypeJwtBearer || s .GrantType == "" ) {
101+ return nil , fmt .Errorf ("assertion is required for JWT Bearer grant type" )
102+ }
103+
104+ if s .Assertion != "" {
105+ if err := validateJWTTokenFormat (s .Assertion ); err != nil {
106+ return nil , err
107+ }
108+ data .Set ("assertion" , s .Assertion )
109+ }
110+
111+ // Optional client_id
112+ if s .ClientID != "" {
113+ data .Set ("client_id" , s .ClientID )
114+ }
115+
116+ // Optional client_secret
117+ if s .ClientSecret != "" {
118+ if s .ClientID == "" {
119+ return nil , fmt .Errorf ("client_id is required when using client secret" )
120+ }
121+ data .Set ("client_secret" , s .ClientSecret )
122+ }
123+
124+ // Optional client_assertion
125+ if s .ClientAssertion != "" {
126+ if s .ClientID == "" {
127+ return nil , fmt .Errorf ("client_id is required when using client assertion" )
128+ }
129+ if err := validateJWTTokenFormat (s .ClientAssertion ); err != nil {
130+ return nil , err
131+ }
132+ data .Set ("client_assertion_type" , clientAssertionType )
133+ data .Set ("client_assertion" , s .ClientAssertion )
134+ }
135+ if len (s .Scopes ) > 0 {
136+ data .Set ("scope" , strings .Join (s .Scopes , " " ))
137+ }
138+
139+ req , err := http .NewRequest ("POST" , s .TokenURL , strings .NewReader (data .Encode ()))
140+ if err != nil {
141+ return nil , fmt .Errorf ("token request object creation failed: %w" , err )
142+ }
143+ req .Header .Set ("Content-Type" , "application/x-www-form-urlencoded" )
144+
145+ client := s .HTTPClient
146+ if client == nil {
147+ client = http .DefaultClient
148+ }
149+
150+ resp , err := client .Do (req )
151+ if err != nil {
152+ return nil , fmt .Errorf ("token request failed: %w" , err )
153+ }
154+ defer func () {
155+ if err := resp .Body .Close (); err != nil {
156+ log .Printf ("failed to close response body: %v" , err )
157+ }
158+ }()
159+
160+ body , _ := io .ReadAll (resp .Body )
161+ if resp .StatusCode != http .StatusOK {
162+ return nil , fmt .Errorf ("token request failed: %s" , body )
163+ }
164+
165+ var token oauth2.Token
166+ err = json .Unmarshal (body , & token )
167+ if err != nil {
168+ return nil , fmt .Errorf ("token unmarshal error: %w" , err )
169+ }
170+ return & token , nil
171+ }
172+
173+ // validateJWTTokenFormat checks if the provided JWT token has a valid format.
174+ func validateJWTTokenFormat (token string ) error {
175+ parts := strings .Split (token , "." )
176+ if len (parts ) != 3 {
177+ return errors .New ("token must have three parts separated by '.'" )
178+ }
179+
180+ for i , part := range parts {
181+ if part == "" {
182+ return fmt .Errorf ("token part is empty" )
183+ }
184+ if _ , err := base64 .RawURLEncoding .DecodeString (part ); err != nil {
185+ return fmt .Errorf ("invalid base64 encoding in part %d" , i + 1 )
186+ }
187+ }
188+
189+ return nil
190+ }
0 commit comments