Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions command/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"errors"
"fmt"
"io"
"iter"
"maps"
"os"
"path/filepath"
"sort"
"slices"

"strings"
)
Expand Down Expand Up @@ -62,20 +64,14 @@ func generateDefaultAlias(name string) string {
return strings.ToLower(strings.ReplaceAll(name, " ", "-"))
}

func (a *accountSet) ForEach(f func(id string, account Account, alias string)) {
// Golang does not maintain the order of maps, so we create a slice which is sorted instead.
var accounts []*Account
for _, acc := range a.accounts {
accounts = append(accounts, acc)
}

sort.SliceStable(accounts, func(i, j int) bool {
return accounts[i].Name < accounts[j].Name
func (a *accountSet) Seq() iter.Seq[Account] {
return iter.Seq[Account](func(yield func(account Account) bool) {
for _, v := range a.accounts {
if !yield(*v) {
return
}
}
})

for _, acc := range accounts {
f(acc.ID, *acc, acc.Alias)
}
}

// Add adds an account to the set.
Expand Down Expand Up @@ -169,6 +165,18 @@ func (a *accountSet) ReplaceWith(other []Account) {
}
}

func (a accountSet) Sorted() iter.Seq[Account] {
keys := slices.Sorted(maps.Keys(a.accounts))
return iter.Seq[Account](func(yield func(Account) bool) {
for _, k := range keys {
v := a.accounts[k]
if !yield(*v) {
return
}
}
})
}

func (a accountSet) WriteTable(w io.Writer, withHeaders bool) {
tbl := csv.NewWriter(w)
tbl.Comma = '\t'
Expand All @@ -177,9 +185,9 @@ func (a accountSet) WriteTable(w io.Writer, withHeaders bool) {
tbl.Write([]string{"id", "name", "alias"})
}

a.ForEach(func(id string, acc Account, alias string) {
tbl.Write([]string{id, acc.Name, alias})
})
for account := range a.Sorted() {
tbl.Write([]string{account.ID, account.Name, account.Alias})
}

tbl.Flush()
}
Expand Down Expand Up @@ -218,6 +226,10 @@ func (c *Config) Decode(reader io.Reader) error {
return nil
}

func (c Config) EnumerateAccounts() iter.Seq[Account] {
return c.Accounts.Seq()
}

func (c *Config) AddAccount(id string, account Account) {
if c.Accounts == nil {
c.Accounts = &accountSet{accounts: make(map[string]*Account)}
Expand Down
149 changes: 118 additions & 31 deletions command/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ import (
"encoding/json"
"errors"
"fmt"
"iter"
"os"
"slices"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/charmbracelet/huh"
"github.com/riotgames/key-conjurer/pkg/oauth2cli"
"github.com/spf13/cobra"
)
Expand All @@ -22,6 +25,11 @@ var (
FlagTimeToLive = "ttl"
FlagBypassCache = "bypass-cache"
FlagLogin = "login"
FlagNoInteractive = "no-interactive"

ErrNoRoles = errors.New("no roles")
ErrNoRole = errors.New("no role")
ErrNoAccountArg = errors.New("account name or alias is required")
)

var (
Expand All @@ -35,18 +43,20 @@ var (
)

func init() {
getCmd.Flags().String(FlagRegion, "us-west-2", "The AWS region to use")
getCmd.Flags().Uint(FlagTimeToLive, 1, "The key timeout in hours from 1 to 8.")
getCmd.Flags().UintP(FlagTimeRemaining, "t", DefaultTimeRemaining, "Request new keys if there are no keys in the environment or the current keys expire within <time-remaining> minutes. Defaults to 60.")
getCmd.Flags().StringP(FlagRoleName, "r", "", "The name of the role to assume.")
getCmd.Flags().String(FlagRoleSessionName, "KeyConjurer-AssumeRole", "the name of the role session name that will show up in CloudTrail logs")
getCmd.Flags().StringP(FlagOutputType, "o", outputTypeEnvironmentVariable, "Format to save new credentials in. Supported outputs: env, awscli, json")
getCmd.Flags().String(FlagShellType, shellTypeInfer, "If output type is env, determines which format to output credentials in - by default, the format is inferred based on the execution environment. WSL users may wish to overwrite this to `bash`")
getCmd.Flags().Bool(FlagBypassCache, false, "Do not check the cache for accounts and send the application ID as-is to Okta. This is useful if you have an ID you know is an Okta application ID and it is not stored in your local account cache.")
getCmd.Flags().Bool(FlagLogin, false, "Login to Okta before running the command")
getCmd.Flags().String(FlagAWSCLIPath, "~/.aws/", "Path for directory used by the aws CLI")
getCmd.Flags().BoolP(FlagURLOnly, "u", false, "Print only the URL to visit rather than a user-friendly message")
getCmd.Flags().BoolP(FlagNoBrowser, "b", false, "Do not open a browser window, printing the URL instead")
flags := getCmd.Flags()
flags.String(FlagRegion, "us-west-2", "The AWS region to use")
flags.Uint(FlagTimeToLive, 1, "The key timeout in hours from 1 to 8.")
flags.UintP(FlagTimeRemaining, "t", DefaultTimeRemaining, "Request new keys if there are no keys in the environment or the current keys expire within <time-remaining> minutes. Defaults to 60.")
flags.StringP(FlagRoleName, "r", "", "The name of the role to assume.")
flags.String(FlagRoleSessionName, "KeyConjurer-AssumeRole", "the name of the role session name that will show up in CloudTrail logs")
flags.StringP(FlagOutputType, "o", outputTypeEnvironmentVariable, "Format to save new credentials in. Supported outputs: env, awscli, json")
flags.String(FlagShellType, shellTypeInfer, "If output type is env, determines which format to output credentials in - by default, the format is inferred based on the execution environment. WSL users may wish to overwrite this to `bash`")
flags.Bool(FlagBypassCache, false, "Do not check the cache for accounts and send the application ID as-is to Okta. This is useful if you have an ID you know is an Okta application ID and it is not stored in your local account cache.")
flags.Bool(FlagLogin, false, "Login to Okta before running the command")
flags.String(FlagAWSCLIPath, "~/.aws/", "Path for directory used by the aws CLI")
flags.BoolP(FlagURLOnly, "u", false, "Print only the URL to visit rather than a user-friendly message")
flags.BoolP(FlagNoBrowser, "b", false, "Do not open a browser window, printing the URL instead")
flags.Bool(FlagNoInteractive, false, "Disable interactive prompts")
}

func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Account, bool) {
Expand All @@ -61,7 +71,7 @@ type GetCommand struct {
TimeToLive uint
TimeRemaining uint
OutputType, ShellType, RoleName, AWSCLIPath, OIDCDomain, ClientID, Region string
Login, URLOnly, NoBrowser, BypassCache, MachineOutput bool
Login, URLOnly, NoBrowser, BypassCache, MachineOutput, NoInteractive bool

UsageFunc func() error
PrintErrln func(...any)
Expand All @@ -84,11 +94,17 @@ func (g *GetCommand) Parse(cmd *cobra.Command, args []string) error {
g.Region, _ = flags.GetString(FlagRegion)
g.UsageFunc = cmd.Usage
g.PrintErrln = cmd.PrintErrln
g.NoInteractive, _ = flags.GetBool(FlagNoInteractive)
g.MachineOutput = ShouldUseMachineOutput(flags) || g.URLOnly
if len(args) == 0 {
return fmt.Errorf("account name or alias is required")

if len(args) > 0 {
g.AccountIDOrName = args[0]
} else if g.NoInteractive {
return ErrNoAccountArg
} else {
// We can resolve this at execution time with an interactive prompt.
g.AccountIDOrName = ""
}
g.AccountIDOrName = args[0]
return nil
}

Expand All @@ -111,6 +127,12 @@ func (g GetCommand) Execute(ctx context.Context, config *Config) error {
var accountID string
if g.AccountIDOrName != "" {
accountID = g.AccountIDOrName
} else if !g.NoInteractive {
acc, err := accountsInteractivePrompt(config.EnumerateAccounts(), nil)
if err != nil {
return err
}
accountID = acc.ID
} else if config.LastUsedAccount != nil {
// No account specified. Can we use the most recent one?
accountID = *config.LastUsedAccount
Expand All @@ -123,21 +145,13 @@ func (g GetCommand) Execute(ctx context.Context, config *Config) error {
return UnknownAccountError(g.AccountIDOrName, FlagBypassCache)
}

if g.RoleName == "" {
if account.MostRecentRole == "" {
g.PrintErrln("You must specify the --role flag with this command")
return nil
}
g.RoleName = account.MostRecentRole
}

if config.TimeRemaining != 0 && g.TimeRemaining == DefaultTimeRemaining {
g.TimeRemaining = config.TimeRemaining
}

credentials := LoadAWSCredentialsFromEnvironment()
if !credentials.ValidUntil(account, time.Duration(g.TimeRemaining)*time.Minute) {
newCredentials, err := g.fetchNewCredentials(ctx, *account, config)
newCredentials, err := g.fetchNewCredentials(ctx, account, config)
if errors.Is(err, ErrTokensExpiredOrAbsent) && g.Login {
loginCommand := LoginCommand{
OIDCDomain: g.OIDCDomain,
Expand All @@ -149,7 +163,17 @@ func (g GetCommand) Execute(ctx context.Context, config *Config) error {
if err != nil {
return err
}
newCredentials, err = g.fetchNewCredentials(ctx, *account, config)
newCredentials, err = g.fetchNewCredentials(ctx, account, config)
}

if errors.Is(err, ErrNoRoles) {
g.PrintErrln("You don't have access to any roles on this account.")
return nil
}

if errors.Is(err, ErrNoRole) {
g.PrintErrln("You must specify a role with --role or using the interactive prompt.")
return nil
}

if err != nil {
Expand All @@ -159,24 +183,44 @@ func (g GetCommand) Execute(ctx context.Context, config *Config) error {
credentials = *newCredentials
}

if account != nil {
account.MostRecentRole = g.RoleName
}

config.LastUsedAccount = &accountID
return echoCredentials(accountID, accountID, credentials, g.OutputType, g.ShellType, g.AWSCLIPath)
}

func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, cfg *Config) (*CloudCredentials, error) {
// fetchNewCredentials fetches new credentials for the given account.
//
// 'account' will have its MostRecentRole field updated to the role used if this call is successful.
func (g GetCommand) fetchNewCredentials(ctx context.Context, account *Account, cfg *Config) (*CloudCredentials, error) {
samlResponse, assertionStr, err := oauth2cli.DiscoverConfigAndExchangeTokenForAssertion(ctx, &keychainTokenSource{}, g.OIDCDomain, g.ClientID, account.ID)
if err != nil {
return nil, err
}

roles := listRoles(samlResponse)
if len(roles) == 0 {
return nil, ErrNoRoles
}

if g.RoleName == "" {
if g.NoInteractive {
if account.MostRecentRole == "" {
return nil, ErrNoRole
} else {
g.RoleName = account.MostRecentRole
}
} else {
g.RoleName, err = rolesInteractivePrompt(listRoles(samlResponse), account.MostRecentRole)
if err != nil {
return nil, ErrNoRole
}
}
}

pair, ok := findRoleInSAML(g.RoleName, samlResponse)
if !ok {
return nil, UnknownRoleError(g.RoleName, g.AccountIDOrName)
}
account.MostRecentRole = g.RoleName

if g.TimeToLive == 1 && cfg.TTL != 0 {
g.TimeToLive = cfg.TTL
Expand Down Expand Up @@ -251,3 +295,46 @@ func echoCredentials(id, name string, credentials CloudCredentials, outputType,
return fmt.Errorf("%s is an invalid output type", outputType)
}
}

func accountsInteractivePrompt(accounts iter.Seq[Account], selected *Account) (Account, error) {
var opts []huh.Option[Account]
for account := range accounts {
opts = append(opts, huh.Option[Account]{
Key: account.Alias,
Value: account,
})
}

slices.SortStableFunc(opts, func(a huh.Option[Account], b huh.Option[Account]) int {
return strings.Compare(a.Key, b.Key)
})

ctrl := huh.NewSelect[Account]().
Options(opts...).
Title("account").
Description("Choose an account using your arrow keys or by typing the account name and pressing return to confirm your selection.")

if selected != nil {
ctrl = ctrl.Value(selected)
}

err := huh.Run(ctrl)
if err != nil {
return Account{}, err
}
return ctrl.GetValue().(Account), nil
}

func rolesInteractivePrompt(roles []string, mostRecent string) (string, error) {
opts := huh.NewOptions(roles...)
ctrl := huh.NewSelect[string]().
Options(opts...).
Value(&mostRecent).
Description("Choose a role using your arrow keys and press the return key to confirm.")

err := huh.Run(ctrl)
if err != nil {
return "", err
}
return ctrl.GetValue().(string), nil
}
24 changes: 23 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.28.3
github.com/aws/aws-sdk-go-v2/service/sts v1.32.4
github.com/aws/smithy-go v1.22.0
github.com/charmbracelet/huh v0.6.0
github.com/coreos/go-oidc v2.2.1+incompatible
github.com/go-ini/ini v1.61.0
github.com/hashicorp/vault/api v1.15.0
Expand All @@ -24,6 +25,7 @@ require (

require (
al.essio.dev/pkg/shellescape v1.5.1 // indirect
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.44 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.19 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23 // indirect
Expand All @@ -33,9 +35,19 @@ require (
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.24.5 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/catppuccin/go v0.2.0 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/charmbracelet/bubbles v0.20.0 // indirect
github.com/charmbracelet/bubbletea v1.1.0 // indirect
github.com/charmbracelet/lipgloss v0.13.0 // indirect
github.com/charmbracelet/x/ansi v0.2.3 // indirect
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect
github.com/charmbracelet/x/term v0.2.0 // indirect
github.com/danieljoos/wincred v1.2.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/go-jose/go-jose/v4 v4.0.1 // indirect
github.com/godbus/dbus/v5 v5.1.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
Expand All @@ -51,17 +63,27 @@ require (
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.15.3-0.20240618155329-98d742f6907a // indirect
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d // indirect
github.com/patrickmn/go-cache v0.0.0-20180815053127-5633e0862627 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pquerna/cachecontrol v0.2.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect
github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/text v0.15.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/time v0.0.0-20220922220347-f3bd1da661af // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
gopkg.in/ini.v1 v1.42.0 // indirect
Expand Down
Loading
Loading