Skip to content

Commit e9aa155

Browse files
authored
PAPI: auto enable on upgrade (#3659)
1 parent b3a4d8e commit e9aa155

File tree

10 files changed

+173
-58
lines changed

10 files changed

+173
-58
lines changed

cmd/crowdsec-cli/clicapi/capi.go

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package clicapi
22

33
import (
44
"context"
5-
"errors"
65
"fmt"
76
"io"
87
"net/url"
@@ -155,34 +154,36 @@ func (cli *cliCapi) newRegisterCmd() *cobra.Command {
155154
return cmd
156155
}
157156

157+
type capiStatus struct {
158+
authenticated bool
159+
enrolled bool
160+
subscriptionType string
161+
}
162+
158163
// queryCAPIStatus checks if the Central API is reachable, and if the credentials are correct. It then checks if the instance is enrolled in the console.
159-
func queryCAPIStatus(ctx context.Context, db *database.Client, hub *cwhub.Hub, credURL string, login string, password string) (bool, bool, error) {
164+
func queryCAPIStatus(ctx context.Context, db *database.Client, hub *cwhub.Hub, credURL string, login string, password string) (capiStatus, error) {
160165
apiURL, err := url.Parse(credURL)
161166
if err != nil {
162-
return false, false, err
167+
return capiStatus{}, err
163168
}
164169

165170
itemsForAPI := hub.GetInstalledListForAPI()
166171

167-
if len(itemsForAPI) == 0 {
168-
return false, false, errors.New("no scenarios or appsec-rules installed, abort")
169-
}
170-
171172
passwd := strfmt.Password(password)
172173

173174
client, err := apiclient.NewClient(&apiclient.Config{
174175
MachineID: login,
175176
Password: passwd,
176177
URL: apiURL,
177-
// I don't believe papi is neede to check enrollement
178+
// I don't believe papi is needed to check enrollement
178179
// PapiURL: papiURL,
179180
VersionPrefix: "v3",
180181
UpdateScenario: func(_ context.Context) ([]string, error) {
181182
return itemsForAPI, nil
182183
},
183184
})
184185
if err != nil {
185-
return false, false, err
186+
return capiStatus{}, err
186187
}
187188

188189
pw := strfmt.Password(password)
@@ -195,20 +196,20 @@ func queryCAPIStatus(ctx context.Context, db *database.Client, hub *cwhub.Hub, c
195196

196197
authResp, _, err := client.Auth.AuthenticateWatcher(ctx, t)
197198
if err != nil {
198-
return false, false, err
199+
return capiStatus{}, err
199200
}
200201

201202
if err := db.SaveAPICToken(ctx, apiclient.TokenDBField, authResp.Token); err != nil {
202-
return false, false, err
203+
return capiStatus{}, err
203204
}
204205

205206
client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token
206207

207208
if client.IsEnrolled() {
208-
return true, true, nil
209+
return capiStatus{true, true, client.GetSubscriptionType()}, nil
209210
}
210211

211-
return true, false, nil
212+
return capiStatus{true, false, ""}, nil
212213
}
213214

214215
func (cli *cliCapi) Status(ctx context.Context, db *database.Client, out io.Writer, hub *cwhub.Hub) error {
@@ -223,17 +224,18 @@ func (cli *cliCapi) Status(ctx context.Context, db *database.Client, out io.Writ
223224
fmt.Fprintf(out, "Loaded credentials from %s\n", cfg.API.Server.OnlineClient.CredentialsFilePath)
224225
fmt.Fprintf(out, "Trying to authenticate with username %s on %s\n", cred.Login, cred.URL)
225226

226-
auth, enrolled, err := queryCAPIStatus(ctx, db, hub, cred.URL, cred.Login, cred.Password)
227+
status, err := queryCAPIStatus(ctx, db, hub, cred.URL, cred.Login, cred.Password)
227228
if err != nil {
228229
return fmt.Errorf("failed to authenticate to Central API (CAPI): %w", err)
229230
}
230231

231-
if auth {
232+
if status.authenticated {
232233
fmt.Fprint(out, "You can successfully interact with Central API (CAPI)\n")
233234
}
234235

235-
if enrolled {
236+
if status.enrolled {
236237
fmt.Fprint(out, "Your instance is enrolled in the console\n")
238+
fmt.Fprintf(out, "Subscription type: %s\n", status.subscriptionType)
237239
}
238240

239241
switch *cfg.API.Server.OnlineClient.Sharing {

pkg/apiclient/auth_jwt.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,19 @@ type JWTTransport struct {
3030
RetryConfig *RetryConfig
3131
// Transport is the underlying HTTP transport to use when making requests.
3232
// It will default to http.DefaultTransport if nil.
33-
Transport http.RoundTripper
34-
UpdateScenario func(context.Context) ([]string, error)
33+
Transport http.RoundTripper
34+
UpdateScenario func(context.Context) ([]string, error)
35+
TokenRefreshChan chan struct{} // will write to this channel when the token is refreshed
36+
3537
refreshTokenMutex sync.Mutex
3638
TokenSave TokenSave
3739
}
3840

3941
func (t *JWTTransport) refreshJwtToken(ctx context.Context) error {
4042
var err error
4143

44+
log.Debugf("refreshing jwt token for '%s'", *t.MachineID)
45+
4246
if t.UpdateScenario != nil {
4347
t.Scenarios, err = t.UpdateScenario(ctx)
4448
if err != nil {
@@ -142,6 +146,12 @@ func (t *JWTTransport) refreshJwtToken(ctx context.Context) error {
142146

143147
log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String())
144148

149+
select {
150+
case t.TokenRefreshChan <- struct{}{}:
151+
default:
152+
// Do not block if no one is waiting for the token refresh (ie, PAPI fully disabled)
153+
}
154+
145155
return nil
146156
}
147157

pkg/apiclient/client.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,29 @@ func (c *ApiClient) IsEnrolled() bool {
6969
return ok
7070
}
7171

72+
func (c *ApiClient) GetSubscriptionType() string {
73+
jwtTransport := c.client.Transport.(*JWTTransport)
74+
tokenStr := jwtTransport.Token
75+
76+
token, _ := jwt.Parse(tokenStr, nil)
77+
if token == nil {
78+
return ""
79+
}
80+
81+
claims := token.Claims.(jwt.MapClaims)
82+
subscriptionType, ok := claims["subscription_type"].(string)
83+
if ok {
84+
return subscriptionType
85+
}
86+
87+
return ""
88+
}
89+
90+
func (c *ApiClient) GetTokenRefreshChan() chan struct{} {
91+
jwtTransport := c.client.Transport.(*JWTTransport)
92+
return jwtTransport.TokenRefreshChan
93+
}
94+
7295
type service struct {
7396
client *ApiClient
7497
}
@@ -149,7 +172,8 @@ func NewClient(config *Config) (*ApiClient, error) {
149172
WithStatusCodeConfig(http.StatusServiceUnavailable, 5, true, false),
150173
WithStatusCodeConfig(http.StatusGatewayTimeout, 5, true, false),
151174
),
152-
TokenSave: config.TokenSave,
175+
TokenSave: config.TokenSave,
176+
TokenRefreshChan: make(chan struct{}),
153177
}
154178

155179
transport, baseURL := createTransport(config.URL)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package apiclient
2+
3+
const (
4+
SubscriptionTypeEnterprise = "ENTERPRISE"
5+
SubscriptionTypeSecOps = "SECOPS"
6+
SubscriptionTypeCommunity = "COMMUNITY"
7+
)

pkg/apiserver/apiserver.go

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -260,19 +260,17 @@ func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg) (*APISer
260260

261261
controller.AlertsAddChan = apiClient.AlertsAddChan
262262

263-
if config.ConsoleConfig.IsPAPIEnabled() && config.OnlineClient.Credentials.PapiURL != "" {
264-
if apiClient.apiClient.IsEnrolled() {
265-
log.Info("Machine is enrolled in the console, Loading PAPI Client")
263+
if apiClient.apiClient.IsEnrolled() {
264+
log.Info("Machine is enrolled in the console, Loading PAPI Client")
266265

267-
papiClient, err = NewPAPI(apiClient, dbClient, config.ConsoleConfig, *config.PapiLogLevel)
268-
if err != nil {
269-
return nil, err
270-
}
271-
272-
controller.DecisionDeleteChan = papiClient.Channels.DeleteDecisionChannel
273-
} else {
274-
log.Error("Machine is not enrolled in the console, can't synchronize with the console")
266+
papiClient, err = NewPAPI(apiClient, dbClient, config.ConsoleConfig, *config.PapiLogLevel)
267+
if err != nil {
268+
return nil, err
275269
}
270+
271+
controller.DecisionDeleteChan = papiClient.Channels.DeleteDecisionChannel
272+
} else {
273+
log.Error("Machine is not enrolled in the console, can't synchronize with the console")
276274
}
277275
}
278276

@@ -343,9 +341,8 @@ func (s *APIServer) initAPIC(ctx context.Context) {
343341
s.apic.pushTomb.Go(func() error { return s.apicPush(ctx) })
344342
s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) })
345343

346-
// csConfig.API.Server.ConsoleConfig.ShareCustomScenarios
347344
if s.apic.apiClient.IsEnrolled() {
348-
if s.consoleConfig.IsPAPIEnabled() && s.papi != nil {
345+
if s.papi != nil {
349346
if s.papi.URL != "" {
350347
log.Info("Starting PAPI decision receiver")
351348
s.papi.pullTomb.Go(func() error { return s.papiPull(ctx) })

pkg/apiserver/papi.go

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -258,32 +258,66 @@ func (p *Papi) Pull(ctx context.Context) error {
258258
}
259259
}
260260

261-
p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp)
261+
tokenRefreshChan := p.apiClient.GetTokenRefreshChan()
262+
var papiChan chan longpollclient.Event // Chan is nil by default to block until PAPI actually establishes the connection
263+
papiCtx, cancel := context.WithCancel(ctx)
262264

263-
for event := range p.Client.Start(ctx, lastTimestamp) {
264-
logger := p.Logger.WithField("request-id", event.RequestId)
265-
// update last timestamp in database
266-
newTime := time.Now().UTC()
265+
currentSubscriptionType := p.apiClient.GetSubscriptionType()
267266

268-
binTime, err := newTime.MarshalText()
269-
if err != nil {
270-
return fmt.Errorf("failed to serialize last timestamp: %w", err)
271-
}
267+
p.Logger.Debugf("current subscription type is %s", currentSubscriptionType)
272268

273-
err = p.handleEvent(event, false)
274-
if err != nil {
275-
logger.Errorf("failed to handle event: %s", err)
276-
continue
277-
}
269+
if currentSubscriptionType == apiclient.SubscriptionTypeEnterprise || currentSubscriptionType == apiclient.SubscriptionTypeSecOps {
270+
// If allowed to use PAPI, start it
271+
// Otherwise it will be started when the token is refreshed with an ent subscription
272+
p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp)
273+
papiChan = p.Client.Start(papiCtx, lastTimestamp)
274+
}
278275

279-
if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil {
280-
return fmt.Errorf("failed to update last timestamp: %w", err)
281-
}
276+
for {
277+
select {
278+
case <-tokenRefreshChan:
279+
subType := p.apiClient.GetSubscriptionType()
280+
if subType == currentSubscriptionType {
281+
continue
282+
}
283+
currentSubscriptionType = subType
284+
p.Logger.Infof("Subscription type changed to %s", subType)
285+
switch subType {
286+
case apiclient.SubscriptionTypeEnterprise, apiclient.SubscriptionTypeSecOps:
287+
p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp)
288+
papiChan = p.Client.Start(papiCtx, lastTimestamp)
289+
default:
290+
// PAPI got started but the user downgraded (or removed the engine from the console)
291+
p.Logger.Info("Stopping PAPI because of plan downgrade or engine removal")
292+
cancel() // This will stop any ongoing PAPI pull
293+
p.Client.Stop()
294+
papiCtx, cancel = context.WithCancel(ctx) //nolint:fatcontext // Recreate the context if the pull is restarted
295+
papiChan = nil
296+
p.Logger.Debug("done stopping PAPI pull")
297+
}
298+
case event := <-papiChan:
299+
logger := p.Logger.WithField("request-id", event.RequestId)
300+
// update last timestamp in database
301+
newTime := time.Now().UTC()
282302

283-
logger.Debugf("set last timestamp to %s", newTime)
284-
}
303+
binTime, _ := newTime.MarshalText() // No need to check the error, time.Now().UTC() always returns a valid time
285304

286-
return nil
305+
lastTimestamp = newTime
306+
307+
err = p.handleEvent(event, false)
308+
if err != nil {
309+
logger.Errorf("failed to handle event: %s", err)
310+
continue
311+
}
312+
313+
if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil {
314+
// Killing PAPI is overkill if we cannot update the last timestamp
315+
logger.Errorf("failed to update last timestamp in database: %s", err)
316+
}
317+
318+
logger.Debugf("set last timestamp to %s", newTime)
319+
}
320+
}
287321
}
288322

289323
func (p *Papi) SyncDecisions(ctx context.Context) error {

pkg/csconfig/api.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"github.com/crowdsecurity/go-cs-lib/yamlpatch"
2121

2222
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
23+
"github.com/crowdsecurity/crowdsec/pkg/types"
2324
)
2425

2526
type APICfg struct {
@@ -123,6 +124,10 @@ func (o *OnlineApiClientCfg) Load() error {
123124
o.Credentials = nil
124125
}
125126

127+
if o.Credentials != nil && o.Credentials.PapiURL == "" {
128+
o.Credentials.PapiURL = types.PAPIBaseURL
129+
}
130+
126131
return nil
127132
}
128133

pkg/csconfig/api_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313

1414
"github.com/crowdsecurity/go-cs-lib/cstest"
1515
"github.com/crowdsecurity/go-cs-lib/ptr"
16+
17+
"github.com/crowdsecurity/crowdsec/pkg/types"
1618
)
1719

1820
func TestLoadLocalApiClientCfg(t *testing.T) {
@@ -93,6 +95,7 @@ func TestLoadOnlineApiClientCfg(t *testing.T) {
9395
URL: "http://crowdsec.api",
9496
Login: "test",
9597
Password: "testpassword",
98+
PapiURL: types.PAPIBaseURL,
9699
},
97100
},
98101
{
@@ -211,6 +214,7 @@ func TestLoadAPIServer(t *testing.T) {
211214
URL: "http://crowdsec.api",
212215
Login: "test",
213216
Password: "testpassword",
217+
PapiURL: types.PAPIBaseURL,
214218
},
215219
Sharing: ptr.Of(true),
216220
PullConfig: CapiPullConfig{

0 commit comments

Comments
 (0)