diff --git a/.github/workflows/codecov.yaml b/.github/workflows/codecov.yaml index c833fad4..7765a469 100644 --- a/.github/workflows/codecov.yaml +++ b/.github/workflows/codecov.yaml @@ -11,8 +11,6 @@ on: jobs: run: runs-on: ubuntu-latest - env: - go-version: 'stable' steps: - name: Checkout code @@ -21,11 +19,11 @@ jobs: - name: Set up Go uses: actions/setup-go@0c52d547c9bc32b1aa3301fd7a9cb496313a4491 # v5.0.0 with: - go-version: ${{ env.go-version }} + go-version-file: go.mod env: GOPROXY: direct GONOSUMDB: "*" - GOPRIVATE: https://github.com/CheckmarxDev/ # Add your private organization url here + GOPRIVATE: https://github.com/CheckmarxDev/ - name: Install dependencies run: go install golang.org/x/tools/cmd/cover@latest diff --git a/.github/workflows/new-rules.yml b/.github/workflows/new-rules.yml index 61b5043a..34ee7188 100644 --- a/.github/workflows/new-rules.yml +++ b/.github/workflows/new-rules.yml @@ -12,6 +12,6 @@ jobs: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - uses: actions/setup-go@0c52d547c9bc32b1aa3301fd7a9cb496313a4491 # v5.0.0 with: - go-version: "^1.22" + go-version-file: go.mod - name: Check Gitleaks new rules run: go run .ci/check_new_rules.go diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 20b596b9..245adcbb 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -59,7 +59,7 @@ jobs: - uses: actions/setup-go@0c52d547c9bc32b1aa3301fd7a9cb496313a4491 # v5.0.0 with: - go-version: "^1.22" + go-version-file: go.mod - name: Go Mod Tidy run: go mod tidy diff --git a/.github/workflows/validate-readme.yml b/.github/workflows/validate-readme.yml index dfd36871..5e3a86fe 100644 --- a/.github/workflows/validate-readme.yml +++ b/.github/workflows/validate-readme.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - uses: actions/setup-go@0c52d547c9bc32b1aa3301fd7a9cb496313a4491 # v5.0.0 with: - go-version: "^1.22" + go-version-file: go.mod - name: update README run: ./.ci/update-readme.sh diff --git a/Dockerfile b/Dockerfile index 9865c245..1d9acc07 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ # and "Missing User Instruction" since 2ms container is stopped after scan # Builder image -FROM checkmarx/go:1.24.4-r0-ae7309142bb6bd@sha256:ae7309142bb6bd82e0272c3624ec53c0c68d855f6b63e985c5caaff5c1705644 AS builder +FROM checkmarx/go:1.25.3-r0-b47cbbc1194cd0@sha256:b47cbbc1194cd0d801fe7739fca12091d610117b0d30c32b52fc900217a0821a AS builder WORKDIR /app diff --git a/README.md b/README.md index caca8617..d3f7d684 100644 --- a/README.md +++ b/README.md @@ -176,7 +176,7 @@ Usage: 2ms [command] Scan Commands - confluence Scan Confluence server + confluence Scan Confluence Cloud discord Scan Discord server filesystem Scan local folder git Scan local Git repository @@ -274,30 +274,56 @@ This command is used to scan a [Confluence](https://www.atlassian.com/software/c 2ms confluence [flags] ``` -| Flag | Type | Default | Description | -| ------------ | ----- | ------------------------------ | -------------------------------------------------------------------------------- | -| `` | string | - | Confluence instance URL in the following format: `https://.atlassian.net/wiki` | -| `--history` | - | Doesn't scan history revisions | Scans pages history revisions | -| `--spaces` | string | all spaces | The names or IDs of the Confluence spaces to scan | -| `--token` | string | - | The Confluence API token for authentication | -| `--username` | string | - | Confluence user name or email for authentication | +| Flag | Type | Default | Description | +| --------------- | ----------- | ------- |----------------------------------------------------------------------------------| +| `--space-keys` | string list | (all) | Comma-separated list of space **keys** to scan. | +| `--space-ids` | string list | (all) | Comma-separated list of space **IDs** to scan. | +| `--page-ids` | string list | (all) | Comma-separated list of **page IDs** to scan. | +| `--history` | bool | `false` | Also scan **all versions** of each page (page history). | +| `--username` | string | | Confluence username/email (used for HTTP Basic Auth). | +| `--token-type` | string | | Token type for Confluence API. Accepted values: `api-token`, `scoped-api-token`. | +| `--token-value` | string | | The API token value. **Required** when `--token-type` is set. | -For example: +#### Authentication +- To scan **private spaces**, provide `--username`, `--token-type` and `--token-value` (API token). +- How to create a Confluence API token: https://support.atlassian.com/atlassian-account/docs/manage-api-tokens-for-your-atlassian-account/ + +#### Examples + +- Scan **all public pages** (no auth): + ```bash + 2ms confluence https://.atlassian.net/wiki + ``` -- To scan public spaces: +- Scan **private pages with an api token** (requires auth): + ```bash + 2ms confluence https://.atlassian.net/wiki --username --token-type api-token --token-value + ``` +- Scan **private pages with a scoped api token** (requires auth): ```bash - 2ms confluence https://checkmarx.atlassian.net/wiki --spaces secrets + 2ms confluence https://.atlassian.net/wiki --username --token-type scoped-api-token --token-value ``` - 💡 [The `secrets` Confluence site](https://checkmarx.atlassian.net/wiki/spaces/secrets) purposely created with plain example secrets as a test subject for this demo -- To scan private spaces, authentication is required +- Scan specific **spaces by key**: ```bash - 2ms confluence --username --token --spaces + 2ms confluence https://.atlassian.net/wiki --space-keys Key1,Key2 ``` - [How to get a Confluence API token](https://support.atlassian.com/atlassian-account/docs/manage-api-tokens-for-your-atlassian-account/). -[![asciicast](https://asciinema.org/a/607179.svg)](https://asciinema.org/a/607179) +- Scan specific **spaces by ID**: + ```bash + 2ms confluence https://.atlassian.net/wiki --space-ids 1234567890,9876543210 + ``` + +- Scan specific **pages by ID**: + ```bash + 2ms confluence https://.atlassian.net/wiki --page-ids 11223344556,99887766554 + ``` + +- Include **page history** (all revisions): + ```bash + 2ms confluence https://.atlassian.net/wiki --history + ``` ### Paligo diff --git a/cmd/main.go b/cmd/main.go index 9cef8ff8..70342515 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -49,7 +49,7 @@ var configFilePath string var vConfig = viper.New() var allPlugins = []plugins.IPlugin{ - &plugins.ConfluencePlugin{}, + plugins.NewConfluencePlugin(), &plugins.DiscordPlugin{}, &plugins.FileSystemPlugin{}, &plugins.SlackPlugin{}, @@ -112,9 +112,17 @@ func Execute() (int, error) { } subCommand.GroupID = group + pluginPreRun := subCommand.PreRunE // Capture plugin name for closure pluginName := plugin.GetName() subCommand.PreRunE = func(cmd *cobra.Command, args []string) error { + // run plugin's own PreRunE (if any) + if pluginPreRun != nil { + if err := pluginPreRun(cmd, args); err != nil { + return err + } + } + // run engine-level PreRunE return preRun(pluginName, engineInstance, cmd, args) } subCommand.PostRunE = func(cmd *cobra.Command, args []string) error { diff --git a/engine/engine.go b/engine/engine.go index 29bc48a0..62ba302c 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -342,7 +342,7 @@ func (e *Engine) detectSecrets( if buildErr != nil { return fmt.Errorf("failed to build secret: %w", buildErr) } - if !isSecretIgnored(secret, e.ignoredIds, e.allowedValues) { + if !isSecretIgnored(secret, e.ignoredIds, e.allowedValues, value.Line, value.Match, pluginName) { secrets <- secret } else { log.Debug().Msgf("Secret %s was ignored", secret.ID) @@ -575,13 +575,15 @@ func getStartAndEndLines( return startLine, endLine, nil } -func isSecretIgnored(secret *secrets.Secret, ignoredIds, allowedValues *[]string) bool { +func isSecretIgnored(secret *secrets.Secret, ignoredIds, allowedValues *[]string, secretLine, secretMatch, pluginName string) bool { for _, allowedValue := range *allowedValues { if secret.Value == allowedValue { return true } } - + if pluginName == "confluence" && isSecretFromConfluenceResourceIdentifier(secret.RuleID, secretLine, secretMatch) { + return true + } return slices.Contains(*ignoredIds, secret.ID) } @@ -740,3 +742,19 @@ func (e *Engine) Scan(pluginName string) { func (e *Engine) Wait() { e.wg.Wait() } + +// isSecretFromConfluenceResourceIdentifier reports whether a regex match found in a line +// actually belongs to Confluence Storage Format metadata (the `ri:` namespace) rather than +// real user content. This lets us ignore false-positives that cannot be suppressed via the +// generic-api-key rule allow-list. +func isSecretFromConfluenceResourceIdentifier(secretRuleID, secretLine, secretMatch string) bool { + if secretRuleID != rules.GenericApiKeyID || secretLine == "" || secretMatch == "" { + return false + } + + q := regexp.QuoteMeta(secretMatch) + + pat := `<[^>]*\sri:` + q + `[^>]*>` + re := regexp.MustCompile(pat) + return re.MatchString(secretLine) +} diff --git a/engine/engine_test.go b/engine/engine_test.go index ce7adf61..b4dfed6c 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -727,6 +727,120 @@ func TestGetFindingId(t *testing.T) { }) } +func TestIsSecretFromConfluenceResourceIdentifier(t *testing.T) { + tests := []struct { + name string + ruleID string + line string + match string + want bool + }{ + { + name: "matches ri:secret attribute with quoted value", + ruleID: rules.GenericApiKeyID, + line: ``, + match: `secret="12345"`, + want: true, + }, + { + name: "matches with extra whitespace and self-closing tag", + ruleID: rules.GenericApiKeyID, + line: ``, + match: `secret="12345"`, + want: true, + }, + { + name: "no match when value format differs (expects exact literal)", + ruleID: rules.GenericApiKeyID, + line: ``, + match: `secret=12345`, + want: false, + }, + { + name: "no match when value appears in a different attribute", + ruleID: rules.GenericApiKeyID, + line: ``, + match: `secret=12345`, + want: false, + }, + { + name: "no match when ri: prefixes the element name (not an attribute)", + ruleID: rules.GenericApiKeyID, + line: ``, + match: `secret`, + want: false, + }, + { + name: "no match when text is outside any tag", + ruleID: rules.GenericApiKeyID, + line: `ri:secret=12345`, + match: `secret=12345`, + want: false, + }, + { + name: "no match for xri: prefixed attribute", + ruleID: rules.GenericApiKeyID, + line: ``, + match: `secret="12345"`, + want: false, + }, + { + name: "no match when rule ID is not generic-api-key does not apply", + ruleID: "some-other-rule", + line: ``, + match: `secret="12345"`, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isSecretFromConfluenceResourceIdentifier(tt.ruleID, tt.line, tt.match) + assert.Equal(t, tt.want, got, "ruleID=%q, line=%q, match=%q", tt.ruleID, tt.line, tt.match) + }) + } +} + +// if any of these tests fails, we should review isSecretFromConfluenceResourceIdentifier and/or generic-api-key rule +func TestDetectWithConfluenceMetadata(t *testing.T) { + secretsCases := []struct { + Content string + Name string + ShouldFind bool + }{ + { + Content: "", + Name: "should not detect from confluence userkey metadata", + ShouldFind: false, + }, + } + + detector, err := Init(&EngineConfig{}) + if err != nil { + t.Fatal(err) + } + + for _, secret := range secretsCases { + t.Run(secret.Name, func(t *testing.T) { + secretsChan := make(chan *secrets.Secret, 1) + c := plugins.ConfluencePlugin{} + err = detector.DetectFragment(item{content: &secret.Content}, secretsChan, c.GetName()) + if err != nil { + return + } + close(secretsChan) + + s := <-secretsChan + + if secret.ShouldFind { + assert.Equal(t, s.LineContent, secret.Content) + } else { + assert.Nil(t, s) + } + }) + } +} + type item struct { content *string id string diff --git a/engine/rules/generic-key.go b/engine/rules/generic-key.go index 640cb00f..5b84ce16 100644 --- a/engine/rules/generic-key.go +++ b/engine/rules/generic-key.go @@ -7,6 +7,8 @@ import ( "github.com/zricethezav/gitleaks/v8/config" ) +const GenericApiKeyID = "generic-api-key" + func GenericCredential() *config.Rule { regex := generateSemiGenericRegexIncludingXml([]string{ "access", @@ -21,7 +23,7 @@ func GenericCredential() *config.Rule { }, `[\w.=-]{10,150}|[a-z0-9][a-z0-9+/]{11,}={0,3}`, true) return &config.Rule{ - RuleID: "generic-api-key", + RuleID: GenericApiKeyID, Description: "Detected a Generic API Key, potentially exposing access to various services and sensitive operations.", Regex: regex, Keywords: []string{ diff --git a/plugins/confluence.go b/plugins/confluence.go index d64768fd..1c982e21 100644 --- a/plugins/confluence.go +++ b/plugins/confluence.go @@ -1,357 +1,429 @@ package plugins import ( - "encoding/json" + "bufio" + "context" + "errors" "fmt" - "net/http" + "io" + "net/url" + "strconv" "strings" - "sync" - "github.com/checkmarx/2ms/v4/lib/utils" + "github.com/checkmarx/2ms/v4/engine/chunk" "github.com/rs/zerolog/log" "github.com/spf13/cobra" +) - "net/url" +var ( + ErrHTTPSRequired = errors.New("must use https") ) +// CLI flags for Confluence. const ( - argUrl = "url" - argSpaces = "spaces" - argUsername = "username" - argToken = "token" - argHistory = "history" - confluenceDefaultWindow = 25 - confluenceMaxRequests = 500 + flagSpaceIDs = "space-ids" + flagSpaceKeys = "space-keys" + flagPageIDs = "page-ids" + flagUsername = "username" + flagTokenType = "token-type" // "api-token" or "scoped-api-token" + flagTokenValue = "token-value" // required when token-type is set + flagHistory = "history" ) -var ( - username string - token string +// Confluence Cloud REST API v2 per-request limits (server caps by endpoint/param). +const ( + // maxPageIDsPerRequest is the per-request server cap for the number of page IDs + // accepted by GET /pages via the ids= query parameter. + maxPageIDsPerRequest = 250 + + // maxSpaceIDsPerRequest is the per-request server cap for the number of space IDs + // accepted by GET /pages via the space-id= query parameter. + maxSpaceIDsPerRequest = 100 + + // maxSpaceKeysPerRequest is the per-request server cap for the number of space keys + // accepted by GET /spaces via the keys= query parameter. + maxSpaceKeysPerRequest = 250 + + // maxPageSize is the requested number of items per page in paginated responses. + // Confluence v2 accepts 1–250; we use 250 to minimize requests so we're less likely to hit rate limits. + maxPageSize = 250 +) + +type TokenType string + +const ( + ApiToken TokenType = "api-token" + ScopedApiToken TokenType = "scoped-api-token" //nolint:gosec // constant label, not a credential ) type ConfluencePlugin struct { Plugin - Spaces []string - History bool - client IConfluenceClient + + SpaceIDs []string + SpaceKeys []string + PageIDs []string + History bool itemsChan chan ISourceItem errorsChan chan error -} -func (p *ConfluencePlugin) GetName() string { - return "confluence" + client ConfluenceClient + chunker chunk.IChunk } -func isValidURL(cmd *cobra.Command, args []string) error { - urlStr := args[0] - parsedURL, err := url.Parse(urlStr) - if err != nil && parsedURL.Scheme != "https" { - return fmt.Errorf("invalid URL format") +func NewConfluencePlugin() IPlugin { + return &ConfluencePlugin{ + chunker: chunk.New(), } - return nil } -func (p *ConfluencePlugin) DefineCommand(items chan ISourceItem, errors chan error) (*cobra.Command, error) { +func (p *ConfluencePlugin) GetName() string { return "confluence" } + +func (p *ConfluencePlugin) DefineCommand(items chan ISourceItem, errs chan error) (*cobra.Command, error) { p.itemsChan = items - p.errorsChan = errors + p.errorsChan = errs + + var username string + var tokenType TokenType + var tokenValue string - var confluenceCmd = &cobra.Command{ + cmd := &cobra.Command{ Use: fmt.Sprintf("%s ", p.GetName()), - Short: "Scan Confluence server", - Long: "Scan Confluence server for sensitive information", + Short: "Scan Confluence Cloud", + Long: "Scan Confluence Cloud for sensitive information", Example: fmt.Sprintf(" 2ms %s https://checkmarx.atlassian.net/wiki", p.GetName()), Args: cobra.MatchAll(cobra.ExactArgs(1), isValidURL), + PreRunE: func(cmd *cobra.Command, args []string) error { + tokenType = TokenType(strings.ToLower(string(tokenType))) + if tokenValue != "" && tokenType == "" { + return fmt.Errorf("--%s must be set when --%s is provided", flagTokenType, flagTokenValue) + } + if !isValidTokenType(tokenType) { + return fmt.Errorf("invalid --%s %q; valid values are %q or %q", + flagTokenType, tokenType, ApiToken, ScopedApiToken) + } + if tokenType != "" && tokenValue == "" { + return fmt.Errorf("--%s requires --%s", flagTokenType, flagTokenValue) + } + if err := p.initialize(args[0], username, tokenType, tokenValue); err != nil { + return err + } + if username == "" || tokenValue == "" { + log.Warn().Msg("Confluence credentials not provided. The scan will run anonymously (public pages only).") + } + return nil + }, Run: func(cmd *cobra.Command, args []string) { - p.initialize(args[0]) - wg := &sync.WaitGroup{} - p.scanConfluence(wg) - wg.Wait() + log.Info().Msg("Confluence plugin started") + if err := p.walkAndEmitPages(context.Background()); err != nil { + p.errorsChan <- err + return + } close(items) }, } - flags := confluenceCmd.Flags() - flags.StringSliceVar(&p.Spaces, argSpaces, []string{}, "Confluence spaces: The names or IDs of the spaces to scan") - flags.StringVar(&username, argUsername, "", "Confluence user name or email for authentication") - flags.StringVar(&token, argToken, "", "The Confluence API token for authentication") - flags.BoolVar(&p.History, argHistory, false, "Scan pages history") + flags := cmd.Flags() + flags.StringSliceVar(&p.SpaceIDs, flagSpaceIDs, []string{}, "Comma-separated list of Confluence space IDs to scan.") + flags.StringSliceVar(&p.SpaceKeys, flagSpaceKeys, []string{}, "Comma-separated list of Confluence space keys to scan.") + flags.StringSliceVar(&p.PageIDs, flagPageIDs, []string{}, "Comma-separated list of Confluence page IDs to scan.") + flags.StringVar(&username, flagUsername, "", "Confluence user name or email for authentication.") + flags.StringVar((*string)(&tokenType), flagTokenType, "", `Token type: "api-token" or "scoped-api-token".`) + flags.StringVar(&tokenValue, flagTokenValue, "", "Token value.") + flags.BoolVar(&p.History, flagHistory, false, "Also scan all page revisions (all versions).") - return confluenceCmd, nil + return cmd, nil } -func (p *ConfluencePlugin) initialize(urlArg string) { - url := strings.TrimRight(urlArg, "/") - - if username == "" || token == "" { - log.Warn().Msg("confluence credentials were not provided. The scan will be made anonymously only for the public pages") - } - p.client = newConfluenceClient(url, token, username) - - p.Limit = make(chan struct{}, confluenceMaxRequests) -} - -func (p *ConfluencePlugin) scanConfluence(wg *sync.WaitGroup) { - spaces, err := p.getSpaces() +// isValidURL validates the single CLI argument as an HTTPS URL. +func isValidURL(_ *cobra.Command, args []string) error { + inputURL := strings.TrimSpace(args[0]) + parsedURL, err := url.Parse(inputURL) if err != nil { - p.errorsChan <- err + return fmt.Errorf("invalid URL: %w", err) + } + if parsedURL.Scheme != "https" { + return fmt.Errorf("invalid URL: %w", ErrHTTPSRequired) } + return nil +} - for _, space := range spaces { - wg.Add(1) - go p.scanConfluenceSpace(wg, space) +// isValidTokenType reports whether the provided tokenType is supported. +// Valid values are the empty string (no auth), "api-token", and "scoped-api-token". +func isValidTokenType(tokenType TokenType) bool { + switch tokenType { + case "", ApiToken, ScopedApiToken: + return true + default: + return false } } -func (p *ConfluencePlugin) scanConfluenceSpace(wg *sync.WaitGroup, space ConfluenceSpaceResult) { - defer wg.Done() +// initialize stores the base wiki URL and constructs the Confluence client. +func (p *ConfluencePlugin) initialize(base, username string, tokenType TokenType, tokenValue string) error { + baseWikiURL := strings.TrimRight(base, "/") - pages, err := p.getPages(space) + client, err := NewConfluenceClient(baseWikiURL, username, tokenType, tokenValue) if err != nil { - p.errorsChan <- err - return + return err } + p.client = client - for _, page := range pages.Pages { - wg.Add(1) - p.Limit <- struct{}{} - go func(page ConfluencePage) { - p.scanPageAllVersions(wg, page, space) - <-p.Limit - }(page) - } + return nil } -func (p *ConfluencePlugin) scanPageAllVersions(wg *sync.WaitGroup, page ConfluencePage, space ConfluenceSpaceResult) { - defer wg.Done() - - previousVersion := p.scanPageVersion(page, space, 0) - if !p.History { - return - } +// walkAndEmitPages discovers pages by the provided selectors (space IDs, space keys, page IDs). +// If no selector is provided, it walks all accessible pages. Pages are de-duplicated by ID. +func (p *ConfluencePlugin) walkAndEmitPages(ctx context.Context) error { + seenPageIDs := make(map[string]struct{}, len(p.PageIDs)) + seenSpaceIDs := make(map[string]struct{}, len(p.SpaceIDs)) - for previousVersion > 0 { - previousVersion = p.scanPageVersion(page, space, previousVersion) + if len(p.SpaceIDs) > 0 { + if err := p.scanBySpaceIDs(ctx, seenPageIDs, seenSpaceIDs); err != nil { + return err + } } -} -func (p *ConfluencePlugin) scanPageVersion(page ConfluencePage, space ConfluenceSpaceResult, version int) int { - pageContent, err := p.client.getPageContentRequest(page, version) - if err != nil { - p.errorsChan <- err - return 0 + if len(p.SpaceKeys) > 0 { + if err := p.scanBySpaceKeys(ctx, seenPageIDs, seenSpaceIDs); err != nil { + return err + } } - itemID := fmt.Sprintf("%s-%s-%s", p.GetName(), space.Key, page.ID) - p.itemsChan <- convertPageToItem(pageContent, itemID) - - return pageContent.History.PreviousVersion.Number -} -func convertPageToItem(pageContent *ConfluencePageContent, itemID string) ISourceItem { - return &item{ - Content: &pageContent.Body.Storage.Value, - ID: itemID, - Source: pageContent.Links["base"] + pageContent.Links["webui"], + if len(p.PageIDs) > 0 { + if err := p.scanByPageIDs(ctx, seenPageIDs); err != nil { + return err + } } -} -func (p *ConfluencePlugin) getSpaces() ([]ConfluenceSpaceResult, error) { - totalSpaces, err := p.client.getSpacesRequest(0) - if err != nil { - return nil, err + if len(p.SpaceIDs) == 0 && len(p.SpaceKeys) == 0 && len(p.PageIDs) == 0 { + if err := p.client.WalkAllPages(ctx, maxPageSize, func(page *Page) error { + return p.emitUniquePage(ctx, page, seenPageIDs) + }); err != nil { + return err + } } - actualSize := totalSpaces.Size + return nil +} - for actualSize == confluenceDefaultWindow { - moreSpaces, err := p.client.getSpacesRequest(totalSpaces.Size) - if err != nil { - return nil, err +// scanBySpaceIDs walks pages in the explicitly provided space IDs (p.SpaceIDs). +// It deduplicates space IDs with seenSpaceIDs, batches requests (maxSpaceIDsPerRequest), +// and emits pages via emitUniquePage while tracking seenPageIDs. +func (p *ConfluencePlugin) scanBySpaceIDs(ctx context.Context, seenPageIDs, seenSpaceIDs map[string]struct{}) error { + var uniqueSpaceIDs []string + for _, spaceID := range p.SpaceIDs { + if _, alreadySeen := seenSpaceIDs[spaceID]; alreadySeen { + continue } - - totalSpaces.Results = append(totalSpaces.Results, moreSpaces.Results...) - totalSpaces.Size += moreSpaces.Size - actualSize = moreSpaces.Size + seenSpaceIDs[spaceID] = struct{}{} + uniqueSpaceIDs = append(uniqueSpaceIDs, spaceID) } - if len(p.Spaces) == 0 { - log.Info().Msgf(" Total of all %d Spaces detected", len(totalSpaces.Results)) - return totalSpaces.Results, nil - } + return p.walkPagesByIDBatches( + ctx, + uniqueSpaceIDs, + maxSpaceIDsPerRequest, + seenPageIDs, + p.client.WalkPagesBySpaceIDs, + ) +} - filteredSpaces := make([]ConfluenceSpaceResult, 0) - if len(p.Spaces) > 0 { - for _, space := range totalSpaces.Results { - for _, spaceToScan := range p.Spaces { - if space.Key == spaceToScan || space.Name == spaceToScan || fmt.Sprintf("%d", space.ID) == spaceToScan { - filteredSpaces = append(filteredSpaces, space) - } +// scanBySpaceKeys resolves space keys (p.SpaceKeys) to space IDs, deduplicates with +// seenSpaceIDs, then walks pages by those IDs in batches. Each page is emitted via +// emitUniquePage, updating seenPageIDs. +func (p *ConfluencePlugin) scanBySpaceKeys(ctx context.Context, seenPageIDs, seenSpaceIDs map[string]struct{}) error { + for _, spaceKeyBatch := range chunkStrings(p.SpaceKeys, maxSpaceKeysPerRequest) { + var newlyResolvedSpaceIDs []string + if err := p.client.WalkSpacesByKeys(ctx, spaceKeyBatch, maxPageSize, func(space *Space) error { + if _, alreadySeen := seenSpaceIDs[space.ID]; alreadySeen { + return nil } + seenSpaceIDs[space.ID] = struct{}{} + newlyResolvedSpaceIDs = append(newlyResolvedSpaceIDs, space.ID) + return nil + }); err != nil { + return err + } + + if err := p.walkPagesByIDBatches( + ctx, + newlyResolvedSpaceIDs, + maxSpaceIDsPerRequest, + seenPageIDs, + p.client.WalkPagesBySpaceIDs, + ); err != nil { + return err } } + return nil +} - log.Info().Msgf(" Total of filtered %d Spaces detected", len(filteredSpaces)) - return filteredSpaces, nil +// scanByPageIDs walks the specific page IDs in p.PageIDs, batching requests (maxPageIDsPerRequest), +// and emits each page via emitUniquePage while tracking seenPageIDs to avoid duplicates. +func (p *ConfluencePlugin) scanByPageIDs(ctx context.Context, seenPageIDs map[string]struct{}) error { + return p.walkPagesByIDBatches( + ctx, + p.PageIDs, + maxPageIDsPerRequest, + seenPageIDs, + p.client.WalkPagesByIDs, + ) } -func (p *ConfluencePlugin) getPages(space ConfluenceSpaceResult) (*ConfluencePageResult, error) { - totalPages, err := p.client.getPagesRequest(space, 0) +// emitInChunks emits page content as one or many items. +func (p *ConfluencePlugin) emitInChunks(page *Page) error { + if page.Body.Storage == nil { + return nil + } - if err != nil { - return nil, fmt.Errorf("unexpected error creating an http request %w", err) + if len(page.Body.Storage.Value) < int(p.chunker.GetFileThreshold()) { + p.itemsChan <- p.convertPageToItem(page) + return nil } - actualSize := len(totalPages.Pages) + reader := bufio.NewReaderSize( + strings.NewReader(page.Body.Storage.Value), p.chunker.GetSize()+p.chunker.GetMaxPeekSize(), + ) - for actualSize == confluenceDefaultWindow { - morePages, err := p.client.getPagesRequest(space, len(totalPages.Pages)) + // We don't care about line-count logic here + totalLines := -1 + for { + chunkStr, err := p.chunker.ReadChunk(reader, totalLines) if err != nil { - return nil, fmt.Errorf("unexpected error creating an http request %w", err) + if err == io.EOF { + return nil + } + return fmt.Errorf("failed to read chunk for page %s: %w", page.ID, err) } + tmp := *page + tmp.Body.Storage = &struct { + Value string `json:"value"` + }{Value: chunkStr} - totalPages.Pages = append(totalPages.Pages, morePages.Pages...) - actualSize = len(morePages.Pages) + p.itemsChan <- p.convertPageToItem(&tmp) } - - log.Info().Msgf(" Space - %s have %d pages", space.Name, len(totalPages.Pages)) - - return totalPages, nil } -/* - * Confluence client - */ - -type IConfluenceClient interface { - getSpacesRequest(start int) (*ConfluenceSpaceResponse, error) - getPagesRequest(space ConfluenceSpaceResult, start int) (*ConfluencePageResult, error) - getPageContentRequest(page ConfluencePage, version int) (*ConfluencePageContent, error) -} +// emitUniquePage emits the current version of a page (and, if enabled, its historical versions) +// ensuring each page ID is emitted only once. +func (p *ConfluencePlugin) emitUniquePage(ctx context.Context, page *Page, seenPageIDs map[string]struct{}) error { + if _, alreadySeen := seenPageIDs[page.ID]; alreadySeen { + return nil + } + seenPageIDs[page.ID] = struct{}{} -type confluenceClient struct { - baseURL string - token string - username string -} + // current version + if err := p.emitInChunks(page); err != nil { + return err + } -func newConfluenceClient(baseURL, token, username string) IConfluenceClient { - return &confluenceClient{ - baseURL: baseURL, - token: token, - username: username, + if p.History { + if err := p.emitHistory(ctx, page); err != nil { + return err + } } + return nil } -func (c *confluenceClient) GetCredentials() (string, string) { - return c.username, c.token +// emitHistory enumerates all versions of a page and emits each version +// except the current one (which is already emitted by emitUniquePage). +func (p *ConfluencePlugin) emitHistory(ctx context.Context, page *Page) error { + current := page.Version.Number + return p.client.WalkPageVersions(ctx, page.ID, maxPageSize, func(versionNumber int) error { + if versionNumber == current { + return nil // already emitted current version + } + versionedPage, err := p.client.FetchPageAtVersion(ctx, page.ID, versionNumber) + if err != nil { + return err + } + return p.emitInChunks(versionedPage) + }) } -func (c *confluenceClient) GetAuthorizationHeader() string { - if c.username == "" || c.token == "" { - return "" +// chunkStrings splits a slice into chunks of at most chunkSize elements. +func chunkStrings(input []string, chunkSize int) [][]string { + if chunkSize <= 0 || len(input) == 0 { + return nil + } + var chunks [][]string + for startIndex := 0; startIndex < len(input); startIndex += chunkSize { + endIndex := min(startIndex+chunkSize, len(input)) + chunks = append(chunks, input[startIndex:endIndex]) } - return utils.CreateBasicAuthCredentials(c) + return chunks } -func (c *confluenceClient) getSpacesRequest(start int) (*ConfluenceSpaceResponse, error) { - url := fmt.Sprintf("%s/rest/api/space?start=%d", c.baseURL, start) - data, resp, err := utils.HttpRequest(http.MethodGet, url, c, utils.RetrySettings{}) - if err != nil { - return nil, fmt.Errorf("unexpected error creating an http request %w", err) - } - defer resp.Body.Close() +// convertPageToItem converts a Confluence Page into an ISourceItem +func (p *ConfluencePlugin) convertPageToItem(page *Page) ISourceItem { + itemID := fmt.Sprintf("%s-%s-%s", p.GetName(), page.ID, strconv.Itoa(page.Version.Number)) - response := &ConfluenceSpaceResponse{} - jsonErr := json.Unmarshal(data, response) - if jsonErr != nil { - return nil, fmt.Errorf("could not unmarshal response %w", err) + sourceURL := "" + if resolvedURL, ok := p.resolveConfluenceSourceURL(page, page.Version.Number); ok { + sourceURL = resolvedURL } - return response, nil -} - -func (c *confluenceClient) getPagesRequest(space ConfluenceSpaceResult, start int) (*ConfluencePageResult, error) { - url := fmt.Sprintf("%s/rest/api/space/%s/content?start=%d", c.baseURL, space.Key, start) - data, resp, err := utils.HttpRequest(http.MethodGet, url, c, utils.RetrySettings{}) - if err != nil { - return nil, fmt.Errorf("unexpected error creating an http request %w", err) + var content *string + if page.Body.Storage != nil { + content = &page.Body.Storage.Value } - defer resp.Body.Close() - response := ConfluencePageResponse{} - jsonErr := json.Unmarshal(data, &response) - if jsonErr != nil { - return nil, fmt.Errorf("could not unmarshal response %w", err) + return &item{ + ID: itemID, + Source: sourceURL, + Content: content, } - - return &response.Results, nil } -func (c *confluenceClient) getPageContentRequest(page ConfluencePage, version int) (*ConfluencePageContent, error) { - var url string - - // If no version given get the latest, else get the specified version - if version == 0 { - url = fmt.Sprintf("%s/rest/api/content/%s?expand=body.storage,version,history.previousVersion", c.baseURL, page.ID) - } else { - url = fmt.Sprintf("%s/rest/api/content/%s?status=historical&version=%d&expand=body.storage,version,history.previousVersion", - c.baseURL, page.ID, version) +// resolveConfluenceSourceURL resolves a URL for a page. +// It prefers the "_links.webui" path and appends pageVersion. +// Falls back to "_links.base" when webui is unavailable. +func (p *ConfluencePlugin) resolveConfluenceSourceURL(page *Page, versionNumber int) (string, bool) { + if page.Links == nil { + return "", false } - request, resp, err := utils.HttpRequest(http.MethodGet, url, c, utils.RetrySettings{MaxRetries: 3, ErrorCodes: []int{500}}) - if err != nil { - return nil, fmt.Errorf("unexpected error creating an http request %w", err) - } - defer resp.Body.Close() - pageContent := ConfluencePageContent{} - jsonErr := json.Unmarshal(request, &pageContent) - if jsonErr != nil { - return nil, jsonErr + // Prefer "webui" + if webUIPath, ok := page.Links["webui"]; ok && webUIPath != "" { + baseURL, err := url.Parse(strings.TrimRight(p.client.WikiBaseURL(), "/") + "/") // e.g., https://tenant.atlassian.net/wiki/ + if err != nil { + return "", false + } + relativeURL, err := url.Parse(strings.TrimPrefix(webUIPath, "/")) // "pages/viewpage.action?..." + if err != nil { + return "", false + } + resolvedURL := baseURL.ResolveReference(relativeURL) // preserves /wiki + queryValues := resolvedURL.Query() + queryValues.Set("pageVersion", strconv.Itoa(versionNumber)) + resolvedURL.RawQuery = queryValues.Encode() + return resolvedURL.String(), true } - return &pageContent, nil -} - -type ConfluenceSpaceResult struct { - ID int `json:"id"` - Key string `json:"key"` - Name string `json:"Name"` - Links map[string]string `json:"_links"` -} - -type ConfluenceSpaceResponse struct { - Results []ConfluenceSpaceResult `json:"results"` - Size int `json:"size"` -} - -type ConfluencePageContent struct { - Body struct { - Storage struct { - Value string `json:"value"` - } `json:"storage"` - } `json:"body"` - History struct { - PreviousVersion struct { - Number int - } `json:"previousVersion"` - } `json:"history"` - Version struct { - Number int `json:"number"` - } `json:"version"` - Links map[string]string `json:"_links"` -} - -type ConfluencePage struct { - ID string `json:"id"` - Type string `json:"type"` - Title string `json:"title"` -} + // Fallback: "_links.base" + if baseLink, ok := page.Links["base"]; ok && baseLink != "" { + return baseLink, true + } -type ConfluencePageResult struct { - Pages []ConfluencePage `json:"results"` + return "", false } -type ConfluencePageResponse struct { - Results ConfluencePageResult `json:"page"` +// walkPagesByIDBatches batch IDs and emit pages using the provided walker. +func (p *ConfluencePlugin) walkPagesByIDBatches( + ctx context.Context, + ids []string, + perBatch int, + seenPageIDs map[string]struct{}, + walker func(context.Context, []string, int, func(*Page) error) error, +) error { + for _, idBatch := range chunkStrings(ids, perBatch) { + if err := walker(ctx, idBatch, maxPageSize, func(page *Page) error { + return p.emitUniquePage(ctx, page, seenPageIDs) + }); err != nil { + return err + } + } + return nil } diff --git a/plugins/confluence_client.go b/plugins/confluence_client.go new file mode 100644 index 00000000..da6aaf42 --- /dev/null +++ b/plugins/confluence_client.go @@ -0,0 +1,637 @@ +package plugins + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "path" + "strconv" + "strings" + "time" +) + +const ( + httpTimeout = 60 * time.Second +) + +var ( + ErrUnsupportedTokenType = errors.New("unsupported token type") + ErrEmptyCloudID = errors.New("empty cloudID") + ErrUnexpectedHTTPStatus = errors.New("unexpected http status") +) + +// ConfluenceClient defines the operations required by the Confluence plugin. +// Methods stream results via visitor callbacks and handle pagination internally. +type ConfluenceClient interface { + WalkAllPages(ctx context.Context, limit int, visit func(*Page) error) error + WalkPagesByIDs(ctx context.Context, pageIDs []string, limit int, visit func(*Page) error) error + WalkPagesBySpaceIDs(ctx context.Context, spaceIDs []string, limit int, visit func(*Page) error) error + WalkPageVersions(ctx context.Context, pageID string, limit int, visit func(int) error) error + FetchPageAtVersion(ctx context.Context, pageID string, version int) (*Page, error) + WalkSpacesByKeys(ctx context.Context, spaceKeys []string, limit int, visit func(*Space) error) error + WikiBaseURL() string +} + +// httpConfluenceClient is a ConfluenceClient implementation backed by net/http. +// It supports optional Basic Auth using a Confluence email/username and API token. +type httpConfluenceClient struct { + baseWikiURL string + httpClient *http.Client + username string + token string + apiBase string +} + +// NewConfluenceClient constructs a ConfluenceClient for the given base wiki URL +// (e.g., https://.atlassian.net/wiki). If username and token are +// non-empty, requests use HTTP Basic Auth. +func NewConfluenceClient(baseWikiURL, username string, tokenType TokenType, tokenValue string) (ConfluenceClient, error) { + c := &httpConfluenceClient{ + baseWikiURL: strings.TrimRight(baseWikiURL, "/"), + httpClient: &http.Client{Timeout: httpTimeout}, + username: username, + token: tokenValue, + } + apiBase, err := c.buildAPIBase(context.Background(), tokenType) + if err != nil { + return nil, err + } + c.apiBase = apiBase + return c, nil +} + +// WikiBaseURL returns the base Confluence wiki URL configured for this client. +func (c *httpConfluenceClient) WikiBaseURL() string { return c.baseWikiURL } + +// buildAPIBase returns the REST v2 base URL to use for this client. +// For api-token (or no) tokens it builds "/api/v2". +// For scoped tokens it discovers the site's cloudId and builds +// "https://api.atlassian.com/ex/confluence/{cloudId}/wiki/api/v2". +func (c *httpConfluenceClient) buildAPIBase(ctx context.Context, tokenType TokenType) (string, error) { + switch tokenType { + case "", ApiToken: + u, err := url.Parse(c.baseWikiURL) + if err != nil { + return "", fmt.Errorf("parse base wiki url: %w", err) + } + u.Path = path.Join(u.Path, "api", "v2") + return strings.TrimRight(u.String(), "/"), nil + + case ScopedApiToken: + cloudID, err := c.discoverCloudID(ctx) + if err != nil { + return "", err + } + u, _ := url.Parse("https://api.atlassian.com") + u.Path = path.Join("/ex/confluence", cloudID, "wiki", "api", "v2") + return strings.TrimRight(u.String(), "/"), nil + + default: + return "", fmt.Errorf("%w %q", ErrUnsupportedTokenType, tokenType) + } +} + +// discoverCloudID resolves the Atlassian cloudId for baseWikiURL by calling +// "https:///_edge/tenant_info" and decoding {"cloudId": "..."}. +// Used when constructing the v2 API base for scoped api tokens. +func (c *httpConfluenceClient) discoverCloudID(ctx context.Context) (string, error) { + site, err := url.Parse(c.baseWikiURL) + if err != nil { + return "", fmt.Errorf("parse base url: %w", err) + } + site.RawQuery, site.Fragment = "", "" + site.Scheme = "https" + site.Path = "/_edge/tenant_info" + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, site.String(), http.NoBody) + if err != nil { + return "", fmt.Errorf("build tenant_info request: %w", err) + } + resp, err := c.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("tenant_info request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) + return "", fmt.Errorf("tenant_info: %w %d: %s", ErrUnexpectedHTTPStatus, resp.StatusCode, strings.TrimSpace(string(b))) + } + + var tmp struct { + CloudID string `json:"cloudId"` + } + if err := json.NewDecoder(resp.Body).Decode(&tmp); err != nil { + return "", fmt.Errorf("decode tenant_info: %w", err) + } + if tmp.CloudID == "" { + return "", fmt.Errorf("tenant_info: %w", ErrEmptyCloudID) + } + return tmp.CloudID, nil +} + +// WalkAllPages iterates all accessible pages and calls visit for each Page. +func (c *httpConfluenceClient) WalkAllPages(ctx context.Context, limit int, visit func(*Page) error) error { + apiURL := c.apiURL("/pages") + q := apiURL.Query() + q.Set("limit", strconv.Itoa(limit)) + q.Set("body-format", "storage") + apiURL.RawQuery = q.Encode() + return c.walkPagesPaginated(ctx, apiURL.String(), visit) +} + +// WalkPagesByIDs iterates the given page IDs and calls visit for each Page. +func (c *httpConfluenceClient) WalkPagesByIDs(ctx context.Context, pageIDs []string, limit int, visit func(*Page) error) error { + apiURL := c.apiURL("/pages") + q := apiURL.Query() + q.Set("limit", strconv.Itoa(limit)) + q.Set("body-format", "storage") + q.Set("id", strings.Join(pageIDs, ",")) + apiURL.RawQuery = q.Encode() + return c.walkPagesPaginated(ctx, apiURL.String(), visit) +} + +// WalkPagesBySpaceIDs iterates pages across the provided space IDs and calls visit. +func (c *httpConfluenceClient) WalkPagesBySpaceIDs(ctx context.Context, spaceIDs []string, limit int, visit func(*Page) error) error { + apiURL := c.apiURL("/pages") + q := apiURL.Query() + q.Set("limit", strconv.Itoa(limit)) + q.Set("body-format", "storage") + q.Set("space-id", strings.Join(spaceIDs, ",")) + apiURL.RawQuery = q.Encode() + return c.walkPagesPaginated(ctx, apiURL.String(), visit) +} + +// WalkPageVersions lists version numbers for a page and calls visit for each. +func (c *httpConfluenceClient) WalkPageVersions(ctx context.Context, pageID string, limit int, visit func(int) error) error { + apiURL := c.apiURL(fmt.Sprintf("/pages/%s/versions", url.PathEscape(pageID))) + q := apiURL.Query() + q.Set("limit", strconv.Itoa(limit)) + apiURL.RawQuery = q.Encode() + + return walkPaginated[int]( + ctx, + apiURL, + c.getJSON, + parseVersionsResponse, + visit, + ) +} + +// FetchPageAtVersion fetches a page at a specific version. +func (c *httpConfluenceClient) FetchPageAtVersion(ctx context.Context, pageID string, version int) (*Page, error) { + apiURL := c.apiURL(fmt.Sprintf("/pages/%s", url.PathEscape(pageID))) + q := apiURL.Query() + q.Set("version", strconv.Itoa(version)) + q.Set("body-format", "storage") + apiURL.RawQuery = q.Encode() + + bodyBytes, _, err := c.getJSON(ctx, apiURL.String()) + if err != nil { + return nil, err + } + var page Page + if err := json.Unmarshal(bodyBytes, &page); err != nil { + return nil, fmt.Errorf("decode page version: %w", err) + } + return &page, nil +} + +// WalkSpacesByKeys lists spaces by their keys and calls visit for each Space. +func (c *httpConfluenceClient) WalkSpacesByKeys(ctx context.Context, spaceKeys []string, limit int, visit func(*Space) error) error { + apiURL := c.apiURL("/spaces") + q := apiURL.Query() + q.Set("limit", strconv.Itoa(limit)) + q.Set("keys", strings.Join(spaceKeys, ",")) + apiURL.RawQuery = q.Encode() + + return walkPaginated[*Space]( + ctx, + apiURL, + c.getJSON, + parseSpacesResponse, + visit, + ) +} + +// Generic pager +// Fetches items from initialURL, applies parse, calls visit for each item, +// and advances using resolveNext until there is no next page. +func walkPaginated[T any]( + ctx context.Context, + apiURL *url.URL, + get func(context.Context, string) ([]byte, http.Header, error), + parse func(http.Header, []byte) ([]T, string, string, error), + visit func(T) error, +) error { + base := baseWithoutCursor(apiURL) + nextURL := apiURL.String() + + for { + body, headers, err := get(ctx, nextURL) + if err != nil { + return err + } + + items, linkNext, bodyNext, err := parse(headers, body) + if err != nil { + return err + } + + for _, it := range items { + if err := visit(it); err != nil { + return err + } + } + + rawNext := linkNext + if rawNext == "" { + rawNext = bodyNext + } + if rawNext == "" { + return nil + } + + cur := cursorFromURL(rawNext) + if cur == "" { + return nil + } + nextURL = withCursor(base, cur) + } +} + +// walkPagesPaginated iterates pages starting from initialURL (streaming decode of results array). +func (c *httpConfluenceClient) walkPagesPaginated( + ctx context.Context, initialURL string, visit func(*Page) error, +) error { + // Build a base URL without any cursor, then append the next cursor each time. + start, err := url.Parse(initialURL) + if err != nil { + return fmt.Errorf("parse initial pages url: %w", err) + } + base := baseWithoutCursor(start) + + nextURL := initialURL + for { + rc, headers, err := c.getJSONStream(ctx, nextURL) + if err != nil { + return err + } + + // Prefer Link header; body may also include _links.next. + linkNext := nextURLFromLinkHeader(headers) + bodyNext, decodeErr := streamPagesFromBody(rc, visit) + closeErr := rc.Close() + if decodeErr != nil { + return decodeErr + } + if closeErr != nil { + return closeErr + } + + // Extract only the cursor and rebuild the next URL from our base. + cur := firstNonEmptyString(cursorFromURL(linkNext), cursorFromURL(bodyNext)) + if cur == "" { + return nil + } + nextURL = withCursor(base, cur) + } +} + +// apiURL joins the relative API path to the base wiki URL or the platform host, +// producing a URL rooted at .../api/v2/. +func (c *httpConfluenceClient) apiURL(relativePath string) *url.URL { + parsedURL, _ := url.Parse(c.apiBase) + parsedURL.Path = path.Join(parsedURL.Path, strings.TrimPrefix(relativePath, "/")) + return parsedURL +} + +// getJSON performs a GET request and returns the response body and headers. +// It sets Accept: application/json and uses Basic Auth when credentials were +// provided. Non-2xx responses return an error with a short body snippet. +// HTTP 429 includes a human-friendly message derived from Retry-After. +func (c *httpConfluenceClient) getJSON(ctx context.Context, reqURL string) ([]byte, http.Header, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, http.NoBody) + if err != nil { + return nil, nil, fmt.Errorf("build request: %w", err) + } + if c.username != "" && c.token != "" { + req.SetBasicAuth(c.username, c.token) + } + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("http get: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusTooManyRequests { + return nil, nil, fmt.Errorf("%s", rateLimitMessage(resp.Header)) + } + if resp.StatusCode < 200 || resp.StatusCode > 299 { + snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 8192)) + return nil, nil, fmt.Errorf("%w %d: %s", ErrUnexpectedHTTPStatus, resp.StatusCode, strings.TrimSpace(string(snippet))) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("read body: %w", err) + } + return body, resp.Header.Clone(), nil +} + +// getJSONStream performs a GET request and returns the response Body (caller must Close) +// and headers, allowing streaming decode without buffering the entire payload. +// HTTP errors are handled similarly to getJSON. +func (c *httpConfluenceClient) getJSONStream(ctx context.Context, reqURL string) (io.ReadCloser, http.Header, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, http.NoBody) + if err != nil { + return nil, nil, fmt.Errorf("build request: %w", err) + } + if c.username != "" && c.token != "" { + req.SetBasicAuth(c.username, c.token) + } + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("http get: %w", err) + } + + if resp.StatusCode == http.StatusTooManyRequests { + defer resp.Body.Close() + return nil, nil, fmt.Errorf("%s", rateLimitMessage(resp.Header)) + } + if resp.StatusCode < 200 || resp.StatusCode > 299 { + defer resp.Body.Close() + snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 8192)) + return nil, nil, fmt.Errorf("%w %d: %s", ErrUnexpectedHTTPStatus, resp.StatusCode, strings.TrimSpace(string(snippet))) + } + + // Caller must Close. + return resp.Body, resp.Header.Clone(), nil +} + +// rateLimitMessage formats a user-friendly message for HTTP 429 responses, +// using the Retry-After header when available (seconds). +func rateLimitMessage(h http.Header) string { + retryAfter := strings.TrimSpace(h.Get("Retry-After")) + if retryAfter == "" { + return "rate limited (429)" + } + secs, err := strconv.Atoi(retryAfter) // seconds + if err != nil || secs < 0 { + return "rate limited (429)" + } + minutes := secs / 60 + seconds := secs % 60 + return fmt.Sprintf("rate limited (429) — retry after %d minute(s) %d second(s)", minutes, seconds) +} + +// PageVersion models the "version" object returned by Confluence. +type PageVersion struct { + Number int `json:"number"` +} + +// PageBody contains the Storage-Format body of a page. +type PageBody struct { + Storage *struct { + Value string `json:"value"` + } `json:"storage,omitempty"` +} + +// Page represents a Confluence page +type Page struct { + ID string `json:"id"` + Status string `json:"status"` + Title string `json:"title"` + SpaceID string `json:"spaceId"` + Type string `json:"type"` + Body PageBody `json:"body"` + Links map[string]string `json:"_links"` + Version PageVersion `json:"version"` +} + +// Space represents a Confluence space +type Space struct { + ID string `json:"id"` + Key string `json:"key"` + Name string `json:"name"` + Links map[string]string `json:"_links"` +} + +// listSpacesResponse models the JSON response returned by /spaces queries. +type listSpacesResponse struct { + Results []*Space `json:"results"` + Links map[string]string `json:"_links"` +} + +type versionEntry struct { + Number int `json:"number"` +} + +// listVersionsResponse models the JSON response returned by /pages/{id}/versions. +type listVersionsResponse struct { + Results []versionEntry `json:"results"` + Links map[string]string `json:"_links"` +} + +// parseSpacesResponse decodes a spaces response and returns the spaces plus any +// "next" URL found in either the Link header or the body _links.next. +func parseSpacesResponse(headers http.Header, body []byte) ([]*Space, string, string, error) { + var payload listSpacesResponse + if err := json.Unmarshal(body, &payload); err != nil { + return nil, "", "", fmt.Errorf("decode spaces: %w", err) + } + linkNext := nextURLFromLinkHeader(headers) + bodyNext := "" + if payload.Links != nil { + bodyNext = payload.Links["next"] + } + return payload.Results, linkNext, bodyNext, nil +} + +// parseVersionsResponse decodes a versions response and returns a slice of +// version numbers plus any "next" URL (Link header or body _links.next). +func parseVersionsResponse(headers http.Header, body []byte) ([]int, string, string, error) { + var payload listVersionsResponse + if err := json.Unmarshal(body, &payload); err != nil { + return nil, "", "", fmt.Errorf("decode versions: %w", err) + } + + versionNumbers := make([]int, 0, len(payload.Results)) + for _, entry := range payload.Results { + versionNumbers = append(versionNumbers, entry.Number) + } + + linkNext := nextURLFromLinkHeader(headers) + bodyNext := "" + if payload.Links != nil { + bodyNext = payload.Links["next"] + } + return versionNumbers, linkNext, bodyNext, nil +} + +// nextURLFromLinkHeader extracts the rel="next" URL from the Link header. +// It returns an empty string when no such relation is present. +func nextURLFromLinkHeader(h http.Header) string { + link := h.Get("Link") + if link == "" { + return "" + } + // Example: Link: ; rel="next", ; rel="base" + for part := range strings.SplitSeq(link, ",") { + part = strings.TrimSpace(part) + if !strings.Contains(part, `rel="next"`) { + continue + } + if i := strings.IndexByte(part, '<'); i >= 0 { + part = part[i+1:] + if j := strings.IndexByte(part, '>'); j >= 0 { + return part[:j] + } + } + } + return "" +} + +// streamPagesFromBody streams the Confluence /pages response, +// calling visit(*Page) for each element in results, and returns body _links.next if present. +func streamPagesFromBody(r io.Reader, visit func(*Page) error) (string, error) { + dec := json.NewDecoder(r) + + tok, err := dec.Token() + if err != nil { + return "", fmt.Errorf("decode: top-level token: %w", err) + } + delim, ok := tok.(json.Delim) + if !ok || delim != '{' { + return "", fmt.Errorf("decode: expected '{' at top-level") + } + + var bodyLinksNext string + + for { + t, err := dec.Token() + if err != nil { + return "", fmt.Errorf("decode: key token: %w", err) + } + + if d, ok := t.(json.Delim); ok && d == '}' { + return bodyLinksNext, nil + } + + key, ok := t.(string) + if !ok { + return "", fmt.Errorf("decode: expected object key") + } + + switch key { + case "results": + if err := decodeResultsArray(dec, visit); err != nil { + return "", err + } + case "_links": + next, err := decodeLinksNext(dec) + if err != nil { + return "", err + } + bodyLinksNext = next + default: + var skip any + if err := dec.Decode(&skip); err != nil { + return "", fmt.Errorf("decode: skip: %w", err) + } + } + } +} + +// decodeResultsArray consumes the next token, which must be '[' for a "results" +// array, then stream-decodes each element into a Page and calls visit for it. +// It stops at the closing ']' and returns any decoding or visitor error. +func decodeResultsArray(dec *json.Decoder, visit func(*Page) error) error { + tok, err := dec.Token() + if err != nil { + return fmt.Errorf("decode: results '[': %w", err) + } + delim, ok := tok.(json.Delim) + if !ok || delim != '[' { + return fmt.Errorf("decode: expected '[' for results") + } + + for dec.More() { + var p Page + if err := dec.Decode(&p); err != nil { + return fmt.Errorf("decode: page: %w", err) + } + if err := visit(&p); err != nil { + return err + } + } + + // read closing ']' + if tok, err = dec.Token(); err != nil { + return fmt.Errorf("decode: results ']': %w", err) + } + if d, ok := tok.(json.Delim); !ok || d != ']' { + return fmt.Errorf("decode: expected closing ']' for results") + } + return nil +} + +// decodeLinksNext decodes the JSON object that follows the "_links" key and +// returns its "next" value (empty string if absent). It consumes the entire +// object and wraps any decoding error with context. +func decodeLinksNext(dec *json.Decoder) (string, error) { + var ln map[string]string + if err := dec.Decode(&ln); err != nil { + return "", fmt.Errorf("decode: _links: %w", err) + } + return ln["next"], nil +} + +// baseWithoutCursor returns a shallow copy of inputURL with the "cursor" query +// parameter removed. The original URL is not modified. +func baseWithoutCursor(inputURL *url.URL) *url.URL { + cloneURL := *inputURL + queryParams := cloneURL.Query() + queryParams.Del("cursor") + cloneURL.RawQuery = queryParams.Encode() + return &cloneURL +} + +// cursorFromURL parses rawURL (absolute or relative) and returns the "cursor" +// query parameter value if present; otherwise returns an empty string. +func cursorFromURL(rawURL string) string { + if strings.TrimSpace(rawURL) == "" { + return "" + } + parsedURL, err := url.Parse(rawURL) + if err != nil { + return "" + } + return parsedURL.Query().Get("cursor") +} + +// withCursor returns the string form of baseURL with its "cursor" query +// parameter set to cursorValue (overwriting any existing one). +func withCursor(baseURL *url.URL, cursorValue string) string { + updatedURL := *baseURL + queryParams := updatedURL.Query() + queryParams.Set("cursor", cursorValue) + updatedURL.RawQuery = queryParams.Encode() + return updatedURL.String() +} + +// firstNonEmptyString returns primary if it is non-empty; otherwise fallback. +func firstNonEmptyString(primary, fallback string) string { + if primary != "" { + return primary + } + return fallback +} diff --git a/plugins/confluence_client_mock_test.go b/plugins/confluence_client_mock_test.go new file mode 100644 index 00000000..1d3bd500 --- /dev/null +++ b/plugins/confluence_client_mock_test.go @@ -0,0 +1,140 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/checkmarx/2ms/v4/plugins (interfaces: ConfluenceClient) +// +// Generated by this command: +// +// mockgen -destination=confluence_client_mock_test.go -package=plugins github.com/checkmarx/2ms/v4/plugins ConfluenceClient +// + +// Package plugins is a generated GoMock package. +package plugins + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockConfluenceClient is a mock of ConfluenceClient interface. +type MockConfluenceClient struct { + ctrl *gomock.Controller + recorder *MockConfluenceClientMockRecorder + isgomock struct{} +} + +// MockConfluenceClientMockRecorder is the mock recorder for MockConfluenceClient. +type MockConfluenceClientMockRecorder struct { + mock *MockConfluenceClient +} + +// NewMockConfluenceClient creates a new mock instance. +func NewMockConfluenceClient(ctrl *gomock.Controller) *MockConfluenceClient { + mock := &MockConfluenceClient{ctrl: ctrl} + mock.recorder = &MockConfluenceClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConfluenceClient) EXPECT() *MockConfluenceClientMockRecorder { + return m.recorder +} + +// FetchPageAtVersion mocks base method. +func (m *MockConfluenceClient) FetchPageAtVersion(ctx context.Context, pageID string, version int) (*Page, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FetchPageAtVersion", ctx, pageID, version) + ret0, _ := ret[0].(*Page) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchPageAtVersion indicates an expected call of FetchPageAtVersion. +func (mr *MockConfluenceClientMockRecorder) FetchPageAtVersion(ctx, pageID, version any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchPageAtVersion", reflect.TypeOf((*MockConfluenceClient)(nil).FetchPageAtVersion), ctx, pageID, version) +} + +// WalkAllPages mocks base method. +func (m *MockConfluenceClient) WalkAllPages(ctx context.Context, limit int, visit func(*Page) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WalkAllPages", ctx, limit, visit) + ret0, _ := ret[0].(error) + return ret0 +} + +// WalkAllPages indicates an expected call of WalkAllPages. +func (mr *MockConfluenceClientMockRecorder) WalkAllPages(ctx, limit, visit any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WalkAllPages", reflect.TypeOf((*MockConfluenceClient)(nil).WalkAllPages), ctx, limit, visit) +} + +// WalkPageVersions mocks base method. +func (m *MockConfluenceClient) WalkPageVersions(ctx context.Context, pageID string, limit int, visit func(int) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WalkPageVersions", ctx, pageID, limit, visit) + ret0, _ := ret[0].(error) + return ret0 +} + +// WalkPageVersions indicates an expected call of WalkPageVersions. +func (mr *MockConfluenceClientMockRecorder) WalkPageVersions(ctx, pageID, limit, visit any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WalkPageVersions", reflect.TypeOf((*MockConfluenceClient)(nil).WalkPageVersions), ctx, pageID, limit, visit) +} + +// WalkPagesByIDs mocks base method. +func (m *MockConfluenceClient) WalkPagesByIDs(ctx context.Context, pageIDs []string, limit int, visit func(*Page) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WalkPagesByIDs", ctx, pageIDs, limit, visit) + ret0, _ := ret[0].(error) + return ret0 +} + +// WalkPagesByIDs indicates an expected call of WalkPagesByIDs. +func (mr *MockConfluenceClientMockRecorder) WalkPagesByIDs(ctx, pageIDs, limit, visit any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WalkPagesByIDs", reflect.TypeOf((*MockConfluenceClient)(nil).WalkPagesByIDs), ctx, pageIDs, limit, visit) +} + +// WalkPagesBySpaceIDs mocks base method. +func (m *MockConfluenceClient) WalkPagesBySpaceIDs(ctx context.Context, spaceIDs []string, limit int, visit func(*Page) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WalkPagesBySpaceIDs", ctx, spaceIDs, limit, visit) + ret0, _ := ret[0].(error) + return ret0 +} + +// WalkPagesBySpaceIDs indicates an expected call of WalkPagesBySpaceIDs. +func (mr *MockConfluenceClientMockRecorder) WalkPagesBySpaceIDs(ctx, spaceIDs, limit, visit any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WalkPagesBySpaceIDs", reflect.TypeOf((*MockConfluenceClient)(nil).WalkPagesBySpaceIDs), ctx, spaceIDs, limit, visit) +} + +// WalkSpacesByKeys mocks base method. +func (m *MockConfluenceClient) WalkSpacesByKeys(ctx context.Context, spaceKeys []string, limit int, visit func(*Space) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WalkSpacesByKeys", ctx, spaceKeys, limit, visit) + ret0, _ := ret[0].(error) + return ret0 +} + +// WalkSpacesByKeys indicates an expected call of WalkSpacesByKeys. +func (mr *MockConfluenceClientMockRecorder) WalkSpacesByKeys(ctx, spaceKeys, limit, visit any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WalkSpacesByKeys", reflect.TypeOf((*MockConfluenceClient)(nil).WalkSpacesByKeys), ctx, spaceKeys, limit, visit) +} + +// WikiBaseURL mocks base method. +func (m *MockConfluenceClient) WikiBaseURL() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WikiBaseURL") + ret0, _ := ret[0].(string) + return ret0 +} + +// WikiBaseURL indicates an expected call of WikiBaseURL. +func (mr *MockConfluenceClientMockRecorder) WikiBaseURL() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WikiBaseURL", reflect.TypeOf((*MockConfluenceClient)(nil).WikiBaseURL)) +} diff --git a/plugins/confluence_client_test.go b/plugins/confluence_client_test.go new file mode 100644 index 00000000..880e7fb9 --- /dev/null +++ b/plugins/confluence_client_test.go @@ -0,0 +1,1332 @@ +package plugins + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBuildAPIBase(t *testing.T) { + tlsTenant := func(cloudID string) *httptest.Server { + return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/_edge/tenant_info" { + _, _ = w.Write([]byte(`{"cloudId":"` + cloudID + `"}`)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + } + + tests := []struct { + name string + setup func() *httpConfluenceClient + tokenType TokenType + expectedBase string + expectedErr error + }{ + { + name: "api-token", + setup: func() *httpConfluenceClient { + return &httpConfluenceClient{baseWikiURL: "https://tenant.atlassian.net/wiki"} + }, + tokenType: ApiToken, + expectedBase: "https://tenant.atlassian.net/wiki/api/v2", + expectedErr: nil, + }, + { + name: "scoped-api-token (discovers cloudId)", + setup: func() *httpConfluenceClient { + ts := tlsTenant("abc-123") + t.Cleanup(ts.Close) + base, _ := url.Parse(ts.URL) + base.Path = "/wiki" + return &httpConfluenceClient{ + baseWikiURL: base.String(), + httpClient: ts.Client(), // trust the TLS test server + } + }, + tokenType: ScopedApiToken, + expectedBase: "https://api.atlassian.com/ex/confluence/abc-123/wiki/api/v2", + expectedErr: nil, + }, + { + name: "scoped (discoverCloudID error)", + setup: func() *httpConfluenceClient { + // TLS server that returns 500 for /_edge/tenant_info + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/_edge/tenant_info" { + http.Error(w, "boom", http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusNotFound) + })) + t.Cleanup(ts.Close) + + base, _ := url.Parse(ts.URL) + base.Path = "/wiki" + return &httpConfluenceClient{ + baseWikiURL: base.String(), + httpClient: ts.Client(), // trust test server + } + }, + tokenType: ScopedApiToken, + expectedBase: "", + expectedErr: ErrUnexpectedHTTPStatus, + }, + { + name: "unsupported", + setup: func() *httpConfluenceClient { + return &httpConfluenceClient{baseWikiURL: "https://example.test/wiki"} + }, + tokenType: TokenType("bad"), + expectedErr: ErrUnsupportedTokenType, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + c := tc.setup() + actualBase, err := c.buildAPIBase(context.Background(), tc.tokenType) + assert.ErrorIs(t, err, tc.expectedErr) + assert.Equal(t, tc.expectedBase, actualBase) + }) + } +} + +func TestDiscoverCloudID(t *testing.T) { + tests := []struct { + name string + ctx context.Context + setup func(t *testing.T) (*httpConfluenceClient, func()) + expectedID string + expectedErr error + }{ + { + name: "success", + setup: func(t *testing.T) (*httpConfluenceClient, func()) { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/_edge/tenant_info", r.URL.Path) + _, _ = io.WriteString(w, `{"cloudId":"abc-123"}`) + })) + // base has /wiki just like real-world usage + base, _ := url.Parse(ts.URL) + base.Path = "/wiki" + c := &httpConfluenceClient{ + baseWikiURL: base.String(), + httpClient: ts.Client(), + } + return c, ts.Close + }, + expectedID: "abc-123", + expectedErr: nil, + }, + { + name: "parse base url error", + setup: func(t *testing.T) (*httpConfluenceClient, func()) { + c := &httpConfluenceClient{ + baseWikiURL: "http://[::1", // invalid + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + return c, func() {} + }, + expectedID: "", + expectedErr: fmt.Errorf("parse \"http://[::1\": missing ']' in host"), + }, + { + name: "client do error", + setup: func(t *testing.T) (*httpConfluenceClient, func()) { + c := &httpConfluenceClient{ + baseWikiURL: "https://127.0.0.1:1/wiki", + httpClient: &http.Client{Timeout: 200 * time.Millisecond}, + } + return c, func() {} + }, + expectedID: "", + expectedErr: fmt.Errorf("tenant_info request"), + }, + { + name: "non-200 http with snippet", + setup: func(t *testing.T) (*httpConfluenceClient, func()) { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/_edge/tenant_info", r.URL.Path) + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, "fail") + })) + base, _ := url.Parse(ts.URL) + base.Path = "/wiki" + c := &httpConfluenceClient{ + baseWikiURL: base.String(), + httpClient: ts.Client(), + } + return c, ts.Close + }, + expectedID: "", + expectedErr: ErrUnexpectedHTTPStatus, + }, + { + name: "decode error", + setup: func(t *testing.T) (*httpConfluenceClient, func()) { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/_edge/tenant_info", r.URL.Path) + _, _ = io.WriteString(w, "{") // invalid JSON + })) + base, _ := url.Parse(ts.URL) + base.Path = "/wiki" + c := &httpConfluenceClient{ + baseWikiURL: base.String(), + httpClient: ts.Client(), + } + return c, ts.Close + }, + expectedID: "", + expectedErr: io.ErrUnexpectedEOF, + }, + { + name: "empty cloudId", + setup: func(t *testing.T) (*httpConfluenceClient, func()) { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/_edge/tenant_info", r.URL.Path) + _, _ = io.WriteString(w, `{"cloudId":""}`) + })) + base, _ := url.Parse(ts.URL) + base.Path = "/wiki" + c := &httpConfluenceClient{ + baseWikiURL: base.String(), + httpClient: ts.Client(), + } + return c, ts.Close + }, + expectedID: "", + expectedErr: ErrEmptyCloudID, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + c, cleanup := tc.setup(t) + defer cleanup() + + actualID, err := c.discoverCloudID(context.Background()) + assert.Equal(t, tc.expectedID, actualID) + if tc.name == "parse base url error" || tc.name == "client do error" { + assert.Contains(t, err.Error(), tc.expectedErr.Error()) + } else { + assert.ErrorIs(t, err, tc.expectedErr) + } + }) + } +} + +func TestAPIURL(t *testing.T) { + tests := []struct { + name string + apiBase string + inPath string + expectedFull string + }{ + { + name: "leading slash", + apiBase: "https://example.test/wiki/api/v2", + inPath: "/pages", + expectedFull: "https://example.test/wiki/api/v2/pages", + }, + { + name: "missing leading slash", + apiBase: "https://example.test/wiki/api/v2", + inPath: "pages", + expectedFull: "https://example.test/wiki/api/v2/pages", + }, + { + name: "trailing slash base ok", + apiBase: "https://example.test/wiki/api/v2/", + inPath: "/pages", + expectedFull: "https://example.test/wiki/api/v2/pages", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + c := &httpConfluenceClient{apiBase: tc.apiBase} + actual := c.apiURL(tc.inPath).String() + assert.Equal(t, tc.expectedFull, actual) + }) + } +} + +func TestNextURLFromLinkHeader(t *testing.T) { + tests := []struct { + name string + link string + expectedNext string + }{ + { + name: `has rel="next"`, + link: `; rel="base", ; rel="next"`, + expectedNext: "/wiki/api/v2/pages?cursor=bar", + }, + { + name: `has rel="next" but it is empty`, + link: `; rel="base", ; rel="next"`, + expectedNext: "", + }, + { + name: "empty header", + link: "", + expectedNext: "", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + h := http.Header{} + if tc.link != "" { + h.Set("Link", tc.link) + } + actual := nextURLFromLinkHeader(h) + assert.Equal(t, tc.expectedNext, actual) + }) + } +} + +func TestRateLimitMessage(t *testing.T) { + tests := []struct { + name string + retryAfter string + expectedText string + }{ + { + name: "seconds to minutes+seconds", + retryAfter: "75", + expectedText: "rate limited (429) — retry after 1 minute(s) 15 second(s)", + }, + { + name: "invalid value", + retryAfter: "NotANumber", + expectedText: "rate limited (429)", + }, + { + name: "no header", + retryAfter: "", + expectedText: "rate limited (429)", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + h := http.Header{} + if tc.retryAfter != "" { + h.Set("Retry-After", tc.retryAfter) + } + actual := rateLimitMessage(h) + assert.Equal(t, tc.expectedText, actual) + }) + } +} + +func TestBaseWithoutCursor(t *testing.T) { + tests := []struct { + name string + raw string + expected string + }{ + { + name: "remove only cursor", + raw: "https://x.test/wiki/api/v2/pages?cursor=abc&limit=10", + expected: "https://x.test/wiki/api/v2/pages?limit=10", + }, + { + name: "no cursor present", + raw: "https://x.test/wiki/api/v2/pages?limit=10", + expected: "https://x.test/wiki/api/v2/pages?limit=10", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + u, _ := url.Parse(tc.raw) + actual := baseWithoutCursor(u).String() + assert.Equal(t, tc.expected, actual) + }) + } +} + +func TestCursorFromURL(t *testing.T) { + tests := []struct { + name string + rawURL string + expected string + }{ + { + name: "relative with cursor", + rawURL: "/wiki/api/v2/pages?cursor=abc", + expected: "abc", + }, + { + name: "empty", + rawURL: "", + expected: "", + }, + { + name: "invalid url", + rawURL: "%", + expected: "", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + actual := cursorFromURL(tc.rawURL) + assert.Equal(t, tc.expected, actual) + }) + } +} + +func TestWithCursor(t *testing.T) { + tests := []struct { + name string + raw string + cur string + expectedFull string + }{ + { + name: "add cursor", + raw: "https://x.test/wiki/api/v2/pages?limit=10", + cur: "next", + expectedFull: "https://x.test/wiki/api/v2/pages?limit=10&cursor=next", + }, + { + name: "overwrite cursor", + raw: "https://x.test/wiki/api/v2/pages?cursor=old", + cur: "new", + expectedFull: "https://x.test/wiki/api/v2/pages?cursor=new", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + u, _ := url.Parse(tc.raw) + actual := withCursor(u, tc.cur) + + actualURL, err := url.Parse(actual) + assert.NoError(t, err) + + expectedURL, err := url.Parse(tc.expectedFull) + assert.NoError(t, err) + + assert.Equal(t, expectedURL.Scheme, actualURL.Scheme) + assert.Equal(t, expectedURL.Host, actualURL.Host) + assert.Equal(t, expectedURL.Path, actualURL.Path) + assert.Equal(t, expectedURL.Query(), actualURL.Query()) + }) + } +} + +func TestFirstNonEmptyString(t *testing.T) { + tests := []struct { + name string + first string + second string + expected string + }{ + { + name: "primary", + first: "a", + second: "b", + expected: "a", + }, + { + name: "fallback", + first: "", + second: "b", + expected: "b", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + actual := firstNonEmptyString(tc.first, tc.second) + assert.Equal(t, tc.expected, actual) + }) + } +} + +func TestWalkPaginated(t *testing.T) { + type step struct { + getErr error + parseErr error + items []int + linkNext string + bodyNext string + } + + tests := []struct { + name string + steps []step + visitErrOnVal *int + expected []int + expectedErr error + }{ + { + name: "uses bodyNext then ends", + steps: []step{ + {items: []int{1, 2}, bodyNext: "/api?cursor=two"}, + {items: []int{3}}, + }, + expected: []int{1, 2, 3}, + }, + { + name: "get error on first call", + steps: []step{ + {getErr: assert.AnError}, + }, + expectedErr: assert.AnError, + }, + { + name: "parse error on first call", + steps: []step{ + {parseErr: assert.AnError}, + }, + expectedErr: assert.AnError, + }, + { + name: "visit error on first item", + steps: []step{ + {items: []int{42}}, + }, + visitErrOnVal: func() *int { v := 42; return &v }(), + expectedErr: assert.AnError, + }, + { + name: "no next stops after first page", + steps: []step{ + {items: []int{1, 2}}, + }, + expected: []int{1, 2}, + }, + { + name: "next present but without cursor then stops", + steps: []step{ + {items: []int{1}, linkNext: "/api?foo=bar"}, + }, + expected: []int{1}, + }, + { + name: "prefers linkNext over bodyNext", + steps: []step{ + {items: []int{1}, linkNext: "/api?cursor=two", bodyNext: "/api?cursor=ignored"}, + {items: []int{2}}, + }, + expected: []int{1, 2}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + i := 0 + + get := func(ctx context.Context, _ string) ([]byte, http.Header, error) { + st := tc.steps[i] + if st.getErr != nil { + return nil, nil, st.getErr + } + return nil, http.Header{}, nil + } + + parse := func(_ http.Header, _ []byte) ([]int, string, string, error) { + st := tc.steps[i] + i++ + if st.parseErr != nil { + return nil, "", "", st.parseErr + } + return st.items, st.linkNext, st.bodyNext, nil + } + + var actualVersions []int + visit := func(n int) error { + if tc.visitErrOnVal != nil && n == *tc.visitErrOnVal { + return assert.AnError + } + actualVersions = append(actualVersions, n) + return nil + } + + start, _ := url.Parse("https://example.test/api") + err := walkPaginated[int](context.Background(), start, get, parse, visit) + + assert.ErrorIs(t, err, tc.expectedErr) + assert.Equal(t, tc.expected, actualVersions) + }) + } +} + +func TestStreamPagesFromBody(t *testing.T) { + tests := []struct { + name string + jsonInput string + visit func(*Page) error + expectedErr error + expectedVisited []string + expectedNext string + useContainsForErr bool + }{ + { + name: "results + _links.next", + jsonInput: `{ + "results": [ + {"id":"1","title":"A","body":{"storage":{"value":"x"}},"_links":{"self":"/s1"}}, + {"id":"2","title":"B","body":{"storage":{"value":"y"}},"_links":{"self":"/s2"}} + ], + "_links": {"next": "/wiki/api/v2/pages?cursor=next"} + }`, + visit: func(*Page) error { return nil }, + expectedErr: nil, + expectedVisited: []string{"1", "2"}, + expectedNext: "/wiki/api/v2/pages?cursor=next", + }, + { + name: "top-level ReadToken() fails", + jsonInput: ``, + visit: func(*Page) error { return nil }, + expectedErr: io.EOF, + }, + { + name: "top-level token not '{'", + jsonInput: `[]`, + visit: func(*Page) error { return nil }, + expectedErr: fmt.Errorf("decode: expected '{' at top-level"), + useContainsForErr: true, + }, + { + name: "unexpected token kind after '{'", + jsonInput: `{ 1 }`, + visit: func(*Page) error { return nil }, + expectedErr: fmt.Errorf("decode: key token"), + useContainsForErr: true, + }, + { + name: "key token ReadToken() fails", + jsonInput: `{"`, + visit: func(*Page) error { return nil }, + expectedErr: fmt.Errorf("decode: key token"), + useContainsForErr: true, + }, + { + name: "results decode error (not an array)", + jsonInput: `{ + "results": {}, + "_links": {"next": "/wiki/api/v2/pages?cursor=n"} + }`, + visit: func(*Page) error { return nil }, + expectedErr: fmt.Errorf("decode: expected '[' for results"), + useContainsForErr: true, + }, + { + name: "_links decode error (invalid type)", + jsonInput: `{ + "results": [], + "_links": 123 + }`, + visit: func(*Page) error { return nil }, + expectedErr: fmt.Errorf("decode: _links"), + useContainsForErr: true, + }, + { + name: "visitor returns error (propagated)", + jsonInput: `{ + "results": [{"id":"42","title":"Only"}] + }`, + visit: func(*Page) error { + return assert.AnError + }, + expectedErr: assert.AnError, + expectedVisited: nil, + }, + { + name: "unknown key is skipped", + jsonInput: `{ + "ignore_me": {"nested":{"deep": [1,2,3]}}, + "results": [{"id":"99","title":"X"}] + }`, + visit: func(*Page) error { return nil }, + expectedErr: nil, + expectedVisited: []string{"99"}, + expectedNext: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var actualVisited []string + + next, err := streamPagesFromBody(strings.NewReader(tc.jsonInput), func(p *Page) error { + if tc.visit != nil { + if e := tc.visit(p); e != nil { + return e + } + } + actualVisited = append(actualVisited, p.ID) + return nil + }) + + if tc.useContainsForErr { + assert.Contains(t, err.Error(), tc.expectedErr.Error()) + } else { + assert.ErrorIs(t, err, tc.expectedErr) + } + assert.Equal(t, tc.expectedVisited, actualVisited) + assert.Equal(t, tc.expectedNext, next) + }) + } +} + +func TestParseSpacesResponse(t *testing.T) { + tests := []struct { + name string + headers http.Header + body string + expectedIDs []string + expectedLinkNext string + expectedBodyNext string + expectedErr error + useContainsForErr bool + }{ + { + name: "link header and body _links.next", + headers: func() http.Header { + h := http.Header{} + h.Set("Link", `; rel="next"`) + return h + }(), + body: `{"results":[{"id":"S1","key":"K1"}],"_links":{"next":"/wiki/api/v2/spaces?cursor=2"}}`, + expectedIDs: []string{"S1"}, + expectedLinkNext: "/wiki/api/v2/spaces?cursor=1", + expectedBodyNext: "/wiki/api/v2/spaces?cursor=2", + }, + { + name: "no link header or _links in body", + headers: http.Header{}, + body: `{"results":[{"id":"S9","key":"K9"}]}`, + expectedIDs: []string{"S9"}, + }, + { + name: "decode spaces error", + headers: http.Header{}, + body: `{`, + expectedErr: fmt.Errorf(""), + useContainsForErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + spaces, linkNext, bodyNext, err := parseSpacesResponse(tc.headers, []byte(tc.body)) + + if tc.useContainsForErr { + assert.Contains(t, err.Error(), tc.expectedErr.Error()) + } else { + assert.ErrorIs(t, err, tc.expectedErr) + } + + var actualIDs []string + for _, s := range spaces { + actualIDs = append(actualIDs, s.ID) + } + assert.Equal(t, tc.expectedIDs, actualIDs) + assert.Equal(t, tc.expectedLinkNext, linkNext) + assert.Equal(t, tc.expectedBodyNext, bodyNext) + }) + } +} + +func TestParseVersionsResponse(t *testing.T) { + tests := []struct { + name string + headers http.Header + body string + expectedVersions []int + expectedLinkNext string + expectedBodyNext string + expectedErr error + useContainsForErr bool + }{ + { + name: "link header and body _links.next", + headers: func() http.Header { + h := http.Header{} + h.Set("Link", `; rel="next"`) + return h + }(), + body: `{"results":[{"number":3},{"number":2},{"number":1}],"_links":{"next":"/wiki/api/v2/pages/123/versions?cursor=2"}}`, + expectedVersions: []int{3, 2, 1}, + expectedLinkNext: "/wiki/api/v2/pages/123/versions?cursor=1", + expectedBodyNext: "/wiki/api/v2/pages/123/versions?cursor=2", + }, + { + name: "no link header or _links in body", + headers: http.Header{}, + body: `{"results":[{"number":7},{"number":6}]}`, + expectedVersions: []int{7, 6}, + }, + { + name: "decode versions error ", + headers: http.Header{}, + body: `{`, + expectedErr: fmt.Errorf("unexpected end of JSON input"), + useContainsForErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + versions, linkNext, bodyNext, err := parseVersionsResponse(tc.headers, []byte(tc.body)) + + if tc.useContainsForErr { + assert.Contains(t, err.Error(), tc.expectedErr.Error()) + } else { + assert.ErrorIs(t, err, tc.expectedErr) + } + + assert.Equal(t, tc.expectedVersions, versions) + assert.Equal(t, tc.expectedLinkNext, linkNext) + assert.Equal(t, tc.expectedBodyNext, bodyNext) + }) + } +} +func TestGetJSON(t *testing.T) { + tests := []struct { + name string + username string + token string + setupServer func(t *testing.T) *httptest.Server + expectedErr error + expectedHeader string + }{ + { + name: "success with auth and Accept header", + username: "user", + token: "token", + setupServer: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Validate headers + expAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:token")) + assert.Equal(t, expAuth, r.Header.Get("Authorization")) + assert.Equal(t, "application/json", r.Header.Get("Accept")) + w.Header().Set("Link", "mockLink") + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + }, + expectedErr: nil, + expectedHeader: "mockLink", + }, + { + name: "429 returns friendly error", + setupServer: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "120") + w.WriteHeader(http.StatusTooManyRequests) + })) + }, + expectedErr: fmt.Errorf("rate limited (429) — retry after 2 minute(s) 0 second(s)"), + expectedHeader: "", + }, + { + name: "non-2xx returns snippet", + setupServer: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "something went wrong", http.StatusInternalServerError) + })) + }, + expectedErr: ErrUnexpectedHTTPStatus, + expectedHeader: "", + }, + { + name: "no auth header when username/token empty", + setupServer: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "", r.Header.Get("Authorization")) + _, _ = w.Write([]byte(`{}`)) + })) + }, + expectedErr: nil, + expectedHeader: "", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ts := tc.setupServer(t) + defer ts.Close() + + client := &httpConfluenceClient{ + baseWikiURL: ts.URL + "/wiki", + httpClient: &http.Client{Timeout: 5 * time.Second}, + username: tc.username, + token: tc.token, + } + + _, headers, err := client.getJSON(context.Background(), ts.URL) + if tc.name == "429 returns friendly error" { + assert.Contains(t, err.Error(), tc.expectedErr.Error()) + } else { + assert.ErrorIs(t, err, tc.expectedErr) + } + + var actualHeader string + if headers != nil { + actualHeader = headers.Get("Link") + } + assert.Equal(t, tc.expectedHeader, actualHeader) + }) + } +} + +func TestGetJSONStream(t *testing.T) { + tests := []struct { + name string + username string + token string + setupServer func(t *testing.T) *httptest.Server + expectedErr error + expectedHeader string + expectedBody string + }{ + { + name: "success returns ReadCloser and headers", + username: "u", + token: "p", + setupServer: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("u:p")) + assert.Equal(t, expAuth, r.Header.Get("Authorization")) + assert.Equal(t, "application/json", r.Header.Get("Accept")) + w.Header().Set("Link", "mockLink") + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + }, + expectedErr: nil, + expectedHeader: "mockLink", + expectedBody: `{"ok":true}`, + }, + { + name: "429 returns friendly error", + setupServer: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "5") + w.WriteHeader(http.StatusTooManyRequests) + })) + }, + expectedErr: fmt.Errorf("rate limited (429) — retry after 0 minute(s) 5 second(s)"), + expectedHeader: "", + expectedBody: "", + }, + { + name: "non-2xx returns snippet", + setupServer: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "simulated error", http.StatusBadRequest) + })) + }, + expectedErr: ErrUnexpectedHTTPStatus, + expectedHeader: "", + expectedBody: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ts := tc.setupServer(t) + defer ts.Close() + + c := &httpConfluenceClient{ + httpClient: &http.Client{Timeout: 5 * time.Second}, + username: tc.username, + token: tc.token, + } + rc, headers, err := c.getJSONStream(context.Background(), ts.URL) + if tc.name == "429 returns friendly error" { + assert.Contains(t, err.Error(), tc.expectedErr.Error()) + } else { + assert.ErrorIs(t, err, tc.expectedErr) + } + + var actualHeader string + if headers != nil { + actualHeader = headers.Get("Link") + } + assert.Equal(t, tc.expectedHeader, actualHeader) + + var actualBody string + if rc != nil { + defer rc.Close() + b, _ := io.ReadAll(rc) + actualBody = strings.TrimSpace(string(b)) + } + assert.Equal(t, strings.TrimSpace(tc.expectedBody), actualBody) + }) + } +} + +func TestWalkPagesPaginated(t *testing.T) { + var calls int + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + switch calls { + case 1: + // First page + Link header with a next cursor + w.Header().Set("Link", fmt.Sprintf("<%s/wiki/api/v2/pages?cursor=next>; rel=\"next\"", ts.URL)) + _, _ = io.WriteString(w, `{"results":[{"id":"1","title":"A"},{"id":"2","title":"B"}]}`) + default: + _, _ = io.WriteString(w, `{"results":[{"id":"3","title":"C"}]}`) + } + })) + defer ts.Close() + + c := &httpConfluenceClient{ + baseWikiURL: ts.URL + "/wiki", + apiBase: ts.URL + "/wiki/api/v2", + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + + var actualIDs []string + actualErr := c.walkPagesPaginated(context.Background(), ts.URL, func(p *Page) error { + actualIDs = append(actualIDs, p.ID) + return nil + }) + assert.Equal(t, nil, actualErr) + assert.Equal(t, []string{"1", "2", "3"}, actualIDs) +} + +func TestWalkAllPages(t *testing.T) { + expectPath := "/wiki/api/v2/pages" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, expectPath, r.URL.Path) + assert.Equal(t, "250", r.URL.Query().Get("limit")) + assert.Equal(t, "storage", r.URL.Query().Get("body-format")) + _, _ = io.WriteString(w, `{"results":[{"id":"1","title":"A"},{"id":"2","title":"B"}]}`) + })) + defer ts.Close() + + client := &httpConfluenceClient{ + baseWikiURL: ts.URL + "/wiki", + apiBase: ts.URL + "/wiki/api/v2", + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + var actual []string + actualErr := client.WalkAllPages(context.Background(), 250, func(p *Page) error { + actual = append(actual, p.ID) + return nil + }) + assert.Equal(t, nil, actualErr) + assert.Equal(t, []string{"1", "2"}, actual) +} + +func TestWalkPagesByIDs(t *testing.T) { + expectPath := "/wiki/api/v2/pages" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, expectPath, r.URL.Path) + assert.Equal(t, "2", r.URL.Query().Get("limit")) + assert.Equal(t, "10,20", r.URL.Query().Get("id")) + assert.Equal(t, "storage", r.URL.Query().Get("body-format")) + _, _ = io.WriteString(w, `{"results":[{"id":"10"},{"id":"20"}]}`) + })) + defer ts.Close() + + client := &httpConfluenceClient{ + baseWikiURL: ts.URL + "/wiki", + apiBase: ts.URL + "/wiki/api/v2", + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + var actual []string + actualErr := client.WalkPagesByIDs(context.Background(), []string{"10", "20"}, 2, func(p *Page) error { + actual = append(actual, p.ID) + return nil + }) + assert.Equal(t, nil, actualErr) + assert.Equal(t, []string{"10", "20"}, actual) +} + +func TestWalkPagesBySpaceIDs(t *testing.T) { + expectPath := "/wiki/api/v2/pages" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, expectPath, r.URL.Path) + assert.Equal(t, "2", r.URL.Query().Get("limit")) + assert.Equal(t, "S1,S2", r.URL.Query().Get("space-id")) + assert.Equal(t, "storage", r.URL.Query().Get("body-format")) + _, _ = io.WriteString(w, `{"results":[{"id":"100"},{"id":"200"}]}`) + })) + defer ts.Close() + + client := &httpConfluenceClient{ + baseWikiURL: ts.URL + "/wiki", + apiBase: ts.URL + "/wiki/api/v2", + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + var actual []string + actualErr := client.WalkPagesBySpaceIDs(context.Background(), []string{"S1", "S2"}, 2, func(p *Page) error { + actual = append(actual, p.ID) + return nil + }) + assert.Equal(t, nil, actualErr) + assert.Equal(t, []string{"100", "200"}, actual) +} + +func TestWalkPageVersions(t *testing.T) { + expectPath := "/wiki/api/v2/pages/123/versions" + resp := listVersionsResponse{Results: []versionEntry{{Number: 2}, {Number: 1}}} + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, expectPath, r.URL.Path) + assert.Equal(t, "50", r.URL.Query().Get("limit")) + _ = json.NewEncoder(w).Encode(resp) + })) + defer testServer.Close() + + client := &httpConfluenceClient{ + baseWikiURL: testServer.URL + "/wiki", + apiBase: testServer.URL + "/wiki/api/v2", + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + var actual []int + actualErr := client.WalkPageVersions(context.Background(), "123", 50, func(n int) error { + actual = append(actual, n) + return nil + }) + assert.Equal(t, nil, actualErr) + assert.Equal(t, []int{2, 1}, actual) +} + +func TestFetchPageAtVersion(t *testing.T) { + expectPath := "/wiki/api/v2/pages/123" + + expected := mkPage("123", 7) + + tests := []struct { + name string + handler http.HandlerFunc + expectedErr error + useContainsForErr bool + expectedPage *Page + }{ + { + name: "success", + handler: func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, expectPath, r.URL.Path) + assert.Equal(t, "7", r.URL.Query().Get("version")) + assert.Equal(t, "storage", r.URL.Query().Get("body-format")) + _ = json.NewEncoder(w).Encode(expected) + }, + expectedErr: nil, + expectedPage: expected, + }, + { + name: "getJSON error (non-2xx)", + handler: func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, expectPath, r.URL.Path) + assert.Equal(t, "7", r.URL.Query().Get("version")) + assert.Equal(t, "storage", r.URL.Query().Get("body-format")) + http.Error(w, "boom", http.StatusInternalServerError) + }, + expectedErr: ErrUnexpectedHTTPStatus, + expectedPage: nil, + }, + { + name: "decode error (invalid JSON)", + handler: func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, expectPath, r.URL.Path) + assert.Equal(t, "7", r.URL.Query().Get("version")) + assert.Equal(t, "storage", r.URL.Query().Get("body-format")) + _, _ = io.WriteString(w, "{") + }, + expectedErr: fmt.Errorf("unexpected end of JSON input"), + useContainsForErr: true, + expectedPage: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ts := httptest.NewServer(tc.handler) + defer ts.Close() + + client := &httpConfluenceClient{ + baseWikiURL: ts.URL + "/wiki", + apiBase: ts.URL + "/wiki/api/v2", + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + + actualPage, err := client.FetchPageAtVersion(context.Background(), "123", 7) + + if tc.useContainsForErr { + assert.Contains(t, err.Error(), tc.expectedErr.Error()) + } else { + assert.ErrorIs(t, err, tc.expectedErr) + } + assert.Equal(t, tc.expectedPage, actualPage) + }) + } +} + +func TestWalkSpacesByKeys(t *testing.T) { + expectPath := "/wiki/api/v2/spaces" + resp := listSpacesResponse{ + Results: []*Space{ + {ID: "S1", Key: "KEY1"}, + {ID: "S2", Key: "KEY2"}, + }, + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, expectPath, r.URL.Path) + assert.Equal(t, "2", r.URL.Query().Get("limit")) + assert.Equal(t, "KEY1,KEY2", r.URL.Query().Get("keys")) + _ = json.NewEncoder(w).Encode(resp) + })) + defer ts.Close() + + client := &httpConfluenceClient{ + baseWikiURL: ts.URL + "/wiki", + apiBase: ts.URL + "/wiki/api/v2", + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + + var actual []string + actualErr := client.WalkSpacesByKeys(context.Background(), []string{"KEY1", "KEY2"}, 2, func(s *Space) error { + actual = append(actual, s.ID) + return nil + }) + + assert.Equal(t, nil, actualErr) + expected := []string{"S1", "S2"} + assert.Equal(t, expected, actual) +} + +func TestWikiBaseURL(t *testing.T) { + client := &httpConfluenceClient{ + baseWikiURL: "wikiURL", + } + + actual := client.WikiBaseURL() + assert.Equal(t, "wikiURL", actual) +} + +func TestDecodeResultsArray(t *testing.T) { + makeDec := func(s string) *json.Decoder { + return json.NewDecoder(strings.NewReader(s)) + } + + tests := []struct { + name string + jsonInput string + visit func(*Page) error + expectedErr error + expectedVisited []string + useContainsForErr bool + }{ + { + name: "readToken error at start", + jsonInput: "", + visit: func(*Page) error { return nil }, + expectedErr: io.EOF, + }, + { + name: "first token not '['", + jsonInput: `{}`, // object instead of array + visit: func(*Page) error { return nil }, + expectedErr: fmt.Errorf("decode: expected '[' for results"), + useContainsForErr: true, + }, + { + name: "element decode error", + jsonInput: `["not-an-object"]`, + visit: func(*Page) error { return nil }, + expectedErr: fmt.Errorf("decode: page"), + useContainsForErr: true, + }, + { + name: "visitor returns error", + jsonInput: `[{"id":"1","title":"A"}]`, + visit: func(*Page) error { + return assert.AnError + }, + expectedErr: assert.AnError, + }, + { + name: "success empty array", + jsonInput: `[]`, + visit: func(*Page) error { return nil }, + expectedVisited: nil, + expectedErr: nil, + }, + { + name: "success with two pages", + jsonInput: `[{"id":"1","title":"A"},{"id":"2","title":"B"}]`, + visit: func(p *Page) error { + return nil + }, + expectedVisited: []string{"1", "2"}, + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var actualVisited []string + dec := makeDec(tt.jsonInput) + + visit := func(p *Page) error { + if tt.visit != nil { + if err := tt.visit(p); err != nil { + return err + } + } + actualVisited = append(actualVisited, p.ID) + return nil + } + + err := decodeResultsArray(dec, visit) + + if tt.useContainsForErr { + assert.Contains(t, err.Error(), tt.expectedErr.Error()) + } else { + assert.ErrorIs(t, err, tt.expectedErr) + } + assert.Equal(t, tt.expectedVisited, actualVisited) + }) + } +} + +func TestDecodeLinksNext(t *testing.T) { + tests := []struct { + name string + jsonObj string + expectedNext string + expectedErr error + }{ + { + name: "decode next", + jsonObj: `{"next":"/wiki/api/v2/pages?cursor=abc"}`, + expectedNext: "/wiki/api/v2/pages?cursor=abc", + expectedErr: nil, + }, + { + name: "error decoding", + jsonObj: `{`, + expectedErr: io.ErrUnexpectedEOF, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + dec := json.NewDecoder(strings.NewReader(tc.jsonObj)) + + actualNext, err := decodeLinksNext(dec) + + assert.ErrorIs(t, err, tc.expectedErr) + assert.Equal(t, tc.expectedNext, actualNext) + }) + } +} diff --git a/plugins/confluence_test.go b/plugins/confluence_test.go index 44701282..2a24f703 100644 --- a/plugins/confluence_test.go +++ b/plugins/confluence_test.go @@ -1,829 +1,1494 @@ package plugins import ( - "bytes" + "bufio" + "context" "fmt" - "sort" + "io" + "net/http" "strconv" "strings" - "sync" "testing" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" + "github.com/checkmarx/2ms/v4/engine/chunk" + "github.com/spf13/cobra" "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" ) -type mockConfluenceClient struct { - pageContentResponse []*ConfluencePageContent - pageContentError error - numberOfPages int - firstPagesRequestError error - secondPagesRequestError error - numberOfSpaces int - firstSpacesRequestError error - secondSpacesRequestError error +//go:generate mockgen -destination=confluence_client_mock_test.go -package=plugins github.com/checkmarx/2ms/v4/plugins ConfluenceClient + +const mockGetFileThresholdReturn = 1_000_000 + +func TestGetName(t *testing.T) { + p := &ConfluencePlugin{} + assert.Equal(t, "confluence", p.GetName()) } -func (m *mockConfluenceClient) getSpacesRequest(start int) (*ConfluenceSpaceResponse, error) { - if m.firstSpacesRequestError != nil && start == 0 { - return nil, m.firstSpacesRequestError +func TestIsValidURL(t *testing.T) { + tests := []struct { + name string + input string + expectedErr error + }{ + { + name: "valid https", + input: "https://checkmarx.atlassian.net/wiki", + expectedErr: nil, + }, + { + name: "invalid scheme", + input: "http://checkmarx.atlassian.net/wiki", + expectedErr: ErrHTTPSRequired, + }, + { + name: "not a url", + input: "%", + expectedErr: fmt.Errorf("invalid URL escape"), + }, } - - if m.secondSpacesRequestError != nil && start != 0 { - return nil, m.secondSpacesRequestError + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + args := []string{tc.input} + err := isValidURL(&cobra.Command{}, args) + if tc.name == "not a url" { + assert.Contains(t, err.Error(), tc.expectedErr.Error()) + } else { + assert.ErrorIs(t, err, tc.expectedErr) + } + }) } +} - var spaces []ConfluenceSpaceResult - for i := start; i < m.numberOfSpaces && i-start < confluenceDefaultWindow; i++ { - spaces = append(spaces, ConfluenceSpaceResult{ID: i, Key: strconv.Itoa(i)}) +func TestIsValidTokenType(t *testing.T) { + tests := []struct { + name string + tokenType TokenType + expectedValid bool + }{ + { + name: "empty", + tokenType: "", + expectedValid: true, + }, + { + name: "api-token", + tokenType: ApiToken, + expectedValid: true, + }, + { + name: "scoped-api-token", + tokenType: ScopedApiToken, + expectedValid: true, + }, + { + name: "invalid", + tokenType: TokenType("weird"), + expectedValid: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expectedValid, isValidTokenType(tc.tokenType)) + }) } - return &ConfluenceSpaceResponse{ - Results: spaces, - Size: len(spaces), - }, nil } -func (m *mockConfluenceClient) getPagesRequest(space ConfluenceSpaceResult, start int) (*ConfluencePageResult, error) { - if m.firstPagesRequestError != nil && start == 0 { - return nil, m.firstPagesRequestError - } +func TestInitialize(t *testing.T) { + const ( + baseURL = "https://tenant.atlassian.net/wiki" + expectedAPIBase = "https://tenant.atlassian.net/wiki/api/v2" + username = "user@example.com" + tokenValue = "token123" + ) - if m.secondPagesRequestError != nil && start != 0 { - return nil, m.secondPagesRequestError + tests := []struct { + name string + base string + tokenType TokenType + username string + tokenValue string + expectedErr error + expectedClient ConfluenceClient + }{ + { + name: "valid initialization (api-token)", + base: baseURL, + tokenType: ApiToken, + username: username, + tokenValue: tokenValue, + expectedErr: nil, + expectedClient: &httpConfluenceClient{ + baseWikiURL: baseURL, + httpClient: &http.Client{Timeout: httpTimeout}, + username: username, + token: tokenValue, + apiBase: expectedAPIBase, + }, + }, + { + name: "invalid initialization (unsupported token type)", + base: baseURL, + tokenType: TokenType("bad"), + username: username, + tokenValue: tokenValue, + expectedErr: ErrUnsupportedTokenType, + expectedClient: nil, + }, } - var pages []ConfluencePage - for i := start; i < m.numberOfPages && i-start < confluenceDefaultWindow; i++ { - pages = append(pages, ConfluencePage{ID: strconv.Itoa(i)}) - } - return &ConfluencePageResult{Pages: pages}, nil -} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := NewConfluencePlugin().(*ConfluencePlugin) + err := p.initialize(tc.base, tc.username, tc.tokenType, tc.tokenValue) -func (m *mockConfluenceClient) getPageContentRequest(page ConfluencePage, version int) (*ConfluencePageContent, error) { - if m.pageContentError != nil { - return nil, m.pageContentError + assert.ErrorIs(t, err, tc.expectedErr) + assert.Equal(t, tc.expectedClient, p.client) + }) } - return m.pageContentResponse[version], nil } -func TestGetPages(t *testing.T) { +func TestChunkStrings(t *testing.T) { tests := []struct { - name string - numberOfPages int - firstPagesRequestError error - secondPagesRequestError error - expectedError error + name string + in []string + chunkSize int + chunkSpans [][2]int // [start,end) ranges expected por chunk }{ { - name: "Error while getting pages before pagination is required", - numberOfPages: confluenceDefaultWindow - 2, - firstPagesRequestError: fmt.Errorf("some error before pagination is required"), - expectedError: fmt.Errorf( - "unexpected error creating an http request %w", - fmt.Errorf("some error before pagination is required"), - ), + name: "exact multiple", + in: makeRangeStrings(1, 300), // 300 items + chunkSize: 100, + chunkSpans: [][2]int{{0, 100}, {100, 200}, {200, 300}}, }, { - name: "error while getting pages after pagination is required", - numberOfPages: confluenceDefaultWindow + 2, - secondPagesRequestError: fmt.Errorf("some error after pagination required"), - expectedError: fmt.Errorf( - "unexpected error creating an http request %w", - fmt.Errorf("some error after pagination required"), - ), + name: "not an exact multiple", + in: makeRangeStrings(1, 305), // 305 items + chunkSize: 100, + chunkSpans: [][2]int{{0, 100}, {100, 200}, {200, 300}, {300, 305}}, }, { - name: "pages less than confluenceDefaultWindow", - numberOfPages: confluenceDefaultWindow - 2, - expectedError: nil, + name: "small input", + in: []string{"a", "b"}, + chunkSize: 250, + chunkSpans: [][2]int{{0, 2}}, }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + chunks := chunkStrings(tc.in, tc.chunkSize) + assert.Equal(t, len(tc.chunkSpans), len(chunks)) + for i, span := range tc.chunkSpans { + assert.Equal(t, tc.in[span[0]:span[1]], chunks[i]) + } + }) + } +} + +func TestConvertPageToItem(t *testing.T) { + const base = "https://checkmarx.atlassian.net/wiki" + + tests := []struct { + name string + page *Page + expectCalls int // WikiBaseURL calls + expectedID string + expectedSrc string + expectedContent *string + }{ { - name: "exactly confluenceDefaultWindow pages", - numberOfPages: confluenceDefaultWindow, - expectedError: nil, + name: "webui + version", + page: &Page{ + ID: "123", + Title: "Page Title", + Body: PageBody{Storage: &struct { + Value string `json:"value"` + }{Value: "

