Skip to content

Commit e930f8a

Browse files
committed
feat: add tenant selection and management with CLI support
1 parent 1597737 commit e930f8a

File tree

3 files changed

+320
-2
lines changed

3 files changed

+320
-2
lines changed

src/cmd/cli/command/commands.go

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ import (
99
"os/exec"
1010
"path/filepath"
1111
"regexp"
12+
"sort"
1213
"strings"
1314
"time"
1415

1516
"github.com/AlecAivazis/survey/v2"
1617
"github.com/DefangLabs/defang/src/pkg"
18+
"github.com/DefangLabs/defang/src/pkg/auth"
1719
"github.com/DefangLabs/defang/src/pkg/cli"
1820
cliClient "github.com/DefangLabs/defang/src/pkg/cli/client"
1921
"github.com/DefangLabs/defang/src/pkg/cli/client/byoc"
@@ -56,6 +58,7 @@ var (
5658
modelId = os.Getenv("DEFANG_MODEL_ID") // for Pro users only
5759
nonInteractive = !hasTty
5860
org string
61+
tenantFlag string
5962
providerID = cliClient.ProviderID(pkg.Getenv("DEFANG_PROVIDER", "auto"))
6063
verbose = false
6164
)
@@ -162,6 +165,7 @@ func SetupCommands(ctx context.Context, version string) {
162165
RootCmd.PersistentFlags().StringVarP(&cluster, "cluster", "s", pcluster.DefangFabric, "Defang cluster to connect to")
163166
RootCmd.PersistentFlags().MarkHidden("cluster")
164167
RootCmd.PersistentFlags().StringVar(&org, "org", os.Getenv("DEFANG_ORG"), "override GitHub organization name (tenant)")
168+
RootCmd.PersistentFlags().StringVar(&tenantFlag, "tenant", "", "select tenant by name")
165169
RootCmd.PersistentFlags().VarP(&providerID, "provider", "P", fmt.Sprintf(`bring-your-own-cloud provider; one of %v`, cliClient.AllProviders()))
166170
// RootCmd.Flag("provider").NoOptDefVal = "auto" NO this will break the "--provider aws"
167171
RootCmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "verbose logging") // backwards compat: only used by tail
@@ -214,6 +218,9 @@ func SetupCommands(ctx context.Context, version string) {
214218
// Whoami Command
215219
RootCmd.AddCommand(whoamiCmd)
216220

221+
// Tenants Command
222+
RootCmd.AddCommand(tenantsCmd)
223+
217224
// Logout Command
218225
RootCmd.AddCommand(logoutCmd)
219226

@@ -364,6 +371,14 @@ var RootCmd = &cobra.Command{
364371
}
365372
}
366373

374+
// Configure tenant selection based on --tenant flag
375+
if f := cmd.Root().Flag("tenant"); f != nil && f.Changed {
376+
auth.SetSelectedTenantName(tenantFlag)
377+
} else {
378+
// Default behavior: auto-select tenant by JWT subject if no explicit name is provided
379+
auth.SetAutoSelectBySub(true)
380+
}
381+
367382
client, err = cli.Connect(ctx, getCluster())
368383

369384
if v, err := client.GetVersions(ctx); err == nil {
@@ -375,6 +390,8 @@ var RootCmd = &cobra.Command{
375390
}
376391
}
377392

393+
// (deliberately skip tenant resolution here to avoid blocking non-auth commands)
394+
378395
// Check if we are correctly logged in, but only if the command needs authorization
379396
if _, ok := cmd.Annotations[authNeeded]; !ok {
380397
return nil
@@ -386,7 +403,73 @@ var RootCmd = &cobra.Command{
386403
err = login.InteractiveRequireLoginAndToS(ctx, client, getCluster())
387404
}
388405

389-
return err
406+
if err != nil {
407+
return err
408+
}
409+
410+
// Ensure tenant is resolved post-login as we now have a token
411+
if tok := pcluster.GetExistingToken(getCluster()); tok != "" {
412+
if err2 := auth.ResolveAndSetTenantFromToken(ctx, tok); err2 != nil {
413+
return err2
414+
}
415+
// log the tenant name and id
416+
term.Debug("Selected tenant:", auth.GetSelectedTenantName(), "(", auth.GetSelectedTenantID(), ")")
417+
}
418+
419+
return nil
420+
},
421+
}
422+
423+
var tenantsCmd = &cobra.Command{
424+
Use: "tenants",
425+
Args: cobra.NoArgs,
426+
Annotations: authNeededAnnotation,
427+
Short: "List tenants available to the logged-in user",
428+
RunE: func(cmd *cobra.Command, args []string) error {
429+
ctx := cmd.Context()
430+
tok := pcluster.GetExistingToken(getCluster())
431+
if strings.TrimSpace(tok) == "" {
432+
return errors.New("not logged in; run 'defang login'")
433+
}
434+
435+
tenants, err := auth.ListTenantsFromToken(ctx, tok)
436+
if err != nil {
437+
return err
438+
}
439+
440+
// Sort by name for stable output
441+
sort.Slice(tenants, func(i, j int) bool { return strings.ToLower(tenants[i].Name) < strings.ToLower(tenants[j].Name) })
442+
443+
if len(tenants) == 0 {
444+
term.Info("No tenants found")
445+
return nil
446+
}
447+
448+
currentID := auth.GetSelectedTenantID()
449+
currentName := auth.GetSelectedTenantName()
450+
451+
// Compute longest name for aligned output
452+
maxNameLen := 0
453+
for _, t := range tenants {
454+
if l := len(t.Name); l > maxNameLen {
455+
maxNameLen = l
456+
}
457+
}
458+
459+
for _, t := range tenants {
460+
selected := t.ID == currentID || (currentID == "" && t.Name == currentName && strings.TrimSpace(currentName) != "")
461+
marker := "-"
462+
if selected {
463+
marker = "*" // highlight selected
464+
}
465+
line := fmt.Sprintf("%s %-*s (%s)\n", marker, maxNameLen, t.Name, t.ID)
466+
if selected {
467+
term.Printc(term.BrightCyan, " * ", line)
468+
} else {
469+
term.Printc(term.InfoColor, " * ", line)
470+
}
471+
}
472+
return nil
390473
},
391474
}
392475

