@@ -282,20 +282,26 @@ func TestNewClientExplicitNoAuth(t *testing.T) {
282
282
283
283
func TestCustomToken (t * testing.T ) {
284
284
client := & Client {
285
- signer : testSigner ,
286
- clock : testClock ,
285
+ baseClient : & baseClient {
286
+ signer : testSigner ,
287
+ clock : testClock ,
288
+ },
287
289
}
288
290
token , err := client .CustomToken (context .Background (), "user1" )
289
291
if err != nil {
290
292
t .Fatal (err )
291
293
}
292
- verifyCustomToken (context .Background (), token , nil , t )
294
+ if err := verifyCustomToken (context .Background (), token , nil , "" ); err != nil {
295
+ t .Fatal (err )
296
+ }
293
297
}
294
298
295
299
func TestCustomTokenWithClaims (t * testing.T ) {
296
300
client := & Client {
297
- signer : testSigner ,
298
- clock : testClock ,
301
+ baseClient : & baseClient {
302
+ signer : testSigner ,
303
+ clock : testClock ,
304
+ },
299
305
}
300
306
claims := map [string ]interface {}{
301
307
"foo" : "bar" ,
@@ -306,19 +312,46 @@ func TestCustomTokenWithClaims(t *testing.T) {
306
312
if err != nil {
307
313
t .Fatal (err )
308
314
}
309
- verifyCustomToken (context .Background (), token , claims , t )
315
+ if err := verifyCustomToken (context .Background (), token , claims , "" ); err != nil {
316
+ t .Fatal (err )
317
+ }
310
318
}
311
319
312
320
func TestCustomTokenWithNilClaims (t * testing.T ) {
313
321
client := & Client {
314
- signer : testSigner ,
315
- clock : testClock ,
322
+ baseClient : & baseClient {
323
+ signer : testSigner ,
324
+ clock : testClock ,
325
+ },
316
326
}
317
327
token , err := client .CustomTokenWithClaims (context .Background (), "user1" , nil )
318
328
if err != nil {
319
329
t .Fatal (err )
320
330
}
321
- verifyCustomToken (context .Background (), token , nil , t )
331
+ if err := verifyCustomToken (context .Background (), token , nil , "" ); err != nil {
332
+ t .Fatal (err )
333
+ }
334
+ }
335
+
336
+ func TestCustomTokenForTenant (t * testing.T ) {
337
+ client := & Client {
338
+ baseClient : & baseClient {
339
+ tenantID : "tenantID" ,
340
+ signer : testSigner ,
341
+ clock : testClock ,
342
+ },
343
+ }
344
+ claims := map [string ]interface {}{
345
+ "foo" : "bar" ,
346
+ "premium" : true ,
347
+ }
348
+ token , err := client .CustomTokenWithClaims (context .Background (), "user1" , claims )
349
+ if err != nil {
350
+ t .Fatal (err )
351
+ }
352
+ if err := verifyCustomToken (context .Background (), token , claims , "tenantID" ); err != nil {
353
+ t .Fatal (err )
354
+ }
322
355
}
323
356
324
357
func TestCustomTokenError (t * testing.T ) {
@@ -333,7 +366,7 @@ func TestCustomTokenError(t *testing.T) {
333
366
{"ReservedClaims" , "uid" , map [string ]interface {}{"sub" : "1234" , "aud" : "foo" }},
334
367
}
335
368
336
- client := & Client {
369
+ client := & baseClient {
337
370
signer : testSigner ,
338
371
clock : testClock ,
339
372
}
@@ -628,9 +661,9 @@ func TestCustomTokenVerification(t *testing.T) {
628
661
client := & Client {
629
662
baseClient : & baseClient {
630
663
idTokenVerifier : testIDTokenVerifier ,
664
+ signer : testSigner ,
665
+ clock : testClock ,
631
666
},
632
- signer : testSigner ,
633
- clock : testClock ,
634
667
}
635
668
token , err := client .CustomToken (context .Background (), "user1" )
636
669
if err != nil {
@@ -1137,52 +1170,61 @@ func checkBaseClient(client *Client, wantProjectID string) error {
1137
1170
return nil
1138
1171
}
1139
1172
1140
- func verifyCustomToken (ctx context.Context , token string , expected map [string ]interface {}, t * testing.T ) {
1173
+ func verifyCustomToken (
1174
+ ctx context.Context , token string , expected map [string ]interface {}, tenantID string ) error {
1175
+
1141
1176
if err := testIDTokenVerifier .verifySignature (ctx , token ); err != nil {
1142
- t . Fatal ( err )
1177
+ return err
1143
1178
}
1179
+
1144
1180
var (
1145
1181
header jwtHeader
1146
1182
payload customToken
1147
1183
)
1148
1184
segments := strings .Split (token , "." )
1149
1185
if err := decode (segments [0 ], & header ); err != nil {
1150
- t . Fatal ( err )
1186
+ return err
1151
1187
}
1152
1188
if err := decode (segments [1 ], & payload ); err != nil {
1153
- t . Fatal ( err )
1189
+ return err
1154
1190
}
1155
1191
1156
1192
email , err := testSigner .Email (ctx )
1157
1193
if err != nil {
1158
- t . Fatal ( err )
1194
+ return err
1159
1195
}
1160
1196
1161
1197
if header .Algorithm != "RS256" {
1162
- t .Errorf ("Algorithm: %q; want: 'RS256'" , header .Algorithm )
1198
+ return fmt .Errorf ("Algorithm: %q; want: 'RS256'" , header .Algorithm )
1163
1199
} else if header .Type != "JWT" {
1164
- t .Errorf ("Type: %q; want: 'JWT'" , header .Type )
1200
+ return fmt .Errorf ("Type: %q; want: 'JWT'" , header .Type )
1165
1201
} else if payload .Aud != firebaseAudience {
1166
- t .Errorf ("Audience: %q; want: %q" , payload .Aud , firebaseAudience )
1202
+ return fmt .Errorf ("Audience: %q; want: %q" , payload .Aud , firebaseAudience )
1167
1203
} else if payload .Iss != email {
1168
- t .Errorf ("Issuer: %q; want: %q" , payload .Iss , email )
1204
+ return fmt .Errorf ("Issuer: %q; want: %q" , payload .Iss , email )
1169
1205
} else if payload .Sub != email {
1170
- t .Errorf ("Subject: %q; want: %q" , payload .Sub , email )
1206
+ return fmt .Errorf ("Subject: %q; want: %q" , payload .Sub , email )
1171
1207
}
1172
1208
1173
1209
now := testClock .Now ().Unix ()
1174
1210
if payload .Exp != now + 3600 {
1175
- t .Errorf ("Exp: %d; want: %d" , payload .Exp , now + 3600 )
1211
+ return fmt .Errorf ("Exp: %d; want: %d" , payload .Exp , now + 3600 )
1176
1212
}
1177
1213
if payload .Iat != now {
1178
- t .Errorf ("Iat: %d; want: %d" , payload .Iat , now )
1214
+ return fmt .Errorf ("Iat: %d; want: %d" , payload .Iat , now )
1179
1215
}
1180
1216
1181
1217
for k , v := range expected {
1182
1218
if payload .Claims [k ] != v {
1183
- t .Errorf ("Claim[%q]: %v; want: %v" , k , payload .Claims [k ], v )
1219
+ return fmt .Errorf ("Claim[%q]: %v; want: %v" , k , payload .Claims [k ], v )
1184
1220
}
1185
1221
}
1222
+
1223
+ if payload .TenantID != tenantID {
1224
+ return fmt .Errorf ("Tenant ID: %q; want: %q" , payload .TenantID , tenantID )
1225
+ }
1226
+
1227
+ return nil
1186
1228
}
1187
1229
1188
1230
func logFatal (err error ) {
0 commit comments