content

"}}, + Links: map[string]string{"webui": "/pages/viewpage.action?pageId=123"}, + Version: PageVersion{Number: 4}, + }, + expectCalls: 1, + expectedID: "confluence-123-4", + expectedSrc: base + "/pages/viewpage.action?pageId=123&pageVersion=4", + expectedContent: ptr("

content

"), }, { - name: "fetching more pages after confluenceDefaultWindow", - numberOfPages: confluenceDefaultWindow + 2, - expectedError: nil, + name: "fallback base link", + page: &Page{ + ID: "456", + Links: map[string]string{"base": base}, + Version: PageVersion{Number: 1}, + }, + expectCalls: 0, + expectedID: "confluence-456-1", + expectedSrc: base, + expectedContent: nil, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockClient := mockConfluenceClient{ - numberOfPages: tt.numberOfPages, - firstPagesRequestError: tt.firstPagesRequestError, - secondPagesRequestError: tt.secondPagesRequestError, - } - space := ConfluenceSpaceResult{Name: "Test Space"} - plugin := &ConfluencePlugin{client: &mockClient} - result, err := plugin.getPages(space) - assert.Equal(t, tt.expectedError, err) - if tt.expectedError == nil { - var expectedResult ConfluencePageResult - for i := 0; i < tt.numberOfPages; i++ { - expectedResult.Pages = append(expectedResult.Pages, ConfluencePage{ID: strconv.Itoa(i)}) - } - assert.Equal(t, &expectedResult, result) - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := NewMockConfluenceClient(ctrl) + mockClient.EXPECT().WikiBaseURL().Return(base).Times(tc.expectCalls) + + p := &ConfluencePlugin{client: mockClient} + item := p.convertPageToItem(tc.page) + + assert.Equal(t, tc.expectedID, item.GetID()) + assert.Equal(t, tc.expectedSrc, item.GetSource()) + assert.Equal(t, tc.expectedContent, item.GetContent()) }) } } -func TestGetSpaces(t *testing.T) { +func TestResolveConfluenceSourceURL(t *testing.T) { + const base = "https://checkmarx.atlassian.net/wiki" + tests := []struct { - name string - numberOfSpaces int - firstSpacesRequestError error - secondSpacesRequestError error - expectedError error - filteredSpaces []string + name string + links map[string]string + version int + expectCalls int // WikiBaseURL calls + canResolve bool + expectedURL string + wikiURL string }{ { - name: "Error while getting spaces before pagination is required", - numberOfSpaces: confluenceDefaultWindow - 2, - firstSpacesRequestError: fmt.Errorf("some error before pagination is required"), - expectedError: fmt.Errorf("some error before pagination is required"), + name: "webui relative + version", + links: map[string]string{"webui": "/pages/viewpage.action?pageId=123"}, + version: 4, + expectCalls: 1, + canResolve: true, + expectedURL: base + "/pages/viewpage.action?pageId=123&pageVersion=4", + wikiURL: base, }, { - name: "error while getting spaces after pagination is required", - numberOfSpaces: confluenceDefaultWindow + 2, - secondSpacesRequestError: fmt.Errorf("some error after pagination required"), - expectedError: fmt.Errorf("some error after pagination required"), + name: "webui absolute + version", + links: map[string]string{"webui": base + "/pages/viewpage.action?pageId=456"}, + version: 2, + expectCalls: 1, + canResolve: true, + expectedURL: base + "/pages/viewpage.action?pageId=456&pageVersion=2", + wikiURL: base, }, { - name: "zero spaces", - numberOfSpaces: 0, - expectedError: nil, + name: "webui present and valid but wikiURL invalid", + links: map[string]string{"webui": base + "/pages/viewpage.action?pageId=456"}, + version: 2, + expectCalls: 1, + canResolve: false, + expectedURL: "", + wikiURL: "%", }, { - name: "spaces less than confluenceDefaultWindow", - numberOfSpaces: confluenceDefaultWindow - 2, - expectedError: nil, + name: "fallback base", + links: map[string]string{"base": base}, + version: 1, + expectCalls: 0, + canResolve: true, + expectedURL: base, + wikiURL: base, }, { - name: "exactly confluenceDefaultWindow spaces", - numberOfSpaces: confluenceDefaultWindow, - expectedError: nil, + name: "links nil", + links: nil, + version: 1, + expectCalls: 0, + canResolve: false, + expectedURL: "", }, { - name: "fetching more spaces after confluenceDefaultWindow", - numberOfSpaces: confluenceDefaultWindow + 2, - expectedError: nil, + name: "missing one of the required links", + links: map[string]string{"something": "mock"}, + version: 1, + expectCalls: 0, + canResolve: false, + expectedURL: "", }, { - name: "fetching spaces with filtered spaces", - numberOfSpaces: 5, - filteredSpaces: []string{"2"}, - expectedError: nil, + name: "invalid webui", + links: map[string]string{"webui": "%"}, + version: 1, + expectCalls: 1, + canResolve: false, + expectedURL: "", }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockClient := mockConfluenceClient{ - numberOfSpaces: tt.numberOfSpaces, - firstSpacesRequestError: tt.firstSpacesRequestError, - secondSpacesRequestError: tt.secondSpacesRequestError, - } - plugin := &ConfluencePlugin{ - client: &mockClient, - Spaces: tt.filteredSpaces, - } - result, err := plugin.getSpaces() - assert.Equal(t, tt.expectedError, err) - if tt.expectedError == nil { - var expectedResult []ConfluenceSpaceResult - if len(tt.filteredSpaces) == 0 { - for i := 0; i < tt.numberOfSpaces; i++ { - expectedResult = append(expectedResult, ConfluenceSpaceResult{ID: i, Key: strconv.Itoa(i)}) - } - } else { - for i := 0; i < len(tt.filteredSpaces); i++ { - id, errConvert := strconv.Atoi(tt.filteredSpaces[i]) - key := tt.filteredSpaces[i] - assert.NoError(t, errConvert) - expectedResult = append(expectedResult, ConfluenceSpaceResult{ID: id, Key: key}) - } - } - assert.Equal(t, expectedResult, result) - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := NewMockConfluenceClient(ctrl) + mockClient.EXPECT().WikiBaseURL().Return(tc.wikiURL).Times(tc.expectCalls) + + p := &ConfluencePlugin{client: mockClient} + page := &Page{Links: tc.links} + actualURL, ok := p.resolveConfluenceSourceURL(page, tc.version) + assert.Equal(t, tc.canResolve, ok) + assert.Equal(t, tc.expectedURL, actualURL) }) } } -func TestScanPageVersion(t *testing.T) { +func TestWalkPagesByIDBatches(t *testing.T) { tests := []struct { - name string - mockPageContent *ConfluencePageContent - mockError error - expectError bool - expectItem bool - expectedVersionNum int + name string + allIDs []string + perBatch int + setupWalker func() func(context.Context, []string, int, func(*Page) error) error + expectedBatches [][]string + expectedEmitCount int + expectedErr error }{ { - name: "Successful page scan with previous version", - mockPageContent: &ConfluencePageContent{ - Body: struct { - Storage struct { - Value string `json:"value"` - } `json:"storage"` - }(struct { - Storage struct { - Value string - } - }{ - Storage: struct{ Value string }{Value: "Page content"}, - }), - History: struct { - PreviousVersion struct{ Number int } `json:"previousVersion"` - }(struct { - PreviousVersion struct { - Number int + name: "walks in chunks and emits via walker", + allIDs: []string{"a", "b", "c", "d", "e"}, + perBatch: 2, + setupWalker: func() func(context.Context, []string, int, func(*Page) error) error { + return func(_ context.Context, ids []string, _ int, visit func(*Page) error) error { + for _, id := range ids { + _ = visit(mkPage(id, 1)) } - }{PreviousVersion: struct{ Number int }{Number: 1}}), - Links: map[string]string{ - "base": "https://example.com", - "webui": "/wiki/page", - }, + return nil + } }, - expectItem: true, - expectedVersionNum: 1, + expectedBatches: [][]string{{"a", "b"}, {"c", "d"}, {"e"}}, + expectedEmitCount: 5, + expectedErr: nil, }, { - name: "Error fetching page content", - mockError: fmt.Errorf("fetch error"), - expectError: true, - expectItem: false, - expectedVersionNum: 0, + name: "propagates walker error", + allIDs: []string{"1", "2"}, + perBatch: 10, + setupWalker: func() func(context.Context, []string, int, func(*Page) error) error { + return func(_ context.Context, _ []string, _ int, _ func(*Page) error) error { + return assert.AnError + } + }, + expectedBatches: [][]string{{"1", "2"}}, + expectedEmitCount: 0, + expectedErr: assert.AnError, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockClient := &mockConfluenceClient{ - pageContentResponse: []*ConfluencePageContent{tt.mockPageContent}, - pageContentError: tt.mockError, - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - errorsChan := make(chan error, 1) - itemsChan := make(chan ISourceItem, 1) + mockChunk := chunk.NewMockIChunk(ctrl) + mockChunk.EXPECT().GetFileThreshold().Return(int64(mockGetFileThresholdReturn)).Times(tc.expectedEmitCount) - plugin := &ConfluencePlugin{ - client: mockClient, - errorsChan: errorsChan, - itemsChan: itemsChan, - } + mockClient := NewMockConfluenceClient(ctrl) + mockClient.EXPECT().WikiBaseURL().Return("https://tenant.atlassian.net/wiki").Times(tc.expectedEmitCount) - page := ConfluencePage{ID: "pageID"} - space := ConfluenceSpaceResult{Key: "spaceKey"} - - result := plugin.scanPageVersion(page, space, 0) - - assert.Equal(t, tt.expectedVersionNum, result) - - if tt.expectError { - assert.NotEmpty(t, errorsChan) - err := <-errorsChan - assert.Equal(t, tt.mockError, err) - } else { - assert.Empty(t, errorsChan) + p := &ConfluencePlugin{ + itemsChan: make(chan ISourceItem, 100), + chunker: mockChunk, + client: mockClient, } - if tt.expectItem { - assert.NotEmpty(t, itemsChan) - actualItem := <-itemsChan - expectedItem := item{ - Content: ptrToString("Page content"), - ID: "confluence-spaceKey-pageID", - Source: "https://example.com/wiki/page", - } - assert.Equal(t, &expectedItem, actualItem) - } else { - assert.Empty(t, itemsChan) + seen := map[string]struct{}{} + var seenBatches [][]string + walker := func(ctx context.Context, ids []string, lim int, v func(*Page) error) error { + seenBatches = append(seenBatches, append([]string(nil), ids...)) + return tc.setupWalker()(ctx, ids, lim, v) } - close(itemsChan) - close(errorsChan) + err := p.walkPagesByIDBatches(context.Background(), tc.allIDs, tc.perBatch, seen, walker) + assert.ErrorIs(t, err, tc.expectedErr) + assert.Equal(t, tc.expectedBatches, seenBatches) + assert.Len(t, collectEmittedItems(p.itemsChan), tc.expectedEmitCount) }) } } - -func TestScanPageAllVersions(t *testing.T) { +func TestEmitUniquePage(t *testing.T) { tests := []struct { - name string - mockPageContents []*ConfluencePageContent - expectedErrors []error - expectedItems []item - historyEnabled bool + name string + seenPages map[string]struct{} + page *Page + history bool + setupMocks func(mc *MockConfluenceClient, ch *chunk.MockIChunk, p *Page) + expectedErr error + expectedEmitCount int }{ { - name: "scan with multiple versions and history enabled", - mockPageContents: []*ConfluencePageContent{ - { - Body: struct { - Storage struct { - Value string `json:"value"` - } `json:"storage"` - }(struct { - Storage struct { - Value string - } - }{ - Storage: struct{ Value string }{Value: "Page content 1"}, - }), - History: struct { - PreviousVersion struct{ Number int } `json:"previousVersion"` - }(struct{ PreviousVersion struct{ Number int } }{PreviousVersion: struct{ Number int }{Number: 2}}), - Links: map[string]string{ - "base": "https://example.com", - "webui": "/wiki/page", - }, - }, - { - Body: struct { - Storage struct { - Value string `json:"value"` - } `json:"storage"` - }(struct { - Storage struct { - Value string - } - }{ - Storage: struct{ Value string }{Value: "Page content 2"}, - }), - History: struct { - PreviousVersion struct{ Number int } `json:"previousVersion"` - }(struct{ PreviousVersion struct{ Number int } }{PreviousVersion: struct{ Number int }{Number: 0}}), - Links: map[string]string{ - "base": "https://example.com", - "webui": "/wiki/page", - }, - }, - { - Body: struct { - Storage struct { - Value string `json:"value"` - } `json:"storage"` - }(struct { - Storage struct { - Value string - } - }{ - Storage: struct{ Value string }{Value: "Page content 3"}, - }), - History: struct { - PreviousVersion struct{ Number int } `json:"previousVersion"` - }(struct{ PreviousVersion struct{ Number int } }{PreviousVersion: struct{ Number int }{Number: 1}}), - Links: map[string]string{ - "base": "https://example.com", - "webui": "/wiki/page", - }, - }, - }, - historyEnabled: true, - expectedErrors: nil, - expectedItems: []item{ - { - Content: ptrToString("Page content 1"), - ID: "confluence-spaceKey-pageID", - Source: "https://example.com/wiki/page", - }, - { - Content: ptrToString("Page content 3"), - ID: "confluence-spaceKey-pageID", - Source: "https://example.com/wiki/page", - }, - { - Content: ptrToString("Page content 2"), - ID: "confluence-spaceKey-pageID", - Source: "https://example.com/wiki/page", - }, + name: "first time emits", + seenPages: map[string]struct{}{}, + page: mkPage("42", 3), + history: false, + setupMocks: func(mc *MockConfluenceClient, ch *chunk.MockIChunk, p *Page) { + ch.EXPECT().GetFileThreshold().Return(int64(mockGetFileThresholdReturn)).Times(1) + mc.EXPECT().WikiBaseURL().Return("https://tenant.atlassian.net/wiki").Times(1) }, + expectedErr: nil, + expectedEmitCount: 1, }, { - name: "scan with multiple versions and history disabled", - mockPageContents: []*ConfluencePageContent{ - { - Body: struct { - Storage struct { - Value string `json:"value"` - } `json:"storage"` - }(struct { - Storage struct { - Value string - } - }{ - Storage: struct{ Value string }{Value: "Page content 1"}, - }), - History: struct { - PreviousVersion struct{ Number int } `json:"previousVersion"` - }(struct{ PreviousVersion struct{ Number int } }{PreviousVersion: struct{ Number int }{Number: 2}}), - Links: map[string]string{ - "base": "https://example.com", - "webui": "/wiki/page", - }, - }, - { - Body: struct { - Storage struct { - Value string `json:"value"` - } `json:"storage"` - }(struct { - Storage struct { - Value string - } - }{ - Storage: struct{ Value string }{Value: "Page content 2"}, - }), - History: struct { - PreviousVersion struct{ Number int } `json:"previousVersion"` - }(struct{ PreviousVersion struct{ Number int } }{PreviousVersion: struct{ Number int }{Number: 0}}), - Links: map[string]string{ - "base": "https://example.com", - "webui": "/wiki/page", - }, - }, - { - Body: struct { - Storage struct { - Value string `json:"value"` - } `json:"storage"` - }(struct { - Storage struct { - Value string - } - }{ - Storage: struct{ Value string }{Value: "Page content 3"}, - }), - History: struct { - PreviousVersion struct{ Number int } `json:"previousVersion"` - }(struct{ PreviousVersion struct{ Number int } }{PreviousVersion: struct{ Number int }{Number: 1}}), - Links: map[string]string{ - "base": "https://example.com", - "webui": "/wiki/page", - }, - }, + name: "already seen", + seenPages: map[string]struct{}{"42": {}}, + page: mkPage("42", 3), + history: false, + setupMocks: func(_ *MockConfluenceClient, _ *chunk.MockIChunk, _ *Page) {}, + expectedErr: nil, + expectedEmitCount: 0, + }, + { + name: "emitInChunks error", + seenPages: map[string]struct{}{}, + page: func() *Page { + pg := mkPage("99", 1) + pg.Body.Storage = &struct { + Value string `json:"value"` + }{Value: strings.Repeat("X", 64)} + return pg + }(), + history: false, + setupMocks: func(mc *MockConfluenceClient, ch *chunk.MockIChunk, p *Page) { + ch.EXPECT().GetFileThreshold().Return(int64(1)).Times(1) + ch.EXPECT().GetSize().Return(64).Times(1) + ch.EXPECT().GetMaxPeekSize().Return(0).Times(1) + ch.EXPECT().ReadChunk(gomock.Any(), -1).Return("", assert.AnError).Times(1) }, - historyEnabled: false, - expectedErrors: nil, - expectedItems: []item{ - { - Content: ptrToString("Page content 1"), - ID: "confluence-spaceKey-pageID", - Source: "https://example.com/wiki/page", - }, + expectedErr: assert.AnError, + expectedEmitCount: 0, + }, + { + name: "history enabled and emitHistory returns error (after emitting current)", + seenPages: map[string]struct{}{}, + page: mkPage("77", 5), + history: true, + setupMocks: func(mc *MockConfluenceClient, ch *chunk.MockIChunk, p *Page) { + ch.EXPECT().GetFileThreshold().Return(int64(mockGetFileThresholdReturn)).Times(1) + mc.EXPECT().WikiBaseURL().Return("https://tenant.atlassian.net/wiki").Times(1) + + mc.EXPECT(). + WalkPageVersions(gomock.Any(), p.ID, maxPageSize, gomock.Any()). + Return(assert.AnError).Times(1) }, + expectedErr: assert.AnError, + expectedEmitCount: 1, // current was emitted before history failed }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockClient := &mockConfluenceClient{ - pageContentResponse: tt.mockPageContents, - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - errorsChan := make(chan error, 3) - itemsChan := make(chan ISourceItem, 3) + mockChunk := chunk.NewMockIChunk(ctrl) + mockClient := NewMockConfluenceClient(ctrl) - plugin := &ConfluencePlugin{ - client: mockClient, - errorsChan: errorsChan, - itemsChan: itemsChan, - History: tt.historyEnabled, + p := &ConfluencePlugin{ + itemsChan: make(chan ISourceItem, 10), + chunker: mockChunk, + client: mockClient, + History: tc.history, } - page := ConfluencePage{ID: "pageID"} - space := ConfluenceSpaceResult{Key: "spaceKey"} + if tc.setupMocks != nil { + tc.setupMocks(mockClient, mockChunk, tc.page) + } - var wg sync.WaitGroup - wg.Add(1) - go plugin.scanPageAllVersions(&wg, page, space) - wg.Wait() + err := p.emitUniquePage(context.Background(), tc.page, tc.seenPages) + assert.ErrorIs(t, err, tc.expectedErr) - if len(tt.expectedErrors) == 0 { - assert.Empty(t, errorsChan) - } + emitted := collectEmittedItems(p.itemsChan) + assert.Len(t, emitted, tc.expectedEmitCount) + }) + } +} + +func TestEmitHistory(t *testing.T) { + const base = "https://tenant.atlassian.net/wiki" + + tests := []struct { + name string + pageID string + currentVersion int + versionsWalked []int + errorsFetchPageAt []error + wikiBase string + expectWikiCalls int + expectThresholdCalls int + expectedIDs []string + expectedErr error + }{ + { + name: "happy path emits historical v1..v4", + pageID: "200", + currentVersion: 5, + versionsWalked: []int{1, 2, 3, 4, 5}, + errorsFetchPageAt: []error{nil, nil, nil, nil, nil}, + wikiBase: base, + expectWikiCalls: 4, // v1..v4 + expectThresholdCalls: 4, // v1..v4 + expectedIDs: []string{"confluence-200-1", "confluence-200-2", "confluence-200-3", "confluence-200-4"}, + expectedErr: nil, + }, + { + name: "error fetching page at v1", + pageID: "200", + currentVersion: 2, + versionsWalked: []int{1, 2}, + errorsFetchPageAt: []error{assert.AnError, nil}, + wikiBase: base, + expectWikiCalls: 0, // fail before any emit + expectThresholdCalls: 0, + expectedIDs: []string{}, + expectedErr: assert.AnError, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := NewMockConfluenceClient(ctrl) + mockChunk := chunk.NewMockIChunk(ctrl) - assert.Equal(t, len(tt.expectedErrors), len(errorsChan)) - for _, expectedError := range tt.expectedErrors { - actualError := <-errorsChan - assert.Equal(t, expectedError, actualError) + mockChunk.EXPECT().GetFileThreshold().Return(int64(mockGetFileThresholdReturn)).Times(tc.expectThresholdCalls) + mockClient.EXPECT().WikiBaseURL().Return(tc.wikiBase).Times(tc.expectWikiCalls) + + for i, v := range tc.versionsWalked { + if v == tc.currentVersion { + continue + } + if tc.errorsFetchPageAt[i] != nil { + mockClient. + EXPECT(). + FetchPageAtVersion(gomock.Any(), tc.pageID, v). + Return(nil, tc.errorsFetchPageAt[i]). + Times(1) + // after first error, the walker stops + break + } + mockClient. + EXPECT(). + FetchPageAtVersion(gomock.Any(), tc.pageID, v). + Return(mkPage(tc.pageID, v), nil). + Times(1) } - assert.Equal(t, len(tt.expectedItems), len(itemsChan)) - for _, expectedItem := range tt.expectedItems { - actualItem := <-itemsChan - assert.Equal(t, &expectedItem, actualItem) + mockClient. + EXPECT(). + WalkPageVersions(gomock.Any(), tc.pageID, maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, _ int, visit func(int) error) error { + for _, v := range tc.versionsWalked { + if err := visit(v); err != nil { + return err + } + } + return nil + }). + Times(1) + + p := &ConfluencePlugin{ + client: mockClient, + itemsChan: make(chan ISourceItem, 16), + chunker: mockChunk, } - close(errorsChan) - close(itemsChan) + cur := mkPage(tc.pageID, tc.currentVersion) + err := p.emitHistory(context.Background(), cur) + + assert.ErrorIs(t, err, tc.expectedErr) + + items := collectEmittedItems(p.itemsChan) + got := make([]string, len(items)) + for i := range items { + got[i] = items[i].ID + } + assert.ElementsMatch(t, tc.expectedIDs, got) }) } } -func TestScanConfluenceSpace(t *testing.T) { +func TestScanBySpaceIDs(t *testing.T) { + const base = "https://tenant.atlassian.net/wiki" + tests := []struct { - name string - firstPagesRequestError error - expectedError error - numberOfPages int - mockPageContent *ConfluencePageContent + name string + spaceIDs []string + pagesErr error + expectWikiCalls int + expectThresholdCalls int + expectedIDs []string + expectedErr error }{ { - name: "getPages returns error", - firstPagesRequestError: fmt.Errorf("some error before pagination is required"), - expectedError: fmt.Errorf( - "unexpected error creating an http request %w", - fmt.Errorf("some error before pagination is required"), - ), - numberOfPages: 1, - }, - { - name: "scan confluence space with multiple pages", - firstPagesRequestError: nil, - expectedError: nil, - numberOfPages: 3, - mockPageContent: &ConfluencePageContent{ - Body: struct { - Storage struct { - Value string `json:"value"` - } `json:"storage"` - }(struct { - Storage struct { - Value string - } - }{ - Storage: struct{ Value string }{Value: "Page content"}, - }), - History: struct { - PreviousVersion struct{ Number int } `json:"previousVersion"` - }(struct { - PreviousVersion struct { - Number int - } - }{PreviousVersion: struct{ Number int }{Number: 1}}), - Links: map[string]string{ - "base": "https://example.com", - "webui": "/wiki/page", - }, - }, + name: "dedupe and emit pages", + spaceIDs: []string{"S1", "S2", "S1"}, + pagesErr: nil, + expectWikiCalls: 2, + expectThresholdCalls: 2, + expectedIDs: []string{"confluence-P1-1", "confluence-P2-1"}, + expectedErr: nil, + }, + { + name: "error from WalkPagesBySpaceIDs", + spaceIDs: []string{"S1", "S2"}, + pagesErr: assert.AnError, + expectWikiCalls: 0, + expectThresholdCalls: 0, + expectedIDs: []string{}, + expectedErr: assert.AnError, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockClient := &mockConfluenceClient{ - firstPagesRequestError: tt.firstPagesRequestError, - numberOfPages: tt.numberOfPages, - pageContentResponse: []*ConfluencePageContent{tt.mockPageContent}, - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - errorsChan := make(chan error, 1) - itemsChan := make(chan ISourceItem, 3) + mockChunk := chunk.NewMockIChunk(ctrl) + mockClient := NewMockConfluenceClient(ctrl) - plugin := Plugin{ - Limit: make(chan struct{}, confluenceMaxRequests), + mockChunk.EXPECT().GetFileThreshold().Return(int64(mockGetFileThresholdReturn)).Times(tc.expectThresholdCalls) + mockClient.EXPECT().WikiBaseURL().Return(base).Times(tc.expectWikiCalls) + + p := &ConfluencePlugin{ + itemsChan: make(chan ISourceItem, 10), + client: mockClient, + chunker: mockChunk, + SpaceIDs: tc.spaceIDs, } - confluencePlugin := &ConfluencePlugin{ - Plugin: plugin, - client: mockClient, - errorsChan: errorsChan, - itemsChan: itemsChan, + seenPages := map[string]struct{}{} + seenSpaces := map[string]struct{}{} + + mockClient. + EXPECT(). + WalkPagesBySpaceIDs(gomock.Any(), gomock.Any(), maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, ids []string, _ int, visit func(*Page) error) error { + if tc.pagesErr != nil { + return tc.pagesErr + } + // dedup should yield S1,S2 in any order + assert.ElementsMatch(t, []string{"S1", "S2"}, ids) + _ = visit(mkPage("P1", 1)) + _ = visit(mkPage("P2", 1)) + return nil + }).Times(1) + + err := p.scanBySpaceIDs(context.Background(), seenPages, seenSpaces) + assert.ErrorIs(t, err, tc.expectedErr) + + items := collectEmittedItems(p.itemsChan) + actualIDs := make([]string, 0, len(items)) + for _, it := range items { + actualIDs = append(actualIDs, it.ID) } + assert.ElementsMatch(t, tc.expectedIDs, actualIDs) + }) + } +} - space := ConfluenceSpaceResult{Key: "spaceKey"} - var wg sync.WaitGroup - wg.Add(1) +func TestScanBySpaceKeys(t *testing.T) { + const base = "https://tenant.atlassian.net/wiki" - go confluencePlugin.scanConfluenceSpace(&wg, space) + tests := []struct { + name string + spaceKeys []string + pagesErr error + expectWikiCalls int + expectThresholdCalls int + expectedIDs []string + expectedErr error + }{ + { + name: "dedup and emit pages", + spaceKeys: []string{"K1", "K2", "K1"}, + pagesErr: nil, + expectWikiCalls: 2, + expectThresholdCalls: 2, + expectedIDs: []string{"confluence-P-S1-1", "confluence-P-S2-1"}, + expectedErr: nil, + }, + { + name: "error from WalkPagesBySpaceKeys", + spaceKeys: []string{"K1", "K2"}, + pagesErr: assert.AnError, + expectWikiCalls: 0, + expectThresholdCalls: 0, + expectedIDs: []string{}, + expectedErr: assert.AnError, + }, + } - wg.Wait() + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - close(errorsChan) - close(itemsChan) + mockChunk := chunk.NewMockIChunk(ctrl) + mockChunk.EXPECT().GetFileThreshold().Return(int64(mockGetFileThresholdReturn)).Times(tc.expectThresholdCalls) - if tt.expectedError != nil { - actualError := <-errorsChan - assert.Equal(t, tt.expectedError, actualError) - } else { - assert.Empty(t, errorsChan) - var actualItems []ISourceItem - for i := 0; i < tt.numberOfPages; i++ { - actualItem := <-itemsChan - actualItems = append(actualItems, actualItem) - } - sort.Slice(actualItems, func(i, j int) bool { - return actualItems[i].GetID() < actualItems[j].GetID() - }) - for i := 0; i < tt.numberOfPages; i++ { - expectedItem := item{ - Content: ptrToString("Page content"), - ID: fmt.Sprintf("confluence-spaceKey-%d", i), - Source: "https://example.com/wiki/page", + mockClient := NewMockConfluenceClient(ctrl) + mockClient.EXPECT().WikiBaseURL().Return(base).Times(tc.expectWikiCalls) + + p := &ConfluencePlugin{ + itemsChan: make(chan ISourceItem, 10), + client: mockClient, + chunker: mockChunk, + SpaceKeys: tc.spaceKeys, + } + + seenPages := map[string]struct{}{} + seenSpaces := map[string]struct{}{} + + // Resolve spaces by keys -> S1 for K1, S2 for K2 (dedup happens in code under test) + mockClient. + EXPECT(). + WalkSpacesByKeys(gomock.Any(), gomock.Any(), maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, keys []string, _ int, visit func(*Space) error) error { + for _, k := range keys { + switch k { + case "K1": + _ = visit(&Space{ID: "S1", Key: "K1"}) + case "K2": + _ = visit(&Space{ID: "S2", Key: "K2"}) + } } - assert.Equal(t, &expectedItem, actualItems[i]) - } + return nil + }). + Times(1) + + // Then walk pages by resolved space IDs in batches + mockClient. + EXPECT(). + WalkPagesBySpaceIDs(gomock.Any(), gomock.Any(), maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, ids []string, _ int, visit func(*Page) error) error { + if tc.pagesErr != nil { + return tc.pagesErr + } + assert.ElementsMatch(t, []string{"S1", "S2"}, ids) + for _, id := range ids { + _ = visit(mkPage("P-"+id, 1)) + } + return nil + }). + Times(1) + + err := p.scanBySpaceKeys(context.Background(), seenPages, seenSpaces) + assert.ErrorIs(t, err, tc.expectedErr) + + items := collectEmittedItems(p.itemsChan) + var actualIDs []string + for _, it := range items { + actualIDs = append(actualIDs, it.ID) } + assert.ElementsMatch(t, tc.expectedIDs, actualIDs) }) } } -func TestScanConfluence(t *testing.T) { +func TestScanByPageIDs(t *testing.T) { + const base = "https://tenant.atlassian.net/wiki" + tests := []struct { - name string - firstSpacesRequestError error - expectedError error - numberOfSpaces int - numberOfPages int - mockPageContent *ConfluencePageContent + name string + pageIDs []string + pagesErr error + expectWikiCalls int + expectThresholdCalls int + expectedIDs []string + expectedErr error }{ { - name: "getSpaces returns error", - firstSpacesRequestError: fmt.Errorf("some error before pagination is required"), - expectedError: fmt.Errorf("some error before pagination is required"), - numberOfPages: 1, - }, - { - name: "scan confluence with multiple spaces and pages", - firstSpacesRequestError: nil, - expectedError: nil, - numberOfSpaces: 3, - numberOfPages: 3, - mockPageContent: &ConfluencePageContent{ - Body: struct { - Storage struct { - Value string `json:"value"` - } `json:"storage"` - }(struct { - Storage struct { - Value string + name: "emit pages", + pageIDs: []string{"1", "2", "3"}, + pagesErr: nil, + expectWikiCalls: 3, + expectThresholdCalls: 3, + expectedIDs: []string{"confluence-1-1", "confluence-2-1", "confluence-3-1"}, + expectedErr: nil, + }, + { + name: "error from WalkPagesByIDs", + pageIDs: []string{"1", "2", "3"}, + pagesErr: assert.AnError, + expectWikiCalls: 0, + expectThresholdCalls: 0, + expectedIDs: []string{}, + expectedErr: assert.AnError, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockChunk := chunk.NewMockIChunk(ctrl) + mockClient := NewMockConfluenceClient(ctrl) + + mockChunk.EXPECT().GetFileThreshold().Return(int64(mockGetFileThresholdReturn)).Times(tc.expectThresholdCalls) + mockClient.EXPECT().WikiBaseURL().Return(base).Times(tc.expectWikiCalls) + + p := &ConfluencePlugin{ + itemsChan: make(chan ISourceItem, 10), + client: mockClient, + chunker: mockChunk, + PageIDs: tc.pageIDs, + } + + seenPages := map[string]struct{}{} + + mockClient. + EXPECT(). + WalkPagesByIDs(gomock.Any(), gomock.Any(), maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, ids []string, _ int, visit func(*Page) error) error { + if tc.pagesErr != nil { + return tc.pagesErr } - }{ - Storage: struct{ Value string }{Value: "Page content"}, - }), - History: struct { - PreviousVersion struct{ Number int } `json:"previousVersion"` - }(struct { - PreviousVersion struct { - Number int + assert.ElementsMatch(t, tc.pageIDs, ids) + for _, id := range ids { + _ = visit(mkPage(id, 1)) } - }{PreviousVersion: struct{ Number int }{Number: 1}}), - Links: map[string]string{ - "base": "https://example.com", - "webui": "/wiki/page", - }, + return nil + }).Times(1) + + err := p.scanByPageIDs(context.Background(), seenPages) + assert.ErrorIs(t, err, tc.expectedErr) + + items := collectEmittedItems(p.itemsChan) + actualIDs := make([]string, len(items)) + for i := range items { + actualIDs[i] = items[i].ID + } + assert.ElementsMatch(t, tc.expectedIDs, actualIDs) + }) + } +} + +func TestWalkAndEmitPages(t *testing.T) { + tests := []struct { + name string + setupMocks func(p *ConfluencePlugin, m *MockConfluenceClient, mc *chunk.MockIChunk) + setupPlugin func(p *ConfluencePlugin) + expectedErr error + expectedIDs []string + expectedSources []string + expectedBodies []string + fileThresholdCalls int + wikiBaseURLCalls int + }{ + { + name: "no filters, history off", + setupMocks: func(p *ConfluencePlugin, m *MockConfluenceClient, mc *chunk.MockIChunk) { + page := mkPage("100", 3) + m.EXPECT(). + WalkAllPages(gomock.Any(), maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, _ int, visit func(*Page) error) error { + return visit(page) + }).Times(1) + }, + setupPlugin: func(p *ConfluencePlugin) { p.History = false }, + expectedErr: nil, + expectedIDs: []string{"confluence-100-3"}, + expectedSources: []string{"https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=100&pageVersion=3"}, + expectedBodies: []string{"content 100"}, + fileThresholdCalls: 1, + wikiBaseURLCalls: 1, + }, + { + name: "no filters, history on (current + older versions)", + setupMocks: func(p *ConfluencePlugin, m *MockConfluenceClient, mc *chunk.MockIChunk) { + cur := mkPage("200", 5) + m.EXPECT(). + WalkAllPages(gomock.Any(), maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, _ int, visit func(*Page) error) error { + return visit(cur) + }).Times(1) + + m.EXPECT(). + WalkPageVersions(gomock.Any(), "200", maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, _ int, visit func(int) error) error { + for _, v := range []int{1, 2, 3, 4, 5} { + _ = visit(v) + } + return nil + }).Times(1) + + for _, v := range []int{1, 2, 3, 4} { + m.EXPECT(). + FetchPageAtVersion(gomock.Any(), "200", v). + DoAndReturn(func(_ context.Context, _ string, _ int) (*Page, error) { + return mkPage("200", v), nil + }).Times(1) + } + }, + setupPlugin: func(p *ConfluencePlugin) { p.History = true }, + expectedErr: nil, + expectedIDs: []string{"confluence-200-5", "confluence-200-1", "confluence-200-2", "confluence-200-3", "confluence-200-4"}, + expectedSources: []string{ + "https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=200&pageVersion=5", + "https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=200&pageVersion=1", + "https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=200&pageVersion=2", + "https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=200&pageVersion=3", + "https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=200&pageVersion=4", + }, + expectedBodies: []string{"content 200", "content 200", "content 200", "content 200", "content 200"}, + fileThresholdCalls: 5, + wikiBaseURLCalls: 5, + }, + { + name: "SpaceIDs only (dedupe pages)", + setupMocks: func(p *ConfluencePlugin, m *MockConfluenceClient, mc *chunk.MockIChunk) { + m.EXPECT(). + WalkPagesBySpaceIDs(gomock.Any(), []string{"S1", "S2"}, maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, _ []string, _ int, visit func(*Page) error) error { + _ = visit(mkPage("P1", 2)) + _ = visit(mkPage("P1", 2)) + _ = visit(mkPage("P2", 1)) + return nil + }).Times(1) + }, + setupPlugin: func(p *ConfluencePlugin) { p.SpaceIDs = []string{"S1", "S2"} }, + expectedErr: nil, + expectedIDs: []string{"confluence-P1-2", "confluence-P2-1"}, + expectedSources: []string{ + "https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=P1&pageVersion=2", + "https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=P2&pageVersion=1", + }, + expectedBodies: []string{"content P1", "content P2"}, + fileThresholdCalls: 2, + wikiBaseURLCalls: 2, + }, + { + name: "PageIDs only (dedupe)", + setupMocks: func(p *ConfluencePlugin, m *MockConfluenceClient, mc *chunk.MockIChunk) { + m.EXPECT(). + WalkPagesByIDs(gomock.Any(), []string{"10", "20", "10"}, maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, ids []string, _ int, visit func(*Page) error) error { + for _, id := range ids { + _ = visit(mkPage(id, 1)) + } + return nil + }).Times(1) + }, + setupPlugin: func(p *ConfluencePlugin) { p.PageIDs = []string{"10", "20", "10"} }, + expectedErr: nil, + expectedIDs: []string{"confluence-10-1", "confluence-20-1"}, + expectedSources: []string{ + "https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=10&pageVersion=1", + "https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=20&pageVersion=1", + }, + expectedBodies: []string{"content 10", "content 20"}, + fileThresholdCalls: 2, + wikiBaseURLCalls: 2, + }, + { + name: "filters collide (unique by page ID)", + setupMocks: func(p *ConfluencePlugin, m *MockConfluenceClient, mc *chunk.MockIChunk) { + m.EXPECT(). + WalkPagesBySpaceIDs(gomock.Any(), []string{"S1"}, maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, _ []string, _ int, visit func(*Page) error) error { + _ = visit(mkPage("P1", 3)) + _ = visit(mkPage("P2", 1)) + return nil + }).Times(1) + + m.EXPECT(). + WalkSpacesByKeys(gomock.Any(), gomock.Any(), maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, _ []string, _ int, visit func(*Space) error) error { + _ = visit(&Space{ID: "S1", Key: "Key1"}) + return nil + }).Times(1) + + m.EXPECT(). + WalkPagesByIDs(gomock.Any(), gomock.Any(), maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, ids []string, _ int, visit func(*Page) error) error { + for _, id := range ids { + _ = visit(mkPage(id, 1)) + } + return nil + }).Times(1) + }, + setupPlugin: func(p *ConfluencePlugin) { + p.SpaceIDs = []string{"S1"} + p.SpaceKeys = []string{"Key1"} + p.PageIDs = []string{"P1", "P3"} + }, + expectedErr: nil, + expectedIDs: []string{"confluence-P1-3", "confluence-P2-1", "confluence-P3-1"}, + expectedSources: []string{ + "https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=P1&pageVersion=3", + "https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=P2&pageVersion=1", + "https://tenant.atlassian.net/wiki/pages/viewpage.action?pageId=P3&pageVersion=1", + }, + expectedBodies: []string{"content P1", "content P2", "content P3"}, + fileThresholdCalls: 3, + wikiBaseURLCalls: 3, + }, + { + name: "error in WalkPagesBySpaceIDs", + setupMocks: func(p *ConfluencePlugin, m *MockConfluenceClient, mc *chunk.MockIChunk) { + m.EXPECT(). + WalkPagesBySpaceIDs(gomock.Any(), gomock.Any(), maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, _ []string, _ int, _ func(*Page) error) error { + return assert.AnError + }).Times(1) + }, + setupPlugin: func(p *ConfluencePlugin) { p.SpaceIDs = []string{"1", "2"} }, + expectedErr: assert.AnError, + expectedIDs: []string{}, + expectedSources: []string{}, + expectedBodies: []string{}, + fileThresholdCalls: 0, + wikiBaseURLCalls: 0, + }, + { + name: "error in WalkSpacesByKeys", + setupMocks: func(p *ConfluencePlugin, m *MockConfluenceClient, mc *chunk.MockIChunk) { + m.EXPECT(). + WalkSpacesByKeys(gomock.Any(), gomock.Any(), maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, _ []string, _ int, _ func(*Space) error) error { + return assert.AnError + }).Times(1) + }, + setupPlugin: func(p *ConfluencePlugin) { p.SpaceKeys = []string{"Key1", "Key2"} }, + expectedErr: assert.AnError, + expectedIDs: []string{}, + expectedSources: []string{}, + expectedBodies: []string{}, + fileThresholdCalls: 0, + wikiBaseURLCalls: 0, + }, + { + name: "error in WalkPagesByIDs", + setupMocks: func(p *ConfluencePlugin, m *MockConfluenceClient, mc *chunk.MockIChunk) { + m.EXPECT(). + WalkPagesByIDs(gomock.Any(), gomock.Any(), maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, _ []string, _ int, _ func(*Page) error) error { + return assert.AnError + }).Times(1) }, + setupPlugin: func(p *ConfluencePlugin) { p.PageIDs = []string{"1", "2"} }, + expectedErr: assert.AnError, + expectedIDs: []string{}, + expectedSources: []string{}, + expectedBodies: []string{}, + fileThresholdCalls: 0, + wikiBaseURLCalls: 0, + }, + { + name: "error in WalkAllPages", + setupMocks: func(p *ConfluencePlugin, m *MockConfluenceClient, mc *chunk.MockIChunk) { + m.EXPECT(). + WalkAllPages(gomock.Any(), maxPageSize, gomock.Any()). + DoAndReturn(func(_ context.Context, _ int, _ func(*Page) error) error { + return assert.AnError + }).Times(1) + }, + setupPlugin: func(p *ConfluencePlugin) {}, + expectedErr: assert.AnError, + expectedIDs: []string{}, + expectedSources: []string{}, + expectedBodies: []string{}, + fileThresholdCalls: 0, + wikiBaseURLCalls: 0, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockClient := &mockConfluenceClient{ - firstSpacesRequestError: tt.firstSpacesRequestError, - numberOfPages: tt.numberOfPages, - numberOfSpaces: tt.numberOfSpaces, - pageContentResponse: []*ConfluencePageContent{tt.mockPageContent}, + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p, ctrl, mockClient, mockChunk := newPluginWithMock(t) + defer ctrl.Finish() + + p.itemsChan = make(chan ISourceItem, 200) + if tc.setupPlugin != nil { + tc.setupPlugin(p) } - errorsChan := make(chan error, 1) - itemsChan := make(chan ISourceItem, 3) + mockChunk.EXPECT().GetFileThreshold().Return(int64(mockGetFileThresholdReturn)).Times(tc.fileThresholdCalls) + mockClient.EXPECT().WikiBaseURL().Return("https://tenant.atlassian.net/wiki").Times(tc.wikiBaseURLCalls) - plugin := Plugin{ - Limit: make(chan struct{}, confluenceMaxRequests), + tc.setupMocks(p, mockClient, mockChunk) + + err := p.walkAndEmitPages(context.Background()) + assert.ErrorIs(t, err, tc.expectedErr) + + items := collectEmittedItems(p.itemsChan) + actualIDs := make([]string, 0, len(items)) + actualSources := make([]string, 0, len(items)) + actualBodies := make([]string, 0, len(items)) + for _, it := range items { + actualIDs = append(actualIDs, it.ID) + actualSources = append(actualSources, it.Source) + actualBodies = append(actualBodies, it.Content) } + assert.ElementsMatch(t, tc.expectedIDs, actualIDs) + assert.ElementsMatch(t, tc.expectedSources, actualSources) + assert.ElementsMatch(t, tc.expectedBodies, actualBodies) + }) + } +} - confluencePlugin := &ConfluencePlugin{ - Plugin: plugin, - client: mockClient, - errorsChan: errorsChan, - itemsChan: itemsChan, +func TestDefineCommand(t *testing.T) { + t.Run("RunE validation", func(t *testing.T) { + tests := []struct { + name string + expectedErr error + }{ + { + name: "normal execution", + expectedErr: nil, + }, + { + name: "error during execution", + expectedErr: assert.AnError, + }, + } + + recvErr := func(ch <-chan error) error { + select { + case e := <-ch: + return e + default: + return nil } + } - wg := &sync.WaitGroup{} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - go confluencePlugin.scanConfluence(wg) + items := make(chan ISourceItem, 8) + errs := make(chan error, 1) - wg.Wait() + mockClient := NewMockConfluenceClient(ctrl) - if tt.expectedError != nil { - actualError := <-errorsChan - assert.Equal(t, tt.expectedError, actualError) - } else { - assert.Empty(t, errorsChan) - var actualItems []ISourceItem - for i := 0; i < tt.numberOfSpaces; i++ { - for j := 0; j < tt.numberOfPages; j++ { - actualItem := <-itemsChan - actualItems = append(actualItems, actualItem) - } + p := &ConfluencePlugin{ + itemsChan: items, + errorsChan: errs, + client: mockClient, } - sort.Slice(actualItems, func(i, j int) bool { - splitID := func(id string) (string, string) { - parts := strings.Split(id, "-") - return parts[1], parts[2] - } - spaceKey1, pageID1 := splitID(actualItems[i].GetID()) - spaceKey2, pageID2 := splitID(actualItems[j].GetID()) - - if spaceKey1 != spaceKey2 { - return spaceKey1 < spaceKey2 - } - return pageID1 < pageID2 - }) - for i := 0; i < tt.numberOfSpaces; i++ { - for j := 0; j < tt.numberOfPages; j++ { - expectedItem := item{ - Content: ptrToString("Page content"), - ID: fmt.Sprintf("confluence-%d-%d", i, j), - Source: "https://example.com/wiki/page", - } - assert.Equal(t, &expectedItem, actualItems[i*tt.numberOfPages+j]) - } + cmd, err := p.DefineCommand(items, errs) + assert.NoError(t, err) + + mockClient.EXPECT(). + WalkAllPages(gomock.Any(), maxPageSize, gomock.Any()). + Return(tc.expectedErr). + Times(1) + + cmd.Run(cmd, []string{"https://tenant.atlassian.net/wiki"}) + + err = recvErr(errs) + assert.ErrorIs(t, err, tc.expectedErr) + }) + } + }) + + t.Run("PreRunE validation", func(t *testing.T) { + tests := []struct { + name string + setFlags func(cmd *cobra.Command) + args []string + expectedErr error + }{ + { + name: "token value but no token type", + setFlags: func(cmd *cobra.Command) { + _ = cmd.Flags().Set("token-value", "value") + // token-type intentionally not set + }, + args: []string{"https://tenant.atlassian.net/wiki"}, + expectedErr: fmt.Errorf("--token-type must be set when --token-value is provided"), + }, + { + name: "invalid token type", + setFlags: func(cmd *cobra.Command) { + _ = cmd.Flags().Set("token-type", "bad") + }, + args: []string{"https://tenant.atlassian.net/wiki"}, + expectedErr: fmt.Errorf("invalid --token-type \"bad\"; valid values are \"api-token\" or \"scoped-api-token\""), + }, + { + name: "token type api-token but without token value", + setFlags: func(cmd *cobra.Command) { + _ = cmd.Flags().Set("token-type", string(ApiToken)) + // no token-value + }, + args: []string{"https://tenant.atlassian.net/wiki"}, + expectedErr: fmt.Errorf("--token-type requires --token-value"), + }, + { + name: "without credentials", + setFlags: func(cmd *cobra.Command) { + // nothing set + }, + args: []string{"https://tenant.atlassian.net/wiki"}, + expectedErr: nil, + }, + { + name: "token type api-token and token value provided", + setFlags: func(cmd *cobra.Command) { + _ = cmd.Flags().Set("username", "user@example.com") + _ = cmd.Flags().Set("token-type", string(ApiToken)) + _ = cmd.Flags().Set("token-value", "tok") + }, + args: []string{"https://tenant.atlassian.net/wiki"}, + expectedErr: nil, + }, + { + name: "initialize fails with invalid base URL", + setFlags: func(cmd *cobra.Command) {}, + args: []string{"%"}, + expectedErr: fmt.Errorf("invalid URL escape"), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := &ConfluencePlugin{} + cmd, err := p.DefineCommand(make(chan ISourceItem, 1), make(chan error, 1)) + assert.NoError(t, err) + + if tc.setFlags != nil { + tc.setFlags(cmd) } - } - }) - } + err = cmd.PreRunE(cmd, tc.args) + if tc.name == "initialize fails with invalid base URL" { + assert.Contains(t, err.Error(), tc.expectedErr.Error()) + } else { + assert.Equal(t, tc.expectedErr, err) + } + }) + } + }) } -func TestInitializeConfluence(t *testing.T) { +func TestEmitInChunks(t *testing.T) { + const base = "https://tenant.atlassian.net/wiki" + tests := []struct { - name string - urlArg string - username string - token string - expectURL string - expectLimit int - expectWarn bool + name string + buildPage func() *Page + setupMock func(m *chunk.MockIChunk) + expectedErr error + expectedIDs []string + expectedSources []string + expectedBodies []string }{ { - name: "Valid credentials", - urlArg: "https://example.com/", - username: "user", - token: "token", - expectURL: "https://example.com", - expectLimit: confluenceMaxRequests, - expectWarn: false, + name: "storage nil: no emission (no chunk calls)", + buildPage: func() *Page { + return &Page{ + ID: "42", + Links: map[string]string{"webui": "/pages/viewpage.action?pageId=42"}, + Body: PageBody{Storage: nil}, + Version: PageVersion{Number: 7}, + } + }, + setupMock: func(_ *chunk.MockIChunk) {}, // still no calls expected + expectedErr: nil, + expectedIDs: []string{}, + expectedSources: []string{}, + expectedBodies: []string{}, }, { - name: "No credentials provided", - urlArg: "https://example.com/", - username: "", - token: "", - expectURL: "https://example.com", - expectLimit: confluenceMaxRequests, - expectWarn: true, + name: "below threshold: single item (full body)", + buildPage: func() *Page { + return &Page{ + ID: "100", + Links: map[string]string{"webui": "/pages/viewpage.action?pageId=100"}, + Body: PageBody{ + Storage: &struct { + Value string `json:"value"` + }{Value: "AAAEND"}, + }, + Version: PageVersion{Number: 3}, + } + }, + setupMock: func(m *chunk.MockIChunk) { + m.EXPECT().GetFileThreshold().Return(int64(100)).Times(1) + }, + expectedErr: nil, + expectedIDs: []string{"confluence-100-3"}, + expectedSources: []string{base + "/pages/viewpage.action?pageId=100&pageVersion=3"}, + expectedBodies: []string{"AAAEND"}, }, { - name: "URL without trailing slash", - urlArg: "https://example.com", - username: "user", - token: "token", - expectURL: "https://example.com", - expectLimit: confluenceMaxRequests, - expectWarn: false, + name: "above threshold: two chunks then EOF", + buildPage: func() *Page { + return &Page{ + ID: "999", + Links: map[string]string{"webui": "/pages/viewpage.action?pageId=999"}, + Body: PageBody{ + Storage: &struct { + Value string `json:"value"` + }{Value: strings.Repeat("x", 50)}, + }, + Version: PageVersion{Number: 8}, + } + }, + setupMock: func(m *chunk.MockIChunk) { + // Force chunking branch + m.EXPECT().GetFileThreshold().Return(int64(1)).Times(1) + m.EXPECT().GetSize().Return(8).Times(1) + m.EXPECT().GetMaxPeekSize().Return(4).Times(1) + + gomock.InOrder( + m.EXPECT(). + ReadChunk(gomock.Any(), -1). + DoAndReturn(func(_ *bufio.Reader, _ int) (string, error) { return "CHUNK-1\n", nil }), + m.EXPECT(). + ReadChunk(gomock.Any(), -1). + DoAndReturn(func(_ *bufio.Reader, _ int) (string, error) { return "CHUNK-2\n", nil }), + m.EXPECT(). + ReadChunk(gomock.Any(), -1). + DoAndReturn(func(_ *bufio.Reader, _ int) (string, error) { return "", io.EOF }), + ) + }, + expectedErr: nil, + expectedIDs: []string{"confluence-999-8", "confluence-999-8"}, + expectedSources: []string{ + base + "/pages/viewpage.action?pageId=999&pageVersion=8", + base + "/pages/viewpage.action?pageId=999&pageVersion=8", + }, + expectedBodies: []string{"CHUNK-1\n", "CHUNK-2\n"}, + }, + { + name: "ReadChunk error: wrapped and returned, no items", + buildPage: func() *Page { + return &Page{ + ID: "500", + Links: map[string]string{"webui": "/pages/viewpage.action?pageId=500"}, + Body: PageBody{ + Storage: &struct { + Value string `json:"value"` + }{Value: "trigger chunking"}, + }, + Version: PageVersion{Number: 1}, + } + }, + setupMock: func(m *chunk.MockIChunk) { + m.EXPECT().GetFileThreshold().Return(int64(1)).Times(1) + m.EXPECT().GetSize().Return(8).Times(1) + m.EXPECT().GetMaxPeekSize().Return(4).Times(1) + m.EXPECT(). + ReadChunk(gomock.Any(), -1). + Return("", assert.AnError). + Times(1) + }, + expectedErr: assert.AnError, + expectedIDs: []string{}, + expectedSources: []string{}, + expectedBodies: []string{}, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var logBuf bytes.Buffer - log.Logger = zerolog.New(&logBuf) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - username = tt.username - token = tt.token + mockChunk := chunk.NewMockIChunk(ctrl) + mockClient := NewMockConfluenceClient(ctrl) + mockClient.EXPECT().WikiBaseURL().Return(base).AnyTimes() - p := &ConfluencePlugin{} + p := &ConfluencePlugin{ + itemsChan: make(chan ISourceItem, 16), + chunker: mockChunk, + client: mockClient, + } + + tc.setupMock(mockChunk) + + err := p.emitInChunks(tc.buildPage()) + assert.ErrorIs(t, err, tc.expectedErr) + + // Drain emitted items + n := len(p.itemsChan) + actualIDs := make([]string, 0, n) + actualSources := make([]string, 0, n) + actualBodies := make([]string, 0, n) + for i := 0; i < n; i++ { + it := <-p.itemsChan + actualIDs = append(actualIDs, it.GetID()) + actualSources = append(actualSources, it.GetSource()) + if it.GetContent() != nil { + actualBodies = append(actualBodies, *it.GetContent()) + } else { + actualBodies = append(actualBodies, "") + } + } + + assert.ElementsMatch(t, tc.expectedIDs, actualIDs) + assert.ElementsMatch(t, tc.expectedSources, actualSources) + // preserve order for bodies (chunks are emitted in sequence) + assert.Equal(t, tc.expectedBodies, actualBodies) + }) + } +} - p.initialize(tt.urlArg) +func newPluginWithMock(t *testing.T) (*ConfluencePlugin, *gomock.Controller, *MockConfluenceClient, *chunk.MockIChunk) { + t.Helper() + ctrl := gomock.NewController(t) - assert.NotNil(t, p.client) - client, ok := p.client.(*confluenceClient) - assert.True(t, ok, "Client should be of type *confluenceClient") + mockClient := NewMockConfluenceClient(ctrl) + mockChunk := chunk.NewMockIChunk(ctrl) - assert.Equal(t, tt.expectURL, client.baseURL) + p := &ConfluencePlugin{ + itemsChan: make(chan ISourceItem, 1000), + client: mockClient, + chunker: mockChunk, + } + return p, ctrl, mockClient, mockChunk +} - assert.Equal(t, tt.username, client.username) - assert.Equal(t, tt.token, client.token) +func mkPage(id string, ver int) *Page { + return &Page{ + ID: id, + Title: "T-" + id, + Links: map[string]string{"webui": "/pages/viewpage.action?pageId=" + id}, + Body: PageBody{ + Storage: &struct { + Value string `json:"value"` + }{Value: "content " + id}, + }, + Version: PageVersion{Number: ver}, + } +} - assert.NotNil(t, p.Limit) - assert.Equal(t, tt.expectLimit, cap(p.Limit)) +type emittedItem struct { + ID string + Source string + Content string +} - logOutput := logBuf.String() - if tt.expectWarn { - assert.Contains(t, logOutput, "confluence credentials were not provided", "Expected warning log missing") - } else { - assert.NotContains(t, logOutput, "confluence credentials were not provided", "Unexpected warning log found") - } +func collectEmittedItems(ch chan ISourceItem) []emittedItem { + n := len(ch) + items := make([]emittedItem, 0, n) + for range n { + it := <-ch + content := "" + if it.GetContent() != nil { + content = *it.GetContent() + } + items = append(items, emittedItem{ + ID: it.GetID(), + Source: it.GetSource(), + Content: content, }) } + return items } -func ptrToString(s string) *string { - return &s +func makeRangeStrings(start, end int) []string { + out := make([]string, 0, end-start+1) + for i := start; i <= end; i++ { + out = append(out, strconv.Itoa(i)) + } + return out }