diff --git a/cmd/main.go b/cmd/main.go index 3aa98928..5d93521f 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -131,17 +131,17 @@ func preRun(pluginName string, cmd *cobra.Command, args []string) error { return err } - engine, err := engine.Init(engineConfigVar) + engineInstance, err := engine.Init(engineConfigVar) if err != nil { return err } - if err := engine.AddRegexRules(customRegexRuleVar); err != nil { + if err := engineInstance.AddRegexRules(customRegexRuleVar); err != nil { return err } Channels.WaitGroup.Add(1) - go ProcessItems(engine, pluginName) + go ProcessItems(engineInstance, pluginName) Channels.WaitGroup.Add(1) go ProcessSecrets() @@ -151,10 +151,10 @@ func preRun(pluginName string, cmd *cobra.Command, args []string) error { if validateVar { Channels.WaitGroup.Add(1) - go ProcessValidationAndScoreWithValidation(engine) + go ProcessValidationAndScoreWithValidation(engineInstance) } else { Channels.WaitGroup.Add(1) - go ProcessScoreWithoutValidation(engine) + go ProcessScoreWithoutValidation(engineInstance) } return nil diff --git a/cmd/workers.go b/cmd/workers.go index 958f8327..c0b7a6d8 100644 --- a/cmd/workers.go +++ b/cmd/workers.go @@ -1,21 +1,37 @@ package cmd import ( + "context" "github.com/checkmarx/2ms/engine" "github.com/checkmarx/2ms/engine/extra" "github.com/checkmarx/2ms/lib/secrets" + "golang.org/x/sync/errgroup" "sync" ) -func ProcessItems(engine *engine.Engine, pluginName string) { +func ProcessItems(engineInstance engine.IEngine, pluginName string) { defer Channels.WaitGroup.Done() - wgItems := &sync.WaitGroup{} + + g, ctx := errgroup.WithContext(context.Background()) for item := range Channels.Items { Report.TotalItemsScanned++ - wgItems.Add(1) - go engine.Detect(item, SecretsChan, wgItems, pluginName, Channels.Errors) + item := item + + switch pluginName { + case "filesystem": + g.Go(func() error { + return engineInstance.DetectFile(ctx, item, SecretsChan) + }) + default: + g.Go(func() error { + return 
engineInstance.DetectFragment(item, SecretsChan, pluginName) + }) + } + } + + if err := g.Wait(); err != nil { + Channels.Errors <- err } - wgItems.Wait() close(SecretsChan) } @@ -48,7 +64,7 @@ func ProcessSecretsExtras() { wgExtras.Wait() } -func ProcessValidationAndScoreWithValidation(engine *engine.Engine) { +func ProcessValidationAndScoreWithValidation(engine engine.IEngine) { defer Channels.WaitGroup.Done() wgValidation := &sync.WaitGroup{} @@ -64,7 +80,7 @@ func ProcessValidationAndScoreWithValidation(engine *engine.Engine) { engine.Validate() } -func ProcessScoreWithoutValidation(engine *engine.Engine) { +func ProcessScoreWithoutValidation(engine engine.IEngine) { defer Channels.WaitGroup.Done() wgScore := &sync.WaitGroup{} diff --git a/engine/chunk/chunk.go b/engine/chunk/chunk.go new file mode 100644 index 00000000..e3e8d23b --- /dev/null +++ b/engine/chunk/chunk.go @@ -0,0 +1,206 @@ +package chunk + +//go:generate mockgen -source=$GOFILE -destination=${GOPACKAGE}_mock.go -package=${GOPACKAGE} + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "sync" + "unicode" + + "github.com/h2non/filetype" +) + +const ( + defaultSize = 100 * 1024 // 100Kib + defaultMaxPeekSize = 25 * 1024 // 25Kib + defaultFileThreshold = 1 * 1024 * 1024 // 1MiB +) + +var ErrUnsupportedFileType = errors.New("unsupported file type") + +type Option func(*Chunk) + +// WithSize sets the chunk size +func WithSize(size int) Option { + return func(args *Chunk) { + args.size = size + } +} + +// WithMaxPeekSize sets the max size of look-ahead bytes +func WithMaxPeekSize(maxPeekSize int) Option { + return func(args *Chunk) { + args.maxPeekSize = maxPeekSize + } +} + +// WithSmallFileThreshold sets the threshold for small files +func WithSmallFileThreshold(smallFileThreshold int64) Option { + return func(args *Chunk) { + args.smallFileThreshold = smallFileThreshold + } +} + +// Chunk holds two pools and sizing parameters needed for reading chunks of data with look-ahead +type Chunk struct { + 
bufPool *sync.Pool // *bytes.Buffer with cap Size + MaxPeekSize + peekedBufPool *sync.Pool // *[]byte slices of length Size + MaxPeekSize + size int // base chunk size + maxPeekSize int // max size of look-ahead bytes + smallFileThreshold int64 // files smaller than this skip chunking +} + +type IChunk interface { + GetSize() int + GetMaxPeekSize() int + GetFileThreshold() int64 + ReadChunk(reader *bufio.Reader, totalLines int) (string, error) +} + +func New(opts ...Option) *Chunk { + // set default options + c := &Chunk{ + size: defaultSize, + maxPeekSize: defaultMaxPeekSize, + smallFileThreshold: defaultFileThreshold, + } + // apply overrides + for _, opt := range opts { + opt(c) + } + c.bufPool = &sync.Pool{ + New: func() interface{} { + // pre-allocate dynamic-size buffer for reading chunks (up to chunk size + peek size) + return bytes.NewBuffer(make([]byte, 0, c.size+c.maxPeekSize)) + }, + } + c.peekedBufPool = &sync.Pool{ + New: func() interface{} { + // pre-allocate fixed-size block for loading chunks + b := make([]byte, c.size+c.maxPeekSize) + return &b + }, + } + return c +} + +// GetBuf returns a bytes.Buffer from the pool, seeded with the data +func (c *Chunk) GetBuf(data []byte) (*bytes.Buffer, bool) { + window, ok := c.bufPool.Get().(*bytes.Buffer) + if !ok { + return nil, false + } + window.Write(data) // seed the buffer with the data + return window, ok +} + +// PutBuf returns the bytes.Buffer to the pool +func (c *Chunk) PutBuf(window *bytes.Buffer) { + window.Reset() + c.bufPool.Put(window) +} + +// GetPeekedBuf returns a fixed-size []byte from the pool +func (c *Chunk) GetPeekedBuf() (*[]byte, bool) { + b, ok := c.peekedBufPool.Get().(*[]byte) + return b, ok +} + +// PutPeekedBuf returns the fixed-size []byte to the pool. +// The slice is restored to its full capacity before it is pooled: a slice left at zero length +// would hand the next borrower an empty buffer, and a Read into an empty buffer yields no bytes. +func (c *Chunk) PutPeekedBuf(b *[]byte) { + *b = (*b)[:cap(*b)] // restore the full fixed length for the next borrower + c.peekedBufPool.Put(b) +} + +func (c *Chunk) GetSize() int { + return c.size +} + +func (c *Chunk) GetMaxPeekSize() int { + return 
c.maxPeekSize +} + +func (c *Chunk) GetFileThreshold() int64 { + return c.smallFileThreshold +} + +// ReadChunk returns the next chunk of data from reader. +// +// Up to chunk size plus peek size bytes are inspected with Peek, and only the bytes that end up in the +// returned chunk are consumed. Look-ahead bytes that generateChunk leaves out of this chunk therefore +// stay in the reader for the next call instead of being dropped. +// +// When totalLines is 0 (start of file) the data is also passed to ShouldSkipFile, and an error wrapping +// ErrUnsupportedFileType is returned if the file should be skipped. When there is nothing left to peek, +// the reader's error is returned (io.EOF at the end of the input). +func (c *Chunk) ReadChunk(reader *bufio.Reader, totalLines int) (string, error) { + // peek (do not read) the raw data, so the reader only advances past what is actually returned + rawData, err := reader.Peek(c.size + c.maxPeekSize) + + var chunkStr string + // a short peek still carries valid bytes next to its error (e.g. io.EOF on the last chunk): + // process those bytes first, the error is reported by a later call once no bytes are left + if len(rawData) > 0 { + // only check the filetype at the start of file + if totalLines == 0 && ShouldSkipFile(rawData) { + return "", fmt.Errorf("skipping file: %w", ErrUnsupportedFileType) + } + + chunkStr, err = c.generateChunk(rawData) + if err == nil { + // the chunk is a prefix of the peeked bytes, so discarding its length consumes exactly the chunk + _, err = reader.Discard(len(chunkStr)) + } + } + if err != nil { + return "", err + } + return chunkStr, nil +} + +// generateChunk processes block of raw data and generates chunk to be scanned +func (c *Chunk) generateChunk(rawData []byte) (string, error) { + // Borrow a buffer from the pool and seed it with raw data (up to chunk size) + initialChunkLen := min(len(rawData), c.size) + chunkData, ok := c.GetBuf(rawData[:initialChunkLen]) + if !ok { + return "", fmt.Errorf("expected *bytes.Buffer, got %T", chunkData) + } + defer c.PutBuf(chunkData) + + // keep seeding chunk until detecting the “\n...\n” (i.e. safe boundary) + // or reaching the max limit of chunk size (i.e. 
chunk size + peek size) + for i := chunkData.Len(); i < len(rawData); i++ { + if endsWithTwoNewlines(rawData[:i]) { + break + } + chunkData.WriteByte(rawData[i]) + } + + return chunkData.String(), nil +} + +// endsWithTwoNewlines returns true if b ends in at least two '\n's (ignoring any number of ' ', '\r', or '\t' between them) +func endsWithTwoNewlines(b []byte) bool { + count := 0 + for i := len(b) - 1; i >= 0; i-- { + if b[i] == '\n' { + count++ + if count >= 2 { + return true + } + } else if unicode.IsSpace(rune(b[i])) { + // the presence of other whitespace characters (`\r`, ` `, `\t`) shouldn't reset the count + continue + } else { + return false + } + } + return false +} + +// ShouldSkipFile checks if the file should be skipped based on its content type +func ShouldSkipFile(data []byte) bool { + // TODO: could other optimizations be introduced here? + mimetype, err := filetype.Match(data) + if err != nil { + return true // could not determine file type + } + return mimetype.MIME.Type == "application" // skip binary files +} diff --git a/engine/chunk/chunk_mock.go b/engine/chunk/chunk_mock.go new file mode 100644 index 00000000..75844d0f --- /dev/null +++ b/engine/chunk/chunk_mock.go @@ -0,0 +1,98 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: chunk.go +// +// Generated by this command: +// +// mockgen -source=chunk.go -destination=chunk_mock.go -package=chunk +// + +// Package chunk is a generated GoMock package. +package chunk + +import ( + bufio "bufio" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockIChunk is a mock of IChunk interface. +type MockIChunk struct { + ctrl *gomock.Controller + recorder *MockIChunkMockRecorder + isgomock struct{} +} + +// MockIChunkMockRecorder is the mock recorder for MockIChunk. +type MockIChunkMockRecorder struct { + mock *MockIChunk +} + +// NewMockIChunk creates a new mock instance. 
+func NewMockIChunk(ctrl *gomock.Controller) *MockIChunk { + mock := &MockIChunk{ctrl: ctrl} + mock.recorder = &MockIChunkMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIChunk) EXPECT() *MockIChunkMockRecorder { + return m.recorder +} + +// GetFileThreshold mocks base method. +func (m *MockIChunk) GetFileThreshold() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetFileThreshold") + ret0, _ := ret[0].(int64) + return ret0 +} + +// GetFileThreshold indicates an expected call of GetFileThreshold. +func (mr *MockIChunkMockRecorder) GetFileThreshold() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFileThreshold", reflect.TypeOf((*MockIChunk)(nil).GetFileThreshold)) +} + +// GetMaxPeekSize mocks base method. +func (m *MockIChunk) GetMaxPeekSize() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMaxPeekSize") + ret0, _ := ret[0].(int) + return ret0 +} + +// GetMaxPeekSize indicates an expected call of GetMaxPeekSize. +func (mr *MockIChunkMockRecorder) GetMaxPeekSize() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxPeekSize", reflect.TypeOf((*MockIChunk)(nil).GetMaxPeekSize)) +} + +// GetSize mocks base method. +func (m *MockIChunk) GetSize() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSize") + ret0, _ := ret[0].(int) + return ret0 +} + +// GetSize indicates an expected call of GetSize. +func (mr *MockIChunkMockRecorder) GetSize() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSize", reflect.TypeOf((*MockIChunk)(nil).GetSize)) +} + +// ReadChunk mocks base method. 
+func (m *MockIChunk) ReadChunk(reader *bufio.Reader, totalLines int) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadChunk", reader, totalLines) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadChunk indicates an expected call of ReadChunk. +func (mr *MockIChunkMockRecorder) ReadChunk(reader, totalLines any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadChunk", reflect.TypeOf((*MockIChunk)(nil).ReadChunk), reader, totalLines) +} diff --git a/engine/chunk/chunk_test.go b/engine/chunk/chunk_test.go new file mode 100644 index 00000000..aaa17e5a --- /dev/null +++ b/engine/chunk/chunk_test.go @@ -0,0 +1,162 @@ +package chunk + +import ( + "bufio" + "bytes" + "github.com/stretchr/testify/require" + "io" + "strings" + "testing" +) + +const ( + chunkSize = 10 + maxPeekSize = 5 + smallFileThreshold = int64(20) +) + +func TestGetAndPutBuf(t *testing.T) { + c := New() + data := []byte("test") + buf, ok := c.GetBuf(data) + defer c.PutBuf(buf) + + require.True(t, ok) + require.Equal(t, defaultSize+defaultMaxPeekSize, buf.Cap()) + require.Equal(t, string(data), buf.String()) +} + +func TestGetAndPutPeekedBuf(t *testing.T) { + c := New() + window, ok := c.GetPeekedBuf() + defer c.PutPeekedBuf(window) + + require.True(t, ok) + require.Equal(t, defaultSize+defaultMaxPeekSize, len(*window)) +} + +func TestGetSize(t *testing.T) { + c := New() + require.Equal(t, defaultSize, c.GetSize()) +} + +func TestGetMaxPeekSize(t *testing.T) { + c := New() + require.Equal(t, defaultMaxPeekSize, c.GetMaxPeekSize()) +} + +func TestReadChunk(t *testing.T) { + // Arrange + type testCase struct { + name string + reader io.Reader + expected string + expectedError error + } + testCases := []testCase{ + { + name: "empty", + reader: strings.NewReader(""), + expectedError: io.EOF, + }, + { + name: "unsupported file type", + reader: bytes.NewReader([]byte{'P', 'K', 0x03, 0x04}), + 
expectedError: ErrUnsupportedFileType, + }, + { + name: "successful read", + reader: strings.NewReader("abc\n"), + expected: "abc\n", + }, + { + name: "successful read - peek size exceeded", + reader: strings.NewReader("abc\ndef\nghi\njkl\nmno\npqr\nstu\nvwx\nyz"), + expected: "abc\ndef\nghi\njkl", + }, + { + name: "successful read - multiple lines with consecutives new lines", + reader: strings.NewReader("abc\ndef\n\n\n\n\nghi\njkl"), + expected: "abc\ndef\n\n\n", + }, + { + name: "multiple lines without consecutives new lines", + reader: strings.NewReader("abc\ndef\nghi\n"), + expected: "abc\ndef\nghi\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := New(WithSize(chunkSize), WithMaxPeekSize(maxPeekSize), WithSmallFileThreshold(smallFileThreshold)) + reader := bufio.NewReaderSize(tc.reader, chunkSize+maxPeekSize) + + // Act + result, err := c.ReadChunk(reader, 0) + require.ErrorIs(t, err, tc.expectedError) + + // Assert + require.Equal(t, tc.expected, result) + }) + } +} + +func TestGenerateChunk(t *testing.T) { + // Arrange + testCases := []struct { + name string + rawData []byte + expected string + }{ + // Current split is fine, exit early. 
+ { + name: "safe original split - LF", + rawData: []byte("abc\ndef\n\n\nghijklmnop\n\nqrstuvwxyz"), + expected: "abc\ndef\n\n\n", + }, + { + name: "safe original split - CRLF", + rawData: []byte("abcdef\r\n\r\nghijklmnop\n"), + expected: "abcdef\r\n\r\n", + }, + // Current split is bad, look for a better one + { + name: "safe split - LF", + rawData: []byte("abcdef\nghi\n\njklmnop\n\nqrstuvwxyz"), + expected: "abcdef\nghi\n\n", + }, + { + name: "safe split - CRLF", + rawData: []byte("abcdef\r\nghi\r\n\r\njklmnopqrstuvwxyz"), + expected: "abcdef\r\nghi\r\n\r\n", + }, + { + name: "safe split - blank line", + rawData: []byte("abcdefghi\n\t \t\njklmnopqrstuvwxyz"), + expected: "abcdefghi\n\t \t\n", + }, + // Current split is bad, exhaust options + { + name: "no safe split", + rawData: []byte("abcdefg\nhijklmnopqrstuvwxyz"), + expected: "abcdefg\nhijklmn", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := New(WithSize(chunkSize), WithMaxPeekSize(maxPeekSize), WithSmallFileThreshold(smallFileThreshold)) + reader := bufio.NewReaderSize(bytes.NewReader(tc.rawData), c.size+c.maxPeekSize) + peekedBuf := make([]byte, c.size+c.maxPeekSize) + _, err := reader.Read(peekedBuf) + require.NoError(t, err) + + // Act + chunkStr, err := c.generateChunk(peekedBuf) + require.NoError(t, err) + + // Assert + require.Equal(t, tc.expected, chunkStr) + }) + } +} diff --git a/engine/engine.go b/engine/engine.go index 1d2c80d2..5904351c 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -1,18 +1,24 @@ package engine +//go:generate mockgen -source=$GOFILE -destination=${GOPACKAGE}_mock.go -package=${GOPACKAGE} + import ( + "bufio" + "context" "crypto/sha1" + "errors" "fmt" + "io" "os" "regexp" "strings" "sync" "text/tabwriter" + "github.com/checkmarx/2ms/engine/chunk" "github.com/checkmarx/2ms/engine/linecontent" - "github.com/checkmarx/2ms/engine/score" - "github.com/checkmarx/2ms/engine/rules" + "github.com/checkmarx/2ms/engine/score" + 
"github.com/checkmarx/2ms/engine/semaphore" "github.com/checkmarx/2ms/engine/validation" "github.com/checkmarx/2ms/lib/secrets" "github.com/checkmarx/2ms/plugins" @@ -28,12 +34,27 @@ type Engine struct { rulesBaseRiskScore map[string]float64 detector detect.Detector validator validation.Validator + semaphore semaphore.ISemaphore + chunk chunk.IChunk ignoredIds []string allowedValues []string } -const customRegexRuleIdFormat = "custom-regex-%d" +type IEngine interface { + DetectFragment(item plugins.ISourceItem, secretsChannel chan *secrets.Secret, pluginName string) error + DetectFile(ctx context.Context, item plugins.ISourceItem, secretsChannel chan *secrets.Secret) error + AddRegexRules(patterns []string) error + RegisterForValidation(secret *secrets.Secret, wg *sync.WaitGroup) + Score(secret *secrets.Secret, validateFlag bool, wg *sync.WaitGroup) + Validate() + GetRuleBaseRiskScore(ruleId string) float64 +} + +const ( + customRegexRuleIdFormat = "custom-regex-%d" + CxFileEndMarker = ";cx-file-end" +) type EngineConfig struct { SelectedList []string @@ -46,7 +67,7 @@ type EngineConfig struct { AllowedValues []string } -func Init(engineConfig EngineConfig) (*Engine, error) { +func Init(engineConfig EngineConfig) (IEngine, error) { selectedRules := rules.FilterRules(engineConfig.SelectedList, engineConfig.IgnoreList, engineConfig.SpecialList) if len(*selectedRules) == 0 { return nil, fmt.Errorf("no rules were selected") @@ -73,69 +94,139 @@ func Init(engineConfig EngineConfig) (*Engine, error) { rulesBaseRiskScore: rulesBaseRiskScore, detector: *detector, validator: *validation.NewValidator(), + semaphore: semaphore.NewSemaphore(), + chunk: chunk.New(), ignoredIds: engineConfig.IgnoredIds, allowedValues: engineConfig.AllowedValues, }, nil } -func (e *Engine) Detect(item plugins.ISourceItem, secretsChannel chan *secrets.Secret, wg *sync.WaitGroup, pluginName string, errors chan error) { - defer wg.Done() - const CxFileEndMarker = ";cx-file-end" - +// DetectFragment 
detects secrets in the given fragment +func (e *Engine) DetectFragment(item plugins.ISourceItem, secretsChannel chan *secrets.Secret, pluginName string) error { fragment := detect.Fragment{ Raw: *item.GetContent(), FilePath: item.GetSource(), } - fragment.Raw += CxFileEndMarker + "\n" - gitInfo := item.GetGitInfo() + return e.detectSecrets(item, fragment, secretsChannel, pluginName) +} - values := e.detector.Detect(fragment) +// DetectFile reads the given file and detects secrets in it +func (e *Engine) DetectFile(ctx context.Context, item plugins.ISourceItem, secretsChannel chan *secrets.Secret) error { + fi, err := os.Stat(item.GetSource()) + if err != nil { + return fmt.Errorf("failed to stat %q: %w", item.GetSource(), err) + } - for _, value := range values { - itemId := getFindingId(item, value) - var startLine, endLine int - var err error - if pluginName == "filesystem" { - startLine = value.StartLine + 1 - endLine = value.EndLine + 1 - } else if pluginName == "git" { - startLine, endLine, err = plugins.GetGitStartAndEndLine(gitInfo, value.StartLine, value.EndLine) - if err != nil { - errors <- fmt.Errorf("failed to get git lines for source %s: %w", item.GetSource(), err) - return - } - } else { - startLine = value.StartLine - endLine = value.EndLine + fileSize := fi.Size() + if e.isFileSizeExceedingLimit(fileSize) { + log.Debug().Int64("size", fileSize/1000000).Msg("Skipping file: exceeds --max-target-megabytes") + return nil + } + + // Check if file size exceeds the file threshold, if so, use chunking, if not, read the whole file + if fileSize > e.chunk.GetFileThreshold() { + // ChunkSize * 2 -> raw read buffer + bufio.Reader’s internal slice + // + (ChunkSize+MaxPeekSize) -> peekBuf backing slice + // + (ChunkSize+MaxPeekSize) -> chunkStr copy + weight := int64(e.chunk.GetSize()*4 + e.chunk.GetMaxPeekSize()*2) + err = e.semaphore.AcquireMemoryWeight(ctx, weight) + if err != nil { + return fmt.Errorf("failed to acquire memory: %w", err) } + defer 
e.semaphore.ReleaseMemoryWeight(weight) + + return e.detectChunks(item, secretsChannel) + } + // fileSize * 2 -> data file bytes and its conversion to string + weight := fileSize * 2 + err = e.semaphore.AcquireMemoryWeight(ctx, weight) + if err != nil { + return fmt.Errorf("failed to acquire memory: %w", err) + } + defer e.semaphore.ReleaseMemoryWeight(weight) + + data, err := os.ReadFile(item.GetSource()) + if err != nil { + return fmt.Errorf("read small file %q: %w", item.GetSource(), err) + } + fragment := detect.Fragment{ + Raw: string(data), + FilePath: item.GetSource(), + } + + return e.detectSecrets(item, fragment, secretsChannel, "filesystem") +} + +// detectChunks reads the given file in chunks and detects secrets in each chunk. +// It returns nil once the end of the file is reached, or when the chunk reader reports the file as an +// unsupported type (the file is then skipped); any other read or detection error is returned wrapped. +func (e *Engine) detectChunks(item plugins.ISourceItem, secretsChannel chan *secrets.Secret) error { + f, err := os.Open(item.GetSource()) + if err != nil { + return fmt.Errorf("failed to open file %s: %w", item.GetSource(), err) + } + defer func() { + _ = f.Close() + }() - value.Line = strings.TrimSuffix(value.Line, CxFileEndMarker) + reader := bufio.NewReaderSize(f, e.chunk.GetSize()+e.chunk.GetMaxPeekSize()) + totalLines := 0 - lineContent, err := linecontent.GetLineContent(value.Line, value.Secret) + // Read the file in chunks until EOF + for { + chunkStr, err := e.chunk.ReadChunk(reader, totalLines) if err != nil { - errors <- fmt.Errorf("failed to get line content for source %s: %w", item.GetSource(), err) - return + // match the sentinel itself rather than the message text, so the check survives rewording and extra wrapping + if errors.Is(err, chunk.ErrUnsupportedFileType) { + log.Debug().Msgf("Skipping file %s: unsupported file type", item.GetSource()) + return nil + } + if errors.Is(err, io.EOF) { + return nil + } + return fmt.Errorf("failed to read file %s: %w", item.GetSource(), err) + } + // Count the number of newlines in this chunk + linesInChunk := strings.Count(chunkStr, "\n") + totalLines += linesInChunk + + // Detect secrets in the chunk + fragment := detect.Fragment{ + Raw: chunkStr, + FilePath: item.GetSource(), } - 
secret := &secrets.Secret{ - ID: itemId, - Source: item.GetSource(), - RuleID: value.RuleID, - StartLine: startLine, - StartColumn: value.StartColumn, - EndLine: endLine, - EndColumn: value.EndColumn, - Value: value.Secret, - LineContent: lineContent, - RuleDescription: value.Description, + if detectErr := e.detectSecrets(item, fragment, secretsChannel, "filesystem"); detectErr != nil { + return fmt.Errorf("failed to detect secrets: %w", detectErr) + } + } +} + +// detectSecrets detects secrets and sends them to the secrets channel +func (e *Engine) detectSecrets(item plugins.ISourceItem, fragment detect.Fragment, secrets chan *secrets.Secret, + pluginName string) error { + fragment.Raw += CxFileEndMarker + "\n" + + values := e.detector.Detect(fragment) + for _, value := range values { + secret, buildErr := buildSecret(item, value, pluginName) + if buildErr != nil { + return fmt.Errorf("failed to build secret: %w", buildErr) } if !isSecretIgnored(secret, &e.ignoredIds, &e.allowedValues) { - secretsChannel <- secret + secrets <- secret } else { log.Debug().Msgf("Secret %s was ignored", secret.ID) } } + return nil +} + +// isFileSizeExceedingLimit checks if the file size exceeds the max target megabytes limit +func (e *Engine) isFileSizeExceedingLimit(fileSize int64) bool { + if e.detector.MaxTargetMegaBytes > 0 { + rawLength := fileSize / 1000000 // convert to MB + return rawLength > int64(e.detector.MaxTargetMegaBytes) + } + return false } func (e *Engine) AddRegexRules(patterns []string) error { @@ -155,42 +246,26 @@ func (e *Engine) AddRegexRules(patterns []string) error { return nil } -func (s *Engine) RegisterForValidation(secret *secrets.Secret, wg *sync.WaitGroup) { +func (e *Engine) RegisterForValidation(secret *secrets.Secret, wg *sync.WaitGroup) { defer wg.Done() - s.validator.RegisterForValidation(secret) + e.validator.RegisterForValidation(secret) } -func (s *Engine) Score(secret *secrets.Secret, validateFlag bool, wg *sync.WaitGroup) { +func (e *Engine) 
Score(secret *secrets.Secret, validateFlag bool, wg *sync.WaitGroup) { defer wg.Done() validationStatus := secrets.UnknownResult // default validity if validateFlag { validationStatus = secret.ValidationStatus } - secret.CvssScore = score.GetCvssScore(s.GetRuleBaseRiskScore(secret.RuleID), validationStatus) -} - -func (s *Engine) Validate() { - s.validator.Validate() + secret.CvssScore = score.GetCvssScore(e.GetRuleBaseRiskScore(secret.RuleID), validationStatus) } -func getFindingId(item plugins.ISourceItem, finding report.Finding) string { - idParts := []string{item.GetID(), finding.RuleID, finding.Secret} - sha := sha1.Sum([]byte(strings.Join(idParts, "-"))) - return fmt.Sprintf("%x", sha) +func (e *Engine) Validate() { + e.validator.Validate() } -func isSecretIgnored(secret *secrets.Secret, ignoredIds, allowedValues *[]string) bool { - for _, allowedValue := range *allowedValues { - if secret.Value == allowedValue { - return true - } - } - for _, ignoredId := range *ignoredIds { - if secret.ID == ignoredId { - return true - } - } - return false +func (e *Engine) GetRuleBaseRiskScore(ruleId string) float64 { + return e.rulesBaseRiskScore[ruleId] } func GetRulesCommand(engineConfig *EngineConfig) *cobra.Command { @@ -230,6 +305,73 @@ func GetRulesCommand(engineConfig *EngineConfig) *cobra.Command { } } -func (s *Engine) GetRuleBaseRiskScore(ruleId string) float64 { - return s.rulesBaseRiskScore[ruleId] +// buildSecret creates a secret object from the given source item and finding +func buildSecret(item plugins.ISourceItem, value report.Finding, pluginName string) (*secrets.Secret, error) { + gitInfo := item.GetGitInfo() + itemId := getFindingId(item, value) + startLine, endLine, err := getStartAndEndLines(pluginName, gitInfo, value) + if err != nil { + return nil, fmt.Errorf("failed to get start and end lines for source %s: %w", item.GetSource(), err) + } + + value.Line = strings.TrimSuffix(value.Line, CxFileEndMarker) + + lineContent, err := 
linecontent.GetLineContent(value.Line, value.Secret) + if err != nil { + return nil, fmt.Errorf("failed to get line content for source %s: %w", item.GetSource(), err) + } + + secret := &secrets.Secret{ + ID: itemId, + Source: item.GetSource(), + RuleID: value.RuleID, + StartLine: startLine, + StartColumn: value.StartColumn, + EndLine: endLine, + EndColumn: value.EndColumn, + Value: value.Secret, + LineContent: lineContent, + RuleDescription: value.Description, + } + return secret, nil +} + +func getFindingId(item plugins.ISourceItem, finding report.Finding) string { + idParts := []string{item.GetID(), finding.RuleID, finding.Secret} + sha := sha1.Sum([]byte(strings.Join(idParts, "-"))) + return fmt.Sprintf("%x", sha) +} + +func getStartAndEndLines(pluginName string, gitInfo *plugins.GitInfo, value report.Finding) (int, int, error) { + var startLine, endLine int + var err error + + if pluginName == "filesystem" { + startLine = value.StartLine + 1 + endLine = value.EndLine + 1 + } else if pluginName == "git" { + startLine, endLine, err = plugins.GetGitStartAndEndLine(gitInfo, value.StartLine, value.EndLine) + if err != nil { + return 0, 0, err + } + } else { + startLine = value.StartLine + endLine = value.EndLine + } + + return startLine, endLine, nil +} + +func isSecretIgnored(secret *secrets.Secret, ignoredIds, allowedValues *[]string) bool { + for _, allowedValue := range *allowedValues { + if secret.Value == allowedValue { + return true + } + } + for _, ignoredId := range *ignoredIds { + if secret.ID == ignoredId { + return true + } + } + return false } diff --git a/engine/engine_mock.go b/engine/engine_mock.go new file mode 100644 index 00000000..049fbed5 --- /dev/null +++ b/engine/engine_mock.go @@ -0,0 +1,136 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: engine.go +// +// Generated by this command: +// +// mockgen -source=engine.go -destination=engine_mock.go -package=engine +// + +// Package engine is a generated GoMock package. 
+package engine + +import ( + context "context" + reflect "reflect" + sync "sync" + + secrets "github.com/checkmarx/2ms/lib/secrets" + plugins "github.com/checkmarx/2ms/plugins" + gomock "go.uber.org/mock/gomock" +) + +// MockIEngine is a mock of IEngine interface. +type MockIEngine struct { + ctrl *gomock.Controller + recorder *MockIEngineMockRecorder + isgomock struct{} +} + +// MockIEngineMockRecorder is the mock recorder for MockIEngine. +type MockIEngineMockRecorder struct { + mock *MockIEngine +} + +// NewMockIEngine creates a new mock instance. +func NewMockIEngine(ctrl *gomock.Controller) *MockIEngine { + mock := &MockIEngine{ctrl: ctrl} + mock.recorder = &MockIEngineMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIEngine) EXPECT() *MockIEngineMockRecorder { + return m.recorder +} + +// AddRegexRules mocks base method. +func (m *MockIEngine) AddRegexRules(patterns []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddRegexRules", patterns) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddRegexRules indicates an expected call of AddRegexRules. +func (mr *MockIEngineMockRecorder) AddRegexRules(patterns any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRegexRules", reflect.TypeOf((*MockIEngine)(nil).AddRegexRules), patterns) +} + +// DetectFile mocks base method. +func (m *MockIEngine) DetectFile(ctx context.Context, item plugins.ISourceItem, secretsChannel chan *secrets.Secret) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DetectFile", ctx, item, secretsChannel) + ret0, _ := ret[0].(error) + return ret0 +} + +// DetectFile indicates an expected call of DetectFile. 
+func (mr *MockIEngineMockRecorder) DetectFile(ctx, item, secretsChannel any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DetectFile", reflect.TypeOf((*MockIEngine)(nil).DetectFile), ctx, item, secretsChannel) +} + +// DetectFragment mocks base method. +func (m *MockIEngine) DetectFragment(item plugins.ISourceItem, secretsChannel chan *secrets.Secret, pluginName string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DetectFragment", item, secretsChannel, pluginName) + ret0, _ := ret[0].(error) + return ret0 +} + +// DetectFragment indicates an expected call of DetectFragment. +func (mr *MockIEngineMockRecorder) DetectFragment(item, secretsChannel, pluginName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DetectFragment", reflect.TypeOf((*MockIEngine)(nil).DetectFragment), item, secretsChannel, pluginName) +} + +// GetRuleBaseRiskScore mocks base method. +func (m *MockIEngine) GetRuleBaseRiskScore(ruleId string) float64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRuleBaseRiskScore", ruleId) + ret0, _ := ret[0].(float64) + return ret0 +} + +// GetRuleBaseRiskScore indicates an expected call of GetRuleBaseRiskScore. +func (mr *MockIEngineMockRecorder) GetRuleBaseRiskScore(ruleId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRuleBaseRiskScore", reflect.TypeOf((*MockIEngine)(nil).GetRuleBaseRiskScore), ruleId) +} + +// RegisterForValidation mocks base method. +func (m *MockIEngine) RegisterForValidation(secret *secrets.Secret, wg *sync.WaitGroup) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterForValidation", secret, wg) +} + +// RegisterForValidation indicates an expected call of RegisterForValidation. 
+func (mr *MockIEngineMockRecorder) RegisterForValidation(secret, wg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterForValidation", reflect.TypeOf((*MockIEngine)(nil).RegisterForValidation), secret, wg) +} + +// Score mocks base method. +func (m *MockIEngine) Score(secret *secrets.Secret, validateFlag bool, wg *sync.WaitGroup) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Score", secret, validateFlag, wg) +} + +// Score indicates an expected call of Score. +func (mr *MockIEngineMockRecorder) Score(secret, validateFlag, wg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Score", reflect.TypeOf((*MockIEngine)(nil).Score), secret, validateFlag, wg) +} + +// Validate mocks base method. +func (m *MockIEngine) Validate() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Validate") +} + +// Validate indicates an expected call of Validate. +func (mr *MockIEngineMockRecorder) Validate() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockIEngine)(nil).Validate)) +} diff --git a/engine/engine_test.go b/engine/engine_test.go index 8fb67a96..d76bf02c 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -1,18 +1,42 @@ package engine import ( + "bytes" + "context" "fmt" - "github.com/stretchr/testify/assert" - "sync" + "go.uber.org/mock/gomock" + "io" + "os" + "path/filepath" "testing" + "github.com/checkmarx/2ms/engine/chunk" "github.com/checkmarx/2ms/engine/rules" + "github.com/checkmarx/2ms/engine/semaphore" "github.com/checkmarx/2ms/lib/secrets" "github.com/checkmarx/2ms/plugins" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zricethezav/gitleaks/v8/config" + "github.com/zricethezav/gitleaks/v8/detect" ) var fsPlugin = &plugins.FileSystemPlugin{} +type mock struct { + semaphore 
*semaphore.MockISemaphore + chunk *chunk.MockIChunk +} + +func newMock(ctrl *gomock.Controller) *mock { + return &mock{ + semaphore: semaphore.NewMockISemaphore(ctrl), + chunk: chunk.NewMockIChunk(ctrl), + } +} + func Test_Init(t *testing.T) { allRules := *rules.FilterRules([]string{}, []string{}, []string{}) specialRule := rules.HardcodedPassword() @@ -78,10 +102,10 @@ func TestDetector(t *testing.T) { } secretsChan := make(chan *secrets.Secret, 1) - errorsChan := make(chan error, 1) - wg := &sync.WaitGroup{} - wg.Add(1) - detector.Detect(i, secretsChan, wg, fsPlugin.GetName(), errorsChan) + err = detector.DetectFragment(i, secretsChan, fsPlugin.GetName()) + if err != nil { + return + } close(secretsChan) s := <-secretsChan @@ -154,12 +178,11 @@ func TestSecrets(t *testing.T) { t.Run(name, func(t *testing.T) { fmt.Printf("Start test %s", name) secretsChan := make(chan *secrets.Secret, 1) - errorsChan := make(chan error, 1) - wg := &sync.WaitGroup{} - wg.Add(1) - detector.Detect(item{content: &secret.Content}, secretsChan, wg, fsPlugin.GetName(), errorsChan) + err = detector.DetectFragment(item{content: &secret.Content}, secretsChan, fsPlugin.GetName()) + if err != nil { + return + } close(secretsChan) - close(errorsChan) s := <-secretsChan @@ -172,6 +195,247 @@ func TestSecrets(t *testing.T) { } } +func TestDetectFile(t *testing.T) { + fileSize := 10 + sizeThreshold := int64(20) + chunkSize := 5 + maxPeekSize := 10 + chunkWeight := int64(4*chunkSize + 2*maxPeekSize) // 40 bytes + + testCases := []struct { + name string + makeFile func(tmp string) string + mockFunc func(m *mock) + maxMegabytes int + memoryBudget int64 + expectedLog string + expectedErr error + }{ + { + name: "non existent file", + makeFile: func(tmp string) string { return filepath.Join(tmp, "does-not-exist") }, + mockFunc: func(m *mock) {}, + memoryBudget: 1_000, + expectedErr: fmt.Errorf("failed to stat"), + }, + { + name: "exceed max megabytes", + makeFile: func(tmp string) string { return 
writeTempFile(t, tmp, 2000000, nil) /* 2MB */ }, + mockFunc: func(m *mock) {}, + maxMegabytes: 1, + memoryBudget: 1_000, + }, + { + name: "small file - acquire error", + makeFile: func(tmp string) string { return writeTempFile(t, tmp, fileSize, nil) }, + mockFunc: func(m *mock) { + weight := int64(fileSize * 2) + m.chunk.EXPECT().GetFileThreshold().Return(sizeThreshold) + m.semaphore.EXPECT().AcquireMemoryWeight(gomock.Any(), weight).Return(assert.AnError) + }, + memoryBudget: int64(fileSize*2) - 1, // 19 bytes < 2*filesize = 20 bytes + expectedErr: assert.AnError, + }, + { + name: "small file - success & release", + makeFile: func(tmp string) string { return writeTempFile(t, tmp, fileSize, nil) }, + mockFunc: func(m *mock) { + weight := int64(fileSize * 2) + m.chunk.EXPECT().GetFileThreshold().Return(sizeThreshold) + m.semaphore.EXPECT().AcquireMemoryWeight(gomock.Any(), weight).Return(nil) + m.semaphore.EXPECT().ReleaseMemoryWeight(weight) + }, + memoryBudget: 1_000, + }, + { + name: "large file - acquire error", + makeFile: func(tmp string) string { return writeTempFile(t, tmp, fileSize*2+1, nil) }, + mockFunc: func(m *mock) { + m.chunk.EXPECT().GetFileThreshold().Return(sizeThreshold) + m.chunk.EXPECT().GetSize().Return(chunkSize) + m.chunk.EXPECT().GetMaxPeekSize().Return(maxPeekSize) + m.semaphore.EXPECT().AcquireMemoryWeight(gomock.Any(), chunkWeight).Return(assert.AnError) + }, + memoryBudget: chunkWeight - 1, // 40 - 1 byte < 40 bytes + expectedErr: assert.AnError, + }, + { + name: "large file - read chunk error", + makeFile: func(tmp string) string { return writeTempFile(t, tmp, fileSize*2+1, nil) }, + mockFunc: func(m *mock) { + m.chunk.EXPECT().GetFileThreshold().Return(sizeThreshold) + m.chunk.EXPECT().GetSize().Return(chunkSize) + m.chunk.EXPECT().GetMaxPeekSize().Return(maxPeekSize) + m.semaphore.EXPECT().AcquireMemoryWeight(gomock.Any(), chunkWeight).Return(nil) + m.chunk.EXPECT().GetSize().Return(chunkSize) + 
m.chunk.EXPECT().GetMaxPeekSize().Return(maxPeekSize) + m.chunk.EXPECT().ReadChunk(gomock.Any(), gomock.Any()).Return("", assert.AnError) + m.semaphore.EXPECT().ReleaseMemoryWeight(chunkWeight) + }, + memoryBudget: 1_000, + expectedErr: assert.AnError, + }, + { + name: "large file - success & release", + makeFile: func(tmp string) string { + return writeTempFile(t, tmp, 0, []byte("abc\ndef\nghi\njkl\nmno\npqr\nstu\nvwx\nyz")) + }, + mockFunc: func(m *mock) { + m.chunk.EXPECT().GetFileThreshold().Return(sizeThreshold) + m.chunk.EXPECT().GetSize().Return(chunkSize) + m.chunk.EXPECT().GetMaxPeekSize().Return(maxPeekSize) + m.semaphore.EXPECT().AcquireMemoryWeight(gomock.Any(), chunkWeight).Return(nil) + m.chunk.EXPECT().GetSize().Return(chunkSize) + m.chunk.EXPECT().GetMaxPeekSize().Return(maxPeekSize) + m.chunk.EXPECT().ReadChunk(gomock.Any(), gomock.Any()).Return("abc\ndef\nghi\njkl\nmno\npqr\nstu\nvw", nil) + m.chunk.EXPECT().ReadChunk(gomock.Any(), gomock.Any()).Return("x\nyz", nil) + m.chunk.EXPECT().ReadChunk(gomock.Any(), gomock.Any()).Return("", io.EOF) + m.semaphore.EXPECT().ReleaseMemoryWeight(chunkWeight) + }, + memoryBudget: 1_000, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var logsBuffer bytes.Buffer + log.Logger = log.Output(zerolog.ConsoleWriter{ + Out: &logsBuffer, + NoColor: true, + TimeFormat: "", + }).Level(zerolog.DebugLevel) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + m := newMock(ctrl) + tc.mockFunc(m) + + cfg.Rules = make(map[string]config.Rule) + cfg.Keywords = []string{} + detector := detect.NewDetector(cfg) + detector.MaxTargetMegaBytes = tc.maxMegabytes + engine := &Engine{ + rules: nil, + + semaphore: m.semaphore, + chunk: m.chunk, + detector: *detector, + } + + tmp := t.TempDir() + src := tc.makeFile(tmp) + ctx := context.Background() + err := engine.DetectFile(ctx, &item{source: src}, make(chan *secrets.Secret, 1)) + loggedMessage := logsBuffer.String() + if tc.expectedErr != nil { 
+ require.ErrorContains(t, err, tc.expectedErr.Error()) + } + if tc.expectedLog != "" { + expectedLog := fmt.Sprintf(tc.expectedLog, src) + require.Contains(t, loggedMessage, expectedLog) + } + }) + } +} + +func TestDetectChunks(t *testing.T) { + chunkSize := 5 + maxPeekSize := 20 + + testCases := []struct { + name string + makeFile func(tmp string) string + mockFunc func(m *mock) + expectedLog string + expectedErr error + }{ + { + name: "successful detection", + makeFile: func(tmp string) string { return writeTempFile(t, tmp, 0, []byte("password=supersecret\n")) }, + mockFunc: func(m *mock) { + m.chunk.EXPECT().GetSize().Return(chunkSize) + m.chunk.EXPECT().GetMaxPeekSize().Return(maxPeekSize) + m.chunk.EXPECT().ReadChunk(gomock.Any(), 0).Return("password=supersecret", nil) + m.chunk.EXPECT().ReadChunk(gomock.Any(), 0).Return("", io.EOF) + }, + }, + { + name: "non existent file", + makeFile: func(tmp string) string { return filepath.Join(tmp, "does-not-exist") }, + mockFunc: func(m *mock) {}, + expectedErr: fmt.Errorf("failed to open file"), + }, + { + name: "unsupported file type", + makeFile: func(tmp string) string { return writeTempFile(t, tmp, 0, []byte{'P', 'K', 0x03, 0x04}) }, + mockFunc: func(m *mock) { + m.chunk.EXPECT().GetSize().Return(chunkSize) + m.chunk.EXPECT().GetMaxPeekSize().Return(maxPeekSize) + m.chunk.EXPECT().ReadChunk(gomock.Any(), 0).Return("", fmt.Errorf("skipping file: unsupported file type")) + }, + expectedLog: "Skipping file %s: unsupported file type", + }, + { + name: "end of file error", + makeFile: func(tmp string) string { return writeTempFile(t, tmp, 0, []byte("password=supersecret\n")) }, + mockFunc: func(m *mock) { + m.chunk.EXPECT().GetSize().Return(chunkSize) + m.chunk.EXPECT().GetMaxPeekSize().Return(maxPeekSize) + m.chunk.EXPECT().ReadChunk(gomock.Any(), 0).Return("", io.EOF) + }, + }, + { + name: "chunk read error", + makeFile: func(tmp string) string { return writeTempFile(t, tmp, 0, []byte("password=supersecret\n")) }, + 
mockFunc: func(m *mock) { + m.chunk.EXPECT().GetSize().Return(chunkSize) + m.chunk.EXPECT().GetMaxPeekSize().Return(maxPeekSize) + m.chunk.EXPECT().ReadChunk(gomock.Any(), 0).Return("", assert.AnError) + }, + expectedErr: assert.AnError, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var logsBuffer bytes.Buffer + log.Logger = log.Output(zerolog.ConsoleWriter{ + Out: &logsBuffer, + NoColor: true, + TimeFormat: "", + }).Level(zerolog.DebugLevel) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + m := newMock(ctrl) + tc.mockFunc(m) + + cfg.Rules = make(map[string]config.Rule) + cfg.Keywords = []string{} + detector := detect.NewDetector(cfg) + engine := &Engine{ + rules: nil, + + semaphore: m.semaphore, + chunk: m.chunk, + detector: *detector, + } + tmp := t.TempDir() + src := tc.makeFile(tmp) + + err := engine.detectChunks(&item{source: src}, make(chan *secrets.Secret, 1)) + loggedMessage := logsBuffer.String() + if tc.expectedErr != nil { + require.ErrorContains(t, err, tc.expectedErr.Error()) + } + if tc.expectedLog != "" { + expectedLog := fmt.Sprintf(tc.expectedLog, src) + require.Contains(t, loggedMessage, expectedLog) + } + }) + } +} + type item struct { content *string id string @@ -199,3 +463,27 @@ func (i item) GetSource() string { func (i item) GetGitInfo() *plugins.GitInfo { return nil } + +// writeTempFile writes either the provided content or a buffer of 'size' bytes +func writeTempFile(t *testing.T, dir string, size int, content []byte) string { + t.Helper() + + f, err := os.CreateTemp(dir, "testfile-*.tmp") + require.NoError(t, err, "create temp file") + defer f.Close() + + var data []byte + if content != nil { + data = content + } else { + data = make([]byte, size) + for i := range data { + data[i] = 'a' + } + } + + _, err = f.Write(data) + require.NoError(t, err, "write temp file") + + return f.Name() +} diff --git a/engine/semaphore/semaphore.go b/engine/semaphore/semaphore.go new file mode 100644 index 
00000000..0417cb12 --- /dev/null +++ b/engine/semaphore/semaphore.go @@ -0,0 +1,113 @@ +package semaphore + +//go:generate mockgen -source=$GOFILE -destination=${GOPACKAGE}_mock.go -package=${GOPACKAGE} + +import ( + "context" + "fmt" + "github.com/shirou/gopsutil/mem" + "golang.org/x/sync/semaphore" + "os" + "strconv" + "strings" +) + +type Semaphore struct { + memoryBudget int64 + sem *semaphore.Weighted +} + +type ISemaphore interface { + AcquireMemoryWeight(ctx context.Context, weight int64) error + ReleaseMemoryWeight(weight int64) +} + +func NewSemaphore() *Semaphore { + b := chooseMemoryBudget() + return NewSemaphoreWithBudget(b) +} + +func NewSemaphoreWithBudget(b int64) *Semaphore { + return &Semaphore{ + memoryBudget: b, + sem: semaphore.NewWeighted(b), + } +} + +// AcquireMemoryWeight acquires semaphore with a specified weight +func (s *Semaphore) AcquireMemoryWeight(ctx context.Context, weight int64) error { + if weight > s.memoryBudget { + return fmt.Errorf("buffer size %d exceeds memory budget %d", weight, s.memoryBudget) + } + if err := s.sem.Acquire(ctx, weight); err != nil { + return fmt.Errorf("failed to acquire semaphore: %w", err) + } + return nil +} + +// ReleaseMemoryWeight releases semaphore with a specified weight +func (s *Semaphore) ReleaseMemoryWeight(weight int64) { + s.sem.Release(weight) +} + +// getCgroupMemoryLimit returns the memory cap imposed by cgroups in bytes +func getCgroupMemoryLimit() uint64 { + // Try cgroup v2: unified hierarchy + if data, err := os.ReadFile("/sys/fs/cgroup/memory.max"); err == nil { + s := strings.TrimSpace(string(data)) + if s != "max" { + if v, err := strconv.ParseUint(s, 10, 64); err == nil { + return v + } + } + } + // Fallback cgroup v1 + if data, err := os.ReadFile("/sys/fs/cgroup/memory/memory.limit_in_bytes"); err == nil { + if v, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64); err == nil { + return v + } + } + // No limit detected + return ^uint64(0) // max uint64 +} + +// 
getTotalMemory returns the total physical RAM in bytes +func getTotalMemory() uint64 { + if vm, err := mem.VirtualMemory(); err == nil { + return vm.Total + } + return ^uint64(0) // max uint64 +} + +// computeMemoryBudget computes the memory budget based on the host memory and cgroup limits +func computeMemoryBudget(totalHost, cgroupLimit uint64) int64 { + // Effective total = min(host, cgroup) + var effectiveTotal uint64 + if totalHost < cgroupLimit { + effectiveTotal = totalHost + } else { + effectiveTotal = cgroupLimit + } + + // use 50% but cap to [256 MiB -> total − safety margin] + safetyMargin := uint64(200 * 1024 * 1024) // reserve 200 MiB for OS/other processes + avail := effectiveTotal + if effectiveTotal > safetyMargin { + avail = effectiveTotal - safetyMargin + } + budget := int64(avail / 2) // use half of what remains + if budget < 256*1024*1024 { + budget = 256 * 1024 * 1024 + } + return budget +} + +// chooseMemoryBudget picks 50% of total RAM (but at least 256 MiB) +func chooseMemoryBudget() int64 { + // Physical RAM + totalHost := getTotalMemory() + // Cgroup limit + cgroupLimit := getCgroupMemoryLimit() + + return computeMemoryBudget(totalHost, cgroupLimit) +} diff --git a/engine/semaphore/semaphore_mock.go b/engine/semaphore/semaphore_mock.go new file mode 100644 index 00000000..53c50076 --- /dev/null +++ b/engine/semaphore/semaphore_mock.go @@ -0,0 +1,67 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: semaphore.go +// +// Generated by this command: +// +// mockgen -source=semaphore.go -destination=semaphore_mock.go -package=semaphore +// + +// Package semaphore is a generated GoMock package. +package semaphore + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockISemaphore is a mock of ISemaphore interface. 
+type MockISemaphore struct { + ctrl *gomock.Controller + recorder *MockISemaphoreMockRecorder + isgomock struct{} +} + +// MockISemaphoreMockRecorder is the mock recorder for MockISemaphore. +type MockISemaphoreMockRecorder struct { + mock *MockISemaphore +} + +// NewMockISemaphore creates a new mock instance. +func NewMockISemaphore(ctrl *gomock.Controller) *MockISemaphore { + mock := &MockISemaphore{ctrl: ctrl} + mock.recorder = &MockISemaphoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockISemaphore) EXPECT() *MockISemaphoreMockRecorder { + return m.recorder +} + +// AcquireMemoryWeight mocks base method. +func (m *MockISemaphore) AcquireMemoryWeight(ctx context.Context, weight int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcquireMemoryWeight", ctx, weight) + ret0, _ := ret[0].(error) + return ret0 +} + +// AcquireMemoryWeight indicates an expected call of AcquireMemoryWeight. +func (mr *MockISemaphoreMockRecorder) AcquireMemoryWeight(ctx, weight any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireMemoryWeight", reflect.TypeOf((*MockISemaphore)(nil).AcquireMemoryWeight), ctx, weight) +} + +// ReleaseMemoryWeight mocks base method. +func (m *MockISemaphore) ReleaseMemoryWeight(weight int64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReleaseMemoryWeight", weight) +} + +// ReleaseMemoryWeight indicates an expected call of ReleaseMemoryWeight. 
+func (mr *MockISemaphoreMockRecorder) ReleaseMemoryWeight(weight any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseMemoryWeight", reflect.TypeOf((*MockISemaphore)(nil).ReleaseMemoryWeight), weight) +} diff --git a/engine/semaphore/semaphore_test.go b/engine/semaphore/semaphore_test.go new file mode 100644 index 00000000..83c3a485 --- /dev/null +++ b/engine/semaphore/semaphore_test.go @@ -0,0 +1,81 @@ +package semaphore + +import ( + "context" + "fmt" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestComputeMemoryBudget(t *testing.T) { + mib := 1024 * 1024 // 1MiB + safety := 200 * mib // 200MiB + + type testCase struct { + name string + hostMemory uint64 + cgroupLimit uint64 + expectedBudget int64 + } + testCases := []testCase{ + { + name: "host mememory only", + hostMemory: uint64(4 * 1024 * mib), // 4GiB + cgroupLimit: ^uint64(0), + expectedBudget: int64((4*1024*mib - safety) / 2), // 2GiB - 200MiB + }, + { + name: "cgroup tighter than host", + hostMemory: uint64(4 * 1024 * mib), // 4GiB + cgroupLimit: uint64(2 * 1024 * mib), // 2GiB + expectedBudget: int64((2*1024*mib - safety) / 2), // 1GiB - 200MiB + }, + { + name: "floor budget to 256MiB", + hostMemory: uint64(300 * mib), // 300MiB + cgroupLimit: 0, + expectedBudget: int64(256 * mib), // 256MiB + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + budget := computeMemoryBudget(tc.hostMemory, tc.cgroupLimit) + assert.Equal(t, tc.expectedBudget, budget, "Expected budget does not match actual budget") + }) + } +} + +func TestAcquireReleaseMemoryWeight(t *testing.T) { + weight := int64(1024) // 1KiB + defaultMemoryBudget := int64(1024 * 1024) // 1MiB + type testCase struct { + name string + memoryBudget int64 + expectedError error + } + + testCases := []testCase{ + { + name: "successful acquisition and release", + memoryBudget: defaultMemoryBudget, + }, + { + name: "failed acquisition - over budget", + 
memoryBudget: weight - 1, + expectedError: fmt.Errorf("buffer size %d exceeds memory budget %d", weight, weight-1), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sem := NewSemaphoreWithBudget(tc.memoryBudget) + + err := sem.AcquireMemoryWeight(context.Background(), weight) + if err == nil { + sem.ReleaseMemoryWeight(weight) + } else { + assert.Equal(t, tc.expectedError, err) + } + }) + } +} diff --git a/go.mod b/go.mod index 046f0ef1..1ed47316 100644 --- a/go.mod +++ b/go.mod @@ -5,13 +5,17 @@ go 1.23.6 require ( github.com/bwmarrin/discordgo v0.27.1 github.com/gitleaks/go-gitdiff v0.9.0 + github.com/h2non/filetype v1.1.3 github.com/rs/zerolog v1.32.0 + github.com/shirou/gopsutil v3.21.11+incompatible github.com/slack-go/slack v0.12.2 github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.6 github.com/spf13/viper v1.20.1 github.com/stretchr/testify v1.10.0 github.com/zricethezav/gitleaks/v8 v8.18.2 + go.uber.org/mock v0.5.2 + golang.org/x/sync v0.12.0 golang.org/x/time v0.5.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -23,9 +27,9 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fatih/semgroup v1.2.0 // indirect github.com/fsnotify/fsnotify v1.8.0 // indirect + github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/gorilla/websocket v1.5.0 // indirect - github.com/h2non/filetype v1.1.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/lucasjones/reggen v0.0.0-20200904144131-37ba4fa293bb // indirect @@ -42,10 +46,10 @@ require ( github.com/spf13/afero v1.14.0 // indirect github.com/spf13/cast v1.7.1 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/yusufpapurcu/wmi v1.2.4 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.35.0 // indirect - golang.org/x/sync v0.12.0 // indirect - golang.org/x/sys v0.30.0 // 
indirect + golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect ) diff --git a/go.sum b/go.sum index 8e77827a..35e15b3c 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/ github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/gitleaks/go-gitdiff v0.9.0 h1:SHAU2l0ZBEo8g82EeFewhVy81sb7JCxW76oSPtR/Nqg= github.com/gitleaks/go-gitdiff v0.9.0/go.mod h1:pKz0X4YzCKZs30BL+weqBIG7mx0jl4tF1uXV9ZyNvrA= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-test/deep v1.0.4 h1:u2CU3YKy9I2pmu9pX0eq50wCgjfGIt539SqR7FbHiho= github.com/go-test/deep v1.0.4/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= @@ -73,6 +75,8 @@ github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWR github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFTqyMTC9k= github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk= +github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= +github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/slack-go/slack v0.12.2 h1:x3OppyMyGIbbiyFhsBmpf9pwkUzMhthJMRNmNlA4LaQ= github.com/slack-go/slack v0.12.2/go.mod h1:hlGi5oXA+Gt+yWTPP0plCdRKmjsDxecdHxYQdlMQKOw= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= @@ -94,8 +98,12 @@ github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= 
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= +github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/zricethezav/gitleaks/v8 v8.18.2 h1:slo/sMmgs3qA+6Vv6iqVhsCv+gsl3RekQXqDN0M4g5M= github.com/zricethezav/gitleaks/v8 v8.18.2/go.mod h1:8F5GrdCpEtyN5R+0MKPubbOPqIHptNckH3F7bYrhT+Y= +go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= +go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= @@ -116,6 +124,7 @@ golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -125,8 +134,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= -golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= diff --git a/plugins/filesystem.go b/plugins/filesystem.go index 91d1e539..2a973447 100644 --- a/plugins/filesystem.go +++ b/plugins/filesystem.go @@ -118,17 +118,11 @@ func (p *FileSystemPlugin) GetItems(items chan ISourceItem, errs chan error, wg } func (p *FileSystemPlugin) getItem(filePath string) (*item, error) { - log.Debug().Str("file", filePath).Msg("reading file") - b, err := os.ReadFile(filePath) - if err != nil { - return nil, err - } + log.Debug().Str("file", filePath).Msg("sending file item") - content := string(b) item := &item{ - Content: &content, - ID: fmt.Sprintf("%s-%s-%s", p.GetName(), p.ProjectName, filePath), - Source: filePath, + ID: fmt.Sprintf("%s-%s-%s", p.GetName(), p.ProjectName, filePath), + Source: filePath, } return item, nil } diff --git a/plugins/filesystem_test.go b/plugins/filesystem_test.go index 5c374431..a5ab71e0 100644 --- a/plugins/filesystem_test.go +++ b/plugins/filesystem_test.go @@ -17,10 +17,6 @@ func TestGetItem(t *testing.T) { assert.NoError(t, err, "failed to remove temp file") }(tmpFile.Name()) - expectedContent := "mock expected content" - _, err = tmpFile.WriteString(expectedContent) - assert.NoError(t, err, "failed to write to temp file") - err = tmpFile.Close() assert.NoError(t, err, "failed to close temp file") @@ -31,8 +27,6 @@ func 
TestGetItem(t *testing.T) { it, err := plugin.getItem(tmpFile.Name()) assert.NoError(t, err, "getItem returned an error") - assert.Equal(t, expectedContent, *it.Content, "content should match the written content") - expectedID := fmt.Sprintf("%s-%s-%s", plugin.GetName(), plugin.ProjectName, tmpFile.Name()) assert.Equal(t, expectedID, it.ID, "ID should match the expected format") } @@ -53,8 +47,7 @@ func TestGetItems(t *testing.T) { assert.NoError(t, err, "failed to close temporary file") validFile := tmpFile.Name() - invalidFile := "nonexistent_file.txt" - fileList := []string{validFile, invalidFile} + fileList := []string{validFile} itemsChan := make(chan ISourceItem, len(fileList)) errsChan := make(chan error, len(fileList)) @@ -75,19 +68,10 @@ func TestGetItems(t *testing.T) { for itm := range itemsChan { items = append(items, itm) } - var errs []error - for e := range errsChan { - errs = append(errs, e) - } assert.Equal(t, 1, len(items), "should have one valid item") - assert.Equal(t, 1, len(errs), "should have one error") - - validItem, ok := items[0].(item) + _, ok := items[0].(item) assert.True(t, ok, "item should be of type item") - assert.Equal(t, validContent, *validItem.Content, "content mismatch for valid item") - - assert.Error(t, errs[0], "expected an error for invalid file") } func TestGetFiles(t *testing.T) {