src/pkg/auth/interceptor.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ import (
77
"github.com/bufbuild/connect-go"
88
)
99

10-
const XDefangOrgID = "X-Defang-Orgid"
10+
const (
11+
XDefangOrgID = "X-Defang-Orgid"
12+
XDefangTenantID = "X-Defang-Tenant-Id"
13+
)
1114

1215
type authInterceptor struct {
1316
authorization string
@@ -23,6 +26,9 @@ func (a *authInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
2326
req.Header().Set("Authorization", a.authorization)
2427
req.Header().Set("Content-Type", "application/grpc") // same as the gRPC client
2528
req.Header().Set(XDefangOrgID, a.orgID)
29+
if tid := GetSelectedTenantID(); tid != "" {
30+
req.Header().Set(XDefangTenantID, tid)
31+
}
2632
return next(ctx, req)
2733
}
2834
}
@@ -33,6 +39,9 @@ func (a *authInterceptor) WrapStreamingClient(next connect.StreamingClientFunc)
3339
conn.RequestHeader().Set("Authorization", a.authorization)
3440
conn.RequestHeader().Set("Content-Type", "application/grpc") // same as the gRPC client
3541
conn.RequestHeader().Set(XDefangOrgID, a.orgID)
42+
if tid := GetSelectedTenantID(); tid != "" {
43+
conn.RequestHeader().Set(XDefangTenantID, tid)
44+
}
3645
return conn
3746
}
3847
}

