|
1 | 1 | package remoteoauth |
2 | 2 |
|
3 | 3 | import ( |
4 | | - "context" |
5 | | - "errors" |
6 | | - "fmt" |
7 | | - "net/http" |
8 | | - "os" |
9 | | - |
10 | | - cloudquery_api "github.com/cloudquery/cloudquery-api-go" |
11 | | - "github.com/google/uuid" |
12 | 4 | "golang.org/x/oauth2" |
13 | 5 | ) |
14 | 6 |
|
| 7 | +// NewTokenSource creates a new token source. |
| 8 | +// Deprecated: Use oauth2.StaticTokenSource directly instead. |
15 | 9 | func NewTokenSource(opts ...TokenSourceOption) (oauth2.TokenSource, error) { |
16 | 10 | t := &tokenSource{} |
17 | 11 | for _, opt := range opts { |
18 | 12 | opt(t) |
19 | 13 | } |
20 | | - |
21 | | - if _, cloudEnabled := os.LookupEnv("CQ_CLOUD"); !cloudEnabled { |
22 | | - return oauth2.StaticTokenSource(&t.currentToken), nil |
23 | | - } |
24 | | - |
25 | | - cloudToken, err := newCloudTokenSource(t.defaultContext) |
26 | | - if err != nil { |
27 | | - return nil, err |
28 | | - } |
29 | | - if t.noWrap { |
30 | | - return cloudToken, nil |
31 | | - } |
32 | | - |
33 | | - return oauth2.ReuseTokenSource(nil, cloudToken), nil |
| 14 | + return oauth2.StaticTokenSource(&t.currentToken), nil |
34 | 15 | } |
35 | 16 |
|
36 | 17 | type tokenSource struct { |
37 | | - defaultContext context.Context |
38 | | - currentToken oauth2.Token |
39 | | - noWrap bool |
40 | | -} |
41 | | - |
42 | | -type cloudTokenSource struct { |
43 | | - defaultContext context.Context |
44 | | - apiClient *cloudquery_api.ClientWithResponses |
45 | | - |
46 | | - apiURL string |
47 | | - apiToken string |
48 | | - teamName string |
49 | | - syncName string |
50 | | - testConnUUID uuid.UUID |
51 | | - syncRunUUID uuid.UUID |
52 | | - connectorUUID uuid.UUID |
53 | | - isTestConnection bool |
54 | | -} |
55 | | - |
56 | | -var _ oauth2.TokenSource = (*cloudTokenSource)(nil) |
57 | | - |
58 | | -func newCloudTokenSource(defaultContext context.Context) (oauth2.TokenSource, error) { |
59 | | - t := &cloudTokenSource{ |
60 | | - defaultContext: defaultContext, |
61 | | - } |
62 | | - if t.defaultContext == nil { |
63 | | - t.defaultContext = context.Background() |
64 | | - } |
65 | | - |
66 | | - err := t.initCloudOpts() |
67 | | - if err != nil { |
68 | | - return nil, err |
69 | | - } |
70 | | - |
71 | | - t.apiClient, err = cloudquery_api.NewClientWithResponses(t.apiURL, |
72 | | - cloudquery_api.WithRequestEditorFn(func(_ context.Context, req *http.Request) error { |
73 | | - req.Header.Set("Authorization", "Bearer "+t.apiToken) |
74 | | - return nil |
75 | | - })) |
76 | | - if err != nil { |
77 | | - return nil, fmt.Errorf("failed to create api client: %w", err) |
78 | | - } |
79 | | - |
80 | | - return t, nil |
81 | | -} |
82 | | - |
83 | | -// Token returns a new token from the remote source using the default context. |
84 | | -func (t *cloudTokenSource) Token() (*oauth2.Token, error) { |
85 | | - return t.retrieveToken(t.defaultContext) |
86 | | -} |
87 | | - |
88 | | -func (t *cloudTokenSource) retrieveToken(ctx context.Context) (*oauth2.Token, error) { |
89 | | - var oauthResp *cloudquery_api.ConnectorCredentialsResponseOAuth |
90 | | - if !t.isTestConnection { |
91 | | - resp, err := t.apiClient.GetSyncRunConnectorCredentialsWithResponse(ctx, t.teamName, t.syncName, t.syncRunUUID, t.connectorUUID) |
92 | | - if err != nil { |
93 | | - return nil, fmt.Errorf("failed to get sync run connector credentials: %w", err) |
94 | | - } |
95 | | - if resp.StatusCode() != http.StatusOK { |
96 | | - if resp.JSON422 != nil { |
97 | | - return nil, fmt.Errorf("failed to get sync run connector credentials: %s", resp.JSON422.Message) |
98 | | - } |
99 | | - return nil, fmt.Errorf("failed to get sync run connector credentials: %s", resp.Status()) |
100 | | - } |
101 | | - oauthResp = resp.JSON200.Oauth |
102 | | - } else { |
103 | | - resp, err := t.apiClient.GetTestConnectionConnectorCredentialsWithResponse(ctx, t.teamName, t.testConnUUID, t.connectorUUID) |
104 | | - if err != nil { |
105 | | - return nil, fmt.Errorf("failed to get test connection connector credentials: %w", err) |
106 | | - } |
107 | | - if resp.StatusCode() != http.StatusOK { |
108 | | - if resp.JSON422 != nil { |
109 | | - return nil, fmt.Errorf("failed to get test connection connector credentials: %s", resp.JSON422.Message) |
110 | | - } |
111 | | - return nil, fmt.Errorf("failed to get test connection connector credentials: %s", resp.Status()) |
112 | | - } |
113 | | - oauthResp = resp.JSON200.Oauth |
114 | | - } |
115 | | - |
116 | | - if oauthResp == nil { |
117 | | - return nil, errors.New("missing oauth credentials in response") |
118 | | - } |
119 | | - |
120 | | - tok := &oauth2.Token{ |
121 | | - AccessToken: oauthResp.AccessToken, |
122 | | - } |
123 | | - if oauthResp.Expires != nil { |
124 | | - tok.Expiry = *oauthResp.Expires |
125 | | - } |
126 | | - return tok, nil |
127 | | -} |
128 | | - |
129 | | -func (t *cloudTokenSource) initCloudOpts() error { |
130 | | - var allErr error |
131 | | - |
132 | | - t.apiToken = os.Getenv("CLOUDQUERY_API_KEY") |
133 | | - if t.apiToken == "" { |
134 | | - allErr = errors.Join(allErr, errors.New("CLOUDQUERY_API_KEY missing")) |
135 | | - } |
136 | | - t.apiURL = os.Getenv("CLOUDQUERY_API_URL") |
137 | | - if t.apiURL == "" { |
138 | | - t.apiURL = "https://api.cloudquery.io" |
139 | | - } |
140 | | - |
141 | | - t.teamName = os.Getenv("_CQ_TEAM_NAME") |
142 | | - if t.teamName == "" { |
143 | | - allErr = errors.Join(allErr, errors.New("_CQ_TEAM_NAME missing")) |
144 | | - } |
145 | | - t.syncName = os.Getenv("_CQ_SYNC_NAME") |
146 | | - syncRunID := os.Getenv("_CQ_SYNC_RUN_ID") |
147 | | - testConnID := os.Getenv("_CQ_SYNC_TEST_CONNECTION_ID") |
148 | | - if testConnID == "" && syncRunID == "" { |
149 | | - allErr = errors.Join(allErr, errors.New("_CQ_SYNC_TEST_CONNECTION_ID or _CQ_SYNC_RUN_ID missing")) |
150 | | - } else if testConnID != "" && syncRunID != "" { |
151 | | - allErr = errors.Join(allErr, errors.New("_CQ_SYNC_TEST_CONNECTION_ID and _CQ_SYNC_RUN_ID are mutually exclusive")) |
152 | | - } |
153 | | - |
154 | | - var err error |
155 | | - if syncRunID != "" { |
156 | | - if t.syncName == "" { |
157 | | - allErr = errors.Join(allErr, errors.New("_CQ_SYNC_NAME missing")) |
158 | | - } |
159 | | - |
160 | | - t.syncRunUUID, err = uuid.Parse(syncRunID) |
161 | | - if err != nil { |
162 | | - allErr = errors.Join(allErr, fmt.Errorf("_CQ_SYNC_RUN_ID is not a valid UUID: %w", err)) |
163 | | - } |
164 | | - } |
165 | | - if testConnID != "" { |
166 | | - if t.syncName != "" { |
167 | | - allErr = errors.Join(allErr, errors.New("_CQ_SYNC_NAME should be empty")) |
168 | | - } |
169 | | - |
170 | | - t.testConnUUID, err = uuid.Parse(testConnID) |
171 | | - if err != nil { |
172 | | - allErr = errors.Join(allErr, fmt.Errorf("_CQ_SYNC_TEST_CONNECTION_ID is not a valid UUID: %w", err)) |
173 | | - } |
174 | | - t.isTestConnection = true |
175 | | - } |
176 | | - |
177 | | - connectorID := os.Getenv("_CQ_CONNECTOR_ID") |
178 | | - if connectorID == "" { |
179 | | - allErr = errors.Join(allErr, errors.New("_CQ_CONNECTOR_ID missing")) |
180 | | - } else { |
181 | | - t.connectorUUID, err = uuid.Parse(connectorID) |
182 | | - if err != nil { |
183 | | - allErr = errors.Join(allErr, fmt.Errorf("_CQ_CONNECTOR_ID is not a valid UUID: %w", err)) |
184 | | - } |
185 | | - } |
186 | | - return allErr |
| 18 | + currentToken oauth2.Token |
187 | 19 | } |
0 commit comments