Skip to content

Commit 68c6d9e

Browse files
authored
Merge pull request cli#13046 from cli/wm/gh-api-agent
Ensure `api` and `auth` commands record agentic invocations
2 parents b626711 + 2bf528c commit 68c6d9e

File tree

7 files changed

+94
-22
lines changed

7 files changed

+94
-22
lines changed

AGENTS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ go test -tags acceptance ./acceptance # Acceptance tests
1313
make lint # golangci-lint (same as CI)
1414
```
1515

16+
**Before committing, ensure both tests and linter pass:**
17+
```bash
18+
go test ./...
19+
make lint
20+
```
21+
1622
## Architecture
1723

1824
Entry point: `cmd/gh/main.go``internal/ghcmd.Main()``pkg/cmd/root.NewCmdRoot()`.

internal/authflow/flow.go

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func AuthFlow(httpClient *http.Client, oauthHost string, IO *iostreams.IOStreams
9797
return "", "", err
9898
}
9999

100-
userLogin, err := getViewer(oauthHost, token.Token, IO.ErrOut)
100+
userLogin, err := getViewer(httpClient, oauthHost, token.Token)
101101
if err != nil {
102102
return "", "", err
103103
}
@@ -123,16 +123,10 @@ func (c cfg) ActiveToken(hostname string) (string, string) {
123123
return c.token, "oauth_token"
124124
}
125125

126-
func getViewer(hostname, token string, logWriter io.Writer) (string, error) {
127-
opts := api.HTTPClientOptions{
128-
Config: cfg{token: token},
129-
Log: logWriter,
130-
}
131-
client, err := api.NewHTTPClient(opts)
132-
if err != nil {
133-
return "", err
134-
}
135-
return api.CurrentLoginName(api.NewClientFromHTTP(client), hostname)
126+
func getViewer(httpClient *http.Client, hostname, token string) (string, error) {
127+
authedClient := *httpClient
128+
authedClient.Transport = api.AddAuthTokenHeader(httpClient.Transport, cfg{token: token})
129+
return api.CurrentLoginName(api.NewClientFromHTTP(&authedClient), hostname)
136130
}
137131

138132
func waitForEnter(r io.Reader) error {

internal/authflow/flow_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,48 @@
11
package authflow
22

33
import (
4+
"bytes"
5+
"io"
6+
"net/http"
47
"testing"
58

69
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
711
)
812

13+
func Test_getViewer_leavesUserAgent(t *testing.T) {
14+
var receivedUA string
15+
var receivedAuth string
16+
17+
plainClient := &http.Client{
18+
Transport: &roundTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
19+
receivedUA = req.Header.Get("User-Agent")
20+
receivedAuth = req.Header.Get("Authorization")
21+
22+
return &http.Response{
23+
StatusCode: 200,
24+
Header: http.Header{"Content-Type": []string{"application/json"}},
25+
Body: io.NopCloser(bytes.NewBufferString(`{"data":{"viewer":{"login":"monalisa"}}}`)),
26+
Request: req,
27+
}, nil
28+
}},
29+
}
30+
31+
login, err := getViewer(plainClient, "github.com", "test-token")
32+
require.NoError(t, err)
33+
assert.Equal(t, "monalisa", login)
34+
assert.Empty(t, receivedUA, "User-Agent header should be left unset so that downstream transports can set it")
35+
assert.Equal(t, "token test-token", receivedAuth)
36+
}
37+
38+
type roundTripper struct {
39+
roundTrip func(*http.Request) (*http.Response, error)
40+
}
41+
42+
func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
43+
return t.roundTrip(req)
44+
}
45+
946
func Test_getCallbackURI(t *testing.T) {
1047
tests := []struct {
1148
name string

pkg/cmd/api/api.go

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,13 @@ const (
3434
)
3535

3636
type ApiOptions struct {
37-
AppVersion string
38-
BaseRepo func() (ghrepo.Interface, error)
39-
Branch func() (string, error)
40-
Config func() (gh.Config, error)
41-
HttpClient func() (*http.Client, error)
42-
IO *iostreams.IOStreams
37+
AppVersion string
38+
InvokingAgent string
39+
BaseRepo func() (ghrepo.Interface, error)
40+
Branch func() (string, error)
41+
Config func() (gh.Config, error)
42+
HttpClient func() (*http.Client, error)
43+
IO *iostreams.IOStreams
4344

4445
Hostname string
4546
RequestMethod string
@@ -62,11 +63,12 @@ type ApiOptions struct {
6263

6364
func NewCmdApi(f *cmdutil.Factory, runF func(*ApiOptions) error) *cobra.Command {
6465
opts := ApiOptions{
65-
AppVersion: f.AppVersion,
66-
BaseRepo: f.BaseRepo,
67-
Branch: f.Branch,
68-
Config: f.Config,
69-
IO: f.IOStreams,
66+
AppVersion: f.AppVersion,
67+
InvokingAgent: f.InvokingAgent,
68+
BaseRepo: f.BaseRepo,
69+
Branch: f.Branch,
70+
Config: f.Config,
71+
IO: f.IOStreams,
7072
}
7173

7274
cmd := &cobra.Command{
@@ -385,6 +387,7 @@ func apiRun(opts *ApiOptions) error {
385387
}
386388
opts := api.HTTPClientOptions{
387389
AppVersion: opts.AppVersion,
390+
InvokingAgent: opts.InvokingAgent,
388391
CacheTTL: opts.CacheTTL,
389392
Config: cfg.Authentication(),
390393
EnableCache: opts.CacheTTL > 0,

pkg/cmd/api/api_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,36 @@ func Test_apiRun_cache(t *testing.T) {
13881388
assert.Equal(t, "", stderr.String(), "stderr")
13891389
}
13901390

1391+
func Test_apiRun_invokingAgent(t *testing.T) {
1392+
var receivedUA string
1393+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1394+
receivedUA = r.Header.Get("User-Agent")
1395+
w.WriteHeader(http.StatusNoContent)
1396+
}))
1397+
t.Cleanup(s.Close)
1398+
1399+
ios, _, _, _ := iostreams.Test()
1400+
options := ApiOptions{
1401+
IO: ios,
1402+
AppVersion: "1.2.3",
1403+
InvokingAgent: "copilot-cli",
1404+
Config: func() (gh.Config, error) {
1405+
return &ghmock.ConfigMock{
1406+
AuthenticationFunc: func() gh.AuthConfig {
1407+
cfg := &config.AuthConfig{}
1408+
cfg.SetActiveToken("token", "stub")
1409+
return cfg
1410+
},
1411+
}, nil
1412+
},
1413+
RequestPath: s.URL,
1414+
}
1415+
1416+
require.NoError(t, apiRun(&options))
1417+
assert.Contains(t, receivedUA, "GitHub CLI 1.2.3")
1418+
assert.Contains(t, receivedUA, "Agent/copilot-cli")
1419+
}
1420+
13911421
func Test_openUserFile(t *testing.T) {
13921422
f, err := os.CreateTemp(t.TempDir(), "gh-test")
13931423
if err != nil {

pkg/cmd/factory/default.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ var ssoURLRE = regexp.MustCompile(`\burl=([^;]+)`)
2929
func New(appVersion string, invokingAgent string) *cmdutil.Factory {
3030
f := &cmdutil.Factory{
3131
AppVersion: appVersion,
32+
InvokingAgent: invokingAgent,
3233
Config: configFunc(), // No factory dependencies
3334
ExecutableName: "gh",
3435
}

pkg/cmdutil/factory.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
type Factory struct {
2020
AppVersion string
2121
ExecutableName string
22+
InvokingAgent string
2223

2324
Browser browser.Browser
2425
ExtensionManager extensions.ExtensionManager

0 commit comments

Comments
 (0)