src/pkg/auth/tenant.go

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"net/http"
9+
"strings"
10+
11+
"github.com/golang-jwt/jwt/v5"
12+
)
13+
14+
var (
15+
selectedTenantName string
16+
selectedTenantID string
17+
autoSelectBySub bool
18+
19+
// Returned when multiple tenants share the same name in the userinfo response.
20+
ErrMultipleTenantMatches = errors.New("multiple tenants match the name")
21+
// Returned when no tenant matches the provided name in the userinfo response.
22+
ErrTenantNotFound = errors.New("tenant not found")
23+
// Returned when no access token is available yet (user not logged in).
24+
ErrNoAccessToken = errors.New("no access token available; please login first")
25+
)
26+
27+
// SetSelectedTenantName stores the desired tenant name for selection.
28+
func SetSelectedTenantName(name string) {
29+
selectedTenantName = strings.TrimSpace(name)
30+
autoSelectBySub = false
31+
}
32+
33+
// SetAutoSelectBySub enables or disables auto-select by JWT sub.
34+
func SetAutoSelectBySub(enabled bool) {
35+
autoSelectBySub = enabled
36+
}
37+
38+
// subFromJWT extracts the "sub" claim from the given JWT without verification.
39+
func subFromJWT(token string) (string, error) {
40+
var claims jwt.MapClaims
41+
_, _, err := new(jwt.Parser).ParseUnverified(token, &claims)
42+
if err != nil {
43+
return "", fmt.Errorf("failed to parse access token: %w", err)
44+
}
45+
subVal, ok := claims["sub"]
46+
if !ok {
47+
return "", errors.New("token is missing subject (sub) claim")
48+
}
49+
sub, ok := subVal.(string)
50+
if !ok || sub == "" {
51+
return "", errors.New("invalid subject (sub) claim in token")
52+
}
53+
return sub, nil
54+
}
55+
56+
// GetSelectedTenantName returns the currently selected tenant name.
57+
func GetSelectedTenantName() string { return selectedTenantName }
58+
59+
// SetSelectedTenantID stores the resolved tenant ID used in Fabric requests.
60+
func SetSelectedTenantID(id string) { selectedTenantID = strings.TrimSpace(id) }
61+
62+
// GetSelectedTenantID returns the currently selected tenant ID.
63+
func GetSelectedTenantID() string { return selectedTenantID }
64+
65+
// issuerFromJWT extracts the "iss" claim from the given JWT without verification.
66+
func issuerFromJWT(token string) (string, error) {
67+
var claims jwt.MapClaims
68+
_, _, err := new(jwt.Parser).ParseUnverified(token, &claims)
69+
if err != nil {
70+
return "", fmt.Errorf("failed to parse access token: %w", err)
71+
}
72+
issVal, ok := claims["iss"]
73+
if !ok {
74+
return "", errors.New("token is missing issuer (iss) claim")
75+
}
76+
iss, ok := issVal.(string)
77+
if !ok || iss == "" {
78+
return "", errors.New("invalid issuer (iss) claim in token")
79+
}
80+
return iss, nil
81+
}
82+
83+
// userinfoTenant represents a tenant entry in the /userinfo payload.
84+
type userinfoTenant struct {
85+
ID string `json:"id"`
86+
Name string `json:"name"`
87+
}
88+
89+
// userinfoResponse represents the relevant portion of the /userinfo response.
90+
type userinfoResponse struct {
91+
AllTenants []userinfoTenant `json:"allTenants"`
92+
}
93+
94+
// ResolveAndSetTenantFromToken resolves the tenant ID for the previously set tenant name
95+
// by calling issuer + "/userinfo" with the current access token. On success, it sets the
96+
// global selected tenant ID so subsequent Fabric requests include the header.
97+
func ResolveAndSetTenantFromToken(ctx context.Context, accessToken string) error {
98+
// If neither a specific name was requested nor auto-select was enabled, do nothing
99+
if strings.TrimSpace(selectedTenantName) == "" && !autoSelectBySub {
100+
return nil
101+
}
102+
103+
token := strings.TrimSpace(accessToken)
104+
if token == "" {
105+
return ErrNoAccessToken
106+
}
107+
108+
iss, err := issuerFromJWT(token)
109+
if err != nil {
110+
return err
111+
}
112+
113+
url := strings.TrimRight(iss, "/") + "/userinfo"
114+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
115+
if err != nil {
116+
return err
117+
}
118+
req.Header.Set("Authorization", "Bearer "+token)
119+
req.Header.Set("Accept", "application/json")
120+
121+
resp, err := http.DefaultClient.Do(req)
122+
if err != nil {
123+
return fmt.Errorf("userinfo request failed: %w", err)
124+
}
125+
defer resp.Body.Close()
126+
if resp.StatusCode != http.StatusOK {
127+
return fmt.Errorf("userinfo request failed: %s", resp.Status)
128+
}
129+
130+
var ui userinfoResponse
131+
if err := json.NewDecoder(resp.Body).Decode(&ui); err != nil {
132+
return fmt.Errorf("failed to decode userinfo: %w", err)
133+
}
134+
135+
if autoSelectBySub {
136+
sub, err := subFromJWT(token)
137+
if err != nil {
138+
return err
139+
}
140+
matches := 0
141+
var id string
142+
for _, t := range ui.AllTenants {
143+
if t.ID == sub {
144+
id = t.ID
145+
matches++
146+
}
147+
}
148+
switch matches {
149+
case 0:
150+
return fmt.Errorf("%w: no tenant with id matching JWT sub", ErrTenantNotFound)
151+
case 1:
152+
SetSelectedTenantID(id)
153+
return nil
154+
default:
155+
return fmt.Errorf("%w: multiple tenants with id %q", ErrMultipleTenantMatches, sub)
156+
}
157+
} else {
158+
var (
159+
id string
160+
count int
161+
)
162+
for _, t := range ui.AllTenants {
163+
if t.Name == selectedTenantName {
164+
id = t.ID
165+
count++
166+
}
167+
}
168+
switch count {
169+
case 0:
170+
return fmt.Errorf("%w: %q", ErrTenantNotFound, selectedTenantName)
171+
case 1:
172+
SetSelectedTenantID(id)
173+
return nil
174+
default:
175+
return fmt.Errorf("%w: %q", ErrMultipleTenantMatches, selectedTenantName)
176+
}
177+
}
178+
}
179+
180+
// Tenant represents a tenant entry returned by the /userinfo endpoint.
181+
type Tenant struct {
182+
ID string `json:"id"`
183+
Name string `json:"name"`
184+
}
185+
186+
// ListTenantsFromToken calls issuer + "/userinfo" with the provided access token
187+
// and returns the list of tenants available to the user.
188+
func ListTenantsFromToken(ctx context.Context, accessToken string) ([]Tenant, error) {
189+
token := strings.TrimSpace(accessToken)
190+
if token == "" {
191+
return nil, ErrNoAccessToken
192+
}
193+
194+
iss, err := issuerFromJWT(token)
195+
if err != nil {
196+
return nil, err
197+
}
198+
199+
url := strings.TrimRight(iss, "/") + "/userinfo"
200+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
201+
if err != nil {
202+
return nil, err
203+
}
204+
req.Header.Set("Authorization", "Bearer "+token)
205+
req.Header.Set("Accept", "application/json")
206+
207+
resp, err := http.DefaultClient.Do(req)
208+
if err != nil {
209+
return nil, fmt.Errorf("userinfo request failed: %w", err)
210+
}
211+
defer resp.Body.Close()
212+
if resp.StatusCode != http.StatusOK {
213+
return nil, fmt.Errorf("userinfo request failed: %s", resp.Status)
214+
}
215+
216+
var ui userinfoResponse
217+
if err := json.NewDecoder(resp.Body).Decode(&ui); err != nil {
218+
return nil, fmt.Errorf("failed to decode userinfo: %w", err)
219+
}
220+
221+
tenants := make([]Tenant, 0, len(ui.AllTenants))
222+
for _, t := range ui.AllTenants {
223+
tenants = append(tenants, Tenant{ID: t.ID, Name: t.Name})
224+
}
225+
return tenants, nil
226+
}

0 commit comments

Comments
 (0)