diff --git a/.github/workflows/close-issues.yml b/.github/workflows/close-issues.yml index 6d00bc25b..a6a6d4363 100644 --- a/.github/workflows/close-issues.yml +++ b/.github/workflows/close-issues.yml @@ -3,6 +3,8 @@ on: schedule: - cron: "30 1 * * *" +permissions: {} + jobs: close-issues: runs-on: ubuntu-latest diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index e98e67f9d..535107a26 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -3,10 +3,20 @@ on: pull_request: schedule: - cron: '0 6 * * 6' + +permissions: {} + +concurrency: + group: codeql-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + jobs: analyze: name: Analyze runs-on: ubuntu-latest + permissions: + security-events: write + contents: read steps: - name: Checkout repository diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml index a601a92a6..faca0e3d8 100644 --- a/.github/workflows/fuzz.yml +++ b/.github/workflows/fuzz.yml @@ -6,15 +6,20 @@ on: - cron: "05 14 * * *" workflow_dispatch: +permissions: {} + jobs: fuzz: name: Fuzz tests runs-on: ubuntu-latest + permissions: + contents: read + issues: write steps: - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6 - uses: actions/setup-go@44694675825211faa026b3c33043df3e48a5fa00 # v6 with: - go-version: ">=1.24.0" + go-version: ">=1.25.0" - run: go run mage.go fuzz - run: | gh issue create --title "$GITHUB_WORKFLOW #$GITHUB_RUN_NUMBER failed" \ diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 6e373b8a7..eb6a69972 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -14,9 +14,17 @@ on: - "**/*.md" - "LICENSE" +permissions: {} + +concurrency: + group: lint-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + jobs: lint: runs-on: ubuntu-latest + permissions: + contents: read steps: - uses: 
actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6 - name: Install Go diff --git a/.github/workflows/regression.yml b/.github/workflows/regression.yml index 0c8974611..d8b4e1acb 100644 --- a/.github/workflows/regression.yml +++ b/.github/workflows/regression.yml @@ -8,37 +8,51 @@ on: - "**/*.md" - "LICENSE" pull_request: + branches: + - main paths-ignore: - "**/*.md" - "LICENSE" -jobs: - # Generate matrix of tags for all permutations of the tests - generate-matrix: - runs-on: ubuntu-latest - outputs: - tags: ${{ steps.generate.outputs.tags }} - steps: - - name: Checkout code - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6 +permissions: {} + +concurrency: + group: regression-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} - - name: Generate tag combinations - id: generate - run: | - go run mage.go tagsmatrix > tags.json - echo "tags=$(cat tags.json)" >> "$GITHUB_OUTPUT" - shell: bash +jobs: test: - needs: generate-matrix strategy: fail-fast: false matrix: - go-version: [1.24.x, 1.25.x] + go-version: [1.25.x] os: [ubuntu-latest] - build-flag: ${{ fromJson(needs.generate-matrix.outputs.tags) }} + build-flag: + - "" + - "coraza.rule.mandatory_rule_id_check" + - "coraza.rule.case_sensitive_args_keys" + - "coraza.rule.no_regex_multiline" + - "coraza.no_memoize" + - "coraza.rule.multiphase_evaluation" + - "no_fs_access" + - "coraza.rule.multiphase_evaluation,coraza.rule.mandatory_rule_id_check" + - "coraza.rule.multiphase_evaluation,coraza.rule.case_sensitive_args_keys" + - "coraza.rule.multiphase_evaluation,coraza.rule.no_regex_multiline" + - "no_fs_access,coraza.no_memoize" + - "coraza.rule.mandatory_rule_id_check,coraza.rule.case_sensitive_args_keys,coraza.rule.no_regex_multiline" + - "coraza.rule.multiphase_evaluation,coraza.rule.mandatory_rule_id_check,coraza.rule.case_sensitive_args_keys,coraza.rule.no_regex_multiline,coraza.no_memoize,no_fs_access" + include: + - go-version: 1.26.x + os: 
ubuntu-latest + build-flag: "" + - go-version: 1.26.x + os: ubuntu-latest + build-flag: "coraza.rule.multiphase_evaluation,coraza.rule.mandatory_rule_id_check,coraza.rule.case_sensitive_args_keys,coraza.rule.no_regex_multiline,coraza.no_memoize,no_fs_access" runs-on: ${{ matrix.os }} + permissions: + contents: read env: - GOLANG_BASE_VERSION: "1.24.x" + GOLANG_BASE_VERSION: "1.25.x" steps: - name: Checkout code uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6 @@ -48,9 +62,9 @@ jobs: go-version: ${{ matrix.go-version }} cache: true - name: Tests and coverage - run: | - export BUILD_TAGS=${{ matrix.build-flag }} - go run mage.go coverage + env: + BUILD_TAGS: ${{ matrix.build-flag }} + run: go run mage.go coverage - name: "Codecov: General" uses: codecov/codecov-action@5a1091511ad55cbe89839c7260b706298ca349f7 # v5 if: ${{ matrix.go-version == env.GOLANG_BASE_VERSION }} @@ -84,12 +98,13 @@ jobs: runs-on: ubuntu-latest permissions: checks: read + contents: read steps: - name: GitHub Checks uses: poseidon/wait-for-status-checks@899c768d191b56eef585c18f8558da19e1f3e707 # v0.6.0 with: token: ${{ secrets.GITHUB_TOKEN }} - delay: 120s # give some time to matrix jobs + delay: 30s interval: 10s # default value timeout: 3600s # default value ignore: "codecov/patch,codecov/project" diff --git a/.github/workflows/tinygo.yml b/.github/workflows/tinygo.yml index 00d2de703..9e15119f7 100644 --- a/.github/workflows/tinygo.yml +++ b/.github/workflows/tinygo.yml @@ -14,15 +14,23 @@ on: - "**/*.md" - "LICENSE" +permissions: {} + +concurrency: + group: tinygo-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + jobs: test: strategy: matrix: - go-version: [1.24.x] + go-version: [1.25.x] # tinygo-version is meant to stay aligned with the one used in corazawaf/coraza-proxy-wasm tinygo-version: [0.40.1] os: [ubuntu-latest] runs-on: ${{ matrix.os }} + permissions: + contents: read steps: - name: Checkout code uses: 
actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6 @@ -53,5 +61,5 @@ jobs: - name: Tests run: tinygo test -v -short ./internal/... - - name: Tests memoize - run: tinygo test -v -short -tags=memoize_builders ./internal/... + - name: Tests no_memoize + run: tinygo test -v -short -tags=coraza.no_memoize ./internal/... diff --git a/README.md b/README.md index f284bd945..eea6ad613 100644 --- a/README.md +++ b/README.md @@ -100,9 +100,11 @@ have compatibility guarantees across minor versions - use with care. the operator with `plugins.RegisterOperator` to reduce binary size / startup overhead. * `coraza.rule.multiphase_evaluation` - enables evaluation of rule variables in the phases that they are ready, not only the phase the rule is defined for. -* `memoize_builders` - enables memoization of builders for regex and aho-corasick -dictionaries to reduce memory consumption in deployments that launch several coraza -instances. For more context check [this issue](https://github.com/corazawaf/coraza-caddy/issues/76) +* `coraza.no_memoize` - disables the default memoization of regex and aho-corasick builders. +Memoization is enabled by default and uses a global cache to reuse compiled patterns across WAF +instances, reducing memory consumption and startup overhead. In long-lived processes that perform +live reloads, use `WAF.Close()` (via `experimental.WAFCloser`) to release cached entries when a +WAF is destroyed, or use this tag to opt out of memoization entirely. * `no_fs_access` - indicates that the target environment has no access to FS in order to not leverage OS' filesystem related functionality e.g. file body buffers. * `coraza.rule.case_sensitive_args_keys` - enables case-sensitive matching for ARGS keys, aligning Coraza behavior with RFC 3986 specification. It will be enabled by default in the next major version. * `coraza.rule.no_regex_multiline` - disables enabling by default regexes multiline modifiers in `@rx` operator. 
It aligns with CRS expected behavior, reduces false positives and might improve performances. No multiline regexes by default will be enabled in the next major version. For more context check [this PR](https://github.com/corazawaf/coraza/pull/876) diff --git a/config.go b/config.go index ffb3d5894..6795c9591 100644 --- a/config.go +++ b/config.go @@ -94,6 +94,7 @@ type wafRule struct { // int is a signed integer type that is at least 32 bits in size (platform-dependent size). // We still basically assume 64-bit usage where int are big sizes. type wafConfig struct { + ruleObserver func(rule types.RuleMetadata) rules []wafRule auditLog *auditLogConfig requestBodyAccess bool @@ -119,6 +120,12 @@ func (c *wafConfig) WithRules(rules ...*corazawaf.Rule) WAFConfig { return ret } +func (c *wafConfig) WithRuleObserver(observer func(rule types.RuleMetadata)) WAFConfig { + ret := c.clone() + ret.ruleObserver = observer + return ret +} + func (c *wafConfig) WithDirectivesFromFile(path string) WAFConfig { ret := c.clone() ret.rules = append(ret.rules, wafRule{file: path}) diff --git a/coraza.conf-recommended b/coraza.conf-recommended index 8ef869eae..b0d3654bf 100644 --- a/coraza.conf-recommended +++ b/coraza.conf-recommended @@ -34,6 +34,23 @@ SecRule REQUEST_HEADERS:Content-Type "^application/json" \ SecRule REQUEST_HEADERS:Content-Type "^application/[a-z0-9.-]+[+]json" \ "id:'200006',phase:1,t:none,t:lowercase,pass,nolog,ctl:requestBodyProcessor=JSON" +# Configures the maximum JSON recursion depth limit Coraza will accept. +SecRequestBodyJsonDepthLimit 1024 + +# Enable JSON stream request body parser for NDJSON (Newline Delimited JSON) format. +# This processor handles streaming JSON where each line contains a complete JSON object. +# Commonly used for bulk data imports, log streaming, and batch API endpoints. +# Each JSON object is indexed by line number: json.0.field, json.1.field, etc. 
+# +#SecRule REQUEST_HEADERS:Content-Type "^application/x-ndjson" \ +# "id:'200007',phase:1,t:none,t:lowercase,pass,nolog,ctl:requestBodyProcessor=JSONSTREAM" +# +#SecRule REQUEST_HEADERS:Content-Type "^application/jsonlines" \ +# "id:'200008',phase:1,t:none,t:lowercase,pass,nolog,ctl:requestBodyProcessor=JSONSTREAM" +# +#SecRule REQUEST_HEADERS:Content-Type "^application/json-seq" \ +# "id:'200010',phase:1,t:none,t:lowercase,pass,nolog,ctl:requestBodyProcessor=JSONSTREAM" + # Maximum request body size we will accept for buffering. If you support # file uploads, this value must has to be as large as the largest file # you are willing to accept. @@ -89,6 +106,9 @@ SecResponseBodyAccess On # configuration below to catch documents but avoid static files # (e.g., images and archives). # +# Note: Add 'application/json' and 'application/x-ndjson' if you want to +# inspect JSON and NDJSON response bodies +# SecResponseBodyMimeType text/plain text/html text/xml # Buffer response bodies of up to 512 KB in length. @@ -118,11 +138,11 @@ SecDataDir /tmp/ # #SecUploadDir /opt/coraza/var/upload/ -# If On, the WAF will store the uploaded files in the SecUploadDir -# directory. -# Note: SecUploadKeepFiles is currently NOT supported by Coraza +# Controls whether intercepted uploaded files will be kept after +# transaction is processed. Possible values: On, Off, RelevantOnly. +# RelevantOnly will keep files only when a matching rule is logged (rules with 'nolog' do not qualify). # -#SecUploadKeepFiles Off +#SecUploadKeepFiles RelevantOnly # Uploaded files are by default created with permissions that do not allow # any other user to access them. 
You may need to relax that if you want to diff --git a/examples/http-server/go.mod b/examples/http-server/go.mod index f2c2e1835..d1e0de7a3 100644 --- a/examples/http-server/go.mod +++ b/examples/http-server/go.mod @@ -1,11 +1,11 @@ module github.com/corazawaf/coraza/v3/examples/http-server -go 1.24.0 +go 1.25.0 require github.com/corazawaf/coraza/v3 v3.3.3 require ( - github.com/corazawaf/libinjection-go v0.2.2 // indirect + github.com/corazawaf/libinjection-go v0.3.2 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/magefile/mage v1.15.1-0.20250615140142-78acbaf2e3ae // indirect github.com/petar-dambovaliev/aho-corasick v0.0.0-20250424160509-463d218d4745 // indirect @@ -13,9 +13,9 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/valllabh/ocsf-schema-golang v1.0.3 // indirect - golang.org/x/net v0.43.0 // indirect - golang.org/x/sync v0.16.0 // indirect - golang.org/x/tools v0.35.0 // indirect + golang.org/x/net v0.52.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/tools v0.42.0 // indirect google.golang.org/protobuf v1.35.1 // indirect rsc.io/binaryregexp v0.2.0 // indirect ) diff --git a/examples/http-server/go.sum b/examples/http-server/go.sum index df496a76e..3cba2f706 100644 --- a/examples/http-server/go.sum +++ b/examples/http-server/go.sum @@ -2,8 +2,8 @@ github.com/corazawaf/coraza-coreruleset v0.0.0-20240226094324-415b1017abdc h1:Ol github.com/corazawaf/coraza-coreruleset v0.0.0-20240226094324-415b1017abdc/go.mod h1:7rsocqNDkTCira5T0M7buoKR2ehh7YZiPkzxRuAgvVU= github.com/corazawaf/coraza/v3 v3.3.3 h1:kqjStHAgWqwP5dh7n0vhTOF0a3t+VikNS/EaMiG0Fhk= github.com/corazawaf/coraza/v3 v3.3.3/go.mod h1:xSaXWOhFMSbrV8qOOfBKAyw3aOqfwaSaOy5BgSF8XlA= -github.com/corazawaf/libinjection-go v0.2.2 h1:Chzodvb6+NXh6wew5/yhD0Ggioif9ACrQGR4qjTCs1g= -github.com/corazawaf/libinjection-go v0.2.2/go.mod h1:OP4TM7xdJ2skyXqNX1AN1wN5nNZEmJNuWbNPOItn7aw= 
+github.com/corazawaf/libinjection-go v0.3.2 h1:9rrKt0lpg4WvUXt+lwS06GywfqRXXsa/7JcOw5cQLwI= +github.com/corazawaf/libinjection-go v0.3.2/go.mod h1:Ik/+w3UmTWH9yn366RgS9D95K3y7Atb5m/H/gXzzPCk= github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7DlmewI= github.com/foxcpp/go-mockdns v1.1.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld2qVbOu7Wk= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= @@ -25,16 +25,16 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/valllabh/ocsf-schema-golang v1.0.3 h1:eR8k/3jP/OOqB8LRCtdJ4U+vlgd/gk5y3KMXoodrsrw= github.com/valllabh/ocsf-schema-golang v1.0.3/go.mod h1:sZ3as9xqm1SSK5feFWIR2CuGeGRhsM7TR1MbpBctzPk= -golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= -golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= -golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= 
+golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE= diff --git a/experimental/bodyprocessors/jsonstream.go b/experimental/bodyprocessors/jsonstream.go new file mode 100644 index 000000000..b1eaad125 --- /dev/null +++ b/experimental/bodyprocessors/jsonstream.go @@ -0,0 +1,390 @@ +// Copyright 2026 OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +package bodyprocessors + +import ( + "bufio" + "errors" + "fmt" + "io" + "strconv" + "strings" + + "github.com/tidwall/gjson" + + "github.com/corazawaf/coraza/v3/experimental/plugins" + "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" +) + +const ( + // DefaultStreamRecursionLimit is the default recursion limit for streaming JSON processing + // This protects against deeply nested JSON objects in each line + DefaultStreamRecursionLimit = 1024 + + // recordSeparator is the ASCII RS character (0x1E) used in RFC 7464 JSON Sequences + recordSeparator = '\x1e' +) + +// indexedCollection is an interface for collections that support indexed key-value storage. +type indexedCollection interface { + SetIndex(string, int, string) +} + +// jsonStreamBodyProcessor handles streaming JSON formats. +// Each record/line in the input is expected to be a complete, valid JSON object. +// Empty lines are ignored. Each JSON object is flattened and indexed by record number. 
+// +// Supported formats: +// - NDJSON (application/x-ndjson): Each line is a complete JSON object +// - JSON Lines (application/jsonlines): Alias for NDJSON +// - JSON Sequence (application/json-seq): RFC 7464 format with RS (0x1E) record separator +// +// The processor auto-detects the format based on the presence of RS characters. +type jsonStreamBodyProcessor struct{} + +var _ plugintypes.StreamingBodyProcessor = &jsonStreamBodyProcessor{} + +func (js *jsonStreamBodyProcessor) ProcessRequest(reader io.Reader, v plugintypes.TransactionVariables, _ plugintypes.BodyProcessorOptions) error { + col := v.ArgsPost() + + // Store the raw body for TX variables. + // Note: This creates a memory copy of the entire body, similar to the regular JSON processor. + // This is necessary for operators like @validateSchema that need access to the raw content. + // Memory usage: 2x the body size (once in buffer, once in parsed variables) + var rawBody strings.Builder + + // Create a TeeReader to read the body and store it simultaneously + tee := io.TeeReader(reader, &rawBody) + + // Use default recursion limit for now + // TODO: Use RequestBodyRecursionLimit from BodyProcessorOptions when available + lineNum, err := processJSONStream(tee, col, DefaultStreamRecursionLimit) + if err != nil { + return err + } + + // Store the raw JSON stream in the TX variable for potential validation + if txVar := v.TX(); txVar != nil { + txVar.Set("jsonstream_request_body", []string{rawBody.String()}) + txVar.Set("jsonstream_request_line_count", []string{strconv.Itoa(lineNum)}) + } + + return nil +} + +func (js *jsonStreamBodyProcessor) ProcessResponse(reader io.Reader, v plugintypes.TransactionVariables, _ plugintypes.BodyProcessorOptions) error { + col := v.ResponseArgs() + + // Store the raw body for TX variables. + // Note: This creates a memory copy of the entire body, similar to the regular JSON processor. 
+ // Memory usage: 2x the body size (once in buffer, once in parsed variables) + var rawBody strings.Builder + + // Create a TeeReader to read the body and store it simultaneously + tee := io.TeeReader(reader, &rawBody) + + // Use default recursion limit for response bodies too + // TODO: Consider using a different limit for responses when configurable + lineNum, err := processJSONStream(tee, col, DefaultStreamRecursionLimit) + if err != nil { + return err + } + + // Store the raw JSON stream in the TX variable for potential validation + if txVar := v.TX(); txVar != nil && v.ResponseBody() != nil { + txVar.Set("jsonstream_response_body", []string{rawBody.String()}) + txVar.Set("jsonstream_response_line_count", []string{strconv.Itoa(lineNum)}) + } + + return nil +} + +// processJSONStream processes a stream of JSON objects incrementally. +// Supports both NDJSON (newline-delimited) and RFC 7464 JSON Sequence (RS-delimited) formats. +// The format is auto-detected by peeking at the first chunk of data. +// Returns the number of records processed and any error encountered. +func processJSONStream(reader io.Reader, col indexedCollection, maxRecursion int) (int, error) { + return processJSONStreamWithCallback(reader, maxRecursion, func(_ int, fields map[string]string, _ string) error { + for key, value := range fields { + col.SetIndex(key, 0, value) + } + return nil + }) +} + +// processJSONStreamWithCallback processes a stream of JSON objects, calling fn for each record. +// The callback receives the record number, a map of pre-formatted fields (with record-prefixed keys +// like "json.0.name"), and the raw record text. If fn returns a non-nil error, processing stops. +// Returns the number of records processed and any error encountered. 
+func processJSONStreamWithCallback(reader io.Reader, maxRecursion int, + fn func(recordNum int, fields map[string]string, rawRecord string) error) (int, error) { + bufReader := bufio.NewReader(reader) + + // Peek at the first chunk to detect format without consuming the entire stream + // Use 4KB as a reasonable peek size - enough to detect RS in typical streams + peekSize := 4096 + peekBytes, err := bufReader.Peek(peekSize) + if err != nil && err != io.EOF && err != bufio.ErrBufferFull { + return 0, fmt.Errorf("error peeking stream: %w", err) + } + + // Check if we have any data at all + if len(peekBytes) == 0 { + return 0, errors.New("no valid JSON objects found in stream") + } + + // Auto-detect format: if peek contains RS characters, use JSON Sequence parsing + // Otherwise, use NDJSON (newline) parsing + if containsRS(peekBytes) { + return processJSONSequenceStreamWithCallback(bufReader, maxRecursion, fn) + } + return processNDJSONStreamWithCallback(bufReader, maxRecursion, fn) +} + +// containsRS checks if a byte slice contains the RS character +func containsRS(data []byte) bool { + for _, b := range data { + if b == recordSeparator { + return true + } + } + return false +} + +// newRecordScanner creates a bufio.Scanner with the standard buffer sizes used for JSON record scanning. +func newRecordScanner(reader io.Reader, split bufio.SplitFunc) *bufio.Scanner { + scanner := bufio.NewScanner(reader) + if split != nil { + scanner.Split(split) + } + + // Increase scanner buffer to handle large JSON objects (default is 64KB) + // Set max to 1MB to match typical JSON object sizes while preventing memory exhaustion + const maxScanTokenSize = 1024 * 1024 // 1MB + buf := make([]byte, 64*1024) + scanner.Buffer(buf, maxScanTokenSize) + + return scanner +} + +// scanRecords iterates over scanner records, parsing each as JSON and calling fn. 
+// The formatRecord function wraps each parsed record with the format-specific delimiters +// (e.g., trailing \n for NDJSON, or RS prefix + trailing \n for RFC 7464), +// so the rawRecord passed to fn can be relayed as-is to preserve the original format. +func scanRecords(scanner *bufio.Scanner, maxRecursion int, formatRecord func(string) string, + fn func(recordNum int, fields map[string]string, rawRecord string) error) (int, error) { + recordNum := 0 + + for scanner.Scan() { + record := strings.TrimSpace(scanner.Text()) + if record == "" { + continue + } + + fields, err := parseJSONRecord(record, recordNum, maxRecursion) + if err != nil { + return recordNum, err + } + + if err := fn(recordNum, fields, formatRecord(record)); err != nil { + return recordNum + 1, err + } + + recordNum++ + } + + if err := scanner.Err(); err != nil { + return recordNum, fmt.Errorf("error reading stream: %w", err) + } + + if recordNum == 0 { + return 0, errors.New("no valid JSON objects found in stream") + } + + return recordNum, nil +} + +func formatNDJSON(record string) string { + return record + "\n" +} + +func formatJSONSequence(record string) string { + return "\x1e" + record + "\n" +} + +// processNDJSONStreamWithCallback processes NDJSON format (newline-delimited JSON objects) from a reader, +// calling fn for each record. +func processNDJSONStreamWithCallback(reader io.Reader, maxRecursion int, + fn func(recordNum int, fields map[string]string, rawRecord string) error) (int, error) { + return scanRecords(newRecordScanner(reader, nil), maxRecursion, formatNDJSON, fn) +} + +// processJSONSequenceStreamWithCallback processes RFC 7464 JSON Sequence format (RS-delimited JSON objects) +// from a reader, calling fn for each record. 
+func processJSONSequenceStreamWithCallback(reader io.Reader, maxRecursion int,
+	fn func(recordNum int, fields map[string]string, rawRecord string) error) (int, error) {
+	return scanRecords(newRecordScanner(reader, splitOnRS), maxRecursion, formatJSONSequence, fn)
+}
+
+// splitOnRS is a custom split function for bufio.Scanner that splits on RS (0x1E) characters.
+// This enables streaming processing of RFC 7464 JSON Sequence format.
+//
+// Leading RS bytes are skipped, so consecutive separators never produce empty
+// tokens. The returned token excludes the separator; at EOF any unterminated
+// tail is returned as the final record.
+func splitOnRS(data []byte, atEOF bool) (advance int, token []byte, err error) {
+	// Skip leading RS characters
+	start := 0
+	for start < len(data) && data[start] == recordSeparator {
+		start++
+	}
+
+	// If we've consumed all data and we're at EOF, we're done
+	if atEOF && start >= len(data) {
+		return len(data), nil, nil
+	}
+
+	// Find the next RS character after start
+	for i := start; i < len(data); i++ {
+		if data[i] == recordSeparator {
+			// Found RS, return the record between start and i
+			return i + 1, data[start:i], nil
+		}
+	}
+
+	// If we're at EOF, return remaining data as the last record
+	if atEOF && start < len(data) {
+		return len(data), data[start:], nil
+	}
+
+	// Request more data
+	return 0, nil, nil
+}
+
+// parseJSONRecord parses a single JSON record and returns a map of fields with record-prefixed keys.
+// Keys are formatted as "json.{recordNum}.field.subfield".
+//
+// recordNum is zero-based in the generated keys; error messages use one-based
+// numbering to be friendlier to readers. A nil map is returned on error.
+func parseJSONRecord(jsonText string, recordNum int, maxRecursion int) (map[string]string, error) {
+	// gjson.Parse does not validate its input, so check validity up front.
+	if !gjson.Valid(jsonText) {
+		return nil, fmt.Errorf("invalid JSON at record %d", recordNum+1)
+	}
+
+	data, err := readJSONWithLimit(jsonText, maxRecursion)
+	if err != nil {
+		return nil, fmt.Errorf("error parsing JSON at record %d: %w", recordNum+1, err)
+	}
+
+	// The record-specific root ("json.0", "json.1", ...) is the same for every
+	// key of this record, so build it once instead of formatting it per key.
+	const root = "json"
+	recordRoot := root + "." + strconv.Itoa(recordNum)
+
+	fields := make(map[string]string, len(data))
+	for key, value := range data {
+		switch {
+		case strings.HasPrefix(key, root+"."):
+			// "json.field.subfield" -> "json.{recordNum}.field.subfield";
+			// key[len(root):] keeps the leading dot.
+			key = recordRoot + key[len(root):]
+		case key == root:
+			key = recordRoot
+		}
+		fields[key] = value
+	}
+
+	return fields, nil
+}
+
+// readJSONWithLimit flattens the JSON document s into a map keyed by dotted
+// paths rooted at "json", refusing to descend deeper than maxRecursion levels.
+// On error it returns a nil map rather than a partially populated one.
+// This is a separate copy from internal/bodyprocessors/json.go due to package boundaries.
+// TODO: Remove this when readItems supports maxRecursion parameter natively (see #1110)
+func readJSONWithLimit(s string, maxRecursion int) (map[string]string, error) {
+	res := make(map[string]string)
+	if err := readItemsWithLimit(gjson.Parse(s), []byte("json"), maxRecursion, res); err != nil {
+		return nil, err
+	}
+	return res, nil
+}
+
+// readItemsWithLimit is similar to readItems in internal/bodyprocessors/json.go but with recursion limit.
+// This is a separate copy due to package boundaries (experimental vs internal).
+// TODO: Remove this when readItems supports maxRecursion parameter natively (see #1110)
+//
+// It walks json, writing one entry per leaf value into res under
+// objKey + "." + member name (or array index). For every container that has at
+// least one numerically keyed member, the number of such members is stored
+// under the container's own key. An error is returned once maxRecursion
+// reaches zero.
+//
+// NOTE(review): only maxRecursion == 0 stops the descent, so a negative value
+// appears to disable the limit - confirm whether callers may pass one.
+func readItemsWithLimit(json gjson.Result, objKey []byte, maxRecursion int, res map[string]string) error {
+	arrayLen := 0
+	var iterationError error
+
+	if maxRecursion == 0 {
+		return errors.New("max recursion reached while reading json object")
+	}
+
+	json.ForEach(func(key, value gjson.Result) bool {
+		// Remember where the parent key ends so it can be restored after
+		// this member has been handled.
+		prevParentLength := len(objKey)
+		objKey = append(objKey, '.')
+		if key.Type == gjson.String {
+			objKey = append(objKey, key.Str...)
+		} else {
+			objKey = strconv.AppendInt(objKey, int64(key.Num), 10)
+			arrayLen++
+		}
+
+		var val string
+		switch value.Type {
+		case gjson.JSON:
+			iterationError = readItemsWithLimit(value, objKey, maxRecursion-1, res)
+			if iterationError != nil {
+				return false
+			}
+			objKey = objKey[:prevParentLength]
+			return true
+		case gjson.String:
+			val = value.Str
+		case gjson.Null:
+			val = ""
+		default:
+			val = value.Raw
+		}
+
+		res[string(objKey)] = val
+		objKey = objKey[:prevParentLength]
+
+		return true
+	})
+
+	// When a nested call failed, objKey still points at the failing child, so
+	// bail out before the length entry below is written under the wrong key.
+	if iterationError != nil {
+		return iterationError
+	}
+	if arrayLen > 0 {
+		res[string(objKey)] = strconv.Itoa(arrayLen)
+	}
+	return nil
+}
+
+// ProcessRequestRecords parses reader as a JSON stream and calls fn once per
+// record with the zero-based record number, the flattened fields and the
+// re-delimited raw record. It returns the first parsing or callback error, or
+// an error when the stream contains no records.
+func (js *jsonStreamBodyProcessor) ProcessRequestRecords(reader io.Reader, _ plugintypes.BodyProcessorOptions,
+	fn func(recordNum int, fields map[string]string, rawRecord string) error) error {
+	recordCount, err := processJSONStreamWithCallback(reader, DefaultStreamRecursionLimit, fn)
+	if err != nil {
+		return err
+	}
+	// Defensive: the stream parser already reports empty streams as an error.
+	if recordCount == 0 {
+		return errors.New("no valid JSON objects found in stream")
+	}
+	return nil
+}
+
+// ProcessResponseRecords parses reader as a JSON stream and calls fn once per
+// record with the zero-based record number, the flattened fields and the
+// re-delimited raw record. It returns the first parsing or callback error, or
+// an error when the stream contains no records.
+func (js *jsonStreamBodyProcessor) ProcessResponseRecords(reader io.Reader, _ plugintypes.BodyProcessorOptions,
+	fn func(recordNum int, fields map[string]string, rawRecord string) error) error {
+	recordCount, err := processJSONStreamWithCallback(reader, DefaultStreamRecursionLimit, fn)
+	if err != nil {
+		return err
+	}
+	// Defensive: the stream parser already reports empty streams as an error.
+	if recordCount == 0 {
+		return errors.New("no valid JSON objects found in stream")
+	}
+	return nil
+}
+
+func init() {
+	// Register the processor with multiple names for different content-types
+	for _, name := range []string{"jsonstream", "ndjson", "jsonlines"} {
+		plugins.RegisterBodyProcessor(name, func() plugintypes.BodyProcessor {
+			return &jsonStreamBodyProcessor{}
+		})
+	}
+}
diff --git a/experimental/bodyprocessors/jsonstream_test.go
b/experimental/bodyprocessors/jsonstream_test.go new file mode 100644 index 000000000..848515b6f --- /dev/null +++ b/experimental/bodyprocessors/jsonstream_test.go @@ -0,0 +1,1113 @@ +// Copyright 2026 OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +package bodyprocessors_test + +import ( + "errors" + "fmt" + "strings" + "testing" + + _ "github.com/corazawaf/coraza/v3/experimental/bodyprocessors" + + "github.com/corazawaf/coraza/v3/experimental/plugins" + "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" + "github.com/corazawaf/coraza/v3/internal/corazawaf" +) + +func jsonstreamProcessor(t *testing.T) plugintypes.BodyProcessor { + t.Helper() + jsp, err := plugins.GetBodyProcessor("jsonstream") + if err != nil { + t.Fatal(err) + } + return jsp +} + +func TestJSONStreamSingleLine(t *testing.T) { + input := `{"name": "John", "age": 30} +` + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + argsPost := v.ArgsPost() + + // Check expected keys + if name := argsPost.Get("json.0.name"); len(name) == 0 || name[0] != "John" { + t.Errorf("json.0.name should be 'John', got: %v", name) + } + + if age := argsPost.Get("json.0.age"); len(age) == 0 || age[0] != "30" { + t.Errorf("json.0.age should be '30', got: %v", age) + } +} + +func TestJSONStreamMultipleLines(t *testing.T) { + input := `{"name": "John", "age": 30} +{"name": "Jane", "age": 25} +{"name": "Bob", "age": 35} +` + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + argsPost := v.ArgsPost() + + // Check all three lines + tests := []struct { + line int + name string + age string + }{ + {0, "John", 
"30"}, + {1, "Jane", "25"}, + {2, "Bob", "35"}, + } + + for _, tt := range tests { + nameKey := fmt.Sprintf("json.%d.name", tt.line) + ageKey := fmt.Sprintf("json.%d.age", tt.line) + + if name := argsPost.Get(nameKey); len(name) == 0 || name[0] != tt.name { + t.Errorf("%s should be '%s', got: %v", nameKey, tt.name, name) + } + + if age := argsPost.Get(ageKey); len(age) == 0 || age[0] != tt.age { + t.Errorf("%s should be '%s', got: %v", ageKey, tt.age, age) + } + } +} + +func TestJSONStreamNestedObjects(t *testing.T) { + input := `{"user": {"name": "John", "id": 1}, "active": true} +{"user": {"name": "Jane", "id": 2}, "active": false} +` + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + argsPost := v.ArgsPost() + + // Check nested fields + if name := argsPost.Get("json.0.user.name"); len(name) == 0 || name[0] != "John" { + t.Errorf("json.0.user.name should be 'John', got: %v", name) + } + + if id := argsPost.Get("json.0.user.id"); len(id) == 0 || id[0] != "1" { + t.Errorf("json.0.user.id should be '1', got: %v", id) + } + + if active := argsPost.Get("json.0.active"); len(active) == 0 || active[0] != "true" { + t.Errorf("json.0.active should be 'true', got: %v", active) + } + + if name := argsPost.Get("json.1.user.name"); len(name) == 0 || name[0] != "Jane" { + t.Errorf("json.1.user.name should be 'Jane', got: %v", name) + } + + if active := argsPost.Get("json.1.active"); len(active) == 0 || active[0] != "false" { + t.Errorf("json.1.active should be 'false', got: %v", active) + } +} + +func TestJSONStreamWithArrays(t *testing.T) { + input := `{"name": "John", "tags": ["admin", "user"]} +` + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != 
nil { + t.Errorf("unexpected error: %v", err) + return + } + + argsPost := v.ArgsPost() + + // Check array fields + if tags := argsPost.Get("json.0.tags"); len(tags) == 0 || tags[0] != "2" { + t.Errorf("json.0.tags should be '2' (array length), got: %v", tags) + } + + if tag0 := argsPost.Get("json.0.tags.0"); len(tag0) == 0 || tag0[0] != "admin" { + t.Errorf("json.0.tags.0 should be 'admin', got: %v", tag0) + } + + if tag1 := argsPost.Get("json.0.tags.1"); len(tag1) == 0 || tag1[0] != "user" { + t.Errorf("json.0.tags.1 should be 'user', got: %v", tag1) + } +} + +func TestJSONStreamSkipEmptyLines(t *testing.T) { + input := `{"name": "John"} + +{"name": "Jane"} + +` + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + argsPost := v.ArgsPost() + + // Empty lines should be skipped, so we should only have line 0 and 1 + if name := argsPost.Get("json.0.name"); len(name) == 0 || name[0] != "John" { + t.Errorf("json.0.name should be 'John', got: %v", name) + } + + if name := argsPost.Get("json.1.name"); len(name) == 0 || name[0] != "Jane" { + t.Errorf("json.1.name should be 'Jane', got: %v", name) + } + + // Line 2 should not exist + if name := argsPost.Get("json.2.name"); len(name) != 0 { + t.Errorf("json.2.name should not exist, got: %v", name) + } +} + +func TestJSONStreamArrayAsRoot(t *testing.T) { + input := `[1, 2, 3] +[4, 5, 6] +` + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + argsPost := v.ArgsPost() + + // Check first array + if arr := argsPost.Get("json.0"); len(arr) == 0 || arr[0] != "3" { + t.Errorf("json.0 should be '3' (array length), got: %v", arr) + } + + if val := 
argsPost.Get("json.0.0"); len(val) == 0 || val[0] != "1" { + t.Errorf("json.0.0 should be '1', got: %v", val) + } + + // Check second array + if arr := argsPost.Get("json.1"); len(arr) == 0 || arr[0] != "3" { + t.Errorf("json.1 should be '3' (array length), got: %v", arr) + } + + if val := argsPost.Get("json.1.0"); len(val) == 0 || val[0] != "4" { + t.Errorf("json.1.0 should be '4', got: %v", val) + } +} + +func TestJSONStreamNullAndBooleans(t *testing.T) { + input := `{"null": null, "true": true, "false": false} +` + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + argsPost := v.ArgsPost() + + // Check null value (should be empty string) + if null := argsPost.Get("json.0.null"); len(null) == 0 || null[0] != "" { + t.Errorf("json.0.null should be empty string, got: %v", null) + } + + // Check boolean values + if trueVal := argsPost.Get("json.0.true"); len(trueVal) == 0 || trueVal[0] != "true" { + t.Errorf("json.0.true should be 'true', got: %v", trueVal) + } + + if falseVal := argsPost.Get("json.0.false"); len(falseVal) == 0 || falseVal[0] != "false" { + t.Errorf("json.0.false should be 'false', got: %v", falseVal) + } +} + +func TestJSONStreamInvalidJSON(t *testing.T) { + input := `{"name": "John"} +{invalid json} +{"name": "Jane"} +` + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err == nil { + t.Errorf("expected error for invalid JSON, got none") + } + + if !strings.Contains(err.Error(), "invalid JSON") { + t.Errorf("expected 'invalid JSON' error, got: %v", err) + } +} + +func TestJSONStreamEmptyStream(t *testing.T) { + input := "" + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := 
jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err == nil { + t.Errorf("expected error for empty stream, got none") + } + + if !strings.Contains(err.Error(), "no valid JSON objects") { + t.Errorf("expected 'no valid JSON objects' error, got: %v", err) + } +} + +func TestJSONStreamOnlyEmptyLines(t *testing.T) { + input := "\n\n\n" + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err == nil { + t.Errorf("expected error for only empty lines, got none") + } + + if !strings.Contains(err.Error(), "no valid JSON objects") { + t.Errorf("expected 'no valid JSON objects' error, got: %v", err) + } +} + +func TestJSONStreamRecursionLimit(t *testing.T) { + // Create a deeply nested JSON object that exceeds the default limit + // Default limit is 1024, so we create 1500 levels to trigger the error + deeplyNested := strings.Repeat(`{"a":`, 1500) + "1" + strings.Repeat(`}`, 1500) + input := deeplyNested + "\n" + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + // This should fail because it exceeds DefaultStreamRecursionLimit (1024) + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err == nil { + t.Errorf("expected error due to recursion limit, got none") + } + + if !strings.Contains(err.Error(), "max recursion") { + t.Errorf("expected recursion error, got: %v", err) + } +} + +func TestJSONStreamTXVariables(t *testing.T) { + input := `{"name": "John"} +{"name": "Jane"} +` + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Check TX variables + txVars := v.TX() + + // Check raw body storage + rawBody := 
txVars.Get("jsonstream_request_body") + if len(rawBody) == 0 || rawBody[0] != input { + t.Errorf("jsonstream_request_body not stored correctly") + } + + // Check line count + lineCount := txVars.Get("jsonstream_request_line_count") + if len(lineCount) == 0 || lineCount[0] != "2" { + t.Errorf("jsonstream_request_line_count should be 2, got: %v", lineCount) + } +} + +func TestJSONStreamProcessResponse(t *testing.T) { + input := `{"status": "ok"} +{"status": "error"} +` + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessResponse(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Check response args + responseArgs := v.ResponseArgs() + + if status0 := responseArgs.Get("json.0.status"); len(status0) == 0 || status0[0] != "ok" { + t.Errorf("json.0.status should be 'ok', got: %v", status0) + } + + if status1 := responseArgs.Get("json.1.status"); len(status1) == 0 || status1[0] != "error" { + t.Errorf("json.1.status should be 'error', got: %v", status1) + } +} + +func TestJSONSequenceRFC7464(t *testing.T) { + // RFC 7464 format uses ASCII RS (0x1E) as record separator + const RS = "\x1e" + + input := RS + `{"name": "John", "age": 30}` + "\n" + + RS + `{"name": "Jane", "age": 25}` + "\n" + + RS + `{"name": "Bob", "age": 35}` + "\n" + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + argsPost := v.ArgsPost() + + // Check all three records + tests := []struct { + record int + name string + age string + }{ + {0, "John", "30"}, + {1, "Jane", "25"}, + {2, "Bob", "35"}, + } + + for _, tt := range tests { + nameKey := fmt.Sprintf("json.%d.name", tt.record) + ageKey := fmt.Sprintf("json.%d.age", tt.record) + + if name := argsPost.Get(nameKey); 
len(name) == 0 || name[0] != tt.name { + t.Errorf("%s should be '%s', got: %v", nameKey, tt.name, name) + } + + if age := argsPost.Get(ageKey); len(age) == 0 || age[0] != tt.age { + t.Errorf("%s should be '%s', got: %v", ageKey, tt.age, age) + } + } + + // Check line count + txVars := v.TX() + lineCount := txVars.Get("jsonstream_request_line_count") + if len(lineCount) == 0 || lineCount[0] != "3" { + t.Errorf("jsonstream_request_line_count should be 3, got: %v", lineCount) + } +} + +func TestJSONSequenceNestedObjects(t *testing.T) { + const RS = "\x1e" + + input := RS + `{"user": {"name": "John", "id": 1}, "active": true}` + "\n" + + RS + `{"user": {"name": "Jane", "id": 2}, "active": false}` + "\n" + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + argsPost := v.ArgsPost() + + // Check nested fields + if name := argsPost.Get("json.0.user.name"); len(name) == 0 || name[0] != "John" { + t.Errorf("json.0.user.name should be 'John', got: %v", name) + } + + if id := argsPost.Get("json.0.user.id"); len(id) == 0 || id[0] != "1" { + t.Errorf("json.0.user.id should be '1', got: %v", id) + } + + if active := argsPost.Get("json.0.active"); len(active) == 0 || active[0] != "true" { + t.Errorf("json.0.active should be 'true', got: %v", active) + } +} + +func TestJSONSequenceWithoutTrailingNewlines(t *testing.T) { + const RS = "\x1e" + + // RFC 7464 says newlines are optional, test without them + input := RS + `{"name": "John"}` + RS + `{"name": "Jane"}` + RS + `{"name": "Bob"}` + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + argsPost := v.ArgsPost() + + if name := 
argsPost.Get("json.0.name"); len(name) == 0 || name[0] != "John" { + t.Errorf("json.0.name should be 'John', got: %v", name) + } + + if name := argsPost.Get("json.1.name"); len(name) == 0 || name[0] != "Jane" { + t.Errorf("json.1.name should be 'Jane', got: %v", name) + } + + if name := argsPost.Get("json.2.name"); len(name) == 0 || name[0] != "Bob" { + t.Errorf("json.2.name should be 'Bob', got: %v", name) + } +} + +func TestJSONSequenceEmptyRecords(t *testing.T) { + const RS = "\x1e" + + // Empty records (multiple RS in a row) should be skipped + input := RS + RS + `{"name": "John"}` + "\n" + RS + "\n" + RS + `{"name": "Jane"}` + "\n" + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + argsPost := v.ArgsPost() + + // Should only have 2 records (empty ones skipped) + if name := argsPost.Get("json.0.name"); len(name) == 0 || name[0] != "John" { + t.Errorf("json.0.name should be 'John', got: %v", name) + } + + if name := argsPost.Get("json.1.name"); len(name) == 0 || name[0] != "Jane" { + t.Errorf("json.1.name should be 'Jane', got: %v", name) + } + + // Third record should not exist + if name := argsPost.Get("json.2.name"); len(name) != 0 { + t.Errorf("json.2.name should not exist, got: %v", name) + } +} + +func TestJSONSequenceInvalidJSON(t *testing.T) { + const RS = "\x1e" + + input := RS + `{"name": "John"}` + "\n" + + RS + `{invalid json}` + "\n" + + RS + `{"name": "Jane"}` + "\n" + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err == nil { + t.Errorf("expected error for invalid JSON, got none") + } + + if !strings.Contains(err.Error(), "invalid JSON") { + t.Errorf("expected 'invalid JSON' error, got: %v", err) + } +} + +func 
TestFormatAutoDetection(t *testing.T) { + const RS = "\x1e" + + tests := []struct { + name string + input string + format string + }{ + { + name: "NDJSON without RS", + input: `{"name": "John"}` + "\n" + `{"name": "Jane"}` + "\n", + format: "NDJSON", + }, + { + name: "JSON Sequence with RS", + input: RS + `{"name": "John"}` + "\n" + RS + `{"name": "Jane"}` + "\n", + format: "JSON Sequence", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(tt.input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error for %s: %v", tt.format, err) + return + } + + argsPost := v.ArgsPost() + + // Both formats should produce the same output + if name := argsPost.Get("json.0.name"); len(name) == 0 || name[0] != "John" { + t.Errorf("%s: json.0.name should be 'John', got: %v", tt.format, name) + } + + if name := argsPost.Get("json.1.name"); len(name) == 0 || name[0] != "Jane" { + t.Errorf("%s: json.1.name should be 'Jane', got: %v", tt.format, name) + } + }) + } +} + +// --- Benchmarks --- + +// buildNDJSONStream generates an NDJSON stream with the given number of records using the record template. +func buildNDJSONStream(numRecords int, record string) string { + var sb strings.Builder + sb.Grow(numRecords * (len(record) + 1)) + for i := 0; i < numRecords; i++ { + sb.WriteString(record) + sb.WriteByte('\n') + } + return sb.String() +} + +// buildRFC7464Stream generates an RFC 7464 JSON Sequence stream. 
+func buildRFC7464Stream(numRecords int, record string) string { + var sb strings.Builder + sb.Grow(numRecords * (len(record) + 2)) + for i := 0; i < numRecords; i++ { + sb.WriteByte('\x1e') + sb.WriteString(record) + sb.WriteByte('\n') + } + return sb.String() +} + +const ( + smallRecord = `{"id":1,"name":"Alice"}` + mediumRecord = `{"user_id":1234567890,"name":"User Name","email":"user@example.com","role":"admin","active":true,"tags":["tag1","tag2","tag3"]}` + nestedRecord = `{"user":{"name":"Alice","address":{"city":"NYC","zip":"10001"}},"scores":[95,87,92],"meta":{"created":"2026-01-01","active":true}}` +) + +func BenchmarkJSONStreamProcessor(b *testing.B) { + jsp, err := plugins.GetBodyProcessor("jsonstream") + if err != nil { + b.Fatal(err) + } + + benchmarks := []struct { + name string + numRecords int + record string + }{ + {"small/1", 1, smallRecord}, + {"small/10", 10, smallRecord}, + {"small/100", 100, smallRecord}, + {"small/1000", 1000, smallRecord}, + {"medium/1", 1, mediumRecord}, + {"medium/10", 10, mediumRecord}, + {"medium/100", 100, mediumRecord}, + {"medium/1000", 1000, mediumRecord}, + {"nested/1", 1, nestedRecord}, + {"nested/10", 10, nestedRecord}, + {"nested/100", 100, nestedRecord}, + {"nested/1000", 1000, nestedRecord}, + } + + for _, bm := range benchmarks { + input := buildNDJSONStream(bm.numRecords, bm.record) + b.Run("ProcessRequest/"+bm.name, func(b *testing.B) { + b.SetBytes(int64(len(input))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + v := corazawaf.NewTransactionVariables() + if err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkJSONStreamCallback(b *testing.B) { + bp, err := plugins.GetBodyProcessor("jsonstream") + if err != nil { + b.Fatal(err) + } + sp, ok := bp.(plugintypes.StreamingBodyProcessor) + if !ok { + b.Fatal("jsonstream processor does not implement StreamingBodyProcessor") + } + + benchmarks := []struct { + 
name string + numRecords int + record string + }{ + {"small/1", 1, smallRecord}, + {"small/10", 10, smallRecord}, + {"small/100", 100, smallRecord}, + {"small/1000", 1000, smallRecord}, + {"medium/1", 1, mediumRecord}, + {"medium/10", 10, mediumRecord}, + {"medium/100", 100, mediumRecord}, + {"medium/1000", 1000, mediumRecord}, + {"nested/1", 1, nestedRecord}, + {"nested/10", 10, nestedRecord}, + {"nested/100", 100, nestedRecord}, + {"nested/1000", 1000, nestedRecord}, + } + + noop := func(_ int, _ map[string]string, _ string) error { return nil } + + for _, bm := range benchmarks { + input := buildNDJSONStream(bm.numRecords, bm.record) + b.Run(bm.name, func(b *testing.B) { + b.SetBytes(int64(len(input))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := sp.ProcessRequestRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{}, noop); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkJSONStreamRFC7464(b *testing.B) { + bp, err := plugins.GetBodyProcessor("jsonstream") + if err != nil { + b.Fatal(err) + } + sp, ok := bp.(plugintypes.StreamingBodyProcessor) + if !ok { + b.Fatal("jsonstream processor does not implement StreamingBodyProcessor") + } + + benchmarks := []struct { + name string + numRecords int + record string + }{ + {"small/10", 10, smallRecord}, + {"small/100", 100, smallRecord}, + {"medium/100", 100, mediumRecord}, + {"nested/100", 100, nestedRecord}, + } + + noop := func(_ int, _ map[string]string, _ string) error { return nil } + + for _, bm := range benchmarks { + input := buildRFC7464Stream(bm.numRecords, bm.record) + b.Run(bm.name, func(b *testing.B) { + b.SetBytes(int64(len(input))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := sp.ProcessRequestRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{}, noop); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func TestProcessorRegistration(t *testing.T) { + // Test that all three aliases are registered + aliases := []string{"jsonstream", 
"ndjson", "jsonlines"} + + for _, alias := range aliases { + t.Run(alias, func(t *testing.T) { + processor, err := plugins.GetBodyProcessor(alias) + if err != nil { + t.Errorf("Failed to get processor '%s': %v", alias, err) + } + if processor == nil { + t.Errorf("Processor '%s' is nil", alias) + } + }) + } +} + +func TestJSONStreamLargeToken(t *testing.T) { + // Create a JSON object that exceeds 1MB to trigger scanner buffer error + largeValue := strings.Repeat("x", 2*1024*1024) // 2MB string + input := fmt.Sprintf(`{"large": "%s"}`, largeValue) + "\n" + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + // Should get a scanner error about token too long + if err == nil { + t.Errorf("expected error for token too large, got none") + } + + if !strings.Contains(err.Error(), "error reading stream") && !strings.Contains(err.Error(), "token too long") { + t.Logf("Got error (this is expected): %v", err) + } +} + +func TestJSONSequenceLargeToken(t *testing.T) { + const RS = "\x1e" + // Create a JSON object that exceeds 1MB to trigger scanner buffer error + largeValue := strings.Repeat("x", 2*1024*1024) // 2MB string + input := RS + fmt.Sprintf(`{"large": "%s"}`, largeValue) + "\n" + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + // Should get a scanner error about token too long + if err == nil { + t.Errorf("expected error for token too large, got none") + } + + if !strings.Contains(err.Error(), "error reading stream") && !strings.Contains(err.Error(), "token too long") { + t.Logf("Got error (this is expected): %v", err) + } +} + +func TestProcessResponseWithoutResponseBody(t *testing.T) { + input := `{"status": "ok"} +` + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + // Process response without 
setting up ResponseBody + err := jsp.ProcessResponse(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Check response args were still populated + responseArgs := v.ResponseArgs() + if status := responseArgs.Get("json.0.status"); len(status) == 0 || status[0] != "ok" { + t.Errorf("json.0.status should be 'ok', got: %v", status) + } + + // TX variables should be set (but response body related ones may not be if ResponseBody() is nil) + txVars := v.TX() + lineCount := txVars.Get("jsonstream_response_line_count") + if len(lineCount) == 0 || lineCount[0] != "1" { + t.Logf("jsonstream_response_line_count: %v (may be empty if ResponseBody() is nil)", lineCount) + } +} + +// errorReader is a reader that always returns an error +type errorReader struct{} + +func (e errorReader) Read(p []byte) (n int, err error) { + return 0, errors.New("simulated read error") +} + +func TestJSONStreamPeekError(t *testing.T) { + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + // Use an error reader to trigger peek error + err := jsp.ProcessRequest(errorReader{}, v, plugintypes.BodyProcessorOptions{}) + + if err == nil { + t.Errorf("expected error from peek, got none") + } + + if !strings.Contains(err.Error(), "error peeking stream") { + t.Errorf("expected 'error peeking stream' error, got: %v", err) + } +} + +func TestJSONSequenceOnlyRS(t *testing.T) { + const RS = "\x1e" + + // Only RS characters, no actual JSON + input := RS + RS + RS + + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + + if err == nil { + t.Errorf("expected error for only RS characters, got none") + } + + if !strings.Contains(err.Error(), "no valid JSON objects") { + t.Errorf("expected 'no valid JSON objects' error, got: %v", err) + } +} + +// --- Streaming callback tests --- + 
+func jsonstreamStreamingProcessor(t *testing.T) plugintypes.StreamingBodyProcessor { + t.Helper() + bp := jsonstreamProcessor(t) + sp, ok := bp.(plugintypes.StreamingBodyProcessor) + if !ok { + t.Fatal("jsonstream processor does not implement StreamingBodyProcessor") + } + return sp +} + +func TestStreamingCallbackPerRecord(t *testing.T) { + input := `{"name": "Alice", "age": 30} +{"name": "Bob", "age": 25} +{"name": "Charlie", "age": 35} +` + sp := jsonstreamStreamingProcessor(t) + + var records []int + var rawRecords []string + + err := sp.ProcessRequestRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{}, + func(recordNum int, fields map[string]string, rawRecord string) error { + records = append(records, recordNum) + rawRecords = append(rawRecords, rawRecord) + return nil + }) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(records) != 3 { + t.Fatalf("expected 3 records, got %d", len(records)) + } + + for i, r := range records { + if r != i { + t.Errorf("expected recordNum %d, got %d", i, r) + } + } + + if !strings.Contains(rawRecords[0], "Alice") { + t.Errorf("expected raw record 0 to contain Alice, got: %s", rawRecords[0]) + } + if !strings.Contains(rawRecords[1], "Bob") { + t.Errorf("expected raw record 1 to contain Bob, got: %s", rawRecords[1]) + } +} + +func TestStreamingFieldsHaveRecordPrefix(t *testing.T) { + input := `{"name": "Alice"} +{"name": "Bob"} +` + sp := jsonstreamStreamingProcessor(t) + + var allFields []map[string]string + + err := sp.ProcessRequestRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{}, + func(recordNum int, fields map[string]string, _ string) error { + // Copy fields since the map may be reused + copy := make(map[string]string, len(fields)) + for k, v := range fields { + copy[k] = v + } + allFields = append(allFields, copy) + return nil + }) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(allFields) != 2 { + t.Fatalf("expected 2 records, got 
%d", len(allFields)) + } + + // Record 0 should have json.0.name + if v, ok := allFields[0]["json.0.name"]; !ok || v != "Alice" { + t.Errorf("expected json.0.name=Alice, got %q (ok=%v)", v, ok) + } + + // Record 1 should have json.1.name + if v, ok := allFields[1]["json.1.name"]; !ok || v != "Bob" { + t.Errorf("expected json.1.name=Bob, got %q (ok=%v)", v, ok) + } + + // Record 0 should NOT have json.1.* keys + for k := range allFields[0] { + if strings.HasPrefix(k, "json.1.") { + t.Errorf("record 0 should not have key %q", k) + } + } +} + +func TestStreamingInterruptionStops(t *testing.T) { + input := `{"name": "Alice"} +{"name": "Bob"} +{"name": "Charlie"} +{"name": "Dave"} +` + sp := jsonstreamStreamingProcessor(t) + errBlocked := errors.New("blocked") + processedRecords := 0 + + err := sp.ProcessRequestRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{}, + func(recordNum int, fields map[string]string, _ string) error { + processedRecords++ + if recordNum == 1 { + return errBlocked + } + return nil + }) + + if err != errBlocked { + t.Fatalf("expected errBlocked, got: %v", err) + } + + if processedRecords != 2 { + t.Errorf("expected 2 records processed (0 and 1), got %d", processedRecords) + } +} + +func TestStreamingEmptyStream(t *testing.T) { + sp := jsonstreamStreamingProcessor(t) + + err := sp.ProcessRequestRecords(strings.NewReader(""), plugintypes.BodyProcessorOptions{}, + func(_ int, _ map[string]string, _ string) error { + t.Error("callback should not be called for empty stream") + return nil + }) + + if err == nil { + t.Error("expected error for empty stream") + } +} + +func TestStreamingRFC7464WithCallback(t *testing.T) { + input := "\x1e{\"name\": \"Alice\"}\n\x1e{\"name\": \"Bob\"}\n" + + sp := jsonstreamStreamingProcessor(t) + + var records []int + var allFields []map[string]string + + err := sp.ProcessRequestRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{}, + func(recordNum int, fields map[string]string, _ 
string) error { + records = append(records, recordNum) + copy := make(map[string]string, len(fields)) + for k, v := range fields { + copy[k] = v + } + allFields = append(allFields, copy) + return nil + }) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(records) != 2 { + t.Fatalf("expected 2 records, got %d", len(records)) + } + + if v := allFields[0]["json.0.name"]; v != "Alice" { + t.Errorf("expected json.0.name=Alice, got %q", v) + } + if v := allFields[1]["json.1.name"]; v != "Bob" { + t.Errorf("expected json.1.name=Bob, got %q", v) + } +} + +func TestStreamingResponseRecords(t *testing.T) { + input := `{"status": "ok"} +{"status": "error"} +` + sp := jsonstreamStreamingProcessor(t) + processedRecords := 0 + + err := sp.ProcessResponseRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{}, + func(recordNum int, fields map[string]string, _ string) error { + processedRecords++ + return nil + }) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if processedRecords != 2 { + t.Errorf("expected 2 records, got %d", processedRecords) + } +} + +func TestStreamingBackwardCompat(t *testing.T) { + // Verify that ProcessRequest still works unchanged after refactoring + input := `{"name": "Alice"} +{"name": "Bob"} +` + jsp := jsonstreamProcessor(t) + v := corazawaf.NewTransactionVariables() + + err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + argsPost := v.ArgsPost() + if vals := argsPost.Get("json.0.name"); len(vals) == 0 || vals[0] != "Alice" { + t.Errorf("expected json.0.name=Alice via ProcessRequest, got %v", vals) + } + if vals := argsPost.Get("json.1.name"); len(vals) == 0 || vals[0] != "Bob" { + t.Errorf("expected json.1.name=Bob via ProcessRequest, got %v", vals) + } +} diff --git a/experimental/plugins/bodyprocessors.go b/experimental/plugins/bodyprocessors.go index 10ef7b552..2b4f1ea02 100644 --- 
a/experimental/plugins/bodyprocessors.go +++ b/experimental/plugins/bodyprocessors.go @@ -14,3 +14,9 @@ import ( func RegisterBodyProcessor(name string, fn func() plugintypes.BodyProcessor) { bodyprocessors.RegisterBodyProcessor(name, fn) } + +// GetBodyProcessor returns a body processor by name. +// If the body processor is not found, it returns an error +func GetBodyProcessor(name string) (plugintypes.BodyProcessor, error) { + return bodyprocessors.GetBodyProcessor(name) +} diff --git a/experimental/plugins/plugintypes/bodyprocessor.go b/experimental/plugins/plugintypes/bodyprocessor.go index d3052a53c..ff9f3c04c 100644 --- a/experimental/plugins/plugintypes/bodyprocessor.go +++ b/experimental/plugins/plugintypes/bodyprocessor.go @@ -21,6 +21,8 @@ type BodyProcessorOptions struct { FileMode fs.FileMode // DirMode is the mode of the directory that will be created DirMode fs.FileMode + // RequestBodyRecursionLimit is the maximum recursion level accepted in a body processor + RequestBodyRecursionLimit int } // BodyProcessor interface is used to create @@ -32,3 +34,24 @@ type BodyProcessor interface { ProcessRequest(reader io.Reader, variables TransactionVariables, options BodyProcessorOptions) error ProcessResponse(reader io.Reader, variables TransactionVariables, options BodyProcessorOptions) error } + +// StreamingBodyProcessor extends BodyProcessor with per-record streaming support. +// Body processors that handle multi-record formats (NDJSON, JSON-Seq) can implement +// this interface to enable per-record rule evaluation instead of evaluating rules +// only after the entire body has been consumed. +// +// The callback receives pre-formatted field keys including the record number prefix +// (e.g., "json.0.name", "json.1.age") and the raw record text. Returning a non-nil +// error from the callback stops processing immediately. 
+type StreamingBodyProcessor interface { + BodyProcessor + + // ProcessRequestRecords reads records one at a time from the reader and calls fn + // for each record's parsed fields. Processing stops if fn returns a non-nil error. + ProcessRequestRecords(reader io.Reader, options BodyProcessorOptions, + fn func(recordNum int, fields map[string]string, rawRecord string) error) error + + // ProcessResponseRecords is the response equivalent of ProcessRequestRecords. + ProcessResponseRecords(reader io.Reader, options BodyProcessorOptions, + fn func(recordNum int, fields map[string]string, rawRecord string) error) error +} diff --git a/experimental/plugins/plugintypes/operator.go b/experimental/plugins/plugintypes/operator.go index 3d950aee3..0aa598d46 100644 --- a/experimental/plugins/plugintypes/operator.go +++ b/experimental/plugins/plugintypes/operator.go @@ -5,6 +5,12 @@ package plugintypes import "io/fs" +// Memoizer caches the result of expensive function calls by key. +// Implementations must be safe for concurrent use. +type Memoizer interface { + Do(key string, fn func() (any, error)) (any, error) +} + // OperatorOptions is used to store the options for a rule operator type OperatorOptions struct { // Arguments is used to store the operator args @@ -18,6 +24,9 @@ type OperatorOptions struct { // Datasets contains input datasets or dictionaries Datasets map[string][]string + + // Memoizer caches expensive compilations (regex, aho-corasick). 
+ Memoizer Memoizer } // Operator interface is used to define rule @operators diff --git a/experimental/rule_observer.go b/experimental/rule_observer.go new file mode 100644 index 000000000..dea42fea4 --- /dev/null +++ b/experimental/rule_observer.go @@ -0,0 +1,25 @@ +// Copyright 2026 Juan Pablo Tosso and the OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +package experimental + +import ( + "github.com/corazawaf/coraza/v3" + "github.com/corazawaf/coraza/v3/types" +) + +// wafConfigWithRuleObserver is the private capability interface +type wafConfigWithRuleObserver interface { + WithRuleObserver(func(rule types.RuleMetadata)) coraza.WAFConfig +} + +// WAFConfigWithRuleObserver applies a rule observer if supported. +func WAFConfigWithRuleObserver( + cfg coraza.WAFConfig, + observer func(rule types.RuleMetadata), +) coraza.WAFConfig { + if c, ok := cfg.(wafConfigWithRuleObserver); ok { + return c.WithRuleObserver(observer) + } + return cfg +} diff --git a/experimental/rule_observer_test.go b/experimental/rule_observer_test.go new file mode 100644 index 000000000..ec70ab20f --- /dev/null +++ b/experimental/rule_observer_test.go @@ -0,0 +1,82 @@ +// Copyright 2026 Juan Pablo Tosso and the OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +package experimental_test + +import ( + "testing" + + "github.com/corazawaf/coraza/v3" + "github.com/corazawaf/coraza/v3/experimental" + "github.com/corazawaf/coraza/v3/types" +) + +func TestRuleObserver(t *testing.T) { + testCases := map[string]struct { + directives string + withObserver bool + expectRules int + }{ + "no observer configured": { + directives: ` + SecRule REQUEST_URI "@contains /test" "id:1000,phase:1,deny" + `, + withObserver: false, + expectRules: 0, + }, + "single rule observed": { + directives: ` + SecRule REQUEST_URI "@contains /test" "id:1001,phase:1,deny" + `, + withObserver: true, + expectRules: 1, + }, + "multiple rules observed": { + directives: ` + SecRule REQUEST_URI 
"@contains /a" "id:1002,phase:1,deny" + SecRule REQUEST_URI "@contains /b" "id:1003,phase:2,deny" + `, + withObserver: true, + expectRules: 2, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + var observed []types.RuleMetadata + + cfg := coraza.NewWAFConfig(). + WithDirectives(tc.directives) + + if tc.withObserver { + cfg = experimental.WAFConfigWithRuleObserver(cfg, func(rule types.RuleMetadata) { + observed = append(observed, rule) + }) + } + + waf, err := coraza.NewWAF(cfg) + if err != nil { + t.Fatalf("unexpected error creating WAF: %v", err) + } + if waf == nil { + t.Fatal("waf is nil") + } + + if len(observed) != tc.expectRules { + t.Fatalf("expected %d observed rules, got %d", tc.expectRules, len(observed)) + } + + for _, rule := range observed { + if rule.ID() == 0 { + t.Fatal("expected rule ID to be set") + } + if rule.File() == "" { + t.Fatal("expected rule file to be set") + } + if rule.Line() == 0 { + t.Fatal("expected rule line to be set") + } + } + }) + } +} diff --git a/experimental/streaming.go b/experimental/streaming.go new file mode 100644 index 000000000..2d3028433 --- /dev/null +++ b/experimental/streaming.go @@ -0,0 +1,33 @@ +// Copyright 2026 OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +package experimental + +import ( + "io" + + "github.com/corazawaf/coraza/v3/types" +) + +// StreamingTransaction extends Transaction with streaming body processing capabilities. +// Transactions created by a WAF instance implement this interface when the WAF supports +// streaming body processors (e.g., NDJSON, JSON-Seq). +// +// Unlike the standard ProcessRequestBody/ProcessResponseBody methods which require the +// full body to be buffered first, streaming methods read records directly from input, +// evaluate rules per record, and write clean records to output for relay to the backend. 
+type StreamingTransaction interface { + types.Transaction + + // ProcessRequestBodyFromStream reads records from input, evaluates Phase 2 rules + // per record, and writes clean records to output. If a record triggers an interruption, + // processing stops and the interruption is returned. + // + // For non-streaming body processors, this falls back to buffering the input, + // processing it normally, and copying the result to output. + ProcessRequestBodyFromStream(input io.Reader, output io.Writer) (*types.Interruption, error) + + // ProcessResponseBodyFromStream reads records from input, evaluates Phase 4 rules + // per record, and writes clean records to output. + ProcessResponseBodyFromStream(input io.Reader, output io.Writer) (*types.Interruption, error) +} diff --git a/experimental/waf.go b/experimental/waf.go index bc167a2a9..c2fd967ca 100644 --- a/experimental/waf.go +++ b/experimental/waf.go @@ -4,6 +4,8 @@ package experimental import ( + "io" + "github.com/corazawaf/coraza/v3/internal/corazawaf" "github.com/corazawaf/coraza/v3/types" ) @@ -15,3 +17,19 @@ type Options = corazawaf.Options type WAFWithOptions interface { NewTransactionWithOptions(Options) types.Transaction } + +// WAFWithRules is an interface that allows to inspect the number of +// rules loaded in a WAF instance. This is useful for connectors that +// need to verify rule loading or implement configuration caching. +type WAFWithRules interface { + // RulesCount returns the number of rules in this WAF. + RulesCount() int +} + +// WAFCloser allows closing a WAF instance to release cached resources +// such as compiled regex patterns. Transactions in-flight are unaffected +// as they hold their own references to compiled objects. +// This will be promoted to the public WAF interface in v4. 
+type WAFCloser interface { + io.Closer +} diff --git a/go.mod b/go.mod index d58a70941..5098dac9b 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/corazawaf/coraza/v3 -go 1.24.0 +go 1.25.0 // Testing dependencies: // - go-mockdns @@ -19,7 +19,7 @@ go 1.24.0 require ( github.com/anuraaga/go-modsecurity v0.0.0-20220824035035-b9a4099778df github.com/corazawaf/coraza-coreruleset v0.0.0-20240226094324-415b1017abdc - github.com/corazawaf/libinjection-go v0.2.2 + github.com/corazawaf/libinjection-go v0.3.2 github.com/foxcpp/go-mockdns v1.1.0 github.com/jcchavezs/mergefs v0.1.0 github.com/kaptinlin/jsonschema v0.4.6 @@ -28,8 +28,8 @@ require ( github.com/petar-dambovaliev/aho-corasick v0.0.0-20250424160509-463d218d4745 github.com/tidwall/gjson v1.18.0 github.com/valllabh/ocsf-schema-golang v1.0.3 - golang.org/x/net v0.43.0 - golang.org/x/sync v0.16.0 + golang.org/x/net v0.52.0 + golang.org/x/sync v0.20.0 rsc.io/binaryregexp v0.2.0 ) @@ -45,10 +45,10 @@ require ( github.com/stretchr/testify v1.11.1 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/mod v0.26.0 // indirect - golang.org/x/sys v0.35.0 // indirect - golang.org/x/text v0.28.0 // indirect - golang.org/x/tools v0.35.0 // indirect + golang.org/x/mod v0.33.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.35.0 // indirect + golang.org/x/tools v0.42.0 // indirect google.golang.org/protobuf v1.35.1 // indirect ) diff --git a/go.sum b/go.sum index 9ca14c92f..0a7340c52 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/anuraaga/go-modsecurity v0.0.0-20220824035035-b9a4099778df h1:YWiVl53 github.com/anuraaga/go-modsecurity v0.0.0-20220824035035-b9a4099778df/go.mod h1:7jguE759ADzy2EkxGRXigiC0ER1Yq2IFk2qNtwgzc7U= github.com/corazawaf/coraza-coreruleset v0.0.0-20240226094324-415b1017abdc h1:OlJhrgI3I+FLUCTI3JJW8MoqyM78WbqJjecqMnqG+wc= github.com/corazawaf/coraza-coreruleset 
v0.0.0-20240226094324-415b1017abdc/go.mod h1:7rsocqNDkTCira5T0M7buoKR2ehh7YZiPkzxRuAgvVU= -github.com/corazawaf/libinjection-go v0.2.2 h1:Chzodvb6+NXh6wew5/yhD0Ggioif9ACrQGR4qjTCs1g= -github.com/corazawaf/libinjection-go v0.2.2/go.mod h1:OP4TM7xdJ2skyXqNX1AN1wN5nNZEmJNuWbNPOItn7aw= +github.com/corazawaf/libinjection-go v0.3.2 h1:9rrKt0lpg4WvUXt+lwS06GywfqRXXsa/7JcOw5cQLwI= +github.com/corazawaf/libinjection-go v0.3.2/go.mod h1:Ik/+w3UmTWH9yn366RgS9D95K3y7Atb5m/H/gXzzPCk= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7DlmewI= @@ -57,8 +57,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= -golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -67,16 +67,16 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.17.0/go.mod 
h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= -golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -87,8 +87,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.35.0 
h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -103,16 +103,16 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.15.0/go.mod h1:hpksKq4dtpQWS1uQ61JkdqWM3LscIS6Slf+VVkm+wQk= -golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= +golang.org/x/tools v0.42.0 
h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= diff --git a/go.work b/go.work index 3e2fdc963..c03028a31 100644 --- a/go.work +++ b/go.work @@ -1,4 +1,4 @@ -go 1.24.0 +go 1.25.0 use ( . diff --git a/http/interceptor_test.go b/http/interceptor_test.go index cab89b793..dcd0c9361 100644 --- a/http/interceptor_test.go +++ b/http/interceptor_test.go @@ -92,10 +92,10 @@ func TestWriteWithWriteHeader(t *testing.T) { res := httptest.NewRecorder() rw, responseProcessor := wrap(res, req, tx) - rw.WriteHeader(204) + rw.WriteHeader(201) // although we called WriteHeader, status code should be applied until // responseProcessor is called. - if unwanted, have := 204, res.Code; unwanted == have { + if unwanted, have := 201, res.Code; unwanted == have { t.Errorf("unexpected status code %d", have) } @@ -114,7 +114,7 @@ func TestWriteWithWriteHeader(t *testing.T) { t.Errorf("unexpected error: %v", err) } - if want, have := 204, res.Code; want != have { + if want, have := 201, res.Code; want != have { t.Errorf("unexpected status code, want %d, have %d", want, have) } } @@ -203,10 +203,10 @@ func TestReadFrom(t *testing.T) { } rw, responseProcessor := wrap(resWithReaderFrom, req, tx) - rw.WriteHeader(204) + rw.WriteHeader(201) // although we called WriteHeader, status code should be applied until // responseProcessor is called. 
- if unwanted, have := 204, res.Code; unwanted == have { + if unwanted, have := 201, res.Code; unwanted == have { t.Errorf("unexpected status code %d", have) } @@ -225,7 +225,7 @@ func TestReadFrom(t *testing.T) { t.Errorf("unexpected error: %v", err) } - if want, have := 204, res.Code; want != have { + if want, have := 201, res.Code; want != have { t.Errorf("unexpected status code, want %d, have %d", want, have) } } diff --git a/http/middleware.go b/http/middleware.go index 86e398df0..5ec73ce91 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -55,8 +55,10 @@ func processRequest(tx types.Transaction, req *http.Request) (*types.Interruptio // Transfer-Encoding header is removed by go/http // We manually add it to make rules relying on it work (E.g. CRS rule 920171) - if req.TransferEncoding != nil { - tx.AddRequestHeader("Transfer-Encoding", req.TransferEncoding[0]) + // All values must be added to allow the WAF to detect HTTP request smuggling + // attempts (e.g. TE.TE attacks). + for _, te := range req.TransferEncoding { + tx.AddRequestHeader("Transfer-Encoding", te) } in = tx.ProcessRequestHeaders() @@ -92,6 +94,9 @@ func processRequest(tx types.Transaction, req *http.Request) (*types.Interruptio } } + // ProcessRequestBody evaluates Phase 2 rules. For streaming body processors + // (e.g., NDJSON, JSON-Seq), rules are evaluated per record. For standard + // body processors, rules are evaluated once after the full body is processed. return tx.ProcessRequestBody() } diff --git a/http/middleware_test.go b/http/middleware_test.go index a1c958fdb..9d9450042 100644 --- a/http/middleware_test.go +++ b/http/middleware_test.go @@ -108,6 +108,32 @@ SecRule &REQUEST_HEADERS:Transfer-Encoding "!@eq 0" "id:1,phase:1,deny" } } +func TestProcessRequestMultipleTransferEncodings(t *testing.T) { + // Multiple Transfer-Encoding values are a classic HTTP request smuggling vector (TE.TE attacks). + // All values should be forwarded to the WAF. 
+ waf, _ := coraza.NewWAF(coraza.NewWAFConfig(). + WithDirectives(` +SecRule REQUEST_HEADERS:Transfer-Encoding "@contains identity" "id:1,phase:1,deny" +`)) + tx := waf.NewTransaction() + + req, _ := http.NewRequest("GET", "https://www.coraza.io/test", nil) + req.TransferEncoding = []string{"chunked", "identity"} + + it, err := processRequest(tx, req) + if err != nil { + t.Fatal(err) + } + if it == nil { + t.Fatal("Expected interruption: second Transfer-Encoding value should be processed") + } else if it.RuleID != 1 { + t.Fatalf("Expected rule 1 to be triggered, got rule %d", it.RuleID) + } + if err := tx.Close(); err != nil { + t.Fatal(err) + } +} + func createMultipartRequest(t *testing.T) *http.Request { t.Helper() diff --git a/internal/actions/actions.go b/internal/actions/actions.go index ed8889fb9..df1b33cac 100644 --- a/internal/actions/actions.go +++ b/internal/actions/actions.go @@ -1,6 +1,67 @@ // Copyright 2022 Juan Pablo Tosso and the OWASP Coraza contributors // SPDX-License-Identifier: Apache-2.0 +// Package actions implements SecLang rule actions for processing and control flow. +// +// # Overview +// +// Actions define how the system handles HTTP requests when rule conditions match. +// Actions are defined as part of a SecRule or as parameters for SecAction or SecDefaultAction. +// A rule can have no or several actions which need to be separated by a comma. +// +// # Action Categories +// +// Actions are categorized into five types: +// +// 1. Disruptive Actions +// +// Trigger Coraza operations such as blocking or allowing transactions. +// Only one disruptive action per rule applies; if multiple are specified, +// the last one takes precedence. Disruptive actions will NOT be executed +// if SecRuleEngine is set to DetectionOnly. +// +// Examples: deny, drop, redirect, allow, block, pass +// +// 2. Non-disruptive Actions +// +// Perform operations without affecting rule flow, such as variable modifications, +// logging, or setting metadata. 
These actions execute regardless of SecRuleEngine mode. +// +// Examples: log, nolog, setvar, msg, logdata, severity, tag +// +// 3. Flow Actions +// +// Control rule processing and execution flow. These actions determine which rules +// are evaluated and in what order. +// +// Examples: chain, skip, skipAfter +// +// 4. Meta-data Actions +// +// Provide information about rules, such as identification, versioning, and classification. +// These actions do not affect transaction processing. +// +// Examples: id, rev, msg, tag, severity, maturity, ver +// +// 5. Data Actions +// +// Containers that hold data for use by other actions, such as status codes +// for blocking responses. +// +// Examples: status (used with deny/redirect) +// +// # Usage +// +// Actions are specified in SecRule directives as comma-separated values: +// +// SecRule ARGS "@rx attack" "id:100,deny,log,msg:'Attack detected'" +// +// # Important Notes +// +// When using the allow action for allowlisting, it's recommended to add +// ctl:ruleEngine=On to ensure the rule executes even in DetectionOnly mode. 
+// +// For the complete list of available actions, see: https://coraza.io/docs/seclang/actions/ package actions import ( diff --git a/internal/actions/allow.go b/internal/actions/allow.go index 139c2c827..1d3b7713b 100644 --- a/internal/actions/allow.go +++ b/internal/actions/allow.go @@ -56,7 +56,7 @@ func (a *allowFn) Init(_ plugintypes.RuleMetadata, data string) error { func (a *allowFn) Evaluate(r plugintypes.RuleMetadata, txS plugintypes.TransactionState) { tx := txS.(*corazawaf.Transaction) - tx.AllowType = a.allow + tx.Allow(a.allow) } func (a *allowFn) Type() plugintypes.ActionType { diff --git a/internal/actions/capture.go b/internal/actions/capture.go index bab58bd00..5a846aa7e 100644 --- a/internal/actions/capture.go +++ b/internal/actions/capture.go @@ -21,8 +21,10 @@ import ( // // Example: // ``` -// SecRule REQUEST_BODY "^username=(\w{25,})" phase:2,capture,t:none,chain,id:105 -// SecRule TX:1 "(?:(?:a(dmin|nonymous)))" +// +// SecRule REQUEST_BODY "^username=(\w{25,})" "phase:2,capture,t:none,chain,id:105" +// SecRule TX:1 "(?:(?:a(dmin|nonymous)))" +// // ``` type captureFn struct{} diff --git a/internal/actions/ctl.go b/internal/actions/ctl.go index b8a04e008..a091d7e6a 100644 --- a/internal/actions/ctl.go +++ b/internal/actions/ctl.go @@ -6,6 +6,7 @@ package actions import ( "errors" "fmt" + "regexp" "strconv" "strings" @@ -73,7 +74,13 @@ const ( // // Here are some notes about the options: // -// 1. Option `ruleRemoveTargetById`, `ruleRemoveTargetByMsg`, and `ruleRemoveTargetByTag`, users don't need to use the char ! before the target list. +// 1. Option `ruleRemoveTargetById`, `ruleRemoveTargetByMsg`, and `ruleRemoveTargetByTag` accept a collection key in two forms: +// - **Exact string**: `ARGS:user` — removes only the variable whose name is exactly `user`. +// - **Regular expression** (delimited by `/`): `ARGS:/^json\.\d+\.field$/` — removes all variables whose +// names match the pattern. 
The closing `/` must not be preceded by an odd number of backslashes +// (e.g. `/foo\/` is treated as the literal string `/foo\/`, not a regex). An empty pattern (`//`) is rejected. +// Pattern matching is always case-insensitive because variable names are lowercased before comparison. +// Users do not need to use the `!` character before the target list. // // 2. Option `ruleRemoveById` is triggered at run time and should be specified before the rule in which it is disabling. // @@ -99,17 +106,30 @@ const ( // SecRule REQUEST_URI "@beginsWith /index.php" "phase:1,t:none,pass,\ // nolog,ctl:ruleRemoveTargetById=981260;ARGS:user" // +// # white-list all JSON array fields matching a pattern for rule #932125 when the REQUEST_URI begins with /api/jobs +// +// SecRule REQUEST_URI "@beginsWith /api/jobs" "phase:1,t:none,pass,\ +// nolog,ctl:ruleRemoveTargetById=932125;ARGS:/^json\.\d+\.jobdescription$/" +// // ``` type ctlFn struct { action ctlFunctionType value string collection variables.RuleVariable colKey string + colKeyRx *regexp.Regexp } -func (a *ctlFn) Init(_ plugintypes.RuleMetadata, data string) error { +func (a *ctlFn) Init(m plugintypes.RuleMetadata, data string) error { + // Type-assert RuleMetadata to *corazawaf.Rule to access the rule's memoizer. + // When the assertion fails (e.g., in tests using a stub RuleMetadata), the + // memoizer remains nil and regex compilation proceeds without caching. 
+ var memoizer plugintypes.Memoizer + if r, ok := m.(*corazawaf.Rule); ok { + memoizer = r.Memoizer() + } var err error - a.action, a.value, a.collection, a.colKey, err = parseCtl(data) + a.action, a.value, a.collection, a.colKey, a.colKeyRx, err = parseCtl(data, memoizer) return err } @@ -130,7 +150,7 @@ func (a *ctlFn) Evaluate(_ plugintypes.RuleMetadata, txS plugintypes.Transaction tx := txS.(*corazawaf.Transaction) switch a.action { case ctlRuleRemoveTargetByID: - ran, err := rangeToInts(tx.WAF.Rules.GetRules(), a.value) + start, end, err := parseIDOrRange(a.value) if err != nil { tx.DebugLogger().Error(). Str("ctl", "RuleRemoveTargetByID"). @@ -138,21 +158,23 @@ func (a *ctlFn) Evaluate(_ plugintypes.RuleMetadata, txS plugintypes.Transaction Msg("Invalid range") return } - for _, id := range ran { - tx.RemoveRuleTargetByID(id, a.collection, a.colKey) + for _, r := range tx.WAF.Rules.GetRules() { + if r.ID_ >= start && r.ID_ <= end { + tx.RemoveRuleTargetByID(r.ID_, a.collection, a.colKey, a.colKeyRx) + } } case ctlRuleRemoveTargetByTag: rules := tx.WAF.Rules.GetRules() for _, r := range rules { if utils.InSlice(a.value, r.Tags_) { - tx.RemoveRuleTargetByID(r.ID(), a.collection, a.colKey) + tx.RemoveRuleTargetByID(r.ID(), a.collection, a.colKey, a.colKeyRx) } } case ctlRuleRemoveTargetByMsg: rules := tx.WAF.Rules.GetRules() for _, r := range rules { if r.Msg != nil && r.Msg.String() == a.value { - tx.RemoveRuleTargetByID(r.ID(), a.collection, a.colKey) + tx.RemoveRuleTargetByID(r.ID(), a.collection, a.colKey, a.colKeyRx) } } case ctlAuditEngine: @@ -260,7 +282,7 @@ func (a *ctlFn) Evaluate(_ plugintypes.RuleMetadata, txS plugintypes.Transaction tx.RemoveRuleByID(id) } else { - ran, err := rangeToInts(tx.WAF.Rules.GetRules(), a.value) + start, end, err := parseRange(a.value) if err != nil { tx.DebugLogger().Error(). Str("ctl", "RuleRemoveByID"). 
@@ -268,9 +290,7 @@ func (a *ctlFn) Evaluate(_ plugintypes.RuleMetadata, txS plugintypes.Transaction Msg("Invalid range") return } - for _, id := range ran { - tx.RemoveRuleByID(id) - } + tx.RemoveRuleByIDRange(start, end) } case ctlRuleRemoveByMsg: rules := tx.WAF.Rules.GetRules() @@ -375,18 +395,40 @@ func (a *ctlFn) Type() plugintypes.ActionType { return plugintypes.ActionTypeNondisruptive } -func parseCtl(data string) (ctlFunctionType, string, variables.RuleVariable, string, error) { +func parseCtl(data string, memoizer plugintypes.Memoizer) (ctlFunctionType, string, variables.RuleVariable, string, *regexp.Regexp, error) { action, ctlVal, ok := strings.Cut(data, "=") if !ok { - return ctlUnknown, "", 0, "", errors.New("invalid syntax") + return ctlUnknown, "", 0, "", nil, errors.New("invalid syntax") } value, col, ok := strings.Cut(ctlVal, ";") var colkey, colname string if ok { colname, colkey, _ = strings.Cut(col, ":") + colkey = strings.TrimSpace(colkey) } collection, _ := variables.Parse(strings.TrimSpace(colname)) - colkey = strings.ToLower(colkey) + var keyRx *regexp.Regexp + if isRegex, rxPattern := utils.HasRegex(colkey); isRegex { + if len(rxPattern) == 0 { + return ctlUnknown, "", 0, "", nil, errors.New("empty regex pattern in ctl collection key") + } + var err error + if memoizer != nil { + re, compileErr := memoizer.Do(rxPattern, func() (any, error) { return regexp.Compile(rxPattern) }) + if compileErr != nil { + return ctlUnknown, "", 0, "", nil, fmt.Errorf("invalid regex in ctl collection key: %w", compileErr) + } + keyRx = re.(*regexp.Regexp) + } else { + keyRx, err = regexp.Compile(rxPattern) + if err != nil { + return ctlUnknown, "", 0, "", nil, fmt.Errorf("invalid regex in ctl collection key: %w", err) + } + } + colkey = "" + } else { + colkey = strings.ToLower(colkey) + } var act ctlFunctionType switch action { case "auditEngine": @@ -430,49 +472,46 @@ func parseCtl(data string) (ctlFunctionType, string, variables.RuleVariable, str case 
"debugLogLevel": act = ctlDebugLogLevel default: - return ctlUnknown, "", 0x00, "", fmt.Errorf("unknown ctl action %q", action) + return ctlUnknown, "", 0x00, "", nil, fmt.Errorf("unknown ctl action %q", action) } - return act, value, collection, strings.TrimSpace(colkey), nil + return act, value, collection, colkey, keyRx, nil } -func rangeToInts(rules []corazawaf.Rule, input string) ([]int, error) { - if len(input) == 0 { - return nil, errors.New("empty input") +// parseRange parses a range string of the form "start-end" and returns the start and end +// values as integers. It returns an error if the input is not a valid range. +func parseRange(input string) (start, end int, err error) { + in0, in1, ok := strings.Cut(input, "-") + if !ok { + return 0, 0, errors.New("no range separator found") } - - var ( - ids []int - start, end int - err error - ) - - if in0, in1, ok := strings.Cut(input, "-"); ok { - start, err = strconv.Atoi(in0) - if err != nil { - return nil, err - } - end, err = strconv.Atoi(in1) - if err != nil { - return nil, err - } - - if start > end { - return nil, errors.New("invalid range, start > end") - } - } else { - id, err := strconv.Atoi(input) - if err != nil { - return nil, err - } - start, end = id, id + start, err = strconv.Atoi(in0) + if err != nil { + return 0, 0, err + } + end, err = strconv.Atoi(in1) + if err != nil { + return 0, 0, err } + if start > end { + return 0, 0, errors.New("invalid range, start > end") + } + return start, end, nil +} - for _, r := range rules { - if r.ID_ >= start && r.ID_ <= end { - ids = append(ids, r.ID_) - } +// parseIDOrRange parses either a single integer ID or a range string of the form "start-end". +// For a single ID, start and end are equal. 
+func parseIDOrRange(input string) (start, end int, err error) { + if len(input) == 0 { + return 0, 0, errors.New("empty input") + } + if _, _, ok := strings.Cut(input, "-"); ok { + return parseRange(input) + } + id, err := strconv.Atoi(input) + if err != nil { + return 0, 0, err } - return ids, nil + return id, id, nil } func ctl() plugintypes.Action { diff --git a/internal/actions/ctl_test.go b/internal/actions/ctl_test.go index 252d3879b..38d27be1a 100644 --- a/internal/actions/ctl_test.go +++ b/internal/actions/ctl_test.go @@ -9,8 +9,8 @@ import ( "testing" "github.com/corazawaf/coraza/v3/debuglog" - "github.com/corazawaf/coraza/v3/internal/corazarules" "github.com/corazawaf/coraza/v3/internal/corazawaf" + "github.com/corazawaf/coraza/v3/internal/memoize" "github.com/corazawaf/coraza/v3/types" "github.com/corazawaf/coraza/v3/types/variables" ) @@ -24,6 +24,24 @@ func TestCtl(t *testing.T) { "ruleRemoveTargetById": { input: "ruleRemoveTargetById=123", }, + "ruleRemoveTargetById range": { + // Rule 1 is in WAF; range 1-5 should match it without error + input: "ruleRemoveTargetById=1-5;ARGS:test", + checkTX: func(t *testing.T, tx *corazawaf.Transaction, logEntry string) { + if wantToNotContain := "Invalid range"; strings.Contains(logEntry, wantToNotContain) { + t.Errorf("unexpected error in log: %q", logEntry) + } + }, + }, + "ruleRemoveTargetById regex key": { + // Rule 1 is in WAF; the regex /^test.*/ should remove matching ARGS targets + input: "ruleRemoveTargetById=1;ARGS:/^test.*/", + checkTX: func(t *testing.T, tx *corazawaf.Transaction, logEntry string) { + if strings.Contains(logEntry, "Invalid") || strings.Contains(logEntry, "invalid") { + t.Errorf("unexpected error in log: %q", logEntry) + } + }, + }, "ruleRemoveTargetByTag": { input: "ruleRemoveTargetByTag=tag1", }, @@ -49,7 +67,7 @@ func TestCtl(t *testing.T) { "auditLogParts": { input: "auditLogParts=ABZ", checkTX: func(t *testing.T, tx *corazawaf.Transaction, logEntry string) { - if want, have := 
types.AuditLogPartRequestHeaders, tx.AuditLogParts[0]; want != have { + if want, have := types.AuditLogPartRequestHeaders, tx.AuditLogParts[1]; want != have { t.Errorf("Failed to set audit log parts, want %s, have %s", string(want), string(have)) } }, @@ -173,6 +191,16 @@ func TestCtl(t *testing.T) { }, "ruleRemoveById range": { input: "ruleRemoveById=1-3", + checkTX: func(t *testing.T, tx *corazawaf.Transaction, logEntry string) { + if len(tx.GetRuleRemoveByIDRanges()) != 1 { + t.Errorf("expected 1 range entry, got %d", len(tx.GetRuleRemoveByIDRanges())) + return + } + rng := tx.GetRuleRemoveByIDRanges()[0] + if rng[0] != 1 || rng[1] != 3 { + t.Errorf("unexpected range [%d, %d], want [1, 3]", rng[0], rng[1]) + } + }, }, "ruleRemoveById incorrect": { input: "ruleRemoveById=W", @@ -368,7 +396,7 @@ func TestCtl(t *testing.T) { func TestParseCtl(t *testing.T) { t.Run("invalid ctl", func(t *testing.T) { - ctl, _, _, _, err := parseCtl("invalid") + ctl, _, _, _, _, err := parseCtl("invalid", nil) if err == nil { t.Errorf("expected error, got nil") } @@ -379,7 +407,7 @@ func TestParseCtl(t *testing.T) { }) t.Run("malformed ctl", func(t *testing.T) { - ctl, _, _, _, err := parseCtl("unknown=") + ctl, _, _, _, _, err := parseCtl("unknown=", nil) if err == nil { t.Errorf("expected error, got nil") } @@ -389,35 +417,83 @@ func TestParseCtl(t *testing.T) { } }) + t.Run("invalid regex in colKey", func(t *testing.T) { + _, _, _, _, _, err := parseCtl("ruleRemoveTargetById=1;ARGS:/[invalid/", nil) + if err == nil { + t.Errorf("expected error for invalid regex, got nil") + } + }) + + t.Run("empty regex pattern in colKey", func(t *testing.T) { + _, _, _, _, _, err := parseCtl("ruleRemoveTargetById=1;ARGS://", nil) + if err == nil { + t.Errorf("expected error for empty regex pattern, got nil") + } + }) + + t.Run("escaped slash not treated as regex", func(t *testing.T) { + _, _, _, key, rx, err := parseCtl(`ruleRemoveTargetById=1;ARGS:/user\/`, nil) + if err != nil { + 
t.Fatalf("unexpected error: %s", err.Error()) + } + if rx != nil { + t.Errorf("expected nil regex for escaped-slash key, got: %s", rx.String()) + } + if key != `/user\/` { + t.Errorf("unexpected key, want %q, have %q", `/user\/`, key) + } + }) + + t.Run("memoizer with valid regex", func(t *testing.T) { + m := memoize.NewMemoizer(99) + _, _, _, _, keyRx, err := parseCtl("ruleRemoveTargetById=1;ARGS:/^test.*/", m) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + if keyRx == nil { + t.Error("expected non-nil compiled regex, got nil") + } + }) + + t.Run("memoizer with invalid regex", func(t *testing.T) { + m := memoize.NewMemoizer(100) + _, _, _, _, _, err := parseCtl("ruleRemoveTargetById=1;ARGS:/[invalid/", m) + if err == nil { + t.Error("expected error for invalid regex with memoizer, got nil") + } + }) + tCases := []struct { input string expectAction ctlFunctionType expectValue string expectCollection variables.RuleVariable expectKey string + expectKeyRx string }{ - {"auditEngine=On", ctlAuditEngine, "On", variables.Unknown, ""}, - {"auditLogParts=A", ctlAuditLogParts, "A", variables.Unknown, ""}, - {"requestBodyAccess=On", ctlRequestBodyAccess, "On", variables.Unknown, ""}, - {"requestBodyLimit=100", ctlRequestBodyLimit, "100", variables.Unknown, ""}, - {"requestBodyProcessor=JSON", ctlRequestBodyProcessor, "JSON", variables.Unknown, ""}, - {"forceRequestBodyVariable=On", ctlForceRequestBodyVariable, "On", variables.Unknown, ""}, - {"responseBodyAccess=On", ctlResponseBodyAccess, "On", variables.Unknown, ""}, - {"responseBodyLimit=100", ctlResponseBodyLimit, "100", variables.Unknown, ""}, - {"responseBodyProcessor=JSON", ctlResponseBodyProcessor, "JSON", variables.Unknown, ""}, - {"forceResponseBodyVariable=On", ctlForceResponseBodyVariable, "On", variables.Unknown, ""}, - {"ruleEngine=On", ctlRuleEngine, "On", variables.Unknown, ""}, - {"ruleRemoveById=1", ctlRuleRemoveByID, "1", variables.Unknown, ""}, - {"ruleRemoveById=1-9", 
ctlRuleRemoveByID, "1-9", variables.Unknown, ""}, - {"ruleRemoveByMsg=MY_MSG", ctlRuleRemoveByMsg, "MY_MSG", variables.Unknown, ""}, - {"ruleRemoveByTag=MY_TAG", ctlRuleRemoveByTag, "MY_TAG", variables.Unknown, ""}, - {"ruleRemoveTargetByMsg=MY_MSG;ARGS:user", ctlRuleRemoveTargetByMsg, "MY_MSG", variables.Args, "user"}, - {"ruleRemoveTargetById=2;REQUEST_FILENAME:", ctlRuleRemoveTargetByID, "2", variables.RequestFilename, ""}, + {"auditEngine=On", ctlAuditEngine, "On", variables.Unknown, "", ""}, + {"auditLogParts=A", ctlAuditLogParts, "A", variables.Unknown, "", ""}, + {"requestBodyAccess=On", ctlRequestBodyAccess, "On", variables.Unknown, "", ""}, + {"requestBodyLimit=100", ctlRequestBodyLimit, "100", variables.Unknown, "", ""}, + {"requestBodyProcessor=JSON", ctlRequestBodyProcessor, "JSON", variables.Unknown, "", ""}, + {"forceRequestBodyVariable=On", ctlForceRequestBodyVariable, "On", variables.Unknown, "", ""}, + {"responseBodyAccess=On", ctlResponseBodyAccess, "On", variables.Unknown, "", ""}, + {"responseBodyLimit=100", ctlResponseBodyLimit, "100", variables.Unknown, "", ""}, + {"responseBodyProcessor=JSON", ctlResponseBodyProcessor, "JSON", variables.Unknown, "", ""}, + {"forceResponseBodyVariable=On", ctlForceResponseBodyVariable, "On", variables.Unknown, "", ""}, + {"ruleEngine=On", ctlRuleEngine, "On", variables.Unknown, "", ""}, + {"ruleRemoveById=1", ctlRuleRemoveByID, "1", variables.Unknown, "", ""}, + {"ruleRemoveById=1-9", ctlRuleRemoveByID, "1-9", variables.Unknown, "", ""}, + {"ruleRemoveByMsg=MY_MSG", ctlRuleRemoveByMsg, "MY_MSG", variables.Unknown, "", ""}, + {"ruleRemoveByTag=MY_TAG", ctlRuleRemoveByTag, "MY_TAG", variables.Unknown, "", ""}, + {"ruleRemoveTargetByMsg=MY_MSG;ARGS:user", ctlRuleRemoveTargetByMsg, "MY_MSG", variables.Args, "user", ""}, + {"ruleRemoveTargetById=2;REQUEST_FILENAME:", ctlRuleRemoveTargetByID, "2", variables.RequestFilename, "", ""}, + {"ruleRemoveTargetById=2;ARGS:/^json\\.\\d+\\.description$/", 
ctlRuleRemoveTargetByID, "2", variables.Args, "", `^json\.\d+\.description$`}, } for _, tCase := range tCases { testName, _, _ := strings.Cut(tCase.input, "=") t.Run(testName, func(t *testing.T) { - action, value, collection, colKey, err := parseCtl(tCase.input) + action, value, collection, colKey, colKeyRx, err := parseCtl(tCase.input, nil) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } @@ -433,53 +509,96 @@ func TestParseCtl(t *testing.T) { if colKey != tCase.expectKey { t.Errorf("unexpected key, want: %s, have: %s", tCase.expectKey, colKey) } + if tCase.expectKeyRx == "" { + if colKeyRx != nil { + t.Errorf("unexpected non-nil regex, have: %s", colKeyRx.String()) + } + } else { + if colKeyRx == nil { + t.Errorf("expected non-nil regex matching %q, got nil", tCase.expectKeyRx) + } else if colKeyRx.String() != tCase.expectKeyRx { + t.Errorf("unexpected regex, want: %s, have: %s", tCase.expectKeyRx, colKeyRx.String()) + } + } }) } } -func TestCtlParseRange(t *testing.T) { - rules := []corazawaf.Rule{ - { - RuleMetadata: corazarules.RuleMetadata{ - ID_: 5, - }, - }, - { - RuleMetadata: corazarules.RuleMetadata{ - ID_: 15, - }, - }, +func TestCtlParseIDOrRange(t *testing.T) { + tCases := []struct { + input string + expectStart int + expectEnd int + expectErr bool + }{ + {"1-2", 1, 2, false}, + {"4-5", 4, 5, false}, + {"4-15", 4, 15, false}, + {"5", 5, 5, false}, + {"", 0, 0, true}, + {"test", 0, 0, true}, + {"test-2", 0, 0, true}, + {"2-test", 0, 0, true}, + {"-", 0, 0, true}, + {"4-5-15", 0, 0, true}, + } + for _, tCase := range tCases { + t.Run(tCase.input, func(t *testing.T) { + start, end, err := parseIDOrRange(tCase.input) + if tCase.expectErr && err == nil { + t.Error("expected error for input") + } + + if !tCase.expectErr && err != nil { + t.Errorf("unexpected error for input: %s", err.Error()) + } + + if !tCase.expectErr { + if start != tCase.expectStart { + t.Errorf("unexpected start, want %d, have %d", tCase.expectStart, start) + } + if end 
!= tCase.expectEnd { + t.Errorf("unexpected end, want %d, have %d", tCase.expectEnd, end) + } + } + }) } +} +func TestCtlParseRange(t *testing.T) { tCases := []struct { - _range string - expectedNumberOfIds int - expectErr bool + input string + expectStart int + expectEnd int + expectErr bool }{ - {"1-2", 0, false}, - {"4-5", 1, false}, - {"4-15", 2, false}, - {"5", 1, false}, - {"", 0, true}, - {"test", 0, true}, - {"test-2", 0, true}, - {"2-test", 0, true}, - {"-", 0, true}, - {"4-5-15", 0, true}, + {"1-2", 1, 2, false}, + {"4-15", 4, 15, false}, + {"5-5", 5, 5, false}, + {"test-2", 0, 0, true}, + {"2-test", 0, 0, true}, + {"5-4", 0, 0, true}, // start > end + {"-", 0, 0, true}, + {"nodash", 0, 0, true}, // no range separator } for _, tCase := range tCases { - t.Run(tCase._range, func(t *testing.T) { - ints, err := rangeToInts(rules, tCase._range) + t.Run(tCase.input, func(t *testing.T) { + start, end, err := parseRange(tCase.input) if tCase.expectErr && err == nil { - t.Error("expected error for range") + t.Error("expected error for input") } if !tCase.expectErr && err != nil { - t.Errorf("unexpected error for range: %s", err.Error()) + t.Errorf("unexpected error for input: %s", err.Error()) } - if !tCase.expectErr && len(ints) != tCase.expectedNumberOfIds { - t.Error("unexpected number of ids") + if !tCase.expectErr { + if start != tCase.expectStart { + t.Errorf("unexpected start, want %d, have %d", tCase.expectStart, start) + } + if end != tCase.expectEnd { + t.Errorf("unexpected end, want %d, have %d", tCase.expectEnd, end) + } } }) } diff --git a/internal/auditlog/formats.go b/internal/auditlog/formats.go index 316566927..dcaf142cc 100644 --- a/internal/auditlog/formats.go +++ b/internal/auditlog/formats.go @@ -21,6 +21,7 @@ package auditlog import ( "fmt" + "net/http" "strings" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" @@ -45,81 +46,102 @@ func (nativeFormatter) Format(al plugintypes.AuditLog) ([]byte, error) { 
res.WriteString(boundaryPrefix) res.WriteByte(byte(part)) res.WriteString("--\n") - // [27/Jul/2016:05:46:16 +0200] V5guiH8AAQEAADTeJ2wAAAAK 192.168.3.1 50084 192.168.3.111 80 - _, _ = fmt.Fprintf(&res, "[%s] %s %s %d %s %d", al.Transaction().Timestamp(), al.Transaction().ID(), - al.Transaction().ClientIP(), al.Transaction().ClientPort(), al.Transaction().HostIP(), al.Transaction().HostPort()) + + addSeparator := true + switch part { + case types.AuditLogPartHeader: + // Part A: Audit log header containing only the timestamp and transaction info line + // Note: Part A does not have an empty line separator after it + _, _ = fmt.Fprintf(&res, "[%s] %s %s %d %s %d\n", + al.Transaction().Timestamp(), al.Transaction().ID(), + al.Transaction().ClientIP(), al.Transaction().ClientPort(), + al.Transaction().HostIP(), al.Transaction().HostPort()) + addSeparator = false case types.AuditLogPartRequestHeaders: - // GET /url HTTP/1.1 - // Host: example.com - // User-Agent: Mozilla/5.0 - // Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 - // Accept-Language: en-US,en;q=0.5 - // Accept-Encoding: gzip, deflate - // Referer: http://example.com/index.html - // Connection: keep-alive - // Content-Type: application/x-www-form-urlencoded - // Content-Length: 6 - _, _ = fmt.Fprintf( - &res, - "\n%s %s %s", - al.Transaction().Request().Method(), - al.Transaction().Request().URI(), - al.Transaction().Request().Protocol(), - ) - for k, vv := range al.Transaction().Request().Headers() { - for _, v := range vv { - res.WriteByte('\n') - res.WriteString(k) - res.WriteString(": ") - res.WriteString(v) + // Part B: Request headers + if al.Transaction().HasRequest() { + _, _ = fmt.Fprintf( + &res, + "%s %s %s", + al.Transaction().Request().Method(), + al.Transaction().Request().URI(), + al.Transaction().Request().Protocol(), + ) + for k, vv := range al.Transaction().Request().Headers() { + for _, v := range vv { + res.WriteByte('\n') + res.WriteString(k) + 
res.WriteString(": ") + res.WriteString(v) + } } + res.WriteByte('\n') } case types.AuditLogPartRequestBody: - if body := al.Transaction().Request().Body(); body != "" { - res.WriteByte('\n') - res.WriteString(body) + // Part C: Request body + if al.Transaction().HasRequest() { + if body := al.Transaction().Request().Body(); body != "" { + res.WriteString(body) + res.WriteByte('\n') + } } case types.AuditLogPartIntermediaryResponseBody: - if body := al.Transaction().Response().Body(); body != "" { - res.WriteByte('\n') - res.WriteString(al.Transaction().Response().Body()) + // Part E: Intermediary response body + if al.Transaction().HasResponse() { + if body := al.Transaction().Response().Body(); body != "" { + res.WriteString(body) + res.WriteByte('\n') + } } case types.AuditLogPartResponseHeaders: - for k, vv := range al.Transaction().Response().Headers() { - for _, v := range vv { - res.WriteByte('\n') - res.WriteString(k) - res.WriteString(": ") - res.WriteString(v) + // Part F: Response headers + if al.Transaction().HasResponse() { + // Write status line: HTTP/1.1 200 OK + protocol := al.Transaction().Response().Protocol() + if protocol == "" { + protocol = "HTTP/1.1" + } + status := al.Transaction().Response().Status() + statusText := http.StatusText(status) + _, _ = fmt.Fprintf(&res, "%s %d %s\n", protocol, status, statusText) + + // Write headers + for k, vv := range al.Transaction().Response().Headers() { + for _, v := range vv { + res.WriteString(k) + res.WriteString(": ") + res.WriteString(v) + res.WriteByte('\n') + } } } case types.AuditLogPartAuditLogTrailer: - // Stopwatch: 1470025005945403 1715 (- - -) - // Stopwatch2: 1470025005945403 1715; combined=26, p1=0, p2=0, p3=0, p4=0, p5=26, ↩ - // sr=0, sw=0, l=0, gc=0 - // Response-Body-Transformed: Dechunked - // Producer: ModSecurity for Apache/2.9.1 (http://www.modsecurity.org/). 
- // Server: Apache - // Engine-Mode: "ENABLED" - - // AuditLogTrailer is also expected to contain the error message generated by the rule, if any + // Part H: Audit log trailer for _, alEntry := range al.Messages() { alWithErrMsg, ok := alEntry.(auditLogWithErrMesg) if ok && alWithErrMsg.ErrorMessage() != "" { - res.WriteByte('\n') res.WriteString(alWithErrMsg.ErrorMessage()) + res.WriteByte('\n') } } - - _, _ = fmt.Fprintf(&res, "\nStopwatch: %s\nResponse-Body-Transformed: %s\nProducer: %s\nServer: %s", "", "", "", "") case types.AuditLogPartRulesMatched: + // Part K: Matched rules for _, alEntry := range al.Messages() { - res.WriteByte('\n') res.WriteString(alEntry.Data().Raw()) + res.WriteByte('\n') } + case types.AuditLogPartEndMarker: + // Part Z: Final boundary marker with no content + default: + // For any other parts (D, G, I, J) that aren't explicitly handled, + // they remain empty + } + + // Add separator newline for all parts except A + if addSeparator { + res.WriteByte('\n') } - res.WriteByte('\n') } return []byte(res.String()), nil diff --git a/internal/auditlog/formats_test.go b/internal/auditlog/formats_test.go index 79edba98b..a64b7d728 100644 --- a/internal/auditlog/formats_test.go +++ b/internal/auditlog/formats_test.go @@ -56,9 +56,9 @@ func TestNativeFormatter(t *testing.T) { if !strings.Contains(f.MIME(), "x-coraza-auditlog-native") { t.Errorf("failed to match MIME, expected json and got %s", f.MIME()) } - // Log contains random strings, do a simple sanity check - if !bytes.Contains(data, []byte("[02/Jan/2006:15:04:20 -0700] 123 0 0")) { - t.Errorf("failed to match log, \ngot: %s\n", string(data)) + // Log contains random boundary strings + if len(data) == 0 { + t.Errorf("expected non-empty log output") } scanner := bufio.NewScanner(bytes.NewReader(data)) @@ -69,22 +69,387 @@ func TestNativeFormatter(t *testing.T) { } separator := lines[0] - checkLine(t, lines, 2, "GET /test.php HTTP/1.1") - checkLine(t, lines, 3, "some: request header") + // 
Part B + checkLine(t, lines, 0, mutateSeparator(separator, 'B')) + checkLine(t, lines, 1, "GET /test.php HTTP/1.1") + checkLine(t, lines, 2, "some: request header") + checkLine(t, lines, 3, "") + // Part C checkLine(t, lines, 4, mutateSeparator(separator, 'C')) - checkLine(t, lines, 6, "some request body") + checkLine(t, lines, 5, "some request body") + checkLine(t, lines, 6, "") + // Part E checkLine(t, lines, 7, mutateSeparator(separator, 'E')) - checkLine(t, lines, 9, "some response body") + checkLine(t, lines, 8, "some response body") + checkLine(t, lines, 9, "") + // Part F checkLine(t, lines, 10, mutateSeparator(separator, 'F')) + checkLine(t, lines, 11, "HTTP/1.1 200 OK") checkLine(t, lines, 12, "some: response header") - checkLine(t, lines, 13, mutateSeparator(separator, 'H')) + checkLine(t, lines, 13, "") + // Part H + checkLine(t, lines, 14, mutateSeparator(separator, 'H')) checkLine(t, lines, 15, "error message") - checkLine(t, lines, 16, "Stopwatch: ") - checkLine(t, lines, 17, "Response-Body-Transformed: ") - checkLine(t, lines, 18, "Producer: ") - checkLine(t, lines, 19, "Server: ") - checkLine(t, lines, 20, mutateSeparator(separator, 'K')) - checkLine(t, lines, 22, `SecAction "id:100"`) + checkLine(t, lines, 16, "") + // Part K + checkLine(t, lines, 17, mutateSeparator(separator, 'K')) + checkLine(t, lines, 18, `SecAction "id:100"`) + checkLine(t, lines, 19, "") + }) + + t.Run("with parts A and Z", func(t *testing.T) { + al := &Log{ + Parts_: []types.AuditLogPart{ + types.AuditLogPartHeader, + types.AuditLogPartRequestHeaders, + types.AuditLogPartEndMarker, + }, + Transaction_: Transaction{ + Timestamp_: "02/Jan/2006:15:04:20 -0700", + UnixTimestamp_: 0, + ID_: "123", + Request_: &TransactionRequest{ + URI_: "/test.php", + Method_: "GET", + Headers_: map[string][]string{ + "some": { + "request header", + }, + }, + Protocol_: "HTTP/1.1", + }, + }, + } + data, err := f.Format(al) + if err != nil { + t.Error(err) + } + + scanner := 
bufio.NewScanner(bytes.NewReader(data)) + var lines []string + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + separator := lines[0] + + // Part A (header) - no empty line after it + checkLine(t, lines, 0, mutateSeparator(separator, 'A')) + checkLine(t, lines, 1, "[02/Jan/2006:15:04:20 -0700] 123 0 0") + + // Part B (request headers) - has empty line after it + checkLine(t, lines, 2, mutateSeparator(separator, 'B')) + checkLine(t, lines, 3, "GET /test.php HTTP/1.1") + checkLine(t, lines, 4, "some: request header") + checkLine(t, lines, 5, "") + + // Part Z (end marker) - has empty line after it + checkLine(t, lines, 6, mutateSeparator(separator, 'Z')) + checkLine(t, lines, 7, "") + }) + + t.Run("complete example matching ModSecurity format", func(t *testing.T) { + al := &Log{ + Parts_: []types.AuditLogPart{ + types.AuditLogPartHeader, + types.AuditLogPartRequestHeaders, + types.AuditLogPartIntermediaryResponseHeaders, // Part D + types.AuditLogPartResponseHeaders, // Part F + types.AuditLogPartAuditLogTrailer, // Part H + types.AuditLogPartEndMarker, + }, + Transaction_: Transaction{ + Timestamp_: "20/Feb/2025:13:20:33 +0000", + UnixTimestamp_: 1740576033, + ID_: "174005763366.604533", + ClientIP_: "192.168.65.1", + ClientPort_: 38532, + HostIP_: "172.21.0.3", + HostPort_: 8080, + Request_: &TransactionRequest{ + URI_: "/status/200", + Method_: "GET", + Protocol_: "HTTP/1.1", + HTTPVersion_: "1.1", + Headers_: map[string][]string{ + "Accept": {"*/*"}, + "Connection": {"close"}, + "Host": {"localhost"}, + }, + }, + Response_: &TransactionResponse{ + Status_: 200, + Protocol_: "HTTP/1.1", + Headers_: map[string][]string{ + "Server": {"nginx"}, + "Connection": {"close"}, + }, + }, + }, + Messages_: []plugintypes.AuditLogMessage{ + &Message{ + ErrorMessage_: "ModSecurity: Warning. 
Test message", + }, + }, + } + + data, err := f.Format(al) + if err != nil { + t.Error(err) + } + + output := string(data) + // Verify structure matches ModSecurity format + if !bytes.Contains(data, []byte("174005763366.604533")) { + t.Errorf("Missing transaction ID in output:\n%s", output) + } + if !bytes.Contains(data, []byte("GET /status/200 HTTP/1.1")) { + t.Errorf("Missing request line in output:\n%s", output) + } + if !bytes.Contains(data, []byte("ModSecurity: Warning. Test message")) { + t.Errorf("Missing error message in output:\n%s", output) + } + + scanner := bufio.NewScanner(bytes.NewReader(data)) + var lines []string + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + + // Verify part A exists + if !bytes.Contains(data, []byte("-A--")) { + t.Error("Missing part A boundary") + } + // Verify part Z exists + if !bytes.Contains(data, []byte("-Z--")) { + t.Error("Missing part Z boundary") + } + // Verify part A has transaction info + checkLine(t, lines, 1, "[20/Feb/2025:13:20:33 +0000] 174005763366.604533 192.168.65.1 38532 172.21.0.3 8080") + }) + + t.Run("apache example 1 - cookies and multiple messages", func(t *testing.T) { + al := &Log{ + Parts_: []types.AuditLogPart{ + types.AuditLogPartHeader, + types.AuditLogPartRequestHeaders, + types.AuditLogPartResponseHeaders, + types.AuditLogPartIntermediaryResponseBody, + types.AuditLogPartAuditLogTrailer, + types.AuditLogPartEndMarker, + }, + Transaction_: Transaction{ + Timestamp_: "20/Feb/2025:15:15:26.453565 +0000", + UnixTimestamp_: 1740064526, + ID_: "Z7dHDrSPGgnIk-ru4hvJcQAAAIA", + ClientIP_: "192.168.65.1", + ClientPort_: 42378, + HostIP_: "172.22.0.3", + HostPort_: 8080, + Request_: &TransactionRequest{ + URI_: "/", + Method_: "GET", + Protocol_: "HTTP/1.1", + Headers_: map[string][]string{ + "Accept": {"*/*"}, + "Connection": {"close"}, + "Cookie": {"$Version=1; session=\"deadbeef; PHPSESSID=secret; dummy=qaz\""}, + "Host": {"localhost"}, + "Origin": {"https://www.example.com"}, + 
"Referer": {"https://www.example.com/"}, + "User-Agent": {"OWASP CRS test agent"}, + }, + }, + Response_: &TransactionResponse{ + Status_: 200, + Protocol_: "HTTP/1.1", + Headers_: map[string][]string{ + "Content-Length": {"0"}, + "Connection": {"close"}, + }, + }, + }, + Messages_: []plugintypes.AuditLogMessage{ + &Message{ + ErrorMessage_: `Message: Warning. String match "1" at REQUEST_COOKIES:$Version. [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-921-PROTOCOL-ATTACK.conf"] [line "332"] [id "921250"]`, + }, + &Message{ + ErrorMessage_: `Message: Warning. Operator GE matched 5 at TX:blocking_inbound_anomaly_score. [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-949-BLOCKING-EVALUATION.conf"] [line "233"] [id "949110"]`, + }, + &Message{ + ErrorMessage_: `Apache-Error: [file "apache2_util.c"] [line 275] [level 3] [client 192.168.65.1] ModSecurity: Warning.`, + }, + }, + } + + data, err := f.Format(al) + if err != nil { + t.Error(err) + } + + output := string(data) + // Verify key elements exist + if !bytes.Contains(data, []byte("Z7dHDrSPGgnIk-ru4hvJcQAAAIA")) { + t.Errorf("Missing transaction ID in output") + } + if !bytes.Contains(data, []byte("GET / HTTP/1.1")) { + t.Errorf("Missing request line in output") + } + if !bytes.Contains(data, []byte("Cookie: $Version=1")) { + t.Errorf("Missing cookie header in output") + } + if !bytes.Contains(data, []byte("HTTP/1.1 200")) { + t.Errorf("Missing response status in output:\n%s", output) + } + // Verify multiple messages are present + if !bytes.Contains(data, []byte("REQUEST-921-PROTOCOL-ATTACK.conf")) { + t.Errorf("Missing first message in output") + } + if !bytes.Contains(data, []byte("REQUEST-949-BLOCKING-EVALUATION.conf")) { + t.Errorf("Missing second message in output") + } + if !bytes.Contains(data, []byte("Apache-Error:")) { + t.Errorf("Missing Apache-Error message in output") + } + + // Verify parts A, B, F, E, H, Z exist + if !bytes.Contains(data, []byte("-A--")) { + t.Error("Missing part A") + } + if 
!bytes.Contains(data, []byte("-B--")) { + t.Error("Missing part B") + } + if !bytes.Contains(data, []byte("-F--")) { + t.Error("Missing part F") + } + if !bytes.Contains(data, []byte("-E--")) { + t.Error("Missing part E") + } + if !bytes.Contains(data, []byte("-H--")) { + t.Error("Missing part H") + } + if !bytes.Contains(data, []byte("-Z--")) { + t.Error("Missing part Z") + } + }) + + t.Run("apache example 2 - SQL injection with multiple rules", func(t *testing.T) { + al := &Log{ + Parts_: []types.AuditLogPart{ + types.AuditLogPartHeader, + types.AuditLogPartRequestHeaders, + types.AuditLogPartResponseHeaders, + types.AuditLogPartIntermediaryResponseBody, + types.AuditLogPartAuditLogTrailer, + types.AuditLogPartEndMarker, + }, + Transaction_: Transaction{ + Timestamp_: "23/Feb/2025:22:40:32.479855 +0000", + UnixTimestamp_: 1740350432, + ID_: "Z7uj4AkDMIUwf_JHM4k9hAAAAAY", + ClientIP_: "192.168.65.1", + ClientPort_: 53953, + HostIP_: "172.21.0.3", + HostPort_: 8080, + Request_: &TransactionRequest{ + URI_: "/get?var=sdfsd%27or%201%20%3e%201", + Method_: "GET", + Protocol_: "HTTP/1.0", + Headers_: map[string][]string{ + "Accept": {"text/xml,application/xml,application/xhtml+xml,text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5"}, + "Connection": {"close"}, + "Host": {"localhost"}, + "User-Agent": {"OWASP CRS test agent"}, + }, + }, + Response_: &TransactionResponse{ + Status_: 200, + Protocol_: "HTTP/1.1", + Headers_: map[string][]string{ + "Content-Length": {"0"}, + "Connection": {"close"}, + }, + }, + }, + Messages_: []plugintypes.AuditLogMessage{ + &Message{ + ErrorMessage_: `Message: Warning. Found 5 byte(s) in ARGS:var outside range: 38,44-46,48-58,61,65-90,95,97-122. [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-920-PROTOCOL-ENFORCEMENT.conf"] [line "1739"] [id "920273"]`, + }, + &Message{ + ErrorMessage_: `Message: Warning. 
detected SQLi using libinjection with fingerprint 's&1' [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-942-APPLICATION-ATTACK-SQLI.conf"] [line "66"] [id "942100"]`, + }, + &Message{ + ErrorMessage_: `Message: Warning. Pattern match "(?i)(?:/\\*)+[\"'` + "`" + `]+[\\s\\x0b]?(?:--|[#\\{]|/\\*)?|[\"'` + "`" + `](?:[\\s\\x0b]*(?:(?:x?or|and|div|like|between)" at ARGS:var. [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-942-APPLICATION-ATTACK-SQLI.conf"] [line "822"] [id "942180"]`, + }, + &Message{ + ErrorMessage_: `Message: Warning. Operator GE matched 5 at TX:blocking_inbound_anomaly_score. [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-949-BLOCKING-EVALUATION.conf"] [line "233"] [id "949110"] [msg "Inbound Anomaly Score Exceeded (Total Score: 33)"]`, + }, + }, + } + + data, err := f.Format(al) + if err != nil { + t.Error(err) + } + + output := string(data) + // Verify SQL injection detection messages + if !bytes.Contains(data, []byte("Z7uj4AkDMIUwf_JHM4k9hAAAAAY")) { + t.Errorf("Missing transaction ID") + } + if !bytes.Contains(data, []byte("/get?var=sdfsd%27or%201%20%3e%201")) { + t.Errorf("Missing URI with SQL injection attempt") + } + if !bytes.Contains(data, []byte("libinjection")) { + t.Errorf("Missing libinjection message") + } + if !bytes.Contains(data, []byte("920273")) { + t.Errorf("Missing rule ID 920273") + } + if !bytes.Contains(data, []byte("942100")) { + t.Errorf("Missing rule ID 942100") + } + if !bytes.Contains(data, []byte("942180")) { + t.Errorf("Missing rule ID 942180") + } + if !bytes.Contains(data, []byte("Inbound Anomaly Score Exceeded (Total Score: 33)")) { + t.Errorf("Missing anomaly score message") + } + + scanner := bufio.NewScanner(bytes.NewReader(data)) + var lines []string + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + + // Verify the timestamp with microseconds + if !bytes.Contains(data, []byte("[23/Feb/2025:22:40:32.479855 +0000]")) { + t.Errorf("Missing timestamp with microseconds in output:\n%s", 
output) + } + + // Verify part H contains all messages (they should be concatenated) + partHFound := false + for i, line := range lines { + if strings.Contains(line, "-H--") { + partHFound = true + // Check that messages follow part H marker + if i+1 < len(lines) { + // At least one message should be present + messagesFound := 0 + for j := i + 1; j < len(lines) && !strings.Contains(lines[j], "--"); j++ { + if strings.Contains(lines[j], "Message:") { + messagesFound++ + } + } + if messagesFound == 0 { + t.Errorf("No messages found after part H marker") + } + } + break + } + } + if !partHFound { + t.Error("Part H marker not found") + } }) } diff --git a/internal/auditlog/https_writer_test.go b/internal/auditlog/https_writer_test.go index 65d883a8d..5fc4e193f 100644 --- a/internal/auditlog/https_writer_test.go +++ b/internal/auditlog/https_writer_test.go @@ -57,7 +57,7 @@ func TestHTTPAuditLog(t *testing.T) { t.Fatal("Body is empty") } if !bytes.Contains(body, []byte("test123")) { - t.Fatal("Body does not match") + t.Fatalf("Body does not match, got:\n%s", string(body)) } })) defer server.Close() diff --git a/internal/bodyprocessors/json.go b/internal/bodyprocessors/json.go index 101851f7e..5176d9369 100644 --- a/internal/bodyprocessors/json.go +++ b/internal/bodyprocessors/json.go @@ -4,6 +4,7 @@ package bodyprocessors import ( + "errors" "io" "strconv" "strings" @@ -17,17 +18,18 @@ type jsonBodyProcessor struct{} var _ plugintypes.BodyProcessor = &jsonBodyProcessor{} -func (js *jsonBodyProcessor) ProcessRequest(reader io.Reader, v plugintypes.TransactionVariables, _ plugintypes.BodyProcessorOptions) error { - // Read the entire body to store it and process it +func (js *jsonBodyProcessor) ProcessRequest(reader io.Reader, v plugintypes.TransactionVariables, bpo plugintypes.BodyProcessorOptions) error { + // Read the entire body into memory for two purposes: + // 1. Store raw JSON in TX variables for operators like @validateSchema + // 2. 
Parse and flatten for ARGS_POST collection s := strings.Builder{} if _, err := io.Copy(&s, reader); err != nil { return err } ss := s.String() - - // Process as normal + // Process with recursion limit col := v.ArgsPost() - data, err := readJSON(ss) + data, err := readJSON(ss, bpo.RequestBodyRecursionLimit) if err != nil { return err } @@ -45,6 +47,8 @@ func (js *jsonBodyProcessor) ProcessRequest(reader io.Reader, v plugintypes.Tran return nil } +const ignoreJSONRecursionLimit = -1 + func (js *jsonBodyProcessor) ProcessResponse(reader io.Reader, v plugintypes.TransactionVariables, _ plugintypes.BodyProcessorOptions) error { // Read the entire body to store it and process it s := strings.Builder{} @@ -52,10 +56,9 @@ func (js *jsonBodyProcessor) ProcessResponse(reader io.Reader, v plugintypes.Tra return err } ss := s.String() - - // Process as normal + // Process with no recursion limit as we don't have a directive for response body col := v.ResponseArgs() - data, err := readJSON(ss) + data, err := readJSON(ss, ignoreJSONRecursionLimit) if err != nil { return err } @@ -73,22 +76,33 @@ func (js *jsonBodyProcessor) ProcessResponse(reader io.Reader, v plugintypes.Tra return nil } -func readJSON(s string) (map[string]string, error) { - json := gjson.Parse(s) +func readJSON(s string, maxRecursion int) (map[string]string, error) { res := make(map[string]string) key := []byte("json") - readItems(json, key, res) - return res, nil + + if !gjson.Valid(s) { + return res, errors.New("invalid JSON") + } + json := gjson.Parse(s) + err := readItems(json, key, maxRecursion, res) + return res, err } // Transform JSON to a map[string]string +// This function is recursive and will call itself for nested objects. +// The limit in recursion is defined by maxItems. 
// Example input: {"data": {"name": "John", "age": 30}, "items": [1,2,3]} // Example output: map[string]string{"json.data.name": "John", "json.data.age": "30", "json.items.0": "1", "json.items.1": "2", "json.items.2": "3"} // Example input: [{"data": {"name": "John", "age": 30}, "items": [1,2,3]}] // Example output: map[string]string{"json.0.data.name": "John", "json.0.data.age": "30", "json.0.items.0": "1", "json.0.items.1": "2", "json.0.items.2": "3"} -// TODO add some anti DOS protection -func readItems(json gjson.Result, objKey []byte, res map[string]string) { +func readItems(json gjson.Result, objKey []byte, maxRecursion int, res map[string]string) error { arrayLen := 0 + var iterationError error + if maxRecursion == 0 { + // We reached the limit of nesting we want to handle. This protects against + // DoS attacks using deeply nested JSON structures (e.g., {"a":{"a":{"a":...}}}). + return errors.New("max recursion reached while reading json object") + } json.ForEach(func(key, value gjson.Result) bool { // Avoid string concatenation to maintain a single buffer for key aggregation. 
prevParentLength := len(objKey) @@ -103,7 +117,11 @@ func readItems(json gjson.Result, objKey []byte, res map[string]string) { var val string switch value.Type { case gjson.JSON: - readItems(value, objKey, res) + // call recursively with one less item to avoid doing infinite recursion + iterationError = readItems(value, objKey, maxRecursion-1, res) + if iterationError != nil { + return false + } objKey = objKey[:prevParentLength] return true case gjson.String: @@ -123,6 +141,7 @@ func readItems(json gjson.Result, objKey []byte, res map[string]string) { if arrayLen > 0 { res[string(objKey)] = strconv.Itoa(arrayLen) } + return iterationError } func init() { diff --git a/internal/bodyprocessors/json_test.go b/internal/bodyprocessors/json_test.go index 38e190e03..c4042342f 100644 --- a/internal/bodyprocessors/json_test.go +++ b/internal/bodyprocessors/json_test.go @@ -4,13 +4,23 @@ package bodyprocessors import ( + "errors" + "strings" "testing" + + "github.com/tidwall/gjson" +) + +const ( + deeplyNestedJSONObject = 15000 + maxRecursion = 10000 ) var jsonTests = []struct { name string json string want map[string]string + err error }{ { name: "map", @@ -55,6 +65,7 @@ var jsonTests = []struct { "json.f.0.0": "1", "json.f.0.0.0.z": "abc", }, + err: nil, }, { name: "array", @@ -115,6 +126,35 @@ var jsonTests = []struct { "json.1.f.0.0": "1", "json.1.f.0.0.0.z": "abc", }, + err: nil, + }, + { + name: "unbalanced_brackets", + json: `{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a": 1 }}}}}}}}}}}}}}}}}}}}}}`, + want: map[string]string{}, + err: errors.New("invalid JSON"), + }, + { + name: "broken2", + json: `{"test": 123, "test2": 456, "test3": [22, 44, 55], "test4": 3}`, + want: map[string]string{ + "json.test3.0": "22", + "json.test3.1": "44", + "json.test3.2": "55", + "json.test4": "3", + "json.test": "123", + "json.test2": "456", + "json.test3": "3", + }, + err: nil, + }, + { + name: "bomb", + json: strings.Repeat(`{"a":`, 
deeplyNestedJSONObject) + "1" + strings.Repeat(`}`, deeplyNestedJSONObject), + want: map[string]string{ + "json." + strings.Repeat(`a.`, deeplyNestedJSONObject-1) + "a": "1", + }, + err: errors.New("max recursion reached while reading json object"), }, { name: "empty_object", @@ -143,18 +183,25 @@ func TestReadJSON(t *testing.T) { for _, tc := range jsonTests { tt := tc t.Run(tt.name, func(t *testing.T) { - jsonMap, err := readJSON(tt.json) - if err != nil { - t.Error(err) - } + jsonMap, err := readJSON(tt.json, maxRecursion) // Special case for nested_empty - just check that the function doesn't error if tt.name == "nested_empty" { + if err != nil { + t.Error(err) + } // Print the keys for debugging t.Logf("Actual keys for nested_empty: %v", mapKeys(jsonMap)) return } + if err != nil { + if tt.err == nil || err.Error() != tt.err.Error() { + t.Error(err) + } + return + } + for k, want := range tt.want { if have, ok := jsonMap[k]; ok { if want != have { @@ -183,11 +230,10 @@ func mapKeys(m map[string]string) []string { } func TestInvalidJSON(t *testing.T) { - _, err := readJSON(`{invalid json`) - if err != nil { - // We expect no error since gjson.Parse doesn't return errors for invalid JSON - // Instead, it returns a Result with Type == Null - t.Error("Expected no error for invalid JSON, got:", err) + _, err := readJSON(`{invalid json`, maxRecursion) + if err == nil { + // We expect an error for invalid JSON since we now validate + t.Error("Expected error for invalid JSON, got nil") } } @@ -196,7 +242,7 @@ func BenchmarkReadJSON(b *testing.B) { tt := tc b.Run(tt.name, func(b *testing.B) { for i := 0; i < b.N; i++ { - _, err := readJSON(tt.json) + _, err := readJSON(tt.json, maxRecursion) if err != nil { b.Error(err) } @@ -204,3 +250,71 @@ func BenchmarkReadJSON(b *testing.B) { }) } } + +// readJSONNoValidation is readJSON without the gjson.Valid pre-check. +// Used only in benchmarks to measure the overhead of validation. 
+func readJSONNoValidation(s string, maxRecursion int) (map[string]string, error) { + json := gjson.Parse(s) + res := make(map[string]string) + key := []byte("json") + err := readItems(json, key, maxRecursion, res) + return res, err +} + +// BenchmarkValidationOverhead measures the cost of pre-validating JSON with gjson.Valid +// in the context of the full readJSON pipeline (Valid + Parse + readItems). +// gjson.Parse is lazy (~9ns regardless of input size), so the real overhead is +// gjson.Valid vs the readItems traversal that does the actual parsing work. +func BenchmarkValidationOverhead(b *testing.B) { + benchCases := []struct { + name string + json string + }{ + { + name: "small_object", + json: `{"name":"John","age":30}`, + }, + { + name: "medium_object", + json: `{"user":{"name":"John","email":"john@example.com","roles":["admin","user"]},"settings":{"theme":"dark","notifications":true},"metadata":{"created":"2026-01-01","updated":"2026-02-15"}}`, + }, + { + name: "large_array", + json: func() string { + var sb strings.Builder + sb.WriteString("[") + for i := 0; i < 100; i++ { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString(`{"id":` + strings.Repeat("1", 5) + `,"name":"user","active":true}`) + } + sb.WriteString("]") + return sb.String() + }(), + }, + { + name: "nested_10_levels", + json: strings.Repeat(`{"a":`, 10) + "1" + strings.Repeat(`}`, 10), + }, + } + + for _, bc := range benchCases { + b.Run("WithValidation/"+bc.name, func(b *testing.B) { + b.SetBytes(int64(len(bc.json))) + for i := 0; i < b.N; i++ { + if _, err := readJSON(bc.json, maxRecursion); err != nil { + b.Fatal(err) + } + } + }) + b.Run("WithoutValidation/"+bc.name, func(b *testing.B) { + b.SetBytes(int64(len(bc.json))) + for i := 0; i < b.N; i++ { + if _, err := readJSONNoValidation(bc.json, maxRecursion); err != nil { + b.Fatal(err) + } + } + }) + } +} diff --git a/internal/bodyprocessors/multipart.go b/internal/bodyprocessors/multipart.go index ded38b2b1..c3ec5e733 100644 --- 
a/internal/bodyprocessors/multipart.go +++ b/internal/bodyprocessors/multipart.go @@ -58,6 +58,7 @@ func (mbp *multipartBodyProcessor) ProcessRequest(reader io.Reader, v plugintype filename := originFileName(p) if filename != "" { var size int64 + seenUnexpectedEOF := false if environment.HasAccessToFS { // Only copy file to temp when not running in TinyGo temp, err := os.CreateTemp(storagePath, "crzmp*") @@ -68,16 +69,22 @@ func (mbp *multipartBodyProcessor) ProcessRequest(reader io.Reader, v plugintype defer temp.Close() sz, err := io.Copy(temp, p) if err != nil { - v.MultipartStrictError().(*collections.Single).Set("1") - return err + if !errors.Is(err, io.ErrUnexpectedEOF) { + v.MultipartStrictError().(*collections.Single).Set("1") + return err + } + seenUnexpectedEOF = true } size = sz filesTmpNamesCol.Add("", temp.Name()) } else { sz, err := io.Copy(io.Discard, p) if err != nil { - v.MultipartStrictError().(*collections.Single).Set("1") - return err + if !errors.Is(err, io.ErrUnexpectedEOF) { + v.MultipartStrictError().(*collections.Single).Set("1") + return err + } + seenUnexpectedEOF = true } size = sz } @@ -85,17 +92,26 @@ func (mbp *multipartBodyProcessor) ProcessRequest(reader io.Reader, v plugintype filesCol.Add("", filename) fileSizesCol.SetIndex(filename, 0, fmt.Sprintf("%d", size)) filesNamesCol.Add("", p.FormName()) + filesCombinedSizeCol.(*collections.Single).Set(fmt.Sprintf("%d", totalSize)) + if seenUnexpectedEOF { + break + } } else { // if is a field data, err := io.ReadAll(p) if err != nil { - v.MultipartStrictError().(*collections.Single).Set("1") - return err + if !errors.Is(err, io.ErrUnexpectedEOF) { + v.MultipartStrictError().(*collections.Single).Set("1") + return err + } } totalSize += int64(len(data)) postCol.Add(p.FormName(), string(data)) + filesCombinedSizeCol.(*collections.Single).Set(fmt.Sprintf("%d", totalSize)) + if errors.Is(err, io.ErrUnexpectedEOF) { + break + } } - 
filesCombinedSizeCol.(*collections.Single).Set(fmt.Sprintf("%d", totalSize)) } return nil } diff --git a/internal/bodyprocessors/multipart_test.go b/internal/bodyprocessors/multipart_test.go index 9b6b36fc8..97b9b63b1 100644 --- a/internal/bodyprocessors/multipart_test.go +++ b/internal/bodyprocessors/multipart_test.go @@ -210,3 +210,123 @@ func TestMultipartUnmatchedBoundary(t *testing.T) { } } } + +func TestIncompleteMultipartPayload(t *testing.T) { + testCases := []struct { + name string + input string + }{ + { + name: "inMiddleOfBoundary", + input: ` +-----------------------------9051914041544843365972754266 +Content-Disposition: form-data; name="text" + +text default +-----------------------------9051914041544843365972754266 +Content-Disposition: form-data; name="file1"; filename="a.txt" +Content-Type: text/plain + +Content of a.txt. + +-----------------------------905191404154484336 +`, + }, + { + name: "inMiddleOfHeader", + input: ` +-----------------------------9051914041544843365972754266 +Content-Disposition: form-data; name="text" + +text default +-----------------------------9051914041544843365972754266 +Content-Disposition: form-data; name="file1"; filename="a.txt" +Content-Type: text/plain + +Content of a.txt. + +-----------------------------9051914041544843365972754266 +Content-Disposition: form-data; name="fil`, + }, + { + name: "inMiddleOfContent", + input: ` +-----------------------------9051914041544843365972754266 +Content-Disposition: form-data; name="text" + +text default +-----------------------------9051914041544843365972754266 +Content-Disposition: form-data; name="file1"; filename="a.txt" +Content-Type: text/plain + +Content of a.txt. 
+ +-----------------------------9051914041544843365972754266 +Content-Disposition: form-data; name="file2"; filename="a.html" +Content-Type: text/html + +Content of `, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + payload := strings.TrimSpace(tc.input) + + mp := multipartProcessor(t) + + v := corazawaf.NewTransactionVariables() + if err := mp.ProcessRequest(strings.NewReader(payload), v, plugintypes.BodyProcessorOptions{ + Mime: "multipart/form-data; boundary=---------------------------9051914041544843365972754266", + }); err != nil { + t.Fatal(err) + } + // first we validate we got the headers + headers := v.MultipartPartHeaders() + header1 := "Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"" + header2 := "Content-Type: text/plain" + if h := headers.Get("file1"); len(h) == 0 { + t.Fatal("expected headers for file1") + } else { + if len(h) != 2 { + t.Fatal("expected 2 headers for file1") + } + if (h[0] != header1 && h[0] != header2) || (h[1] != header1 && h[1] != header2) { + t.Fatalf("Got invalid multipart headers") + } + } + + // Verify form field data was correctly processed before the incomplete part + argsPost := v.ArgsPost() + if textValues := argsPost.Get("text"); len(textValues) == 0 { + t.Fatal("expected ArgsPost to contain 'text' field") + } else if textValues[0] != "text default" { + t.Fatalf("expected ArgsPost 'text' to be 'text default', got %q", textValues[0]) + } + }) + } +} + +func TestIncompleteMultipartPayloadInFormField(t *testing.T) { + payload := strings.TrimSpace(` +-----------------------------9051914041544843365972754266 +Content-Disposition: form-data; name="text" + +text defa`) + + mp := multipartProcessor(t) + + v := corazawaf.NewTransactionVariables() + if err := mp.ProcessRequest(strings.NewReader(payload), v, plugintypes.BodyProcessorOptions{ + Mime: "multipart/form-data; boundary=---------------------------9051914041544843365972754266", + }); err != nil { + t.Fatal(err) + } + + // 
Verify the partial form field data was processed + argsPost := v.ArgsPost() + if textValues := argsPost.Get("text"); len(textValues) == 0 { + t.Fatal("expected ArgsPost to contain 'text' field") + } else if textValues[0] != "text defa" { + t.Fatalf("expected ArgsPost 'text' to be 'text defa', got %q", textValues[0]) + } +} diff --git a/internal/collections/map.go b/internal/collections/map.go index ba6adf35e..7ca642809 100644 --- a/internal/collections/map.go +++ b/internal/collections/map.go @@ -60,16 +60,30 @@ func (c *Map) Get(key string) []string { // FindRegex returns all map elements whose key matches the regular expression. func (c *Map) FindRegex(key *regexp.Regexp) []types.MatchData { - var result []types.MatchData + n := 0 + // Collect matching data slices in a single pass to avoid evaluating the regex twice per key. + var matched [][]keyValue for k, data := range c.data { if key.MatchString(k) { - for _, d := range data { - result = append(result, &corazarules.MatchData{ - Variable_: c.variable, - Key_: d.key, - Value_: d.value, - }) + n += len(data) + matched = append(matched, data) + } + } + if n == 0 { + return nil + } + buf := make([]corazarules.MatchData, n) + result := make([]types.MatchData, n) + i := 0 + for _, data := range matched { + for _, d := range data { + buf[i] = corazarules.MatchData{ + Variable_: c.variable, + Key_: d.key, + Value_: d.value, } + result[i] = &buf[i] + i++ } } return result @@ -77,7 +91,6 @@ func (c *Map) FindRegex(key *regexp.Regexp) []types.MatchData { // FindString returns all map elements whose key matches the string. 
func (c *Map) FindString(key string) []types.MatchData { - var result []types.MatchData if key == "" { return c.FindAll() } @@ -87,29 +100,44 @@ func (c *Map) FindString(key string) []types.MatchData { if !c.isCaseSensitive { key = strings.ToLower(key) } - // if key is not empty - if e, ok := c.data[key]; ok { - for _, aVar := range e { - result = append(result, &corazarules.MatchData{ - Variable_: c.variable, - Key_: aVar.key, - Value_: aVar.value, - }) + e, ok := c.data[key] + if !ok || len(e) == 0 { + return nil + } + buf := make([]corazarules.MatchData, len(e)) + result := make([]types.MatchData, len(e)) + for i, aVar := range e { + buf[i] = corazarules.MatchData{ + Variable_: c.variable, + Key_: aVar.key, + Value_: aVar.value, } + result[i] = &buf[i] } return result } // FindAll returns all map elements. func (c *Map) FindAll() []types.MatchData { - var result []types.MatchData + n := 0 + for _, data := range c.data { + n += len(data) + } + if n == 0 { + return nil + } + buf := make([]corazarules.MatchData, n) + result := make([]types.MatchData, n) + i := 0 for _, data := range c.data { for _, d := range data { - result = append(result, &corazarules.MatchData{ + buf[i] = corazarules.MatchData{ Variable_: c.variable, Key_: d.key, Value_: d.value, - }) + } + result[i] = &buf[i] + i++ } } return result diff --git a/internal/collections/map_test.go b/internal/collections/map_test.go index 7c47d29c5..60dc0c68f 100644 --- a/internal/collections/map_test.go +++ b/internal/collections/map_test.go @@ -107,6 +107,120 @@ func TestNewCaseSensitiveKeyMap(t *testing.T) { } +func TestFindAllBulkAllocIndependence(t *testing.T) { + m := NewMap(variables.ArgsGet) + m.Add("key1", "value1") + m.Add("key2", "value2") + m.Add("key3", "value3") + + results := m.FindAll() + if len(results) != 3 { + t.Fatalf("expected 3 results, got %d", len(results)) + } + + // Mutate first result's value through the MatchData interface + // and verify others are not affected + values := 
make([]string, len(results)) + for i, r := range results { + values[i] = r.Value() + } + + // Verify all values are distinct and correct + seen := map[string]bool{} + for _, v := range values { + if seen[v] { + t.Errorf("duplicate value found: %s", v) + } + seen[v] = true + } + if !seen["value1"] || !seen["value2"] || !seen["value3"] { + t.Errorf("expected value1, value2, value3 but got %v", values) + } +} + +func TestFindStringBulkAlloc(t *testing.T) { + m := NewMap(variables.ArgsGet) + m.Add("key", "val1") + m.Add("key", "val2") + + results := m.FindString("key") + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + + // Each result should have distinct values + if results[0].Value() == results[1].Value() { + t.Errorf("expected distinct values, got %q and %q", results[0].Value(), results[1].Value()) + } +} + +func TestFindRegexBulkAlloc(t *testing.T) { + m := NewMap(variables.ArgsGet) + m.Add("abc", "val1") + m.Add("abd", "val2") + m.Add("xyz", "val3") + + re := regexp.MustCompile("^ab") + results := m.FindRegex(re) + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + + // Verify keys match regex + for _, r := range results { + if r.Key() != "abc" && r.Key() != "abd" { + t.Errorf("unexpected key: %s", r.Key()) + } + } +} + +func TestFindAllEmptyMap(t *testing.T) { + m := NewMap(variables.ArgsGet) + results := m.FindAll() + if results != nil { + t.Errorf("expected nil for empty map, got %v", results) + } +} + +func BenchmarkFindAll(b *testing.B) { + b.ReportAllocs() + m := NewMap(variables.RequestHeaders) + for i := 0; i < 20; i++ { + m.Add(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("value-%d", i)) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = m.FindAll() + } +} + +func BenchmarkFindRegex(b *testing.B) { + b.ReportAllocs() + m := NewMap(variables.RequestHeaders) + for i := 0; i < 20; i++ { + m.Add(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("value-%d", i)) + } + // Matches keys ending in 0-9 
(x-header-0 .. x-header-9), roughly half. + re := regexp.MustCompile(`^x-header-\d$`) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = m.FindRegex(re) + } +} + +func BenchmarkFindString(b *testing.B) { + b.ReportAllocs() + m := NewMap(variables.RequestHeaders) + // Single key with multiple values + for i := 0; i < 20; i++ { + m.Add("x-forwarded-for", fmt.Sprintf("10.0.0.%d", i)) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = m.FindString("x-forwarded-for") + } +} + func BenchmarkTxSetGet(b *testing.B) { keys := make(map[int]string, b.N) for i := 0; i < b.N; i++ { diff --git a/internal/collections/named.go b/internal/collections/named.go index c3b1a0033..6d6c922f4 100644 --- a/internal/collections/named.go +++ b/internal/collections/named.go @@ -101,37 +101,49 @@ type NamedCollectionNames struct { } func (c *NamedCollectionNames) FindRegex(key *regexp.Regexp) []types.MatchData { - var res []types.MatchData - + n := 0 + // Collect matching data slices in a single pass to avoid evaluating the regex twice per key. 
+ var matched [][]keyValue for k, data := range c.collection.data { - if !key.MatchString(k) { - continue + if key.MatchString(k) { + n += len(data) + matched = append(matched, data) } + } + if n == 0 { + return nil + } + buf := make([]corazarules.MatchData, n) + res := make([]types.MatchData, n) + i := 0 + for _, data := range matched { for _, d := range data { - res = append(res, &corazarules.MatchData{ + buf[i] = corazarules.MatchData{ Variable_: c.variable, Key_: d.key, Value_: d.key, - }) + } + res[i] = &buf[i] + i++ } } return res } func (c *NamedCollectionNames) FindString(key string) []types.MatchData { - var res []types.MatchData - - for k, data := range c.collection.data { - if k != key { - continue - } - for _, d := range data { - res = append(res, &corazarules.MatchData{ - Variable_: c.variable, - Key_: d.key, - Value_: d.key, - }) + data, ok := c.collection.data[key] + if !ok || len(data) == 0 { + return nil + } + buf := make([]corazarules.MatchData, len(data)) + res := make([]types.MatchData, len(data)) + for i, d := range data { + buf[i] = corazarules.MatchData{ + Variable_: c.variable, + Key_: d.key, + Value_: d.key, } + res[i] = &buf[i] } return res } @@ -141,16 +153,25 @@ func (c *NamedCollectionNames) Get(key string) []string { } func (c *NamedCollectionNames) FindAll() []types.MatchData { - var res []types.MatchData - // Iterates over all the data in the map and adds the key element also to the Key field (The key value may be the value - // that is matched, but it is still also the key of the pair and it is needed to print the matched var name) + n := 0 + for _, data := range c.collection.data { + n += len(data) + } + if n == 0 { + return nil + } + buf := make([]corazarules.MatchData, n) + res := make([]types.MatchData, n) + i := 0 for _, data := range c.collection.data { for _, d := range data { - res = append(res, &corazarules.MatchData{ + buf[i] = corazarules.MatchData{ Variable_: c.variable, Key_: d.key, Value_: d.key, - }) + } + res[i] = 
&buf[i] + i++ } } return res diff --git a/internal/corazawaf/rule.go b/internal/corazawaf/rule.go index c7b047707..d2e3b5a12 100644 --- a/internal/corazawaf/rule.go +++ b/internal/corazawaf/rule.go @@ -14,7 +14,7 @@ import ( "github.com/corazawaf/coraza/v3/experimental/plugins/macro" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" "github.com/corazawaf/coraza/v3/internal/corazarules" - "github.com/corazawaf/coraza/v3/internal/memoize" + utils "github.com/corazawaf/coraza/v3/internal/strings" "github.com/corazawaf/coraza/v3/types" "github.com/corazawaf/coraza/v3/types/variables" ) @@ -97,6 +97,11 @@ type Rule struct { transformations []ruleTransformationParams transformationsID int + // transformationPrefixIDs holds the chain ID at each step of the transformation chain. + // transformationPrefixIDs[i] is the ID representing transformations [0..i]. + // This enables prefix-based caching: rules sharing a common transformation prefix + // can reuse cached intermediate results instead of recomputing from scratch. + transformationPrefixIDs []int // Slice of initialized actions to be evaluated during // the rule evaluation process @@ -146,6 +151,8 @@ type Rule struct { // chainedRules containing rules with just PhaseUnknown variables, may potentially // be anticipated. 
This boolean ensures that it happens withPhaseUnknownVariable bool + + memoizer plugintypes.Memoizer } func (r *Rule) ParentID() int { @@ -161,7 +168,7 @@ const chainLevelZero = 0 // Evaluate will evaluate the current rule for the indicated transaction // If the operator matches, actions will be evaluated, and it will return // the matched variables, keys and values (MatchData) -func (r *Rule) Evaluate(phase types.RulePhase, tx plugintypes.TransactionState, cache map[transformationKey]*transformationValue) { +func (r *Rule) Evaluate(phase types.RulePhase, tx plugintypes.TransactionState, cache map[transformationKey]transformationValue) { // collectiveMatchedValues lives across recursive calls of doEvaluate var collectiveMatchedValues []types.MatchData @@ -180,7 +187,7 @@ func (r *Rule) Evaluate(phase types.RulePhase, tx plugintypes.TransactionState, const noID = 0 -func (r *Rule) doEvaluate(logger debuglog.Logger, phase types.RulePhase, tx *Transaction, collectiveMatchedValues *[]types.MatchData, chainLevel int, cache map[transformationKey]*transformationValue) []types.MatchData { +func (r *Rule) doEvaluate(logger debuglog.Logger, phase types.RulePhase, tx *Transaction, collectiveMatchedValues *[]types.MatchData, chainLevel int, cache map[transformationKey]transformationValue) []types.MatchData { tx.Capture = r.Capture if multiphaseEvaluation { @@ -230,7 +237,7 @@ func (r *Rule) doEvaluate(logger debuglog.Logger, phase types.RulePhase, tx *Tra for _, c := range ecol { if c.Variable == v.Variable { // TODO shall we check the pointer? 
- v.Exceptions = append(v.Exceptions, ruleVariableException{c.KeyStr, nil}) + v.Exceptions = append(v.Exceptions, ruleVariableException{c.KeyStr, c.KeyRx}) } } @@ -372,15 +379,22 @@ func (r *Rule) doEvaluate(logger debuglog.Logger, phase types.RulePhase, tx *Tra } for _, a := range r.actions { - if a.Function.Type() == plugintypes.ActionTypeFlow { - // Flow actions are evaluated also if the rule engine is set to DetectionOnly + // All actions are evaluated independently from the engine being On or in DetectionOnly. + // The action evaluation is responsible for checking the engine mode and deciding if the disruptive action + // has to be enforced or not. This allows finer control to the actions, such as creating the detectionOnlyInterruption and + // allowing RelevantOnly audit logs in detection only mode. + switch a.Function.Type() { + case plugintypes.ActionTypeFlow: logger.Debug().Str("action", a.Name).Int("phase", int(phase)).Msg("Evaluating flow action for rule") - a.Function.Evaluate(r, tx) - } else if a.Function.Type() == plugintypes.ActionTypeDisruptive && tx.RuleEngine == types.RuleEngineOn { + case plugintypes.ActionTypeDisruptive: // The parser enforces that the disruptive action is just one per rule (if more than one, only the last one is kept) - logger.Debug().Str("action", a.Name).Msg("Executing disruptive action for rule") - a.Function.Evaluate(r, tx) + logger.Debug().Str("action", a.Name).Int("phase", int(phase)).Msg("Executing disruptive action for rule") + default: + // Only flow and disruptive actions are supposed to be evaluated here, non disruptive actions + // are evaluated previously, during the variable matching.
+ continue } + a.Function.Evaluate(r, tx) } if r.ID_ != noID { // we avoid matching chains and secmarkers @@ -397,7 +411,7 @@ func (r *Rule) transformMultiMatchArg(arg types.MatchData) ([]string, []error) { return r.executeTransformationsMultimatch(arg.Value()) } -func (r *Rule) transformArg(arg types.MatchData, argIdx int, cache map[transformationKey]*transformationValue) (string, []error) { +func (r *Rule) transformArg(arg types.MatchData, argIdx int, cache map[transformationKey]transformationValue) (string, []error) { switch { case len(r.transformations) == 0: return arg.Value(), nil @@ -409,23 +423,53 @@ func (r *Rule) transformArg(arg types.MatchData, argIdx int, cache map[transform // NOTE: See comment on transformationKey struct to understand this hacky code argKey := arg.Key() argKeyPtr := unsafe.StringData(argKey) - key := transformationKey{ - argKey: argKeyPtr, - argIndex: argIdx, - argVariable: arg.Variable(), - transformationsID: r.transformationsID, + + // Search from longest prefix (full chain) backwards for a cache hit. + // Best case: full chain cached → single map lookup, done. + // Typical case: shared prefix cached → start computing from there. 
+ startIdx := 0 + value := arg.Value() + var errs []error + + for i := len(r.transformationPrefixIDs) - 1; i >= 0; i-- { + key := transformationKey{ + argKey: argKeyPtr, + argIndex: argIdx, + argVariable: arg.Variable(), + transformationsID: r.transformationPrefixIDs[i], + } + if cached, ok := cache[key]; ok { + if i == len(r.transformationPrefixIDs)-1 { + // Full chain cached — nothing more to compute + return cached.arg, cached.errs + } + value = cached.arg + errs = cached.errs + startIdx = i + 1 + break + } } - if cached, ok := cache[key]; ok { - return cached.arg, cached.errs - } else { - ars, es := r.executeTransformations(arg.Value()) - errs := es - cache[key] = &transformationValue{ - arg: ars, - errs: es, + + // Execute remaining transformations, caching each intermediate step + // so later rules sharing a prefix can reuse our work. + for i := startIdx; i < len(r.transformations); i++ { + v, _, err := r.transformations[i].Function(value) + if err != nil { + errs = append(errs, err) + } else { + value = v } - return ars, errs + + key := transformationKey{ + argKey: argKeyPtr, + argIndex: argIdx, + argVariable: arg.Variable(), + transformationsID: r.transformationPrefixIDs[i], + } + cache[key] = transformationValue{arg: value, errs: errs} } + + return value, errs } } @@ -467,14 +511,28 @@ func (r *Rule) AddAction(name string, action plugintypes.Action) error { return nil } +// ClearDisruptiveActions removes all disruptive actions from the rule. +// +// This is used by directives like SecRuleUpdateActionById to clear existing +// disruptive actions before applying new ones, matching ModSecurity behavior +// where updating with a disruptive action replaces the previous one. 
+func (r *Rule) ClearDisruptiveActions() { + actionType := plugintypes.ActionTypeDisruptive + filtered := make([]ruleActionParams, 0, len(r.actions)) + for _, action := range r.actions { + if action.Function.Type() != actionType { + filtered = append(filtered, action) + } + } + r.actions = filtered +} + // hasRegex checks the received key to see if it is between forward slashes. // if it is, it will return true and the content of the regular expression inside the slashes. // otherwise it will return false and the same key. +// Delegates to utils.HasRegex which properly handles escaped slashes. func hasRegex(key string) (bool, string) { - if len(key) > 2 && key[0] == '/' && key[len(key)-1] == '/' { - return true, key[1 : len(key)-1] - } - return false, key + return utils.HasRegex(key) } // caseSensitiveVariable returns true if the variable is case sensitive @@ -515,7 +573,10 @@ func (r *Rule) AddVariable(v variables.RuleVariable, key string, iscount bool) e } var re *regexp.Regexp if isRegex, rx := hasRegex(key); isRegex { - if vare, err := memoize.Do(rx, func() (any, error) { return regexp.Compile(rx) }); err != nil { + if !caseSensitiveVariable(v) { + rx = strings.ToLower(rx) + } + if vare, err := r.memoizeDo(rx, func() (any, error) { return regexp.Compile(rx) }); err != nil { return err } else { re = vare.(*regexp.Regexp) @@ -559,7 +620,10 @@ func needToSplitConcatenatedVariable(v variables.RuleVariable, ve variables.Rule func (r *Rule) AddVariableNegation(v variables.RuleVariable, key string) error { var re *regexp.Regexp if isRegex, rx := hasRegex(key); isRegex { - if vare, err := memoize.Do(rx, func() (any, error) { return regexp.Compile(rx) }); err != nil { + if !caseSensitiveVariable(v) { + rx = strings.ToLower(rx) + } + if vare, err := r.memoizeDo(rx, func() (any, error) { return regexp.Compile(rx) }); err != nil { return err } else { re = vare.(*regexp.Regexp) @@ -613,6 +677,7 @@ func (r *Rule) AddTransformation(name string, t plugintypes.Transformation) 
erro } r.transformations = append(r.transformations, ruleTransformationParams{Function: t}) r.transformationsID = transformationID(r.transformationsID, name) + r.transformationPrefixIDs = append(r.transformationPrefixIDs, r.transformationsID) return nil } @@ -620,6 +685,8 @@ func (r *Rule) AddTransformation(name string, t plugintypes.Transformation) erro // it is mostly used by the "none" transformation func (r *Rule) ClearTransformations() { r.transformations = []ruleTransformationParams{} + r.transformationsID = 0 + r.transformationPrefixIDs = nil } // SetOperator sets the operator of the rule @@ -674,6 +741,23 @@ func (r *Rule) executeTransformations(value string) (string, []error) { return value, errs } +// SetMemoizer sets the memoizer used for caching compiled regexes in variable selectors. +func (r *Rule) SetMemoizer(m plugintypes.Memoizer) { + r.memoizer = m +} + +// Memoizer returns the memoizer used for caching compiled regexes in variable selectors. +func (r *Rule) Memoizer() plugintypes.Memoizer { + return r.memoizer +} + +func (r *Rule) memoizeDo(key string, fn func() (any, error)) (any, error) { + if r.memoizer != nil { + return r.memoizer.Do(key, fn) + } + return fn() +} + // NewRule returns a new initialized rule // By default, the rule is set to phase 2 func NewRule() *Rule { diff --git a/internal/corazawaf/rule_test.go b/internal/corazawaf/rule_test.go index d999bb0c8..b397854f3 100644 --- a/internal/corazawaf/rule_test.go +++ b/internal/corazawaf/rule_test.go @@ -101,7 +101,7 @@ func TestNoMatchEvaluateBecauseOfException(t *testing.T) { _ = r.AddAction("dummyDeny", action) tx := NewWAF().NewTransaction() tx.AddGetRequestArgument("test", "0") - tx.RemoveRuleTargetByID(1, tc.variable, "test") + tx.RemoveRuleTargetByID(1, tc.variable, "test", nil) var matchedValues []types.MatchData matchdata := r.doEvaluate(debuglog.Noop(), types.PhaseRequestHeaders, tx, &matchedValues, 0, tx.transformationCache) if len(matchdata) != 0 { @@ -114,6 +114,52 @@ func 
TestNoMatchEvaluateBecauseOfException(t *testing.T) { } } +func TestNoMatchEvaluateBecauseOfWholeCollectionException(t *testing.T) { + testCases := []struct { + name string + variable variables.RuleVariable + }{ + { + name: "Test ArgsGet whole collection exception", + variable: variables.ArgsGet, + }, + { + name: "Test Args whole collection exception", + variable: variables.Args, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := NewRule() + r.Msg, _ = macro.NewMacro("Message") + r.LogData, _ = macro.NewMacro("Data Message") + r.ID_ = 1 + r.LogID_ = "1" + if err := r.AddVariable(tc.variable, "", false); err != nil { + t.Error(err) + } + dummyEqOp := &dummyEqOperator{} + r.SetOperator(dummyEqOp, "@eq", "0") + action := &dummyDenyAction{} + _ = r.AddAction("dummyDeny", action) + tx := NewWAF().NewTransaction() + tx.AddGetRequestArgument("test", "0") + tx.AddGetRequestArgument("other", "0") + // Remove with empty key should exclude the entire collection + tx.RemoveRuleTargetByID(1, tc.variable, "", nil) + var matchedValues []types.MatchData + matchdata := r.doEvaluate(debuglog.Noop(), types.PhaseRequestHeaders, tx, &matchedValues, 0, tx.transformationCache) + if len(matchdata) != 0 { + t.Errorf("Expected 0 matchdata when whole collection is excluded, got %d", len(matchdata)) + } + if tx.interruption != nil { + t.Errorf("Expected no interruption because whole collection is excluded") + } + }) + } +} + type dummyFlowAction struct{} func (*dummyFlowAction) Init(_ plugintypes.RuleMetadata, _ string) error { @@ -297,6 +343,39 @@ func TestVariablesRxAreCaseSensitive(t *testing.T) { } } +func TestVariablesRxAreLowercasedForCaseInsensitiveCollections(t *testing.T) { + rule := NewRule() + if err := rule.AddVariable(variables.TX, "/MULTIPART_HEADERS_CONTENT_TYPES_.*/", false); err != nil { + t.Error(err) + } + if rule.variables[0].KeyRx == nil { + t.Fatal("expected regex to be set") + } + if rule.variables[0].KeyRx.String() != 
"multipart_headers_content_types_.*" { + t.Errorf("expected lowercased regex for case-insensitive variable TX, got %q", rule.variables[0].KeyRx.String()) + } +} + +func TestVariableNegationRxLowercasedForCaseInsensitiveCollections(t *testing.T) { + rule := NewRule() + if err := rule.AddVariable(variables.TX, "", false); err != nil { + t.Error(err) + } + if err := rule.AddVariableNegation(variables.TX, "/SOME_PATTERN.*/"); err != nil { + t.Error(err) + } + if len(rule.variables[0].Exceptions) != 1 { + t.Fatal("expected 1 exception") + } + ex := rule.variables[0].Exceptions[0] + if ex.KeyRx == nil { + t.Fatal("expected regex exception to be set") + } + if ex.KeyRx.String() != "some_pattern.*" { + t.Errorf("expected lowercased regex for case-insensitive variable TX exception, got %q", ex.KeyRx.String()) + } +} + func TestInferredPhase(t *testing.T) { var b inferredPhases @@ -495,7 +574,7 @@ func TestExecuteTransformationsMultiMatchReturnsMultipleErrors(t *testing.T) { } func TestTransformArgSimple(t *testing.T) { - transformationCache := map[transformationKey]*transformationValue{} + transformationCache := map[transformationKey]transformationValue{} md := &corazarules.MatchData{ Variable_: variables.RequestURI, Key_: "REQUEST_URI", @@ -513,10 +592,11 @@ func TestTransformArgSimple(t *testing.T) { if arg != "/testAB" { t.Errorf("Expected \"/testAB\", got \"%s\"", arg) } - if len(transformationCache) != 1 { - t.Errorf("Expected 1 transformations in cache, got %d", len(transformationCache)) + // Prefix caching stores an entry for each transformation step + if len(transformationCache) != 2 { + t.Errorf("Expected 2 transformations in cache (one per step), got %d", len(transformationCache)) } - // Repeating the same transformation, expecting still one element in the cache (that means it is a cache hit) + // Repeating the same transformation, expecting still two elements in the cache (cache hit) arg, errs = rule.transformArg(md, 0, transformationCache) if errs != nil { 
t.Fatalf("Unexpected errors executing transformations: %v", errs) @@ -524,13 +604,13 @@ func TestTransformArgSimple(t *testing.T) { if arg != "/testAB" { t.Errorf("Expected \"/testAB\", got \"%s\"", arg) } - if len(transformationCache) != 1 { - t.Errorf("Expected 1 transformations in cache, got %d", len(transformationCache)) + if len(transformationCache) != 2 { + t.Errorf("Expected 2 transformations in cache, got %d", len(transformationCache)) } } func TestTransformArgNoCacheForTXVariable(t *testing.T) { - transformationCache := map[transformationKey]*transformationValue{} + transformationCache := map[transformationKey]transformationValue{} md := &corazarules.MatchData{ Variable_: variables.TX, Key_: "Custom_TX_Variable", @@ -550,6 +630,84 @@ func TestTransformArgNoCacheForTXVariable(t *testing.T) { } } +func TestTransformArgPrefixSharing(t *testing.T) { + transformationCache := map[transformationKey]transformationValue{} + md := &corazarules.MatchData{ + Variable_: variables.RequestURI, + Key_: "REQUEST_URI", + Value_: "/test", + } + + // Rule 1 has chain: AppendA + rule1 := NewRule() + _ = rule1.AddTransformation("AppendA", transformationAppendA) + + // Rule 2 has chain: AppendA, AppendB (shares prefix with rule1) + rule2 := NewRule() + _ = rule2.AddTransformation("AppendA", transformationAppendA) + _ = rule2.AddTransformation("AppendB", transformationAppendB) + + // Evaluate rule1 first — caches the AppendA intermediate + arg1, errs := rule1.transformArg(md, 0, transformationCache) + if errs != nil { + t.Fatalf("Unexpected errors: %v", errs) + } + if arg1 != "/testA" { + t.Errorf("Expected \"/testA\", got %q", arg1) + } + if len(transformationCache) != 1 { + t.Errorf("Expected 1 cache entry after rule1, got %d", len(transformationCache)) + } + + // Evaluate rule2 — should reuse the cached AppendA result and only compute AppendB + arg2, errs := rule2.transformArg(md, 0, transformationCache) + if errs != nil { + t.Fatalf("Unexpected errors: %v", errs) + } + if 
arg2 != "/testAB" { + t.Errorf("Expected \"/testAB\", got %q", arg2) + } + // Should now have 2 entries: AppendA (shared) and AppendA+AppendB + if len(transformationCache) != 2 { + t.Errorf("Expected 2 cache entries after rule2 (prefix reuse), got %d", len(transformationCache)) + } +} + +func TestClearTransformationsResetsID(t *testing.T) { + transformationCache := map[transformationKey]transformationValue{} + md := &corazarules.MatchData{ + Variable_: variables.RequestURI, + Key_: "REQUEST_URI", + Value_: "test", + } + + // Rule A: t:AppendA, t:none, t:AppendB — effective chain is only AppendB + ruleA := NewRule() + _ = ruleA.AddTransformation("AppendA", transformationAppendA) + ruleA.ClearTransformations() + _ = ruleA.AddTransformation("AppendB", transformationAppendB) + + // Rule B: t:AppendA, t:AppendB — effective chain is AppendA then AppendB + ruleB := NewRule() + _ = ruleB.AddTransformation("AppendA", transformationAppendA) + _ = ruleB.AddTransformation("AppendB", transformationAppendB) + + argA, _ := ruleA.transformArg(md, 0, transformationCache) + argB, _ := ruleB.transformArg(md, 0, transformationCache) + + if argA != "testB" { + t.Errorf("Rule A (t:none resets): expected \"testB\", got %q", argA) + } + if argB != "testAB" { + t.Errorf("Rule B (t:AppendA,t:AppendB): expected \"testAB\", got %q", argB) + } + // They must produce different results — if ClearTransformations didn't reset the ID, + // they'd collide in the cache and one would get the wrong result. 
+ if argA == argB { + t.Error("Rule A and Rule B produced the same result — ClearTransformations likely didn't reset transformationsID") + } +} + func TestCaptureNotPropagatedToInnerChainRule(t *testing.T) { r := NewRule() r.ID_ = 1 diff --git a/internal/corazawaf/rulegroup.go b/internal/corazawaf/rulegroup.go index 0682dc5d1..0924aa2f5 100644 --- a/internal/corazawaf/rulegroup.go +++ b/internal/corazawaf/rulegroup.go @@ -18,7 +18,8 @@ import ( // It is not concurrent safe, so it's not recommended to use it // after compilation type RuleGroup struct { - rules []Rule + rules []Rule + observer func(rule types.RuleMetadata) } // Add a rule to the collection @@ -61,9 +62,19 @@ func (rg *RuleGroup) Add(rule *Rule) error { } rg.rules = append(rg.rules, *rule) + + if rg.observer != nil { + rg.observer(rule) + } + return nil } +// SetObserver assigns the observer function to the group. +func (rg *RuleGroup) SetObserver(observer func(rule types.RuleMetadata)) { + rg.observer = observer +} + // GetRules returns the slice of rules, func (rg *RuleGroup) GetRules() []Rule { return rg.rules @@ -148,7 +159,7 @@ RulesLoop: r := &rg.rules[i] // if there is already an interruption and the phase isn't logging // we break the loop - if tx.interruption != nil && phase != types.PhaseLogging { + if tx.IsInterrupted() && phase != types.PhaseLogging { break RulesLoop } // Rules with phase 0 will always run @@ -168,12 +179,17 @@ RulesLoop: } // we skip the rule in case it's in the excluded list - for _, trb := range tx.ruleRemoveByID { - if trb == r.ID_ { + if _, skip := tx.ruleRemoveByID[r.ID_]; skip { + tx.DebugLogger().Debug(). + Int("rule_id", r.ID_). + Msg("Skipping rule") + continue RulesLoop + } + for _, rng := range tx.ruleRemoveByIDRanges { + if r.ID_ >= rng[0] && r.ID_ <= rng[1] { tx.DebugLogger().Debug(). Int("rule_id", r.ID_). 
Msg("Skipping rule") - continue RulesLoop } } @@ -247,7 +263,7 @@ RulesLoop: tx.Skip = 0 tx.stopWatches[phase] = time.Now().UnixNano() - ts - return tx.interruption != nil + return tx.IsInterrupted() } // NewRuleGroup creates an empty RuleGroup that diff --git a/internal/corazawaf/transaction.go b/internal/corazawaf/transaction.go index 730a850b1..f51052176 100644 --- a/internal/corazawaf/transaction.go +++ b/internal/corazawaf/transaction.go @@ -14,6 +14,7 @@ import ( "net/url" "os" "path/filepath" + "regexp" "strconv" "strings" "time" @@ -54,6 +55,12 @@ type Transaction struct { // True if the transaction has been disrupted by any rule interruption *types.Interruption + // detectionOnlyInterruption keeps track of the interruption that would have been performed if the engine was On. + // It provides visibility of what would have happened in On mode when the engine is set to "DetectionOnly" + // and is used to correctly emit relevant only audit logs in DetectionOnly mode (When the rules would have + // caused an interruption if the engine was On). + detectionOnlyInterruption *types.Interruption + // This is used to store log messages // Deprecated since Coraza 3.0.5: this variable is not used, logdata values are stored in the matched rules Logdata string @@ -62,6 +69,9 @@ type Transaction struct { SkipAfter string // AllowType is used by the allow disruptive action to skip evaluating rules after being allowed + // Note: Rely on tx.Allow(allowType) for tweaking this field. This field is exposed for backwards + // compatibility, but it is not recommended to be used directly. 
+ // TODO(4.x): Evaluate to make it private AllowType corazatypes.AllowType // Copies from the WAF instance that may be overwritten by the ctl action @@ -89,7 +99,11 @@ type Transaction struct { responseBodyBuffer *BodyBuffer // Rules with this id are going to be skipped while processing a phase - ruleRemoveByID []int + ruleRemoveByID map[int]struct{} + + // ruleRemoveByIDRanges stores ranges of rule IDs to be skipped during a phase. + // Ranges avoid expanding all IDs into the ruleRemoveByID map. + ruleRemoveByIDRanges [][2]int // ruleRemoveTargetByID is used by ctl to remove rule targets by id during the // transaction. All other "target removers" like "ByTag" are an abstraction of "ById" @@ -122,7 +136,7 @@ type Transaction struct { variables TransactionVariables - transformationCache map[transformationKey]*transformationValue + transformationCache map[transformationKey]transformationValue } func (tx *Transaction) ID() string { @@ -305,9 +319,36 @@ func (tx *Transaction) Collection(idx variables.RuleVariable) collection.Collect return collections.Noop } +// Interrupt sets the interruption for the transaction. +// It complies with DetectionOnly definition which requires that disruptive actions are not executed. +// Depending on the RuleEngine mode: +// If On: it immediately interrupts the transaction and generates a response. +// If DetectionOnly: it keeps track of what the interruption would have been if the engine was "On", +// allowing consistent logging and visibility of potential disruptions without actually interrupting the transaction. func (tx *Transaction) Interrupt(interruption *types.Interruption) { - if tx.RuleEngine == types.RuleEngineOn { + switch tx.RuleEngine { + case types.RuleEngineOn: tx.interruption = interruption + case types.RuleEngineDetectionOnly: + // In DetectionOnly mode, the interruption is not actually triggered, which means that + // further rules will continue to be evaluated and more actions can be executed. 
+ // Let's keep only the first interruption here, matching the one that would have been triggered + // if the engine was on. + if tx.detectionOnlyInterruption == nil { + tx.detectionOnlyInterruption = interruption + } + } +} + +// Allow sets the allow type for the transaction. +// It complies with DetectionOnly definition which requires not executing disruptive actions. +// Depending on the RuleEngine mode: +// If On: it will cause the transaction to skip rules according to the allow type (phase, request, all). +// If DetectionOnly: allow is not enforced. +// TODO(4.x): evaluate to expose it in the interface. +func (tx *Transaction) Allow(allowType corazatypes.AllowType) { + if tx.RuleEngine == types.RuleEngineOn { + tx.AllowType = allowType } } @@ -530,19 +571,19 @@ func (tx *Transaction) MatchRule(r *Rule, mds []types.MatchData) { MatchedDatas_: mds, Context_: tx.context, } - // Populate MatchedRule disruption related fields only if the Engine is capable of performing disruptive actions - if tx.RuleEngine == types.RuleEngineOn { - var exists bool - for _, a := range r.actions { - // There can be only at most one disruptive action per rule - if a.Function.Type() == plugintypes.ActionTypeDisruptive { - mr.DisruptiveAction_, exists = corazarules.DisruptiveActionMap[a.Name] - if !exists { - mr.DisruptiveAction_ = corazarules.DisruptiveActionUnknown - } - mr.Disruptive_ = true - break - } + + // Starting from Coraza 3.4, MatchedRule are including the disruptive action (DisruptiveAction_) + // also in DetectionOnly mode. This improves visibility of what would have happened if the engine was on. + // The Disruptive_ boolean still allows to identify actual disruptions from "potential" disruptions. + // Disruptive_ field is also used during logging to print different messages if the disruption has been real or not + // so it is important to set it according to the RuleEngine mode. 
+ for _, a := range r.actions { + // There can be only one disruptive action per rule + if a.Function.Type() == plugintypes.ActionTypeDisruptive { + // if not found it will default to DisruptiveActionUnknown. + mr.DisruptiveAction_ = corazarules.DisruptiveActionMap[a.Name] + mr.Disruptive_ = tx.RuleEngine == types.RuleEngineOn + break } } @@ -614,7 +655,7 @@ func (tx *Transaction) GetField(rv ruleVariableParams) []types.MatchData { isException := false lkey := strings.ToLower(c.Key()) for _, ex := range rv.Exceptions { - if (ex.KeyRx != nil && ex.KeyRx.MatchString(lkey)) || strings.ToLower(ex.KeyStr) == lkey { + if (ex.KeyRx != nil && ex.KeyRx.MatchString(lkey)) || strings.ToLower(ex.KeyStr) == lkey || (ex.KeyStr == "" && ex.KeyRx == nil) { isException = true break } @@ -639,12 +680,17 @@ func (tx *Transaction) GetField(rv ruleVariableParams) []types.MatchData { return matches } -// RemoveRuleTargetByID Removes the VARIABLE:KEY from the rule ID -// It's mostly used by CTL to dynamically remove targets from rules -func (tx *Transaction) RemoveRuleTargetByID(id int, variable variables.RuleVariable, key string) { +// RemoveRuleTargetByID removes the VARIABLE:KEY from the rule ID. +// It is mostly used by CTL to dynamically remove targets from rules. +// key is an exact string to match against the variable name; keyRx is an +// optional compiled regular expression that, when non-nil, is used instead of +// key for pattern-based matching (e.g. removing all ARGS matching +// /^json\.\d+\.field$/ from a given rule). 
+func (tx *Transaction) RemoveRuleTargetByID(id int, variable variables.RuleVariable, key string, keyRx *regexp.Regexp) { c := ruleVariableParams{ Variable: variable, KeyStr: key, + KeyRx: keyRx, } if multiphaseEvaluation && (variable == variables.Args || variable == variables.ArgsNames) { @@ -669,7 +715,22 @@ func (tx *Transaction) RemoveRuleTargetByID(id int, variable variables.RuleVaria // RemoveRuleByID Removes a rule from the transaction // It does not affect the WAF rules func (tx *Transaction) RemoveRuleByID(id int) { - tx.ruleRemoveByID = append(tx.ruleRemoveByID, id) + if tx.ruleRemoveByID == nil { + tx.ruleRemoveByID = map[int]struct{}{} + } + tx.ruleRemoveByID[id] = struct{}{} +} + +// RemoveRuleByIDRange marks rules in the ID range [start, end] (inclusive) to be +// skipped during transaction processing. It does not affect the WAF rules. +func (tx *Transaction) RemoveRuleByIDRange(start, end int) { + tx.ruleRemoveByIDRanges = append(tx.ruleRemoveByIDRanges, [2]int{start, end}) +} + +// GetRuleRemoveByIDRanges returns the list of rule ID ranges that will be skipped +// during transaction processing. +func (tx *Transaction) GetRuleRemoveByIDRanges() [][2]int { + return tx.ruleRemoveByIDRanges } // ProcessConnection should be called at very beginning of a request process, it is @@ -820,7 +881,7 @@ func (tx *Transaction) SetServerName(serverName string) { // // note: Remember to check for a possible intervention. 
func (tx *Transaction) ProcessRequestHeaders() *types.Interruption { - if tx.RuleEngine == types.RuleEngineOff { + if tx.IsRuleEngineOff() { // Rule engine is disabled return nil } @@ -830,7 +891,7 @@ func (tx *Transaction) ProcessRequestHeaders() *types.Interruption { return tx.interruption } - if tx.interruption != nil { + if tx.IsInterrupted() { tx.debugLogger.Error().Msg("Calling ProcessRequestHeaders but there is a preexisting interruption") return tx.interruption } @@ -852,7 +913,7 @@ func setAndReturnBodyLimitInterruption(tx *Transaction, status int) (*types.Inte // it returns an interruption if the writing bytes go beyond the request body limit. // It won't copy the bytes if the body access isn't accessible. func (tx *Transaction) WriteRequestBody(b []byte) (*types.Interruption, int, error) { - if tx.RuleEngine == types.RuleEngineOff { + if tx.IsRuleEngineOff() { return nil, 0, nil } @@ -917,7 +978,7 @@ type ByteLenger interface { // it returns an interruption if the writing bytes go beyond the request body limit. // It won't read the reader if the body access isn't accessible. func (tx *Transaction) ReadRequestBodyFrom(r io.Reader) (*types.Interruption, int, error) { - if tx.RuleEngine == types.RuleEngineOff { + if tx.IsRuleEngineOff() { return nil, 0, nil } @@ -996,11 +1057,10 @@ func (tx *Transaction) ReadRequestBodyFrom(r io.Reader) (*types.Interruption, in // // Remember to check for a possible intervention. 
func (tx *Transaction) ProcessRequestBody() (*types.Interruption, error) { - if tx.RuleEngine == types.RuleEngineOff { + if tx.IsRuleEngineOff() { return nil, nil } - - if tx.interruption != nil { + if tx.IsInterrupted() { tx.debugLogger.Error().Msg("Calling ProcessRequestBody but there is a preexisting interruption") return tx.interruption, nil } @@ -1024,9 +1084,9 @@ func (tx *Transaction) ProcessRequestBody() (*types.Interruption, error) { tx.WAF.Rules.Eval(types.PhaseRequestBody, tx) return tx.interruption, nil } - mime := "" + mimeType := "" if m := tx.variables.requestHeaders.Get("content-type"); len(m) > 0 { - mime = m[0] + mimeType = m[0] } reader, err := tx.requestBodyBuffer.Reader() @@ -1039,7 +1099,7 @@ func (tx *Transaction) ProcessRequestBody() (*types.Interruption, error) { // Default variables.ReqbodyProcessor values // XML and JSON must be forced with ctl:requestBodyProcessor=JSON if tx.ForceRequestBodyVariable { - // We force URLENCODED if mime is x-www... or we have an empty RBP and ForceRequestBodyVariable + // We force URLENCODED if mimeType is x-www... or we have an empty RBP and ForceRequestBodyVariable if rbp == "" { rbp = "URLENCODED" } @@ -1062,10 +1122,18 @@ func (tx *Transaction) ProcessRequestBody() (*types.Interruption, error) { Str("body_processor", rbp). 
Msg("Attempting to process request body") - if err := bodyprocessor.ProcessRequest(reader, tx.Variables(), plugintypes.BodyProcessorOptions{ - Mime: mime, - StoragePath: tx.WAF.UploadDir, - }); err != nil { + bpOpts := plugintypes.BodyProcessorOptions{ + Mime: mimeType, + StoragePath: tx.WAF.UploadDir, + RequestBodyRecursionLimit: tx.WAF.RequestBodyJsonDepthLimit, + } + + // If the body processor supports streaming, evaluate rules per record + if sp, ok := bodyprocessor.(plugintypes.StreamingBodyProcessor); ok { + return tx.processRequestBodyStreaming(sp, reader, bpOpts) + } + + if err := bodyprocessor.ProcessRequest(reader, tx.Variables(), bpOpts); err != nil { tx.debugLogger.Error().Err(err).Msg("Failed to process request body") tx.generateRequestBodyError(err) tx.WAF.Rules.Eval(types.PhaseRequestBody, tx) @@ -1076,6 +1144,298 @@ func (tx *Transaction) ProcessRequestBody() (*types.Interruption, error) { return tx.interruption, nil } +// errStreamInterrupted is a sentinel error used to stop streaming body processing +// when a rule triggers an interruption. +var errStreamInterrupted = errors.New("stream processing interrupted") + +// processRequestBodyStreaming evaluates Phase 2 rules after each record in a streaming body. +// ArgsPost is cleared and repopulated for each record, while TX variables persist across records +// for cross-record correlation (e.g., anomaly scoring). 
+func (tx *Transaction) processRequestBodyStreaming(sp plugintypes.StreamingBodyProcessor, reader io.Reader, opts plugintypes.BodyProcessorOptions) (*types.Interruption, error) { + err := sp.ProcessRequestRecords(reader, opts, func(recordNum int, fields map[string]string, _ string) error { + // Clear ArgsPost and repopulate with this record's fields only + tx.variables.argsPost.Reset() + for key, value := range fields { + tx.variables.argsPost.SetIndex(key, 0, value) + } + + // Evaluate Phase 2 rules for this record + tx.WAF.Rules.Eval(types.PhaseRequestBody, tx) + + if tx.interruption != nil { + tx.debugLogger.Debug(). + Int("record_num", recordNum). + Msg("Stream processing interrupted by rule") + return errStreamInterrupted + } + return nil + }) + + if err != nil && err != errStreamInterrupted { + tx.debugLogger.Error().Err(err).Msg("Failed to process streaming request body") + tx.generateRequestBodyError(err) + if tx.interruption == nil { + tx.WAF.Rules.Eval(types.PhaseRequestBody, tx) + } + } + + return tx.interruption, nil +} + +// processResponseBodyStreaming evaluates Phase 4 rules after each record in a streaming response body. +// ArgsResponse is cleared and repopulated for each record, while TX variables persist across records. +func (tx *Transaction) processResponseBodyStreaming(sp plugintypes.StreamingBodyProcessor, reader io.Reader, opts plugintypes.BodyProcessorOptions) (*types.Interruption, error) { + err := sp.ProcessResponseRecords(reader, opts, func(recordNum int, fields map[string]string, _ string) error { + // Clear ResponseArgs and repopulate with this record's fields only + tx.variables.responseArgs.Reset() + for key, value := range fields { + tx.variables.responseArgs.SetIndex(key, 0, value) + } + + // Evaluate Phase 4 rules for this record + tx.WAF.Rules.Eval(types.PhaseResponseBody, tx) + + if tx.interruption != nil { + tx.debugLogger.Debug(). + Int("record_num", recordNum). 
+ Msg("Response stream processing interrupted by rule") + return errStreamInterrupted + } + return nil + }) + + if err != nil && err != errStreamInterrupted { + tx.debugLogger.Error().Err(err).Msg("Failed to process streaming response body") + tx.generateResponseBodyError(err) + if tx.interruption == nil { + tx.WAF.Rules.Eval(types.PhaseResponseBody, tx) + } + } + + return tx.interruption, nil +} + +// ProcessRequestBodyFromStream processes a streaming request body by reading records +// from input, evaluating Phase 2 rules per record, and writing clean records to output. +// This enables true streaming without buffering the entire body. +// +// Unlike ProcessRequestBody(), this method reads directly from the provided input +// rather than from the transaction's body buffer. Records that pass rule evaluation +// are written to output for relay to the backend. +// +// This method handles Phase 2 rule evaluation internally. +func (tx *Transaction) ProcessRequestBodyFromStream(input io.Reader, output io.Writer) (*types.Interruption, error) { + if tx.RuleEngine == types.RuleEngineOff { + // Pass through: copy input to output without evaluation + if _, err := io.Copy(output, input); err != nil { + return nil, err + } + return nil, nil + } + + if tx.interruption != nil { + return tx.interruption, nil + } + + if tx.lastPhase != types.PhaseRequestHeaders { + tx.debugLogger.Debug().Msg("Skipping ProcessRequestBodyFromStream: wrong phase") + return nil, nil + } + + rbp := tx.variables.reqbodyProcessor.Get() + if tx.ForceRequestBodyVariable && rbp == "" { + rbp = "URLENCODED" + tx.variables.reqbodyProcessor.Set(rbp) + } + rbp = strings.ToLower(rbp) + if rbp == "" { + // No body processor, pass through + if _, err := io.Copy(output, input); err != nil { + return nil, err + } + tx.WAF.Rules.Eval(types.PhaseRequestBody, tx) + return tx.interruption, nil + } + + bodyprocessor, err := bodyprocessors.GetBodyProcessor(rbp) + if err != nil { + 
tx.generateRequestBodyError(errors.New("invalid body processor")) + tx.WAF.Rules.Eval(types.PhaseRequestBody, tx) + return tx.interruption, nil + } + + sp, ok := bodyprocessor.(plugintypes.StreamingBodyProcessor) + if !ok { + // Non-streaming body processor: buffer input, process normally, then copy to output + it, _, err := tx.ReadRequestBodyFrom(input) + if err != nil { + return nil, err + } + if it != nil { + return it, nil + } + + it, err = tx.ProcessRequestBody() + if err != nil || it != nil { + return it, err + } + + // Copy buffered body to output + rbr, err := tx.RequestBodyReader() + if err != nil { + return nil, err + } + if _, err := io.Copy(output, rbr); err != nil { + return nil, err + } + return nil, nil + } + + mime := "" + if m := tx.variables.requestHeaders.Get("content-type"); len(m) > 0 { + mime = m[0] + } + + tx.debugLogger.Debug(). + Str("body_processor", rbp). + Msg("Attempting to process streaming request body with relay") + + streamErr := sp.ProcessRequestRecords(input, plugintypes.BodyProcessorOptions{ + Mime: mime, + StoragePath: tx.WAF.UploadDir, + }, func(recordNum int, fields map[string]string, rawRecord string) error { + // Clear ArgsPost and repopulate with this record's fields only + tx.variables.argsPost.Reset() + for key, value := range fields { + tx.variables.argsPost.SetIndex(key, 0, value) + } + + // Evaluate Phase 2 rules for this record + tx.WAF.Rules.Eval(types.PhaseRequestBody, tx) + + if tx.interruption != nil { + tx.debugLogger.Debug(). + Int("record_num", recordNum). + Msg("Stream relay interrupted by rule") + return errStreamInterrupted + } + + // Record passed evaluation, write to output for relay. + // rawRecord includes format-specific delimiters (e.g., \n for NDJSON, + // RS prefix + \n for RFC 7464), preserving the original stream format. 
+ if _, err := io.WriteString(output, rawRecord); err != nil { + return fmt.Errorf("failed to write record to output: %w", err) + } + + return nil + }) + + if streamErr != nil && streamErr != errStreamInterrupted { + tx.debugLogger.Error().Err(streamErr).Msg("Failed to process streaming request body relay") + tx.generateRequestBodyError(streamErr) + if tx.interruption == nil { + tx.WAF.Rules.Eval(types.PhaseRequestBody, tx) + } + } + + return tx.interruption, nil +} + +// ProcessResponseBodyFromStream processes a streaming response body by reading records +// from input, evaluating Phase 4 rules per record, and writing clean records to output. +// This enables true streaming of response bodies without full buffering. +func (tx *Transaction) ProcessResponseBodyFromStream(input io.Reader, output io.Writer) (*types.Interruption, error) { + if tx.RuleEngine == types.RuleEngineOff { + if _, err := io.Copy(output, input); err != nil { + return nil, err + } + return nil, nil + } + + if tx.interruption != nil { + return tx.interruption, nil + } + + if tx.lastPhase != types.PhaseResponseHeaders { + tx.debugLogger.Debug().Msg("Skipping ProcessResponseBodyFromStream: wrong phase") + return nil, nil + } + + bp := tx.variables.resBodyProcessor.Get() + if bp == "" { + // No body processor, pass through + if _, err := io.Copy(output, input); err != nil { + return nil, err + } + tx.WAF.Rules.Eval(types.PhaseResponseBody, tx) + return tx.interruption, nil + } + + bodyprocessor, err := bodyprocessors.GetBodyProcessor(bp) + if err != nil { + tx.generateResponseBodyError(errors.New("invalid body processor")) + tx.WAF.Rules.Eval(types.PhaseResponseBody, tx) + return tx.interruption, nil + } + + sp, ok := bodyprocessor.(plugintypes.StreamingBodyProcessor) + if !ok { + // Non-streaming: buffer, process, copy + if _, _, err := tx.ReadResponseBodyFrom(input); err != nil { + return nil, err + } + it, err := tx.ProcessResponseBody() + if err != nil || it != nil { + return it, err + } + rbr, 
err := tx.ResponseBodyReader() + if err != nil { + return nil, err + } + if _, err := io.Copy(output, rbr); err != nil { + return nil, err + } + return nil, nil + } + + tx.debugLogger.Debug(). + Str("body_processor", bp). + Msg("Attempting to process streaming response body with relay") + + streamErr := sp.ProcessResponseRecords(input, plugintypes.BodyProcessorOptions{}, func(recordNum int, fields map[string]string, rawRecord string) error { + tx.variables.responseArgs.Reset() + for key, value := range fields { + tx.variables.responseArgs.SetIndex(key, 0, value) + } + + tx.WAF.Rules.Eval(types.PhaseResponseBody, tx) + + if tx.interruption != nil { + tx.debugLogger.Debug(). + Int("record_num", recordNum). + Msg("Response stream relay interrupted by rule") + return errStreamInterrupted + } + + // rawRecord includes format-specific delimiters, preserving the original stream format. + if _, err := io.WriteString(output, rawRecord); err != nil { + return fmt.Errorf("failed to write response record to output: %w", err) + } + + return nil + }) + + if streamErr != nil && streamErr != errStreamInterrupted { + tx.debugLogger.Error().Err(streamErr).Msg("Failed to process streaming response body relay") + tx.generateResponseBodyError(streamErr) + if tx.interruption == nil { + tx.WAF.Rules.Eval(types.PhaseResponseBody, tx) + } + } + + return tx.interruption, nil +} + // ProcessResponseHeaders performs the analysis on the response headers. // // This method performs the analysis on the response headers. Note, however, @@ -1083,7 +1443,7 @@ func (tx *Transaction) ProcessRequestBody() (*types.Interruption, error) { // // Note: Remember to check for a possible intervention. 
func (tx *Transaction) ProcessResponseHeaders(code int, proto string) *types.Interruption { - if tx.RuleEngine == types.RuleEngineOff { + if tx.IsRuleEngineOff() { return nil } @@ -1093,7 +1453,7 @@ func (tx *Transaction) ProcessResponseHeaders(code int, proto string) *types.Int return tx.interruption } - if tx.interruption != nil { + if tx.IsInterrupted() { tx.debugLogger.Error().Msg("Calling ProcessResponseHeaders but there is a preexisting interruption") return tx.interruption } @@ -1125,7 +1485,7 @@ func (tx *Transaction) IsResponseBodyProcessable() bool { // it returns an interruption if the writing bytes go beyond the response body limit. // It won't copy the bytes if the body access isn't accessible. func (tx *Transaction) WriteResponseBody(b []byte) (*types.Interruption, int, error) { - if tx.RuleEngine == types.RuleEngineOff { + if tx.IsRuleEngineOff() { return nil, 0, nil } @@ -1176,7 +1536,7 @@ func (tx *Transaction) WriteResponseBody(b []byte) (*types.Interruption, int, er // it returns an interruption if the writing bytes go beyond the response body limit. // It won't read the reader if the body access isn't accessible. func (tx *Transaction) ReadResponseBodyFrom(r io.Reader) (*types.Interruption, int, error) { - if tx.RuleEngine == types.RuleEngineOff { + if tx.IsRuleEngineOff() { return nil, 0, nil } @@ -1246,11 +1606,11 @@ func (tx *Transaction) ReadResponseBodyFrom(r io.Reader) (*types.Interruption, i // // note Remember to check for a possible intervention. 
func (tx *Transaction) ProcessResponseBody() (*types.Interruption, error) { - if tx.RuleEngine == types.RuleEngineOff { + if tx.IsRuleEngineOff() { return nil, nil } - if tx.interruption != nil { + if tx.IsInterrupted() { tx.debugLogger.Error().Msg("Calling ProcessResponseBody but there is a preexisting interruption") return tx.interruption, nil } @@ -1295,6 +1655,11 @@ func (tx *Transaction) ProcessResponseBody() (*types.Interruption, error) { tx.debugLogger.Debug().Str("body_processor", bp).Msg("Attempting to process response body") + // If the body processor supports streaming, evaluate rules per record + if sp, ok := b.(plugintypes.StreamingBodyProcessor); ok { + return tx.processResponseBodyStreaming(sp, reader, plugintypes.BodyProcessorOptions{}) + } + if err := b.ProcessResponse(reader, tx.Variables(), plugintypes.BodyProcessorOptions{}); err != nil { tx.debugLogger.Error().Err(err).Msg("Failed to process response body") tx.generateResponseBodyError(err) @@ -1319,12 +1684,11 @@ func (tx *Transaction) ProcessLogging() { // If Rule engine is disabled, Log phase rules are not going to be evaluated. // This avoids trying to rely on variables not set by previous rules that // have not been executed - if tx.RuleEngine != types.RuleEngineOff { + if !tx.IsRuleEngineOff() { tx.WAF.Rules.Eval(types.PhaseLogging, tx) } if tx.AuditEngine == types.AuditEngineOff { - // Audit engine disabled tx.debugLogger.Debug(). Msg("Transaction not marked for audit logging, AuditEngine is disabled") return @@ -1342,11 +1706,14 @@ func (tx *Transaction) ProcessLogging() { status := tx.variables.responseStatus.Get() if tx.IsInterrupted() { status = strconv.Itoa(tx.interruption.Status) + } else if tx.IsDetectionOnlyInterrupted() { + // This allows to check for relevant status even in detection only mode. 
+ // Fixes https://github.com/corazawaf/coraza/issues/1333 + status = strconv.Itoa(tx.detectionOnlyInterruption.Status) } if re != nil && !re.Match([]byte(status)) { // Not relevant status - tx.debugLogger.Debug(). - Msg("Transaction status not marked for audit logging") + tx.debugLogger.Debug().Msg("Transaction status not marked for audit logging") return } } @@ -1382,14 +1749,35 @@ func (tx *Transaction) IsInterrupted() bool { return tx.interruption != nil } +// TODO(4.x): evaluate to expose it in the interface. +func (tx *Transaction) IsDetectionOnlyInterrupted() bool { + return tx.detectionOnlyInterruption != nil +} + func (tx *Transaction) Interruption() *types.Interruption { return tx.interruption } +func (tx *Transaction) DetectionOnlyInterruption() *types.Interruption { + return tx.detectionOnlyInterruption +} + func (tx *Transaction) MatchedRules() []types.MatchedRule { return tx.matchedRules } +// hasLogRelevantMatchedRules returns true if any matched rule has Log enabled. +// Rules with nolog (e.g. CRS initialization rules) are excluded, matching +// the same filtering used for audit log part K. +func (tx *Transaction) hasLogRelevantMatchedRules() bool { + for _, mr := range tx.matchedRules { + if mrWithLog, ok := mr.(*corazarules.MatchedRule); ok && mrWithLog.Log() { + return true + } + } + return false +} + func (tx *Transaction) LastPhase() types.RulePhase { return tx.lastPhase } @@ -1563,12 +1951,20 @@ func (tx *Transaction) Close() error { var errs []error if environment.HasAccessToFS { - // TODO(jcchavezs): filesTmpNames should probably be a new kind of collection that - // is aware of the files and then attempt to delete them when the collection - // is resetted or an item is removed. 
- for _, file := range tx.variables.filesTmpNames.Get("") { - if err := os.Remove(file); err != nil { - errs = append(errs, fmt.Errorf("removing temporary file: %v", err)) + // UploadKeepFilesRelevantOnly keeps temporary files only when there are + // log-relevant matched rules (i.e., rules that would be logged; rules + // with actions such as "nolog" are intentionally excluded here). + keepFiles := tx.WAF.UploadKeepFiles == types.UploadKeepFilesOn || + (tx.WAF.UploadKeepFiles == types.UploadKeepFilesRelevantOnly && tx.hasLogRelevantMatchedRules()) + + if !keepFiles { + // TODO(jcchavezs): filesTmpNames should probably be a new kind of collection that + // is aware of the files and then attempt to delete them when the collection + // is resetted or an item is removed. + for _, file := range tx.variables.filesTmpNames.Get("") { + if err := os.Remove(file); err != nil { + errs = append(errs, fmt.Errorf("removing temporary file: %v", err)) + } } } } diff --git a/internal/corazawaf/transaction_test.go b/internal/corazawaf/transaction_test.go index b5eb2905c..62f1935f0 100644 --- a/internal/corazawaf/transaction_test.go +++ b/internal/corazawaf/transaction_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "os" "regexp" "runtime/debug" "strconv" @@ -22,6 +23,7 @@ import ( "github.com/corazawaf/coraza/v3/internal/collections" "github.com/corazawaf/coraza/v3/internal/corazarules" "github.com/corazawaf/coraza/v3/internal/environment" + "github.com/corazawaf/coraza/v3/internal/operators" utils "github.com/corazawaf/coraza/v3/internal/strings" "github.com/corazawaf/coraza/v3/types" "github.com/corazawaf/coraza/v3/types/variables" @@ -752,6 +754,72 @@ func TestAuditLogFields(t *testing.T) { } } +func TestMatchRuleDisruptiveActionPopulated(t *testing.T) { + tests := []struct { + name string + engine types.RuleEngineStatus + wantDisruptive bool + wantAction corazarules.DisruptiveAction + wantInterrupted bool + wantDetectionOnlyInterrupted bool + }{ + { + name: "engine 
on", + engine: types.RuleEngineOn, + wantDisruptive: true, + wantAction: corazarules.DisruptiveActionDeny, + wantInterrupted: true, + wantDetectionOnlyInterrupted: false, + }, + { + name: "engine detection only", + engine: types.RuleEngineDetectionOnly, + wantDisruptive: false, + wantAction: corazarules.DisruptiveActionDeny, + wantInterrupted: false, + wantDetectionOnlyInterrupted: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + waf := NewWAF() + tx := waf.NewTransaction() + tx.RuleEngine = tt.engine + + rule := NewRule() + rule.ID_ = 1 + if err := rule.AddVariable(variables.ArgsGet, "", false); err != nil { + t.Fatal(err) + } + rule.SetOperator(&dummyEqOperator{}, "@eq", "0") + _ = rule.AddAction("deny", &dummyDenyAction{}) + + tx.AddGetRequestArgument("test", "0") + + var matchedValues []types.MatchData + rule.doEvaluate(debuglog.Noop(), types.PhaseRequestHeaders, tx, &matchedValues, 0, tx.transformationCache) + + if len(tx.matchedRules) != 1 { + t.Fatalf("expected 1 matched rule, got %d", len(tx.matchedRules)) + } + mr := tx.matchedRules[0].(*corazarules.MatchedRule) + if mr.Disruptive_ != tt.wantDisruptive { + t.Errorf("Disruptive_: got %t, want %t", mr.Disruptive_, tt.wantDisruptive) + } + if mr.DisruptiveAction_ != tt.wantAction { + t.Errorf("DisruptiveAction_: got %d, want %d", mr.DisruptiveAction_, tt.wantAction) + } + if tx.IsInterrupted() != tt.wantInterrupted { + t.Errorf("IsInterrupted: got %t, want %t", tx.IsInterrupted(), tt.wantInterrupted) + } + if tx.IsDetectionOnlyInterrupted() != tt.wantDetectionOnlyInterrupted { + t.Errorf("IsDetectionOnlyInterrupted: got %t, want %t", tx.IsDetectionOnlyInterrupted(), tt.wantDetectionOnlyInterrupted) + } + }) + } +} + func TestResetCapture(t *testing.T) { tx := makeTransaction(t) tx.Capture = true @@ -1326,6 +1394,61 @@ func BenchmarkTxGetField(b *testing.B) { b.ReportAllocs() } +// makeTransactionWithJSONArgs creates a transaction that includes JSON-array-style +// GET 
arguments (json.0.field … json.9.field) on top of the standard args. +// This simulates the real-world pattern that motivates regex key exceptions. +func makeTransactionWithJSONArgs(t testing.TB) *Transaction { + t.Helper() + tx := makeTransaction(t) + for i := 0; i < 10; i++ { + tx.AddGetRequestArgument(fmt.Sprintf("json.%d.jobdescription", i), "value") + } + return tx +} + +// BenchmarkTxGetFieldWithShortRegexException measures the overhead of GetField +// when a short regex exception (e.g. ^id$) is applied against the args collection. +func BenchmarkTxGetFieldWithShortRegexException(b *testing.B) { + tx := makeTransactionWithJSONArgs(b) + rvp := ruleVariableParams{ + Variable: variables.Args, + Exceptions: []ruleVariableException{ + {KeyRx: regexp.MustCompile(`^id$`)}, + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tx.GetField(rvp) + } + b.StopTimer() + if err := tx.Close(); err != nil { + b.Fatalf("Failed to close transaction: %s", err.Error()) + } +} + +// BenchmarkTxGetFieldWithMediumRegexException measures the overhead of GetField +// when a medium-complexity regex exception (e.g. ^json\.\d+\.jobdescription$) is +// applied — the typical pattern used in URI-scoped CRS exclusions. 
+func BenchmarkTxGetFieldWithMediumRegexException(b *testing.B) { + tx := makeTransactionWithJSONArgs(b) + rvp := ruleVariableParams{ + Variable: variables.Args, + Exceptions: []ruleVariableException{ + {KeyRx: regexp.MustCompile(`^json\.\d+\.jobdescription$`)}, + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tx.GetField(rvp) + } + b.StopTimer() + if err := tx.Close(); err != nil { + b.Fatalf("Failed to close transaction: %s", err.Error()) + } +} + func TestTxProcessURI(t *testing.T) { waf := NewWAF() tx := waf.NewTransaction() @@ -1721,7 +1844,7 @@ func TestResponseBodyForceProcessing(t *testing.T) { if _, err := tx.ProcessRequestBody(); err != nil { t.Fatal(err) } - tx.ProcessResponseHeaders(200, "HTTP/1") + tx.ProcessResponseHeaders(200, "HTTP/1.1") if _, _, err := tx.WriteResponseBody([]byte(`{"key":"value"}`)); err != nil { t.Fatal(err) } @@ -1790,6 +1913,121 @@ func TestCloseFails(t *testing.T) { } } +func TestUploadKeepFiles(t *testing.T) { + if !environment.HasAccessToFS { + t.Skip("skipping test as it requires access to filesystem") + } + + createTmpFile := func(t *testing.T) string { + t.Helper() + f, err := os.CreateTemp(t.TempDir(), "crztest*") + if err != nil { + t.Fatal(err) + } + name := f.Name() + if err := f.Close(); err != nil { + t.Fatalf("failed to close temp file: %v", err) + } + return name + } + + t.Run("Off deletes files", func(t *testing.T) { + waf := NewWAF() + waf.UploadKeepFiles = types.UploadKeepFilesOff + tx := waf.NewTransaction() + tmpFile := createTmpFile(t) + + col := tx.Variables().FilesTmpNames().(*collections.Map) + col.Add("", tmpFile) + + if err := tx.Close(); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(tmpFile); !os.IsNotExist(err) { + t.Fatal("expected temp file to be deleted when UploadKeepFiles is Off") + } + }) + + t.Run("On keeps files", func(t *testing.T) { + waf := NewWAF() + waf.UploadKeepFiles = types.UploadKeepFilesOn + tx := waf.NewTransaction() + tmpFile := createTmpFile(t) 
+ + col := tx.Variables().FilesTmpNames().(*collections.Map) + col.Add("", tmpFile) + + if err := tx.Close(); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(tmpFile); err != nil { + t.Fatal("expected temp file to be kept when UploadKeepFiles is On") + } + }) + + t.Run("RelevantOnly keeps files when log rules matched", func(t *testing.T) { + waf := NewWAF() + waf.UploadKeepFiles = types.UploadKeepFilesRelevantOnly + tx := waf.NewTransaction() + tmpFile := createTmpFile(t) + + // Simulate a matched rule with Log enabled + tx.matchedRules = append(tx.matchedRules, &corazarules.MatchedRule{Log_: true}) + + col := tx.Variables().FilesTmpNames().(*collections.Map) + col.Add("", tmpFile) + + if err := tx.Close(); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(tmpFile); err != nil { + t.Fatal("expected temp file to be kept when UploadKeepFiles is RelevantOnly and log rules matched") + } + }) + + t.Run("RelevantOnly deletes files when only nolog rules matched", func(t *testing.T) { + waf := NewWAF() + waf.UploadKeepFiles = types.UploadKeepFilesRelevantOnly + tx := waf.NewTransaction() + tmpFile := createTmpFile(t) + + // Simulate a matched rule with Log disabled (e.g. 
CRS initialization rules) + tx.matchedRules = append(tx.matchedRules, &corazarules.MatchedRule{Log_: false}) + + col := tx.Variables().FilesTmpNames().(*collections.Map) + col.Add("", tmpFile) + + if err := tx.Close(); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(tmpFile); !os.IsNotExist(err) { + t.Fatal("expected temp file to be deleted when UploadKeepFiles is RelevantOnly and only nolog rules matched") + } + }) + + t.Run("RelevantOnly deletes files when no rules matched", func(t *testing.T) { + waf := NewWAF() + waf.UploadKeepFiles = types.UploadKeepFilesRelevantOnly + tx := waf.NewTransaction() + tmpFile := createTmpFile(t) + + col := tx.Variables().FilesTmpNames().(*collections.Map) + col.Add("", tmpFile) + + if err := tx.Close(); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(tmpFile); !os.IsNotExist(err) { + t.Fatal("expected temp file to be deleted when UploadKeepFiles is RelevantOnly and no rules matched") + } + }) +} + func TestRequestFilename(t *testing.T) { tests := []struct { name string @@ -1864,3 +2102,189 @@ func TestRequestFilename(t *testing.T) { }) } } + +func newTestUnconditionalMatch(t testing.TB) plugintypes.Operator { + t.Helper() + op, err := operators.Get("unconditionalMatch", plugintypes.OperatorOptions{}) + if err != nil { + t.Fatal(err) + } + return op +} + +func TestRemoveRuleByID(t *testing.T) { + waf := NewWAF() + op := newTestUnconditionalMatch(t) + + // Add two rules with different IDs + rule1 := NewRule() + rule1.ID_ = 100 + rule1.LogID_ = "100" + rule1.Phase_ = types.PhaseRequestHeaders + rule1.operator = &ruleOperatorParams{ + Operator: op, + Function: "@unconditionalMatch", + } + rule1.Log = true + if err := waf.Rules.Add(rule1); err != nil { + t.Fatal(err) + } + + rule2 := NewRule() + rule2.ID_ = 200 + rule2.LogID_ = "200" + rule2.Phase_ = types.PhaseRequestHeaders + rule2.operator = &ruleOperatorParams{ + Operator: op, + Function: "@unconditionalMatch", + } + rule2.Log = true + if err := 
waf.Rules.Add(rule2); err != nil { + t.Fatal(err) + } + + tx := waf.NewTransaction() + defer tx.Close() + + // Remove rule 100 + tx.RemoveRuleByID(100) + + // Verify the map was lazily initialized + if tx.ruleRemoveByID == nil { + t.Fatal("ruleRemoveByID should not be nil after RemoveRuleByID") + } + + // Remove another rule + tx.RemoveRuleByID(100) // duplicate removal should be idempotent + tx.RemoveRuleByID(200) + + if len(tx.ruleRemoveByID) != 2 { + t.Errorf("expected 2 entries in ruleRemoveByID map, got %d", len(tx.ruleRemoveByID)) + } +} + +func TestRemoveRuleByIDRange(t *testing.T) { + waf := NewWAF() + + // Use nil-operator rules (SecAction-style): they always match regardless of variables. + for _, id := range []int{100, 150, 200, 300} { + r := NewRule() + r.ID_ = id + r.LogID_ = strconv.Itoa(id) + r.Phase_ = types.PhaseRequestHeaders + // nil operator means the rule always matches + if err := waf.Rules.Add(r); err != nil { + t.Fatal(err) + } + } + + t.Run("range is stored", func(t *testing.T) { + tx := waf.NewTransaction() + defer tx.Close() + + tx.RemoveRuleByIDRange(100, 200) + if len(tx.ruleRemoveByIDRanges) != 1 { + t.Fatalf("expected 1 range entry, got %d", len(tx.ruleRemoveByIDRanges)) + } + if tx.ruleRemoveByIDRanges[0][0] != 100 || tx.ruleRemoveByIDRanges[0][1] != 200 { + t.Errorf("unexpected range: %v", tx.ruleRemoveByIDRanges[0]) + } + }) + + t.Run("rules in range are skipped during eval", func(t *testing.T) { + tx := waf.NewTransaction() + defer tx.Close() + + // Remove rules with IDs 100-200; rule 300 should still be evaluated. 
+ tx.RemoveRuleByIDRange(100, 200) + waf.Rules.Eval(types.PhaseRequestHeaders, tx) + + matchedIDs := make(map[int]bool) + for _, mr := range tx.MatchedRules() { + matchedIDs[mr.Rule().ID()] = true + } + + for _, skipped := range []int{100, 150, 200} { + if matchedIDs[skipped] { + t.Errorf("rule %d should have been skipped but was matched", skipped) + } + } + if !matchedIDs[300] { + t.Errorf("rule 300 should have been matched but was not") + } + }) + + t.Run("multiple ranges", func(t *testing.T) { + tx := waf.NewTransaction() + defer tx.Close() + + tx.RemoveRuleByIDRange(100, 100) + tx.RemoveRuleByIDRange(300, 300) + if len(tx.ruleRemoveByIDRanges) != 2 { + t.Fatalf("expected 2 range entries, got %d", len(tx.ruleRemoveByIDRanges)) + } + + waf.Rules.Eval(types.PhaseRequestHeaders, tx) + + matchedIDs := make(map[int]bool) + for _, mr := range tx.MatchedRules() { + matchedIDs[mr.Rule().ID()] = true + } + if matchedIDs[100] { + t.Error("rule 100 should have been skipped") + } + if matchedIDs[300] { + t.Error("rule 300 should have been skipped") + } + if !matchedIDs[150] { + t.Error("rule 150 should have been matched") + } + if !matchedIDs[200] { + t.Error("rule 200 should have been matched") + } + }) + + t.Run("range reset on transaction reuse", func(t *testing.T) { + tx := waf.NewTransaction() + tx.RemoveRuleByIDRange(100, 200) + tx.Close() + + // Get a new transaction (pool reuse may return the same object) + tx2 := waf.NewTransaction() + defer tx2.Close() + + if len(tx2.ruleRemoveByIDRanges) != 0 { + t.Errorf("expected ruleRemoveByIDRanges to be reset, got %d entries", len(tx2.ruleRemoveByIDRanges)) + } + }) +} + +func BenchmarkRuleEvalWithRemovedRules(b *testing.B) { + waf := NewWAF() + op := newTestUnconditionalMatch(b) + + rule := NewRule() + rule.ID_ = 1000 + rule.LogID_ = "1000" + rule.Phase_ = types.PhaseRequestHeaders + rule.operator = &ruleOperatorParams{ + Operator: op, + Function: "@unconditionalMatch", + } + if err := waf.Rules.Add(rule); err != nil { + 
b.Fatal(err) + } + + tx := waf.NewTransaction() + defer tx.Close() + + for i := 1; i <= 100; i++ { + tx.RemoveRuleByID(i) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + waf.Rules.Eval(types.PhaseRequestHeaders, tx) + } +} diff --git a/internal/corazawaf/waf.go b/internal/corazawaf/waf.go index 32ac51fb1..54e28eaac 100644 --- a/internal/corazawaf/waf.go +++ b/internal/corazawaf/waf.go @@ -12,17 +12,28 @@ import ( "os" "regexp" "strconv" + gosync "sync" + "sync/atomic" "time" "github.com/corazawaf/coraza/v3/debuglog" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" "github.com/corazawaf/coraza/v3/internal/auditlog" "github.com/corazawaf/coraza/v3/internal/environment" + "github.com/corazawaf/coraza/v3/internal/memoize" stringutils "github.com/corazawaf/coraza/v3/internal/strings" "github.com/corazawaf/coraza/v3/internal/sync" "github.com/corazawaf/coraza/v3/types" ) +var wafIDCounter atomic.Uint64 + +// Default settings +const ( + // DefaultRequestBodyJsonDepthLimit is the default limit for the depth of JSON objects in the request body + DefaultRequestBodyJsonDepthLimit = 1024 +) + // WAF instance is used to store configurations and rules // Every web application should have a different WAF instance, // but you can share an instance if you are ok with sharing @@ -44,6 +55,9 @@ type WAF struct { // Request body page file limit RequestBodyLimit int64 + // Request body JSON recursive depth limit + RequestBodyJsonDepthLimit int + // Request body in memory limit requestBodyInMemoryLimit *int64 @@ -80,9 +94,9 @@ type WAF struct { // Path to store data files (ex. cache) DataDir string - // If true, the WAF will store the uploaded files in the UploadDir - // directory - UploadKeepFiles bool + // UploadKeepFiles controls whether uploaded files are kept after the transaction. + // On: always keep, Off: always delete (default), RelevantOnly: keep only if log-relevant rules matched (excluding nolog rules). 
+ UploadKeepFiles types.UploadKeepFilesStatus // UploadFileMode instructs the waf to set the file mode for uploaded files UploadFileMode fs.FileMode // UploadFileLimit is the maximum size of the uploaded file to be stored @@ -135,6 +149,10 @@ type WAF struct { // Configures the maximum number of ARGS that will be accepted for processing. ArgumentLimit int + + memoizerID uint64 + memoizer *memoize.Memoizer + closeOnce gosync.Once } // Options is used to pass options to the WAF instance @@ -180,14 +198,15 @@ func (w *WAF) newTransaction(opts Options) *Transaction { tx.AuditLogFormat = w.AuditLogFormat tx.ForceRequestBodyVariable = false tx.RequestBodyAccess = w.RequestBodyAccess - tx.RequestBodyLimit = int64(w.RequestBodyLimit) + tx.RequestBodyLimit = w.RequestBodyLimit tx.ResponseBodyAccess = w.ResponseBodyAccess - tx.ResponseBodyLimit = int64(w.ResponseBodyLimit) + tx.ResponseBodyLimit = w.ResponseBodyLimit tx.RuleEngine = w.RuleEngine tx.HashEngine = false tx.HashEnforcement = false tx.lastPhase = 0 tx.ruleRemoveByID = nil + tx.ruleRemoveByIDRanges = nil tx.ruleRemoveTargetByID = map[int][]ruleVariableParams{} tx.Skip = 0 tx.AllowType = 0 @@ -198,13 +217,13 @@ func (w *WAF) newTransaction(opts Options) *Transaction { tx.Timestamp = time.Now().UnixNano() tx.audit = false - // Always non-nil if buffers / collections were already initialized so we don't do any of them + // Always non-nil if buffers / collections were already initialized, so we don't do any of them // based on the presence of RequestBodyBuffer. 
if tx.requestBodyBuffer == nil { // if no requestBodyInMemoryLimit has been set we default to the requestBodyLimit requestBodyInMemoryLimit := w.RequestBodyLimit if w.requestBodyInMemoryLimit != nil { - requestBodyInMemoryLimit = int64(*w.requestBodyInMemoryLimit) + requestBodyInMemoryLimit = *w.requestBodyInMemoryLimit } tx.requestBodyBuffer = NewBodyBuffer(types.BodyBufferOptions{ @@ -221,7 +240,7 @@ func (w *WAF) newTransaction(opts Options) *Transaction { }) tx.variables = *NewTransactionVariables() - tx.transformationCache = map[transformationKey]*transformationValue{} + tx.transformationCache = map[transformationKey]transformationValue{} } // set capture variables @@ -299,6 +318,7 @@ func NewWAF() *WAF { RequestBodyAccess: false, RequestBodyLimit: 134217728, // Hard limit equal to _1gib RequestBodyLimitAction: types.BodyLimitActionReject, + RequestBodyJsonDepthLimit: DefaultRequestBodyJsonDepthLimit, ResponseBodyAccess: false, ResponseBodyLimit: 524288, // Hard limit equal to _1gib auditLogWriter: logWriter, @@ -319,6 +339,10 @@ func NewWAF() *WAF { waf.TmpDir = os.TempDir() } + id := wafIDCounter.Add(1) + waf.memoizerID = id + waf.memoizer = memoize.NewMemoizer(id) + waf.Logger.Debug().Msg("A new WAF instance was created") return waf } @@ -418,5 +442,34 @@ func (w *WAF) Validate() error { return errors.New("argument limit should be bigger than 0") } + if w.RequestBodyJsonDepthLimit <= 0 { + return errors.New("request body json depth limit should be bigger than 0") + } + + if environment.HasAccessToFS { + if w.UploadKeepFiles != types.UploadKeepFilesOff && w.UploadDir == "" { + return errors.New("SecUploadDir is required when SecUploadKeepFiles is enabled") + } + } else { + if w.UploadKeepFiles != types.UploadKeepFilesOff { + return errors.New("SecUploadKeepFiles requires filesystem access, which is not available in this build") + } + } + + return nil +} + +// Memoizer returns the WAF's memoizer for caching compiled patterns. 
+func (w *WAF) Memoizer() *memoize.Memoizer { + return w.memoizer +} + +// Close releases cached resources owned by this WAF instance. +// Cached entries shared with other WAF instances remain until all owners release them. +// Transactions already in-flight are unaffected as they hold their own references. +func (w *WAF) Close() error { + w.closeOnce.Do(func() { + memoize.Release(w.memoizerID) + }) return nil } diff --git a/internal/corazawaf/waf_test.go b/internal/corazawaf/waf_test.go index 11a2c75ef..898dd4144 100644 --- a/internal/corazawaf/waf_test.go +++ b/internal/corazawaf/waf_test.go @@ -7,6 +7,9 @@ import ( "io" "os" "testing" + + "github.com/corazawaf/coraza/v3/internal/environment" + "github.com/corazawaf/coraza/v3/types" ) func TestNewTransaction(t *testing.T) { @@ -107,6 +110,39 @@ func TestValidate(t *testing.T) { }, } + if environment.HasAccessToFS { + testCases["upload keep files on without upload dir"] = struct { + customizer func(*WAF) + expectErr bool + }{ + expectErr: true, + customizer: func(w *WAF) { + w.UploadKeepFiles = types.UploadKeepFilesOn + w.UploadDir = "" + }, + } + testCases["upload keep files relevant only without upload dir"] = struct { + customizer func(*WAF) + expectErr bool + }{ + expectErr: true, + customizer: func(w *WAF) { + w.UploadKeepFiles = types.UploadKeepFilesRelevantOnly + w.UploadDir = "" + }, + } + testCases["upload keep files on with upload dir"] = struct { + customizer func(*WAF) + expectErr bool + }{ + expectErr: false, + customizer: func(w *WAF) { + w.UploadKeepFiles = types.UploadKeepFilesOn + w.UploadDir = "/tmp" + }, + } + } + for name, tCase := range testCases { t.Run(name, func(t *testing.T) { waf := NewWAF() diff --git a/internal/memoize/README.md b/internal/memoize/README.md index 5ebe09aba..d0dfd3eb4 100644 --- a/internal/memoize/README.md +++ b/internal/memoize/README.md @@ -1,17 +1,19 @@ # Memoize -Memoize allows to cache certain expensive function calls and -cache the result. 
The main advantage in Coraza is to memoize -the regexes and aho-corasick dictionaries when the connects -spins up more than one WAF in the same process and hence same -regexes are being compiled over and over. +Memoize caches certain expensive function calls (regex and aho-corasick +compilation) so the same patterns are not recompiled when multiple WAF +instances in the same process share rules. -Currently it is opt-in under the `memoize_builders` build tag -as under a misuse (e.g. using after build time) it could lead -to a memory leak as currently the cache is global. +Memoization is **enabled by default** and uses a **global cache** within +the process. In long-lived processes that reload WAF configurations, +use `WAF.Close()` (via `experimental.WAFCloser`) to release cached +entries when a WAF is destroyed. Alternatively, disable memoization with +the `coraza.no_memoize` build tag. -**Important:** Connectors with *live reload* functionality (e.g. Caddy) -could lead to memory leaks which might or might not be negligible in -most of the cases as usually config changes in a WAF are about a few -rules, this is old objects will be still alive in memory until the program -stops. 
+## Build variants + +| Build tag | Behavior | +|-----------------------|--------------------------------------------------| +| *(none)* | Full memoization with `singleflight` (default) | +| `tinygo` | Memoization without `singleflight` (TinyGo) | +| `coraza.no_memoize` | No-op — every call compiles fresh | diff --git a/internal/memoize/noop.go b/internal/memoize/noop.go index 1d1e5447f..f5b82e085 100644 --- a/internal/memoize/noop.go +++ b/internal/memoize/noop.go @@ -1,10 +1,21 @@ // Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors // SPDX-License-Identifier: Apache-2.0 -//go:build !memoize_builders +//go:build coraza.no_memoize package memoize -func Do(_ string, fn func() (any, error)) (any, error) { - return fn() -} +// Memoizer is a no-op implementation when memoization is disabled. +type Memoizer struct{} + +// NewMemoizer returns a no-op Memoizer. +func NewMemoizer(_ uint64) *Memoizer { return &Memoizer{} } + +// Do always calls fn directly without caching. +func (m *Memoizer) Do(_ string, fn func() (any, error)) (any, error) { return fn() } + +// Release is a no-op when memoization is disabled. +func Release(_ uint64) {} + +// Reset is a no-op when memoization is disabled. 
+func Reset() {} diff --git a/internal/memoize/noop_test.go b/internal/memoize/noop_test.go new file mode 100644 index 000000000..bf5c0ed47 --- /dev/null +++ b/internal/memoize/noop_test.go @@ -0,0 +1,206 @@ +// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +//go:build coraza.no_memoize + +package memoize + +import ( + "errors" + "strconv" + "testing" +) + +func TestNoopDo(t *testing.T) { + m := NewMemoizer(1) + calls := 0 + + fn := func() (any, error) { + calls++ + return calls, nil + } + + result, err := m.Do("key1", fn) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + if want, have := 1, result.(int); want != have { + t.Fatalf("unexpected value, want %d, have %d", want, have) + } + + // No-op memoizer should call fn again (no caching). + result, err = m.Do("key1", fn) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + if want, have := 2, result.(int); want != have { + t.Fatalf("expected no caching, want %d, have %d", want, have) + } +} + +func TestNoopDoError(t *testing.T) { + m := NewMemoizer(1) + + _, err := m.Do("key1", func() (any, error) { + return nil, errors.New("fail") + }) + if err == nil { + t.Fatal("expected error") + } +} + +// TestNoopDoMultipleKeys verifies that different keys each invoke fn independently. +func TestNoopDoMultipleKeys(t *testing.T) { + m := NewMemoizer(1) + calls := 0 + + fn := func() (any, error) { + calls++ + return calls, nil + } + + for i := 1; i <= 3; i++ { + result, err := m.Do("key"+strconv.Itoa(i), fn) + if err != nil { + t.Fatalf("unexpected error on key %d: %s", i, err.Error()) + } + if want, have := i, result.(int); want != have { + t.Fatalf("key%d: want %d, have %d", i, want, have) + } + } + if calls != 3 { + t.Fatalf("expected 3 fn calls, got %d", calls) + } +} + +// TestNoopErrorNotCached verifies that errors returned by fn are not cached: +// a subsequent call with the same key will invoke fn again. 
+func TestNoopErrorNotCached(t *testing.T) { + m := NewMemoizer(1) + calls := 0 + + fn := func() (any, error) { + calls++ + if calls == 1 { + return nil, errors.New("transient error") + } + return calls, nil + } + + // First call should return error. + _, err := m.Do("key1", fn) + if err == nil { + t.Fatal("expected error on first call") + } + + // Second call should succeed (no caching of error). + result, err := m.Do("key1", fn) + if err != nil { + t.Fatalf("unexpected error on second call: %s", err.Error()) + } + if want, have := 2, result.(int); want != have { + t.Fatalf("want %d, have %d", want, have) + } + + // Third call: fn invoked again (still no caching). + result, err = m.Do("key1", fn) + if err != nil { + t.Fatalf("unexpected error on third call: %s", err.Error()) + } + if want, have := 3, result.(int); want != have { + t.Fatalf("want %d, have %d", want, have) + } +} + +// TestNoopRelease verifies that Release is a no-op and does not panic or affect subsequent Do calls. +func TestNoopRelease(t *testing.T) { + m := NewMemoizer(1) + calls := 0 + + fn := func() (any, error) { + calls++ + return calls, nil + } + + result, err := m.Do("key1", fn) + if err != nil { + t.Fatalf("unexpected error before Release: %s", err.Error()) + } + if want, have := 1, result.(int); want != have { + t.Fatalf("before Release: want %d, have %d", want, have) + } + if calls != 1 { + t.Fatalf("expected 1 call before Release, got %d", calls) + } + + // Release should not panic and should be a no-op. + Release(1) + + // Subsequent calls should still work normally. + result, err = m.Do("key1", fn) + if err != nil { + t.Fatalf("unexpected error after Release: %s", err.Error()) + } + if want, have := 2, result.(int); want != have { + t.Fatalf("expected fn called again after Release, want %d, have %d", want, have) + } +} + +// TestNoopReset verifies that Reset is a no-op and does not panic or affect subsequent Do calls. 
+func TestNoopReset(t *testing.T) { + m := NewMemoizer(1) + calls := 0 + + fn := func() (any, error) { + calls++ + return calls, nil + } + + for _, key := range []string{"k1", "k2"} { + result, err := m.Do(key, fn) + if err != nil { + t.Fatalf("unexpected error for %s before Reset: %s", key, err.Error()) + } + if result == nil { + t.Fatalf("unexpected nil result for %s", key) + } + } + if calls != 2 { + t.Fatalf("expected 2 calls before Reset, got %d", calls) + } + + // Reset should not panic. + Reset() + + // Calls after Reset should continue working. + result, err := m.Do("k1", fn) + if err != nil { + t.Fatalf("unexpected error after Reset: %s", err.Error()) + } + if want, have := 3, result.(int); want != have { + t.Fatalf("expected fn called again after Reset, want %d, have %d", want, have) + } +} + +// TestNoopMultipleMemoizers verifies that multiple no-op memoizers are independent +// (no shared state between different owner IDs). +func TestNoopMultipleMemoizers(t *testing.T) { + m1 := NewMemoizer(1) + m2 := NewMemoizer(2) + calls := 0 + + fn := func() (any, error) { + calls++ + return calls, nil + } + + r1, _ := m1.Do("shared", fn) + r2, _ := m2.Do("shared", fn) + + if r1.(int) != 1 || r2.(int) != 2 { + t.Fatalf("expected independent calls: m1=%d, m2=%d", r1.(int), r2.(int)) + } + if calls != 2 { + t.Fatalf("expected 2 fn calls, got %d", calls) + } +} diff --git a/internal/memoize/nosync.go b/internal/memoize/nosync.go index d238244d9..45e1b1dee 100644 --- a/internal/memoize/nosync.go +++ b/internal/memoize/nosync.go @@ -1,36 +1,90 @@ // Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors // SPDX-License-Identifier: Apache-2.0 -//go:build tinygo && memoize_builders +//go:build tinygo && !coraza.no_memoize package memoize import "sync" -var doer = makeDoer(new(sync.Map)) +type entry struct { + value any + mu sync.Mutex + owners map[uint64]struct{} + deleted bool +} + +var cache sync.Map // key -> *entry + +// Memoizer caches expensive function 
calls with per-owner tracking. +// TinyGo variant without singleflight. +type Memoizer struct { + ownerID uint64 +} -// Do executes and returns the results of the given function, unless there was a cached -// value of the same key. Only one execution is in-flight for a given key at a time. -// The boolean return value indicates whether v was previously stored. -func Do(key string, fn func() (interface{}, error)) (interface{}, error) { - value, err, _ := doer(key, fn) - return value, err +// NewMemoizer creates a Memoizer that tracks cached entries under the given owner ID. +func NewMemoizer(ownerID uint64) *Memoizer { + return &Memoizer{ownerID: ownerID} } -// makeDoer returns a function that executes and returns the results of the given function -func makeDoer(cache *sync.Map) func(string, func() (interface{}, error)) (interface{}, error, bool) { - return func(key string, fn func() (interface{}, error)) (interface{}, error, bool) { - // Check cache - value, found := cache.Load(key) - if found { - return value, nil, true +// Do returns a cached value for key, or calls fn and caches the result. +func (m *Memoizer) Do(key string, fn func() (any, error)) (any, error) { + if v, ok := cache.Load(key); ok { + e := v.(*entry) + e.mu.Lock() + if !e.deleted { + e.owners[m.ownerID] = struct{}{} + e.mu.Unlock() + return e.value, nil } + e.mu.Unlock() + } - data, err := fn() - if err == nil { - cache.Store(key, data) + data, err := fn() + if err == nil { + e := &entry{ + value: data, + owners: map[uint64]struct{}{m.ownerID: {}}, } + cache.Store(key, e) + } + return data, err +} + +// Release removes ownerID from all cached entries, deleting entries with no remaining owners. +// +// Deletions are deferred until after Range completes because TinyGo's sync.Map +// holds its internal lock for the entire Range call, so calling Delete inside +// the callback would deadlock. 
+func Release(ownerID uint64) { + var toDelete []any + cache.Range(func(key, value any) bool { + e := value.(*entry) + e.mu.Lock() + delete(e.owners, ownerID) + if len(e.owners) == 0 { + e.deleted = true + toDelete = append(toDelete, key) + } + e.mu.Unlock() + return true + }) + for _, key := range toDelete { + cache.Delete(key) + } +} - return data, err, false +// Reset clears the entire cache. Intended for testing. +// +// Keys are collected first and deleted after Range returns to avoid deadlocking +// on TinyGo's mutex-based sync.Map (see Release comment). +func Reset() { + var keys []any + cache.Range(func(key, _ any) bool { + keys = append(keys, key) + return true + }) + for _, key := range keys { + cache.Delete(key) } } diff --git a/internal/memoize/nosync_test.go b/internal/memoize/nosync_test.go index c942a522c..85170ab5d 100644 --- a/internal/memoize/nosync_test.go +++ b/internal/memoize/nosync_test.go @@ -1,167 +1,335 @@ // Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors // SPDX-License-Identifier: Apache-2.0 -//go:build tinygo && memoize_builders - -// https://github.com/kofalt/go-memoize/blob/master/memoize.go +//go:build tinygo && !coraza.no_memoize package memoize import ( "errors" - "sync" + "fmt" + "regexp" "testing" ) func TestDo(t *testing.T) { + t.Cleanup(Reset) + + m := NewMemoizer(1) expensiveCalls := 0 - // Function tracks how many times its been called - expensive := func() (interface{}, error) { + expensive := func() (any, error) { expensiveCalls++ return expensiveCalls, nil } - // First call SHOULD NOT be cached - result, err := Do("key1", expensive) + result, err := m.Do("key1", expensive) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } - if want, have := 1, result.(int); want != have { t.Fatalf("unexpected value, want %d, have %d", want, have) } - // Second call on same key SHOULD be cached - result, err = Do("key1", expensive) + result, err = m.Do("key1", expensive) if err != nil { t.Fatalf("unexpected 
error: %s", err.Error()) } - if want, have := 1, result.(int); want != have { t.Fatalf("unexpected value, want %d, have %d", want, have) } - // First call on a new key SHOULD NOT be cached - result, err = Do("key2", expensive) + result, err = m.Do("key2", expensive) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } - if want, have := 2, result.(int); want != have { t.Fatalf("unexpected value, want %d, have %d", want, have) } } -func TestSuccessCall(t *testing.T) { - do := makeDoer(new(sync.Map)) +func TestFailedCall(t *testing.T) { + t.Cleanup(Reset) - expensiveCalls := 0 + m := NewMemoizer(1) + calls := 0 - // Function tracks how many times its been called - expensive := func() (interface{}, error) { - expensiveCalls++ - return expensiveCalls, nil + twoForTheMoney := func() (any, error) { + calls++ + if calls == 1 { + return calls, errors.New("Try again") + } + return calls, nil } - // First call SHOULD NOT be cached - result, err, cached := do("key1", expensive) - if err != nil { - t.Fatalf("unexpected error: %s", err.Error()) + result, err := m.Do("key1", twoForTheMoney) + if err == nil { + t.Fatalf("expected error") } - if want, have := 1, result.(int); want != have { t.Fatalf("unexpected value, want %d, have %d", want, have) } - if want, have := false, cached; want != have { - t.Fatalf("unexpected caching, want %t, have %t", want, have) + result, err = m.Do("key1", twoForTheMoney) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + if want, have := 2, result.(int); want != have { + t.Fatalf("unexpected value, want %d, have %d", want, have) } - // Second call on same key SHOULD be cached - result, err, cached = do("key1", expensive) + result, err = m.Do("key1", twoForTheMoney) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } - - if want, have := 1, result.(int); want != have { + if want, have := 2, result.(int); want != have { t.Fatalf("unexpected value, want %d, have %d", want, have) } +} + +func TestRelease(t 
*testing.T) { + t.Cleanup(Reset) - if want, have := true, cached; want != have { - t.Fatalf("unexpected caching, want %t, have %t", want, have) + m1 := NewMemoizer(1) + m2 := NewMemoizer(2) + + calls := 0 + fn := func() (any, error) { + calls++ + return calls, nil } - // First call on a new key SHOULD NOT be cached - result, err, cached = do("key2", expensive) - if err != nil { - t.Fatalf("unexpected error: %s", err.Error()) + _, _ = m1.Do("shared", fn) + _, _ = m2.Do("shared", fn) + _, _ = m1.Do("only-waf1", fn) + + Release(1) + + if _, ok := cache.Load("shared"); !ok { + t.Fatal("shared entry should still exist after releasing waf-1") + } + if _, ok := cache.Load("only-waf1"); ok { + t.Fatal("only-waf1 entry should be deleted after releasing its sole owner") } - if want, have := 2, result.(int); want != have { - t.Fatalf("unexpected value, want %d, have %d", want, have) + Release(2) + if _, ok := cache.Load("shared"); ok { + t.Fatal("shared entry should be deleted after releasing all owners") } +} - if want, have := false, cached; want != have { - t.Fatalf("unexpected caching, want %t, have %t", want, have) +func TestReset(t *testing.T) { + m := NewMemoizer(1) + _, _ = m.Do("k1", func() (any, error) { return 1, nil }) + _, _ = m.Do("k2", func() (any, error) { return 2, nil }) + + Reset() + + if _, ok := cache.Load("k1"); ok { + t.Fatal("cache should be empty after Reset") + } + if _, ok := cache.Load("k2"); ok { + t.Fatal("cache should be empty after Reset") } } -func TestFailedCall(t *testing.T) { - do := makeDoer(new(sync.Map)) +// cacheLen counts the number of entries in the global cache. +func cacheLen() int { + n := 0 + cache.Range(func(_, _ any) bool { + n++ + return true + }) + return n +} - calls := 0 +// crsLikePatterns generates n CRS-scale regex patterns. 
+func crsLikePatterns(n int) []string { + patterns := make([]string, n) + for i := range patterns { + patterns[i] = fmt.Sprintf(`(?i)pattern_%d_[a-z]{2,8}\d+`, i) + } + return patterns +} - // This function will fail IFF it has not been called before. - twoForTheMoney := func() (interface{}, error) { - calls++ +func TestMemoizeScaleMultipleOwners(t *testing.T) { + if testing.Short() { + t.Log("skipping scale test in short mode") + return + } + t.Cleanup(Reset) - if calls == 1 { - return calls, errors.New("Try again") - } else { - return calls, nil + const ( + numOwners = 10 + numPatterns = 300 + ) + + patterns := crsLikePatterns(numPatterns) + calls := 0 + fn := func(p string) func() (any, error) { + return func() (any, error) { + calls++ + return regexp.Compile(p) } } - // First call should fail, and not be cached - result, err, cached := do("key1", twoForTheMoney) - if err == nil { - t.Fatalf("expected error") + for i := uint64(1); i <= numOwners; i++ { + m := NewMemoizer(i) + for _, p := range patterns { + if _, err := m.Do(p, fn(p)); err != nil { + t.Fatal(err) + } + } } - if want, have := 1, result.(int); want != have { - t.Fatalf("unexpected value, want %d, have %d", want, have) + if calls != numPatterns { + t.Fatalf("expected %d compilations, got %d", numPatterns, calls) + } + if n := cacheLen(); n != numPatterns { + t.Fatalf("expected %d cache entries, got %d", numPatterns, n) } - if want, have := false, cached; want != have { - t.Fatalf("unexpected caching, want %t, have %t", want, have) + // Release all owners; cache should be empty. 
+ for i := uint64(1); i <= numOwners; i++ { + Release(i) } + if n := cacheLen(); n != 0 { + t.Fatalf("expected empty cache after releasing all owners, got %d", n) + } +} - // Second call should succeed, and not be cached - result, err, cached = do("key1", twoForTheMoney) - if err != nil { - t.Fatalf("unexpected error: %s", err.Error()) +func TestCacheGrowthWithoutClose(t *testing.T) { + if testing.Short() { + t.Log("skipping scale test in short mode") + return } + t.Cleanup(Reset) - if want, have := 2, result.(int); want != have { - t.Fatalf("unexpected value, want %d, have %d", want, have) + const ( + numOwners = 100 + numPatterns = 300 + ) + + patterns := crsLikePatterns(numPatterns) + fn := func(p string) func() (any, error) { + return func() (any, error) { + return regexp.Compile(p) + } } - if want, have := false, cached; want != have { - t.Fatalf("unexpected caching, want %t, have %t", want, have) + for i := uint64(1); i <= numOwners; i++ { + m := NewMemoizer(i) + for _, p := range patterns { + if _, err := m.Do(p, fn(p)); err != nil { + t.Fatal(err) + } + } } - // Third call should succeed, and be cached - result, err, cached = do("key1", twoForTheMoney) - if err != nil { - t.Fatalf("unexpected error: %s", err.Error()) + // Every entry should have all owners. 
+ cache.Range(func(_, value any) bool { + e := value.(*entry) + e.mu.Lock() + defer e.mu.Unlock() + if len(e.owners) != numOwners { + t.Fatalf("expected %d owners per entry, got %d", numOwners, len(e.owners)) + } + return true + }) +} + +func TestCacheBoundedWithClose(t *testing.T) { + if testing.Short() { + t.Log("skipping scale test in short mode") + return } + t.Cleanup(Reset) - if want, have := 2, result.(int); want != have { - t.Fatalf("unexpected value, want %d, have %d", want, have) + const ( + numCycles = 100 + numPatterns = 300 + ) + + patterns := crsLikePatterns(numPatterns) + fn := func(p string) func() (any, error) { + return func() (any, error) { + return regexp.Compile(p) + } + } + + for i := uint64(1); i <= numCycles; i++ { + m := NewMemoizer(i) + for _, p := range patterns { + if _, err := m.Do(p, fn(p)); err != nil { + t.Fatal(err) + } + } + Release(i) + } + + if n := cacheLen(); n != 0 { + t.Fatalf("expected empty cache after all releases, got %d", n) } +} + +func BenchmarkCompileWithoutMemoize(b *testing.B) { + patterns := crsLikePatterns(300) + for _, numWAFs := range []int{1, 10, 100} { + b.Run(fmt.Sprintf("WAFs=%d", numWAFs), func(b *testing.B) { + for i := 0; i < b.N; i++ { + for w := 0; w < numWAFs; w++ { + for _, p := range patterns { + if _, err := regexp.Compile(p); err != nil { + b.Fatal(err) + } + } + } + } + }) + } +} + +func BenchmarkCompileWithMemoize(b *testing.B) { + patterns := crsLikePatterns(300) + for _, numWAFs := range []int{1, 10, 100} { + b.Run(fmt.Sprintf("WAFs=%d", numWAFs), func(b *testing.B) { + for i := 0; i < b.N; i++ { + Reset() + for w := 0; w < numWAFs; w++ { + m := NewMemoizer(uint64(w + 1)) + for _, p := range patterns { + if _, err := m.Do(p, func() (any, error) { + return regexp.Compile(p) + }); err != nil { + b.Fatal(err) + } + } + } + } + }) + } +} - if want, have := true, cached; want != have { - t.Fatalf("unexpected caching, want %t, have %t", want, have) +func BenchmarkRelease(b *testing.B) { + patterns := 
crsLikePatterns(300) + for _, numOwners := range []int{1, 10, 100} { + b.Run(fmt.Sprintf("Owners=%d", numOwners), func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + Reset() + for o := 0; o < numOwners; o++ { + m := NewMemoizer(uint64(o + 1)) + for _, p := range patterns { + m.Do(p, func() (any, error) { + return regexp.Compile(p) + }) + } + } + b.StartTimer() + for o := 0; o < numOwners; o++ { + Release(uint64(o + 1)) + } + } + }) } } diff --git a/internal/memoize/sync.go b/internal/memoize/sync.go index a8d5c9610..c8ebd4903 100644 --- a/internal/memoize/sync.go +++ b/internal/memoize/sync.go @@ -1,9 +1,7 @@ // Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors // SPDX-License-Identifier: Apache-2.0 -//go:build !tinygo && memoize_builders - -// https://github.com/kofalt/go-memoize/blob/master/memoize.go +//go:build !tinygo && !coraza.no_memoize package memoize @@ -13,35 +11,104 @@ import ( "golang.org/x/sync/singleflight" ) -var doer = makeDoer(new(sync.Map), new(singleflight.Group)) +type entry struct { + value any + mu sync.Mutex + owners map[uint64]struct{} + deleted bool +} + +var ( + cache sync.Map // key -> *entry + group singleflight.Group +) + +// Memoizer caches expensive function calls with per-owner tracking. +type Memoizer struct { + ownerID uint64 +} -// Do executes and returns the results of the given function, unless there was a cached -// value of the same key. Only one execution is in-flight for a given key at a time. -// The boolean return value indicates whether v was previously stored. -func Do(key string, fn func() (interface{}, error)) (interface{}, error) { - value, err, _ := doer(key, fn) - return value, err +// NewMemoizer creates a Memoizer that tracks cached entries under the given owner ID. 
+func NewMemoizer(ownerID uint64) *Memoizer { + return &Memoizer{ownerID: ownerID} } -// makeDoer returns a function that executes and returns the results of the given function -func makeDoer(cache *sync.Map, group *singleflight.Group) func(string, func() (interface{}, error)) (interface{}, error, bool) { - return func(key string, fn func() (interface{}, error)) (interface{}, error, bool) { - // Check cache - value, found := cache.Load(key) - if found { - return value, nil, true +// addOwner attempts to register the ownerID on the entry. +// Returns false if the entry has been marked as deleted. +func (m *Memoizer) addOwner(e *entry) bool { + e.mu.Lock() + defer e.mu.Unlock() + if e.deleted { + return false + } + e.owners[m.ownerID] = struct{}{} + return true +} + +// Do returns a cached value for key, or calls fn and caches the result. +// Only one execution is in-flight for a given key at a time. +func (m *Memoizer) Do(key string, fn func() (any, error)) (any, error) { + // Fast path: check cache + if v, ok := cache.Load(key); ok { + e := v.(*entry) + if m.addOwner(e) { + return e.value, nil } + // Entry was deleted concurrently; fall through to slow path. + } - // Combine memoized function with a cache store - value, err, _ := group.Do(key, func() (interface{}, error) { - data, innerErr := fn() - if innerErr == nil { - cache.Store(key, data) + // Slow path: singleflight ensures only one compilation per key + val, err, _ := group.Do(key, func() (any, error) { + // Double-check after acquiring singleflight + if v, ok := cache.Load(key); ok { + e := v.(*entry) + if m.addOwner(e) { + return e.value, nil } + } - return data, innerErr - }) + data, innerErr := fn() + if innerErr == nil { + e := &entry{ + value: data, + owners: map[uint64]struct{}{m.ownerID: {}}, + } + cache.Store(key, e) + } + return data, innerErr + }) - return value, err, false + // Ensure this caller is registered as an owner even if its execution + // was deduplicated by singleflight. 
+ if err == nil { + if v, ok := cache.Load(key); ok { + e := v.(*entry) + m.addOwner(e) + } } + + return val, err +} + +// Release removes ownerID from all cached entries, deleting entries with no remaining owners. +func Release(ownerID uint64) { + cache.Range(func(key, value any) bool { + e := value.(*entry) + e.mu.Lock() + delete(e.owners, ownerID) + if len(e.owners) == 0 { + e.deleted = true + cache.Delete(key) + } + e.mu.Unlock() + return true + }) +} + +// Reset clears the entire cache. Intended for testing. +func Reset() { + cache.Range(func(key, _ any) bool { + cache.Delete(key) + return true + }) } diff --git a/internal/memoize/sync_test.go b/internal/memoize/sync_test.go index d995d1d41..b1a27bc33 100644 --- a/internal/memoize/sync_test.go +++ b/internal/memoize/sync_test.go @@ -1,169 +1,339 @@ // Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors // SPDX-License-Identifier: Apache-2.0 -//go:build !tinygo && memoize_builders - -// https://github.com/kofalt/go-memoize/blob/master/memoize.go +//go:build !tinygo && !coraza.no_memoize package memoize import ( "errors" - "sync" + "fmt" + "os" + "regexp" "testing" - - "golang.org/x/sync/singleflight" ) func TestDo(t *testing.T) { + t.Cleanup(Reset) + + m := NewMemoizer(1) expensiveCalls := 0 - // Function tracks how many times its been called - expensive := func() (interface{}, error) { + expensive := func() (any, error) { expensiveCalls++ return expensiveCalls, nil } - // First call SHOULD NOT be cached - result, err := Do("key1", expensive) + result, err := m.Do("key1", expensive) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } - if want, have := 1, result.(int); want != have { t.Fatalf("unexpected value, want %d, have %d", want, have) } - // Second call on same key SHOULD be cached - result, err = Do("key1", expensive) + result, err = m.Do("key1", expensive) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } - if want, have := 1, result.(int); want != have { 
t.Fatalf("unexpected value, want %d, have %d", want, have) } - // First call on a new key SHOULD NOT be cached - result, err = Do("key2", expensive) + result, err = m.Do("key2", expensive) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } - if want, have := 2, result.(int); want != have { t.Fatalf("unexpected value, want %d, have %d", want, have) } } -func TestSuccessCall(t *testing.T) { - do := makeDoer(new(sync.Map), &singleflight.Group{}) +func TestFailedCall(t *testing.T) { + t.Cleanup(Reset) - expensiveCalls := 0 + m := NewMemoizer(1) + calls := 0 - // Function tracks how many times its been called - expensive := func() (interface{}, error) { - expensiveCalls++ - return expensiveCalls, nil + twoForTheMoney := func() (any, error) { + calls++ + if calls == 1 { + return calls, errors.New("Try again") + } + return calls, nil } - // First call SHOULD NOT be cached - result, err, cached := do("key1", expensive) - if err != nil { - t.Fatalf("unexpected error: %s", err.Error()) + result, err := m.Do("key1", twoForTheMoney) + if err == nil { + t.Fatalf("expected error") } - if want, have := 1, result.(int); want != have { t.Fatalf("unexpected value, want %d, have %d", want, have) } - if want, have := false, cached; want != have { - t.Fatalf("unexpected caching, want %t, have %t", want, have) + result, err = m.Do("key1", twoForTheMoney) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + if want, have := 2, result.(int); want != have { + t.Fatalf("unexpected value, want %d, have %d", want, have) } - // Second call on same key SHOULD be cached - result, err, cached = do("key1", expensive) + result, err = m.Do("key1", twoForTheMoney) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } - - if want, have := 1, result.(int); want != have { + if want, have := 2, result.(int); want != have { t.Fatalf("unexpected value, want %d, have %d", want, have) } +} - if want, have := true, cached; want != have { - t.Fatalf("unexpected 
caching, want %t, have %t", want, have) +func TestRelease(t *testing.T) { + t.Cleanup(Reset) + + m1 := NewMemoizer(1) + m2 := NewMemoizer(2) + + calls := 0 + fn := func() (any, error) { + calls++ + return calls, nil } - // First call on a new key SHOULD NOT be cached - result, err, cached = do("key2", expensive) - if err != nil { - t.Fatalf("unexpected error: %s", err.Error()) + _, _ = m1.Do("shared", fn) + _, _ = m2.Do("shared", fn) + _, _ = m1.Do("only-waf1", fn) + + Release(1) + + if _, ok := cache.Load("shared"); !ok { + t.Fatal("shared entry should still exist after releasing waf-1") + } + if _, ok := cache.Load("only-waf1"); ok { + t.Fatal("only-waf1 entry should be deleted after releasing its sole owner") } - if want, have := 2, result.(int); want != have { - t.Fatalf("unexpected value, want %d, have %d", want, have) + Release(2) + if _, ok := cache.Load("shared"); ok { + t.Fatal("shared entry should be deleted after releasing all owners") } +} - if want, have := false, cached; want != have { - t.Fatalf("unexpected caching, want %t, have %t", want, have) +func TestReset(t *testing.T) { + m := NewMemoizer(1) + _, _ = m.Do("k1", func() (any, error) { return 1, nil }) + _, _ = m.Do("k2", func() (any, error) { return 2, nil }) + + Reset() + + if _, ok := cache.Load("k1"); ok { + t.Fatal("cache should be empty after Reset") + } + if _, ok := cache.Load("k2"); ok { + t.Fatal("cache should be empty after Reset") } } -func TestFailedCall(t *testing.T) { - do := makeDoer(new(sync.Map), &singleflight.Group{}) +// cacheLen counts the number of entries in the global cache. +func cacheLen() int { + n := 0 + cache.Range(func(_, _ any) bool { + n++ + return true + }) + return n +} - calls := 0 +// crsLikePatterns generates n CRS-scale regex patterns. 
+func crsLikePatterns(n int) []string { + patterns := make([]string, n) + for i := range patterns { + patterns[i] = fmt.Sprintf(`(?i)pattern_%d_[a-z]{2,8}\d+`, i) + } + return patterns +} - // This function will fail IFF it has not been called before. - twoForTheMoney := func() (interface{}, error) { - calls++ +func TestMemoizeScaleMultipleOwners(t *testing.T) { + if testing.Short() { + t.Skip("skipping scale test in short mode") + } + t.Cleanup(Reset) - if calls == 1 { - return calls, errors.New("Try again") - } else { - return calls, nil + const ( + numOwners = 10 + numPatterns = 300 + ) + + patterns := crsLikePatterns(numPatterns) + calls := 0 + fn := func(p string) func() (any, error) { + return func() (any, error) { + calls++ + return regexp.Compile(p) } } - // First call should fail, and not be cached - result, err, cached := do("key1", twoForTheMoney) - if err == nil { - t.Fatalf("expected error") + for i := uint64(1); i <= numOwners; i++ { + m := NewMemoizer(i) + for _, p := range patterns { + if _, err := m.Do(p, fn(p)); err != nil { + t.Fatal(err) + } + } } - if want, have := 1, result.(int); want != have { - t.Fatalf("unexpected value, want %d, have %d", want, have) + if calls != numPatterns { + t.Fatalf("expected %d compilations, got %d", numPatterns, calls) + } + if n := cacheLen(); n != numPatterns { + t.Fatalf("expected %d cache entries, got %d", numPatterns, n) } - if want, have := false, cached; want != have { - t.Fatalf("unexpected caching, want %t, have %t", want, have) + // Release all owners; cache should be empty. 
+ for i := uint64(1); i <= numOwners; i++ { + Release(i) } + if n := cacheLen(); n != 0 { + t.Fatalf("expected empty cache after releasing all owners, got %d", n) + } +} - // Second call should succeed, and not be cached - result, err, cached = do("key1", twoForTheMoney) - if err != nil { - t.Fatalf("unexpected error: %s", err.Error()) +func TestCacheGrowthWithoutClose(t *testing.T) { + if testing.Short() { + t.Skip("skipping scale test in short mode") } + t.Cleanup(Reset) - if want, have := 2, result.(int); want != have { - t.Fatalf("unexpected value, want %d, have %d", want, have) + const ( + numOwners = 100 + numPatterns = 300 + ) + + patterns := crsLikePatterns(numPatterns) + fn := func(p string) func() (any, error) { + return func() (any, error) { + return regexp.Compile(p) + } } - if want, have := false, cached; want != have { - t.Fatalf("unexpected caching, want %t, have %t", want, have) + for i := uint64(1); i <= numOwners; i++ { + m := NewMemoizer(i) + for _, p := range patterns { + if _, err := m.Do(p, fn(p)); err != nil { + t.Fatal(err) + } + } } - // Third call should succeed, and be cached - result, err, cached = do("key1", twoForTheMoney) - if err != nil { - t.Fatalf("unexpected error: %s", err.Error()) + // Every entry should have all owners. 
+ cache.Range(func(_, value any) bool { + e := value.(*entry) + e.mu.Lock() + defer e.mu.Unlock() + if len(e.owners) != numOwners { + t.Fatalf("expected %d owners per entry, got %d", numOwners, len(e.owners)) + } + return true + }) +} + +func TestCacheBoundedWithClose(t *testing.T) { + if testing.Short() { + t.Skip("skipping scale test in short mode") } + if os.Getenv("CORAZA_MEMOIZE_SCALE") != "1" { + t.Skip("skipping scale test; set CORAZA_MEMOIZE_SCALE=1 to enable") + } + t.Cleanup(Reset) - if want, have := 2, result.(int); want != have { - t.Fatalf("unexpected value, want %d, have %d", want, have) + const ( + numCycles = 100 + numPatterns = 300 + ) + + patterns := crsLikePatterns(numPatterns) + fn := func(p string) func() (any, error) { + return func() (any, error) { + return regexp.Compile(p) + } } - if want, have := true, cached; want != have { - t.Fatalf("unexpected caching, want %t, have %t", want, have) + for i := uint64(1); i <= numCycles; i++ { + m := NewMemoizer(i) + for _, p := range patterns { + if _, err := m.Do(p, fn(p)); err != nil { + t.Fatal(err) + } + } + Release(i) + } + + // After releasing each owner immediately, only the last cycle's + // entries remain (the last owner hasn't been released yet — but we + // DID release it in the loop). Cache should be empty. 
+ if n := cacheLen(); n != 0 { + t.Fatalf("expected empty cache after all releases, got %d", n) + } +} + +func BenchmarkCompileWithoutMemoize(b *testing.B) { + patterns := crsLikePatterns(300) + for _, numWAFs := range []int{1, 10, 100} { + b.Run(fmt.Sprintf("WAFs=%d", numWAFs), func(b *testing.B) { + for i := 0; i < b.N; i++ { + for w := 0; w < numWAFs; w++ { + for _, p := range patterns { + if _, err := regexp.Compile(p); err != nil { + b.Fatal(err) + } + } + } + } + }) + } +} + +func BenchmarkCompileWithMemoize(b *testing.B) { + patterns := crsLikePatterns(300) + for _, numWAFs := range []int{1, 10, 100} { + b.Run(fmt.Sprintf("WAFs=%d", numWAFs), func(b *testing.B) { + for i := 0; i < b.N; i++ { + Reset() + for w := 0; w < numWAFs; w++ { + m := NewMemoizer(uint64(w + 1)) + for _, p := range patterns { + if _, err := m.Do(p, func() (any, error) { + return regexp.Compile(p) + }); err != nil { + b.Fatal(err) + } + } + } + } + }) + } +} + +func BenchmarkRelease(b *testing.B) { + patterns := crsLikePatterns(300) + for _, numOwners := range []int{1, 10, 100} { + b.Run(fmt.Sprintf("Owners=%d", numOwners), func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + Reset() + for o := 0; o < numOwners; o++ { + m := NewMemoizer(uint64(o + 1)) + for _, p := range patterns { + _, _ = m.Do(p, func() (any, error) { + return regexp.Compile(p) + }) + } + } + b.StartTimer() + for o := 0; o < numOwners; o++ { + Release(uint64(o + 1)) + } + } + }) } } diff --git a/internal/operators/operators.go b/internal/operators/operators.go index 3ddd87989..28658fd45 100644 --- a/internal/operators/operators.go +++ b/internal/operators/operators.go @@ -29,6 +29,13 @@ import ( var operators = map[string]plugintypes.OperatorFactory{} +func memoizeDo(m plugintypes.Memoizer, key string, fn func() (any, error)) (any, error) { + if m != nil { + return m.Do(key, fn) + } + return fn() +} + // Get returns an operator by name func Get(name string, options plugintypes.OperatorOptions) 
(plugintypes.Operator, error) { if op, ok := operators[name]; ok { diff --git a/internal/operators/pm.go b/internal/operators/pm.go index b66da3be5..35e5b3244 100644 --- a/internal/operators/pm.go +++ b/internal/operators/pm.go @@ -11,7 +11,6 @@ import ( ahocorasick "github.com/petar-dambovaliev/aho-corasick" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" - "github.com/corazawaf/coraza/v3/internal/memoize" ) // Description: @@ -51,7 +50,7 @@ func newPM(options plugintypes.OperatorOptions) (plugintypes.Operator, error) { DFA: true, }) - m, _ := memoize.Do(data, func() (any, error) { return builder.Build(dict), nil }) + m, _ := memoizeDo(options.Memoizer, data, func() (any, error) { return builder.Build(dict), nil }) // TODO this operator is supposed to support snort data syntax: "@pm A|42|C|44|F" return &pm{matcher: m.(ahocorasick.AhoCorasick)}, nil } diff --git a/internal/operators/pm_from_dataset.go b/internal/operators/pm_from_dataset.go index e5118a9f0..1c3def538 100644 --- a/internal/operators/pm_from_dataset.go +++ b/internal/operators/pm_from_dataset.go @@ -11,7 +11,6 @@ import ( ahocorasick "github.com/petar-dambovaliev/aho-corasick" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" - "github.com/corazawaf/coraza/v3/internal/memoize" ) // Description: @@ -46,7 +45,7 @@ func newPMFromDataset(options plugintypes.OperatorOptions) (plugintypes.Operator DFA: true, }) - m, _ := memoize.Do(data, func() (any, error) { return builder.Build(dataset), nil }) + m, _ := memoizeDo(options.Memoizer, data, func() (any, error) { return builder.Build(dataset), nil }) return &pm{matcher: m.(ahocorasick.AhoCorasick)}, nil } diff --git a/internal/operators/pm_from_file.go b/internal/operators/pm_from_file.go index 4d4ae4089..17a32030c 100644 --- a/internal/operators/pm_from_file.go +++ b/internal/operators/pm_from_file.go @@ -13,7 +13,6 @@ import ( ahocorasick "github.com/petar-dambovaliev/aho-corasick" 
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" - "github.com/corazawaf/coraza/v3/internal/memoize" ) // Description: @@ -65,7 +64,7 @@ func newPMFromFile(options plugintypes.OperatorOptions) (plugintypes.Operator, e DFA: false, }) - m, _ := memoize.Do(strings.Join(options.Path, ",")+filepath, func() (any, error) { return builder.Build(lines), nil }) + m, _ := memoizeDo(options.Memoizer, strings.Join(options.Path, ",")+filepath, func() (any, error) { return builder.Build(lines), nil }) return &pm{matcher: m.(ahocorasick.AhoCorasick)}, nil } diff --git a/internal/operators/restpath.go b/internal/operators/restpath.go index 2af865e12..d94298d36 100644 --- a/internal/operators/restpath.go +++ b/internal/operators/restpath.go @@ -11,7 +11,6 @@ import ( "strings" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" - "github.com/corazawaf/coraza/v3/internal/memoize" ) var rePathTokenRe = regexp.MustCompile(`\{([^\}]+)\}`) @@ -48,7 +47,7 @@ func newRESTPath(options plugintypes.OperatorOptions) (plugintypes.Operator, err data = strings.Replace(data, token[0], fmt.Sprintf("(?P<%s>[^?/]+)", token[1]), 1) } - re, err := memoize.Do(data, func() (any, error) { return regexp.Compile(data) }) + re, err := memoizeDo(options.Memoizer, data, func() (any, error) { return regexp.Compile(data) }) if err != nil { return nil, err } diff --git a/internal/operators/rx.go b/internal/operators/rx.go index 207b26e51..f85c1f812 100644 --- a/internal/operators/rx.go +++ b/internal/operators/rx.go @@ -14,7 +14,6 @@ import ( "rsc.io/binaryregexp" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" - "github.com/corazawaf/coraza/v3/internal/memoize" ) // Description: @@ -68,7 +67,7 @@ func newRX(options plugintypes.OperatorOptions) (plugintypes.Operator, error) { return newBinaryRX(options) } - re, err := memoize.Do(data, func() (any, error) { return regexp.Compile(data) }) + re, err := memoizeDo(options.Memoizer, data, func() (any, error) { return 
regexp.Compile(data) }) if err != nil { return nil, err } @@ -77,15 +76,27 @@ func newRX(options plugintypes.OperatorOptions) (plugintypes.Operator, error) { func (o *rx) Evaluate(tx plugintypes.TransactionState, value string) bool { if tx.Capturing() { - match := o.re.FindStringSubmatch(value) - if len(match) == 0 { + // FindStringSubmatchIndex returns a slice of index pairs [start0, end0, start1, end1, ...] + // instead of allocating new strings for each capture group. We then slice the original + // input value[start:end] to get zero-allocation substrings. + match := o.re.FindStringSubmatchIndex(value) + if match == nil { return false } - for i, c := range match { + // match has 2 entries per group: match[2*i] is the start index, + // match[2*i+1] is the end index for capture group i. Group 0 is + // the full match, groups 1..N are the parenthesized sub-expressions. + for i := 0; i < len(match)/2; i++ { if i == 9 { return true } - tx.CaptureField(i, c) + // A negative start index means the group did not participate in the match + // (e.g. an optional group like (foo)? when foo is absent). 
// BenchmarkRxCapture compares two ways of reading a capture group from the
// same compiled pattern and the same input string:
//
//   - FindStringSubmatch, which returns the matched substrings directly.
//   - FindStringSubmatchIndex, which returns index pairs that are then used
//     to slice the input.
//
// Each strategy runs as its own sub-benchmark so the results can be compared
// side by side.
func BenchmarkRxCapture(b *testing.B) {
	// Pattern with three capture groups; the input below matches all of them.
	pattern := `(?sm)^/api/v(\d+)/users/(\w+)/posts/(\d+)`
	input := "/api/v3/users/jptosso/posts/42"

	// Compile once, outside the timed loops, so only matching is measured.
	re := regexp.MustCompile(pattern)

	b.Run("FindStringSubmatch", func(b *testing.B) {
		for b.Loop() {
			match := re.FindStringSubmatch(input)
			if len(match) == 0 {
				b.Fatal("expected match")
			}
			// Read the first capture group so the result is used.
			_ = match[1]
		}
	})
	b.Run("FindStringSubmatchIndex", func(b *testing.B) {
		for b.Loop() {
			match := re.FindStringSubmatchIndex(input)
			if match == nil {
				b.Fatal("expected match")
			}
			// match[2] and match[3] are the start and end offsets of the
			// first capture group; slice the input instead of copying.
			_ = input[match[2]:match[3]]
		}
	})
}
af2a30cb2..fab78ab49 100644 --- a/internal/operators/validate_nid.go +++ b/internal/operators/validate_nid.go @@ -12,7 +12,6 @@ import ( "strings" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" - "github.com/corazawaf/coraza/v3/internal/memoize" ) type validateNidFunction = func(input string) bool @@ -61,7 +60,7 @@ func newValidateNID(options plugintypes.OperatorOptions) (plugintypes.Operator, return nil, fmt.Errorf("invalid @validateNid argument") } - re, err := memoize.Do(expr, func() (any, error) { return regexp.Compile(expr) }) + re, err := memoizeDo(options.Memoizer, expr, func() (any, error) { return regexp.Compile(expr) }) if err != nil { return nil, err } diff --git a/internal/operators/validate_schema.go b/internal/operators/validate_schema.go index 26d43499a..b09eb7f28 100644 --- a/internal/operators/validate_schema.go +++ b/internal/operators/validate_schema.go @@ -18,7 +18,6 @@ import ( "github.com/kaptinlin/jsonschema" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" - "github.com/corazawaf/coraza/v3/internal/memoize" "github.com/corazawaf/coraza/v3/types" ) @@ -79,7 +78,7 @@ func NewValidateSchema(options plugintypes.OperatorOptions) (plugintypes.Operato } key := md5Hash(schemaData) - schema, err := memoize.Do(key, func() (any, error) { + schema, err := memoizeDo(options.Memoizer, key, func() (any, error) { // Preliminarily validate that the schema is valid JSON var jsonSchema any if err := json.Unmarshal(schemaData, &jsonSchema); err != nil { diff --git a/internal/seclang/directives.go b/internal/seclang/directives.go index 9ee5127e1..07dee1494 100644 --- a/internal/seclang/directives.go +++ b/internal/seclang/directives.go @@ -14,10 +14,10 @@ import ( "strings" "github.com/corazawaf/coraza/v3/debuglog" + "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" "github.com/corazawaf/coraza/v3/internal/auditlog" "github.com/corazawaf/coraza/v3/internal/corazawaf" 
"github.com/corazawaf/coraza/v3/internal/environment" - "github.com/corazawaf/coraza/v3/internal/memoize" utils "github.com/corazawaf/coraza/v3/internal/strings" "github.com/corazawaf/coraza/v3/types" ) @@ -282,6 +282,29 @@ func directiveSecRequestBodyAccess(options *DirectiveOptions) error { return nil } +// Description: Configures the maximum JSON recursion depth limit Coraza will accept. +// Default: 1024 +// Syntax: SecRequestBodyJsonDepthLimit [LIMIT] +// --- +// Anything over the limit will generate a REQBODY_ERROR in the JSON body processor. +func directiveSecRequestBodyJsonDepthLimit(options *DirectiveOptions) error { + if len(options.Opts) == 0 { + return errEmptyOptions + } + + limit, err := strconv.Atoi(options.Opts) + if err != nil { + return err + } + + if limit <= 0 { + return errors.New("limit must be a positive integer") + } + + options.WAF.RequestBodyJsonDepthLimit = limit + return nil +} + // Description: Configures the rules engine. // Syntax: SecRuleEngine On|Off|DetectionOnly // Default: Off @@ -813,7 +836,7 @@ func directiveSecAuditLogRelevantStatus(options *DirectiveOptions) error { return errEmptyOptions } - re, err := memoize.Do(options.Opts, func() (any, error) { return regexp.Compile(options.Opts) }) + re, err := options.WAF.Memoizer().Do(options.Opts, func() (any, error) { return regexp.Compile(options.Opts) }) if err != nil { return err } @@ -912,12 +935,34 @@ func directiveSecDataDir(options *DirectiveOptions) error { return nil } +// Description: Configures whether intercepted files will be kept after the transaction is processed. +// Syntax: SecUploadKeepFiles On|RelevantOnly|Off +// Default: Off +// --- +// The `SecUploadKeepFiles` directive is used to configure whether intercepted files are +// preserved on disk after the transaction is processed. +// This directive requires the storage directory to be defined (using `SecUploadDir`). +// +// Possible values are: +// - On: Keep all uploaded files. 
+// - Off: Do not keep uploaded files. +// - RelevantOnly: Keep only uploaded files that matched at least one rule that would be +// logged (excluding rules with the `nolog` action). func directiveSecUploadKeepFiles(options *DirectiveOptions) error { - b, err := parseBoolean(options.Opts) + if len(options.Opts) == 0 { + return errEmptyOptions + } + + status, err := types.ParseUploadKeepFilesStatus(options.Opts) if err != nil { return err } - options.WAF.UploadKeepFiles = b + + if !environment.HasAccessToFS && status != types.UploadKeepFilesOff { + return fmt.Errorf("SecUploadKeepFiles: cannot enable keeping uploaded files: filesystem access is disabled") + } + + options.WAF.UploadKeepFiles = status return nil } @@ -944,6 +989,11 @@ func directiveSecUploadFileLimit(options *DirectiveOptions) error { return err } +// Description: Configures the directory where uploaded files will be stored. +// Syntax: SecUploadDir /path/to/dir +// Default: "" +// --- +// This directive is required when enabling SecUploadKeepFiles. func directiveSecUploadDir(options *DirectiveOptions) error { if len(options.Opts) == 0 { return errEmptyOptions @@ -1102,6 +1152,17 @@ func updateTargetBySingleID(id int, variables string, options *DirectiveOptions) return rp.ParseVariables(strings.Trim(variables, "\"")) } +// hasDisruptiveActions checks if any of the parsed actions are disruptive. +// Returns true if at least one action has ActionTypeDisruptive, false otherwise. +func hasDisruptiveActions(actions []ruleAction) bool { + for _, action := range actions { + if action.Atype == plugintypes.ActionTypeDisruptive { + return true + } + } + return false +} + // Description: Updates the action list of the specified rule(s). 
// Syntax: SecRuleUpdateActionById ID ACTIONLIST // --- @@ -1113,6 +1174,7 @@ func updateTargetBySingleID(id int, variables string, options *DirectiveOptions) // ```apache // SecRuleUpdateActionById 12345 "deny,status:403" // ``` +// The rule ID can be single IDs or ranges of IDs. The targets are separated by a pipe character. func directiveSecRuleUpdateActionByID(options *DirectiveOptions) error { if len(options.Opts) == 0 { return errEmptyOptions @@ -1152,18 +1214,37 @@ func directiveSecRuleUpdateActionByID(options *DirectiveOptions) error { return fmt.Errorf("invalid range: %s", idOrRange) } - for _, rule := range options.WAF.Rules.GetRules() { - if rule.ID_ < start && rule.ID_ > end { + // Parse actions once to check if any are disruptive. + // Trim surrounding quotes because the SecLang syntax uses quoted action lists + // (e.g., SecRuleUpdateActionById 1004 "pass") and strings.Fields preserves them. + trimmedActions := strings.Trim(actions, "\"") + parsedActions, err := parseActions(options.WAF.Logger, trimmedActions) + if err != nil { + return err + } + + // Check if any of the new actions are disruptive + hasDisruptiveAction := hasDisruptiveActions(parsedActions) + + rules := options.WAF.Rules.GetRules() + for i := range rules { + if rules[i].ID_ < start || rules[i].ID_ > end { continue } + + // Only clear disruptive actions if the update contains a disruptive action + if hasDisruptiveAction { + rules[i].ClearDisruptiveActions() + } + rp := RuleParser{ - rule: &rule, + rule: &rules[i], options: RuleOptions{ WAF: options.WAF, }, defaultActions: map[types.RulePhase][]ruleAction{}, } - if err := rp.ParseActions(strings.Trim(actions, "\"")); err != nil { + if err := rp.applyParsedActions(parsedActions); err != nil { return err } } @@ -1173,11 +1254,32 @@ func directiveSecRuleUpdateActionByID(options *DirectiveOptions) error { } func updateActionBySingleID(id int, actions string, options *DirectiveOptions) error { - rule := options.WAF.Rules.FindByID(id) if rule 
== nil { return fmt.Errorf("SecRuleUpdateActionById: rule \"%d\" not found", id) } + + // Parse actions first to check if any are disruptive. + // Trim surrounding quotes from the SecLang action list syntax. + trimmedActions := strings.Trim(actions, "\"") + parsedActions, err := parseActions(options.WAF.Logger, trimmedActions) + if err != nil { + return err + } + + // Check if any of the new actions are disruptive. + // hasDisruptiveActions returns false when parsedActions is empty or contains + // only non-disruptive actions, preserving existing disruptive actions on the rule. + hasDisruptiveAction := hasDisruptiveActions(parsedActions) + + // Only clear disruptive actions if the update contains a disruptive action + // This matches ModSecurity behavior where SecRuleUpdateActionById replaces + // disruptive actions but preserves them if only non-disruptive actions are updated + if hasDisruptiveAction { + rule.ClearDisruptiveActions() + } + + // Apply the parsed actions to the rule without re-parsing rp := RuleParser{ rule: rule, options: RuleOptions{ @@ -1185,7 +1287,7 @@ func updateActionBySingleID(id int, actions string, options *DirectiveOptions) e }, defaultActions: map[types.RulePhase][]ruleAction{}, } - return rp.ParseActions(strings.Trim(actions, "\"")) + return rp.applyParsedActions(parsedActions) } // Description: Updates the target (variable) list of the specified rule(s) by tag. 
diff --git a/internal/seclang/directives_test.go b/internal/seclang/directives_test.go index 373f100be..5cc91ebb0 100644 --- a/internal/seclang/directives_test.go +++ b/internal/seclang/directives_test.go @@ -154,8 +154,7 @@ func TestDirectives(t *testing.T) { "SecUploadKeepFiles": { {"", expectErrorOnDirective}, {"Ox", expectErrorOnDirective}, - {"On", func(w *corazawaf.WAF) bool { return w.UploadKeepFiles }}, - {"Off", func(w *corazawaf.WAF) bool { return !w.UploadKeepFiles }}, + {"Off", func(w *corazawaf.WAF) bool { return w.UploadKeepFiles == types.UploadKeepFilesOff }}, }, "SecUploadFileMode": { {"", expectErrorOnDirective}, @@ -317,6 +316,10 @@ func TestDirectives(t *testing.T) { {"/tmp-non-existing", expectErrorOnDirective}, {os.TempDir(), func(w *corazawaf.WAF) bool { return w.UploadDir == os.TempDir() }}, } + directiveCases["SecUploadKeepFiles"] = append(directiveCases["SecUploadKeepFiles"], + directiveCase{"On", func(w *corazawaf.WAF) bool { return w.UploadKeepFiles == types.UploadKeepFilesOn }}, + directiveCase{"RelevantOnly", func(w *corazawaf.WAF) bool { return w.UploadKeepFiles == types.UploadKeepFilesRelevantOnly }}, + ) } for name, dCases := range directiveCases { diff --git a/internal/seclang/directivesmap.gen.go b/internal/seclang/directivesmap.gen.go index b6ec36ee4..c9c118674 100644 --- a/internal/seclang/directivesmap.gen.go +++ b/internal/seclang/directivesmap.gen.go @@ -13,6 +13,7 @@ var ( _ directive = directiveSecResponseBodyAccess _ directive = directiveSecRequestBodyLimit _ directive = directiveSecRequestBodyAccess + _ directive = directiveSecRequestBodyJsonDepthLimit _ directive = directiveSecRuleEngine _ directive = directiveSecWebAppID _ directive = directiveSecServerSignature @@ -75,6 +76,7 @@ var directivesMap = map[string]directive{ "secresponsebodyaccess": directiveSecResponseBodyAccess, "secrequestbodylimit": directiveSecRequestBodyLimit, "secrequestbodyaccess": directiveSecRequestBodyAccess, + "secrequestbodyjsondepthlimit": 
directiveSecRequestBodyJsonDepthLimit, "secruleengine": directiveSecRuleEngine, "secwebappid": directiveSecWebAppID, "secserversignature": directiveSecServerSignature, diff --git a/internal/seclang/parser.go b/internal/seclang/parser.go index e1fbf3a70..e3dfc111d 100644 --- a/internal/seclang/parser.go +++ b/internal/seclang/parser.go @@ -148,7 +148,7 @@ func (p *Parser) parseString(data string) error { func (p *Parser) evaluateLine(l string) error { if l == "" || l[0] == '#' { - panic("invalid line") + return errors.New("invalid line") } // first we get the directive dir, opts, _ := strings.Cut(l, " ") diff --git a/internal/seclang/rule_parser.go b/internal/seclang/rule_parser.go index feeaa6663..9ed6151d9 100644 --- a/internal/seclang/rule_parser.go +++ b/internal/seclang/rule_parser.go @@ -198,6 +198,9 @@ func (rp *RuleParser) ParseOperator(operator string) error { Root: rp.options.ParserConfig.Root, Datasets: rp.options.Datasets, } + if rp.options.WAF != nil { + opts.Memoizer = rp.options.WAF.Memoizer() + } if wd := rp.options.ParserConfig.WorkingDir; wd != "" { opts.Path = append(opts.Path, wd) @@ -265,11 +268,25 @@ func (rp *RuleParser) ParseDefaultActions(actions string) error { // ParseActions parses a comma separated list of actions:arguments // Arguments can be wrapper inside quotes func (rp *RuleParser) ParseActions(actions string) error { - disabledActions := rp.options.ParserConfig.DisabledRuleActions act, err := parseActions(rp.options.WAF.Logger, actions) if err != nil { return err } + return rp.applyParsedActions(act) +} + +// applyParsedActions applies a list of already-parsed actions to the rule. +// +// This is a helper method used internally by ParseActions and directive handlers +// (such as SecRuleUpdateActionById) to avoid code duplication. It's useful when +// actions have been parsed once for inspection and need to be applied without +// re-parsing, avoiding redundant parsing operations. 
// HasRegex reports whether s uses the ModSecurity "/pattern/" regex form: a
// leading forward slash and a trailing forward slash that is not escaped.
//
// On success it returns (true, pattern), where pattern is the text between
// the two slashes; "//" yields an empty pattern. In every other case,
// including a closing slash preceded by an odd number of backslashes such as
// "/foo\/", it returns (false, s) with the input unchanged.
func HasRegex(s string) (bool, string) {
	end := len(s) - 1
	if end < 1 || s[0] != '/' || s[end] != '/' {
		return false, s
	}

	// Walk left from the closing slash over the run of backslashes, flipping
	// the flag each time: an odd run means the closing slash is escaped.
	// The walk stops before index 0, which is known to hold the opening '/'.
	escaped := false
	for pos := end - 1; pos > 0 && s[pos] == '\\'; pos-- {
		escaped = !escaped
	}
	if escaped {
		return false, s
	}
	return true, s[1:end]
}
name: "empty string", + input: "", + expectIsRegex: false, + expectPattern: "", + }, + } + + for _, tCase := range tCases { + t.Run(tCase.name, func(t *testing.T) { + gotIsRegex, gotPattern := HasRegex(tCase.input) + if gotIsRegex != tCase.expectIsRegex { + t.Errorf("HasRegex(%q): isRegex = %v, want %v", tCase.input, gotIsRegex, tCase.expectIsRegex) + } + if gotPattern != tCase.expectPattern { + t.Errorf("HasRegex(%q): pattern = %q, want %q", tCase.input, gotPattern, tCase.expectPattern) + } + }) + } +} diff --git a/internal/transformations/escape_seq_decode.go b/internal/transformations/escape_seq_decode.go index f740b683b..f245109e3 100644 --- a/internal/transformations/escape_seq_decode.go +++ b/internal/transformations/escape_seq_decode.go @@ -97,6 +97,7 @@ func doEscapeSeqDecode(input string, pos int) (string, bool) { data[d] = input[i+1] d++ i += 2 + changed = true } else { /* Input character not a backslash, copy it. */ data[d] = input[i] diff --git a/internal/transformations/escape_seq_decode_test.go b/internal/transformations/escape_seq_decode_test.go index d9a62f891..5e2a3ebd3 100644 --- a/internal/transformations/escape_seq_decode_test.go +++ b/internal/transformations/escape_seq_decode_test.go @@ -30,6 +30,10 @@ func TestEscapeSeqDecode(t *testing.T) { input: "\\a\\b\\f\\n\\r\\t\\v\\u0000\\?\\'\\\"\\0\\12\\123\\x00\\xff", want: "\a\b\f\n\r\t\vu0000?'\"\x00\nS\x00\xff", }, + { + input: "\\z", + want: "z", + }, } for _, tc := range tests { diff --git a/internal/transformations/remove_comments.go b/internal/transformations/remove_comments.go index e2a136c40..a5eba6c7a 100644 --- a/internal/transformations/remove_comments.go +++ b/internal/transformations/remove_comments.go @@ -18,9 +18,11 @@ charLoop: switch { case (input[i] == '/') && (i+1 < inputLen) && (input[i+1] == '*'): incomment = true + changed = true i += 2 case (input[i] == '<') && (i+3 < inputLen) && (input[i+1] == '!') && (input[i+2] == '-') && (input[i+3] == '-'): incomment = true + changed = 
true i += 4 case (input[i] == '-') && (i+1 < inputLen) && (input[i+1] == '-'): input[i] = ' ' diff --git a/internal/transformations/remove_comments_test.go b/internal/transformations/remove_comments_test.go new file mode 100644 index 000000000..001afe166 --- /dev/null +++ b/internal/transformations/remove_comments_test.go @@ -0,0 +1,91 @@ +// Copyright 2024 Juan Pablo Tosso and the OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +package transformations + +import "testing" + +func TestRemoveComments(t *testing.T) { + tests := []struct { + name string + input string + want string + changed bool + }{ + { + name: "no comments", + input: "hello world", + want: "hello world", + changed: false, + }, + { + name: "c-style comment", + input: "hello /* comment */ world", + want: "hello world", + changed: true, + }, + { + name: "html comment", + input: "hello <!-- comment --> world", + want: "hello world", + changed: true, + }, + { + name: "c-style comment only", + input: "/* comment */", + want: "\x00", + changed: true, + }, + { + name: "html comment only", + input: "<!-- comment -->", + want: "\x00", + changed: true, + }, + { + name: "unclosed c-style comment", + input: "hello /* unclosed", + want: "hello ", + changed: true, + }, + { + name: "unclosed html comment", + input: "hello <!-- unclosed", + want: "hello ", + changed: true, + }, + { + name: "double dash", + input: "hello -- rest", + want: "hello ", + changed: true, + }, + { + name: "hash comment", + input: "hello # rest", + want: "hello ", + changed: true, + }, + { + name: "empty string", + input: "", + want: "", + changed: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + have, changed, err := removeComments(tc.input) + if err != nil { + t.Fatal(err) + } + if changed != tc.changed { + t.Errorf("changed: want %t, have %t", tc.changed, changed) + } + if have != tc.want { + t.Errorf("value: want %q, have %q", tc.want, have) + } + }) + } +} diff --git 
a/magefile.go b/magefile.go index 943ba4b2b..841ffead8 100644 --- a/magefile.go +++ b/magefile.go @@ -108,7 +108,7 @@ func Test() error { return err } - if err := sh.RunV("go", "test", "-tags=memoize_builders", "./..."); err != nil { + if err := sh.RunV("go", "test", "-tags=coraza.no_memoize", "./..."); err != nil { return err } @@ -120,7 +120,7 @@ func Test() error { return err } - if err := sh.RunV("go", "test", "-tags=memoize_builders", "./testing/coreruleset"); err != nil { + if err := sh.RunV("go", "test", "-tags=coraza.no_memoize", "./testing/coreruleset"); err != nil { return err } @@ -182,8 +182,8 @@ func Coverage() error { if err := sh.RunV("go", "test", tagsCmd, "-coverprofile=build/coverage-ftw.txt", "-covermode=atomic", "-coverpkg=./...", "./testing/coreruleset"); err != nil { return err } - // we run tinygo tag only if memoize_builders is not enabled - if !strings.Contains(tags, "memoize_builders") { + // we run tinygo tag only if coraza.no_memoize is not enabled + if !strings.Contains(tags, "coraza.no_memoize") { if tagsCmd != "" { tagsCmd += ",tinygo" } @@ -222,7 +222,7 @@ func Fuzz() error { for _, pkgTests := range tests { for _, test := range pkgTests.tests { fmt.Println("Running", test) - if err := sh.RunV("go", "test", "-fuzz="+test, "-fuzztime=2m", pkgTests.pkg); err != nil { + if err := sh.RunV("go", "test", "-fuzz="+test, "-fuzztime=3m", pkgTests.pkg); err != nil { return err } } @@ -281,7 +281,7 @@ func TagsMatrix() error { "coraza.rule.mandatory_rule_id_check", "coraza.rule.case_sensitive_args_keys", "coraza.rule.no_regex_multiline", - "memoize_builders", + "coraza.no_memoize", "coraza.rule.multiphase_evaluation", "no_fs_access", } diff --git a/streaming_integration_test.go b/streaming_integration_test.go new file mode 100644 index 000000000..a16420ebe --- /dev/null +++ b/streaming_integration_test.go @@ -0,0 +1,211 @@ +// Copyright 2026 OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +package coraza + +import ( + 
"strings" + "testing" + + _ "github.com/corazawaf/coraza/v3/experimental/bodyprocessors" +) + +func TestStreamingPerRecordEvaluation(t *testing.T) { + // WAF with a Phase 1 rule to set the body processor and a Phase 2 rule + // that matches on ARGS_POST content. + waf, err := NewWAF(NewWAFConfig().WithDirectives(` + SecRuleEngine On + SecRequestBodyAccess On + SecRule REQUEST_HEADERS:content-type "application/x-ndjson" "id:1,phase:1,pass,nolog,ctl:requestBodyProcessor=JSONSTREAM" + SecRule ARGS_POST "@contains evil" "id:100,phase:2,deny,status:403,msg:'Evil detected'" + `)) + if err != nil { + t.Fatalf("failed to create WAF: %v", err) + } + + t.Run("interruption stops at bad record", func(t *testing.T) { + tx := waf.NewTransaction() + defer tx.Close() + + tx.ProcessConnection("127.0.0.1", 1234, "", 0) + tx.ProcessURI("/test", "POST", "HTTP/1.1") + tx.AddRequestHeader("Content-Type", "application/x-ndjson") + tx.AddRequestHeader("Host", "example.com") + + if it := tx.ProcessRequestHeaders(); it != nil { + t.Fatalf("unexpected interruption at headers: %v", it) + } + + // 5 records, record 2 (index 2) contains "evil" + body := `{"name": "Alice"} +{"name": "Bob"} +{"name": "evil payload"} +{"name": "Charlie"} +{"name": "Dave"} +` + if it, _, err := tx.ReadRequestBodyFrom(strings.NewReader(body)); err != nil { + t.Fatalf("failed to write request body: %v", err) + } else if it != nil { + t.Fatalf("unexpected interruption writing body: %v", it) + } + + it, err := tx.ProcessRequestBody() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if it == nil { + t.Fatal("expected interruption from evil record, got nil") + } + if it.Status != 403 { + t.Errorf("expected status 403, got %d", it.Status) + } + + // Verify that rule 100 was matched + matched := tx.MatchedRules() + found := false + for _, mr := range matched { + if mr.Rule().ID() == 100 { + found = true + break + } + } + if !found { + t.Error("expected rule 100 to be in matched rules") + } + }) + + 
t.Run("clean records pass through", func(t *testing.T) { + tx := waf.NewTransaction() + defer tx.Close() + + tx.ProcessConnection("127.0.0.1", 1234, "", 0) + tx.ProcessURI("/test", "POST", "HTTP/1.1") + tx.AddRequestHeader("Content-Type", "application/x-ndjson") + tx.AddRequestHeader("Host", "example.com") + + if it := tx.ProcessRequestHeaders(); it != nil { + t.Fatalf("unexpected interruption at headers: %v", it) + } + + // All records are clean + body := `{"name": "Alice"} +{"name": "Bob"} +{"name": "Charlie"} +` + if it, _, err := tx.ReadRequestBodyFrom(strings.NewReader(body)); err != nil { + t.Fatalf("failed to write request body: %v", err) + } else if it != nil { + t.Fatalf("unexpected interruption writing body: %v", it) + } + + it, err := tx.ProcessRequestBody() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if it != nil { + t.Fatalf("unexpected interruption for clean records: %v", it) + } + }) +} + +func TestStreamingTXVariablesPersistAcrossRecords(t *testing.T) { + // WAF that increments a TX variable for each record using setvar. + // After processing 3 records, tx.score should be 3. + // A rule checks if tx.score >= 3 and blocks. 
+ waf, err := NewWAF(NewWAFConfig().WithDirectives(` + SecRuleEngine On + SecRequestBodyAccess On + SecRule REQUEST_HEADERS:content-type "application/x-ndjson" "id:1,phase:1,pass,nolog,ctl:requestBodyProcessor=JSONSTREAM" + SecRule ARGS_POST "@rx .*" "id:100,phase:2,pass,nolog,setvar:tx.score=+1" + SecRule TX:score "@ge 3" "id:200,phase:2,deny,status:403,msg:'Score threshold reached'" + `)) + if err != nil { + t.Fatalf("failed to create WAF: %v", err) + } + + t.Run("TX variables accumulate across records", func(t *testing.T) { + tx := waf.NewTransaction() + defer tx.Close() + + tx.ProcessConnection("127.0.0.1", 1234, "", 0) + tx.ProcessURI("/test", "POST", "HTTP/1.1") + tx.AddRequestHeader("Content-Type", "application/x-ndjson") + tx.AddRequestHeader("Host", "example.com") + + if it := tx.ProcessRequestHeaders(); it != nil { + t.Fatalf("unexpected interruption at headers: %v", it) + } + + // 3 records - each increments tx.score by 1 + // After record 2 (3rd record), tx.score should reach 3 and trigger rule 200 + body := `{"data": "a"} +{"data": "b"} +{"data": "c"} +` + if it, _, err := tx.ReadRequestBodyFrom(strings.NewReader(body)); err != nil { + t.Fatalf("failed to write request body: %v", err) + } else if it != nil { + t.Fatalf("unexpected interruption writing body: %v", it) + } + + it, err := tx.ProcessRequestBody() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if it == nil { + t.Fatal("expected interruption from score threshold, got nil") + } + if it.Status != 403 { + t.Errorf("expected status 403, got %d", it.Status) + } + + // Verify that rule 200 was matched (score threshold) + matched := tx.MatchedRules() + found := false + for _, mr := range matched { + if mr.Rule().ID() == 200 { + found = true + break + } + } + if !found { + ids := make([]int, 0, len(matched)) + for _, mr := range matched { + ids = append(ids, mr.Rule().ID()) + } + t.Errorf("expected rule 200 to be in matched rules, got: %v", ids) + } + }) + + t.Run("fewer records dont 
trigger threshold", func(t *testing.T) { + tx := waf.NewTransaction() + defer tx.Close() + + tx.ProcessConnection("127.0.0.1", 1234, "", 0) + tx.ProcessURI("/test", "POST", "HTTP/1.1") + tx.AddRequestHeader("Content-Type", "application/x-ndjson") + tx.AddRequestHeader("Host", "example.com") + + if it := tx.ProcessRequestHeaders(); it != nil { + t.Fatalf("unexpected interruption at headers: %v", it) + } + + // Only 2 records - tx.score reaches 2, below threshold of 3 + body := `{"data": "a"} +{"data": "b"} +` + if it, _, err := tx.ReadRequestBodyFrom(strings.NewReader(body)); err != nil { + t.Fatalf("failed to write request body: %v", err) + } else if it != nil { + t.Fatalf("unexpected interruption writing body: %v", it) + } + + it, err := tx.ProcessRequestBody() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if it != nil { + t.Fatalf("unexpected interruption for 2 records (below threshold): %v", it) + } + }) +} diff --git a/testing/auditlog_test.go b/testing/auditlog_test.go index ef935f841..6307bc88c 100644 --- a/testing/auditlog_test.go +++ b/testing/auditlog_test.go @@ -11,7 +11,6 @@ import ( "encoding/json" "fmt" "os" - "path/filepath" "strings" "testing" @@ -23,11 +22,12 @@ import ( func TestAuditLogMessages(t *testing.T) { waf := corazawaf.NewWAF() parser := seclang.NewParser(waf) - // generate a random tmp file - file, err := os.Create(filepath.Join(t.TempDir(), "tmp.log")) + + file, err := os.CreateTemp(t.TempDir(), "tmp.log") if err != nil { t.Fatal(err) } + defer file.Close() if err := parser.FromString(fmt.Sprintf("SecAuditLog %s", file.Name())); err != nil { t.Fatal(err) } @@ -83,12 +83,12 @@ func TestAuditLogRelevantOnly(t *testing.T) { `); err != nil { t.Fatal(err) } - // generate a random tmp file - file, err := os.Create(filepath.Join(t.TempDir(), "tmp.log")) + + file, err := os.CreateTemp(t.TempDir(), "tmp.log") if err != nil { t.Fatal(err) } - defer os.Remove(file.Name()) + defer file.Close() if err := 
parser.FromString(fmt.Sprintf("SecAuditLog %s", file.Name())); err != nil { t.Fatal(err) } @@ -110,11 +110,11 @@ func TestAuditLogRelevantOnly(t *testing.T) { func TestAuditLogRelevantOnlyOk(t *testing.T) { waf := corazawaf.NewWAF() parser := seclang.NewParser(waf) - // generate a random tmp file - file, err := os.Create(filepath.Join(t.TempDir(), "tmp.log")) + file, err := os.CreateTemp(t.TempDir(), "tmp.log") if err != nil { t.Fatal(err) } + defer file.Close() defer os.Remove(file.Name()) if err := parser.FromString(fmt.Sprintf("SecAuditLog %s", file.Name())); err != nil { t.Fatal(err) @@ -158,12 +158,11 @@ func TestAuditLogRelevantOnlyNoAuditlog(t *testing.T) { `); err != nil { t.Fatal(err) } - // generate a random tmp file - file, err := os.Create(filepath.Join(t.TempDir(), "tmp.log")) + file, err := os.CreateTemp(t.TempDir(), "tmp.log") if err != nil { t.Fatal(err) } - defer os.Remove(file.Name()) + defer file.Close() if err := parser.FromString(fmt.Sprintf("SecAuditLog %s", file.Name())); err != nil { t.Fatal(err) } @@ -197,12 +196,11 @@ func TestAuditLogOnWithNoLog(t *testing.T) { `); err != nil { t.Fatal(err) } - // generate a random tmp file - file, err := os.Create(filepath.Join(t.TempDir(), "tmp.log")) + file, err := os.CreateTemp(t.TempDir(), "tmp.log") if err != nil { t.Fatal(err) } - defer os.Remove(file.Name()) + defer file.Close() if err := parser.FromString(fmt.Sprintf("SecAuditLog %s", file.Name())); err != nil { t.Fatal(err) } @@ -236,12 +234,11 @@ func TestAuditLogRequestMethodURIProtocol(t *testing.T) { `); err != nil { t.Fatal(err) } - // generate a random tmp file - file, err := os.Create(filepath.Join(t.TempDir(), "tmp.log")) + file, err := os.CreateTemp(t.TempDir(), "tmp.log") if err != nil { t.Fatal(err) } - defer os.Remove(file.Name()) + defer file.Close() if err := parser.FromString(fmt.Sprintf("SecAuditLog %s", file.Name())); err != nil { t.Fatal(err) } @@ -292,12 +289,11 @@ func TestAuditLogRequestBody(t *testing.T) { `); err != nil { 
t.Fatal(err) } - // generate a random tmp file - file, err := os.Create(filepath.Join(t.TempDir(), "tmp.log")) + file, err := os.CreateTemp(t.TempDir(), "tmp.log") if err != nil { t.Fatal(err) } - defer os.Remove(file.Name()) + defer file.Close() if err := parser.FromString(fmt.Sprintf("SecAuditLog %s", file.Name())); err != nil { t.Fatal(err) } @@ -350,12 +346,11 @@ func TestAuditLogHFlag(t *testing.T) { `); err != nil { t.Fatal(err) } - // generate a random tmp file - file, err := os.Create(filepath.Join(t.TempDir(), "tmp.log")) + file, err := os.CreateTemp(t.TempDir(), "tmp.log") if err != nil { t.Fatal(err) } - defer os.Remove(file.Name()) + defer file.Close() if err := parser.FromString(fmt.Sprintf("SecAuditLog %s", file.Name())); err != nil { t.Fatal(err) } @@ -400,12 +395,11 @@ func TestAuditLogWithKFlagWithoutHFlag(t *testing.T) { `); err != nil { t.Fatal(err) } - // generate a random tmp file - file, err := os.Create(filepath.Join(t.TempDir(), "tmp.log")) + file, err := os.CreateTemp(t.TempDir(), "tmp.log") if err != nil { t.Fatal(err) } - defer os.Remove(file.Name()) + defer file.Close() if err := parser.FromString(fmt.Sprintf("SecAuditLog %s", file.Name())); err != nil { t.Fatal(err) } @@ -434,3 +428,50 @@ func TestAuditLogWithKFlagWithoutHFlag(t *testing.T) { t.Errorf("Not expected audit log to contain %q, got %q", notExpected, alWithErrMsg.ErrorMessage()) } } + +func TestAuditLogRelevantOnlyDetectionOnly(t *testing.T) { + waf := corazawaf.NewWAF() + parser := seclang.NewParser(waf) + if err := parser.FromString(` + SecRuleEngine DetectionOnly + SecAuditEngine RelevantOnly + SecAuditLogFormat json + SecAuditLogType serial + SecAuditLogRelevantStatus "403" + SecRule ARGS "@unconditionalMatch" "id:1,phase:1,deny,log,auditlog,msg:'expected rule message'" + `); err != nil { + t.Fatal(err) + } + file, err := os.CreateTemp(t.TempDir(), "tmp.log") + if err != nil { + t.Fatal(err) + } + defer file.Close() + if err = parser.FromString(fmt.Sprintf("SecAuditLog 
%s", file.Name())); err != nil { + t.Fatal(err) + } + tx := waf.NewTransaction() + tx.AddGetRequestArgument("test", "test") + tx.ProcessRequestHeaders() + // now we read file + if _, err = file.Seek(0, 0); err != nil { + t.Fatal(err) + } + tx.ProcessLogging() + var al auditlog.Log + if err = json.NewDecoder(file).Decode(&al); err != nil { + t.Fatal(err) + } + if len(al.Messages()) != 1 { + t.Fatalf("Expected 1 message, got %d", len(al.Messages())) + } + type auditLogWithErrMesg interface{ ErrorMessage() string } + alWithErrMsg, ok := al.Messages()[0].(auditLogWithErrMesg) + if !ok { + t.Fatalf("Expected message to be of type auditLogWithErrMesg") + } + expected := "expected rule message" + if !strings.Contains(alWithErrMsg.ErrorMessage(), expected) { + t.Errorf("Expected audit log to contain %q, got %q", expected, alWithErrMsg.ErrorMessage()) + } +} diff --git a/testing/coraza_test.go b/testing/coraza_test.go index 8edaff198..4d3d33323 100644 --- a/testing/coraza_test.go +++ b/testing/coraza_test.go @@ -24,8 +24,9 @@ func TestEngine(t *testing.T) { t.Run(p.Meta.Name, func(t *testing.T) { tt, err := testList(t, &p) if err != nil { - t.Error(err) + t.Fatal(err) } + for _, test := range tt { t.Run(test.Name, func(t *testing.T) { if err := test.RunPhases(); err != nil { diff --git a/testing/coreruleset/coreruleset_test.go b/testing/coreruleset/coreruleset_test.go index 55977fe28..930fef1dc 100644 --- a/testing/coreruleset/coreruleset_test.go +++ b/testing/coreruleset/coreruleset_test.go @@ -16,6 +16,7 @@ import ( "net/url" "os" "path/filepath" + "runtime" "strconv" "strings" "testing" @@ -32,6 +33,7 @@ import ( coreruleset "github.com/corazawaf/coraza-coreruleset/v4" crstests "github.com/corazawaf/coraza-coreruleset/v4/tests" "github.com/corazawaf/coraza/v3" + "github.com/corazawaf/coraza/v3/experimental" txhttp "github.com/corazawaf/coraza/v3/http" "github.com/corazawaf/coraza/v3/types" ) @@ -42,7 +44,7 @@ func BenchmarkCRSCompilation(b *testing.B) { b.Fatal(err) } 
for i := 0; i < b.N; i++ { - _, err := coraza.NewWAF(coraza.NewWAFConfig(). + waf, err := coraza.NewWAF(coraza.NewWAFConfig(). WithRootFS(coreruleset.FS). WithDirectives(string(rec)). WithDirectives("Include @crs-setup.conf.example"). @@ -50,6 +52,9 @@ func BenchmarkCRSCompilation(b *testing.B) { if err != nil { b.Fatal(err) } + if closer, ok := waf.(experimental.WAFCloser); ok { + closer.Close() + } } } @@ -60,7 +65,7 @@ func BenchmarkCRSSimpleGET(b *testing.B) { for i := 0; i < b.N; i++ { tx := waf.NewTransaction() tx.ProcessConnection("127.0.0.1", 8080, "127.0.0.1", 8080) - tx.ProcessURI("GET", "/some_path/with?parameters=and&other=Stuff", "HTTP/1.1") + tx.ProcessURI("/some_path/with?parameters=and&other=Stuff", "GET", "HTTP/1.1") tx.AddRequestHeader("Host", "localhost") tx.AddRequestHeader("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/75.0.3770.100 Safari/537.36") tx.AddRequestHeader("Accept", "application/json") @@ -88,7 +93,7 @@ func BenchmarkCRSSimplePOST(b *testing.B) { for i := 0; i < b.N; i++ { tx := waf.NewTransaction() tx.ProcessConnection("127.0.0.1", 8080, "127.0.0.1", 8080) - tx.ProcessURI("POST", "/some_path/with?parameters=and&other=Stuff", "HTTP/1.1") + tx.ProcessURI("/some_path/with?parameters=and&other=Stuff", "POST", "HTTP/1.1") tx.AddRequestHeader("Host", "localhost") tx.AddRequestHeader("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/75.0.3770.100 Safari/537.36") tx.AddRequestHeader("Accept", "application/json") @@ -122,7 +127,7 @@ func BenchmarkCRSLargePOST(b *testing.B) { for i := 0; i < b.N; i++ { tx := waf.NewTransaction() tx.ProcessConnection("127.0.0.1", 8080, "127.0.0.1", 8080) - tx.ProcessURI("POST", "/some_path/with?parameters=and&other=Stuff", "HTTP/1.1") + tx.ProcessURI("/some_path/with?parameters=and&other=Stuff", "POST", "HTTP/1.1") tx.AddRequestHeader("Host", "localhost") tx.AddRequestHeader("User-Agent", 
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/75.0.3770.100 Safari/537.36") tx.AddRequestHeader("Accept", "application/json") @@ -146,6 +151,123 @@ func BenchmarkCRSLargePOST(b *testing.B) { } } +// BenchmarkCRSTransformationCache measures the transformation cache performance +// across different request sizes. The transformation cache benefit scales with +// (number of arguments) × (number of rules sharing transformation prefixes), +// so these benchmarks exercise that by varying argument counts and value sizes. +func BenchmarkCRSTransformationCache(b *testing.B) { + waf := crsWAF(b) + + // Small: 2 query params, short values (typical simple API call) + smallQuery := "user=admin&action=view" + // Medium: 10 params with moderate values (typical form submission) + mediumParams := []string{ + "username=johndoe", + "email=john@example.com", + "first_name=John", + "last_name=Doe", + "address=123+Main+Street", + "city=Springfield", + "state=IL", + "zip=62701", + "phone=555-0123", + "comment=This+is+a+test+comment+with+some+content", + } + mediumBody := strings.Join(mediumParams, "&") + // Large: 30 params with longer values (complex form, many args) + var largeParams []string + for i := 0; i < 30; i++ { + largeParams = append(largeParams, fmt.Sprintf("field_%d=%s", i, strings.Repeat("value", 20))) + } + largeBody := strings.Join(largeParams, "&") + + b.Run("SmallGET_2params", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tx := waf.NewTransaction() + tx.ProcessConnection("127.0.0.1", 8080, "127.0.0.1", 8080) + tx.ProcessURI("GET", "/api/endpoint?"+smallQuery, "HTTP/1.1") + tx.AddRequestHeader("Host", "localhost") + tx.AddRequestHeader("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/75.0.3770.100 Safari/537.36") + tx.AddRequestHeader("Accept", "application/json") + tx.ProcessRequestHeaders() + if _, err := 
tx.ProcessRequestBody(); err != nil { + b.Error(err) + } + tx.AddResponseHeader("Content-Type", "application/json") + tx.ProcessResponseHeaders(200, "OK") + if _, err := tx.ProcessResponseBody(); err != nil { + b.Error(err) + } + tx.ProcessLogging() + if err := tx.Close(); err != nil { + b.Error(err) + } + } + }) + + b.Run("MediumPOST_10params", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tx := waf.NewTransaction() + tx.ProcessConnection("127.0.0.1", 8080, "127.0.0.1", 8080) + tx.ProcessURI("POST", "/api/submit?source=web", "HTTP/1.1") + tx.AddRequestHeader("Host", "localhost") + tx.AddRequestHeader("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/75.0.3770.100 Safari/537.36") + tx.AddRequestHeader("Accept", "text/html") + tx.AddRequestHeader("Content-Type", "application/x-www-form-urlencoded") + tx.ProcessRequestHeaders() + if _, _, err := tx.WriteRequestBody([]byte(mediumBody)); err != nil { + b.Error(err) + } + if _, err := tx.ProcessRequestBody(); err != nil { + b.Error(err) + } + tx.AddResponseHeader("Content-Type", "text/html") + tx.ProcessResponseHeaders(200, "OK") + if _, err := tx.ProcessResponseBody(); err != nil { + b.Error(err) + } + tx.ProcessLogging() + if err := tx.Close(); err != nil { + b.Error(err) + } + } + }) + + b.Run("LargePOST_30params", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tx := waf.NewTransaction() + tx.ProcessConnection("127.0.0.1", 8080, "127.0.0.1", 8080) + tx.ProcessURI("POST", "/api/bulk?source=web&format=json", "HTTP/1.1") + tx.AddRequestHeader("Host", "localhost") + tx.AddRequestHeader("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/75.0.3770.100 Safari/537.36") + tx.AddRequestHeader("Accept", "text/html") + tx.AddRequestHeader("Content-Type", "application/x-www-form-urlencoded") + tx.ProcessRequestHeaders() + if _, _, 
err := tx.WriteRequestBody([]byte(largeBody)); err != nil { + b.Error(err) + } + if _, err := tx.ProcessRequestBody(); err != nil { + b.Error(err) + } + tx.AddResponseHeader("Content-Type", "text/html") + tx.ProcessResponseHeaders(200, "OK") + if _, err := tx.ProcessResponseBody(); err != nil { + b.Error(err) + } + tx.ProcessLogging() + if err := tx.Close(); err != nil { + b.Error(err) + } + } + }) +} + func TestFTW(t *testing.T) { conf := coraza.NewWAFConfig() @@ -221,6 +343,9 @@ SecRule REQUEST_HEADERS:X-CRS-Test "@rx ^.*$" \ if err != nil { t.Fatal(err) } + if closer, ok := waf.(experimental.WAFCloser); ok { + defer closer.Close() + } // CRS regression tests are expected to be run with https://github.com/coreruleset/albedo as backend server s := httptest.NewServer(txhttp.WrapHandler(waf, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -283,6 +408,90 @@ SecRule REQUEST_HEADERS:X-CRS-Test "@rx ^.*$" \ } } +func BenchmarkCRSMultiWAFCompilation(b *testing.B) { + for i := 0; i < b.N; i++ { + for w := 0; w < 10; w++ { + waf := crsWAF(b) + if closer, ok := waf.(experimental.WAFCloser); ok { + closer.Close() + } + } + } +} + +func BenchmarkCRSMemoizeSpeedup(b *testing.B) { + for i := 0; i < b.N; i++ { + // Cold compilation (first WAF) + cold := time.Now() + waf1 := crsWAF(b) + coldDur := time.Since(cold) + if closer, ok := waf1.(experimental.WAFCloser); ok { + closer.Close() + } + + // Warm compilation (second WAF, patterns already cached) + warm := time.Now() + waf2 := crsWAF(b) + warmDur := time.Since(warm) + if closer, ok := waf2.(experimental.WAFCloser); ok { + closer.Close() + } + + b.Logf("cold=%s warm=%s speedup=%.1fx", coldDur, warmDur, float64(coldDur)/float64(warmDur)) + } +} + +func TestCRSCloseReleasesMemory(t *testing.T) { + if os.Getenv("CORAZA_RUN_CRS_CLOSE_MEMTEST") == "" { + t.Skip("skipping memory diagnostic test; set CORAZA_RUN_CRS_CLOSE_MEMTEST=1 to run") + } + + var m runtime.MemStats + + runtime.GC() + runtime.ReadMemStats(&m) + 
baseHeap := m.HeapAlloc + + // Build WAFs directly (not via crsWAF) so we control the lifecycle + // without t.Cleanup holding references that prevent GC. + rec, err := os.ReadFile(filepath.Join("..", "..", "coraza.conf-recommended")) + if err != nil { + t.Fatal(err) + } + conf := coraza.NewWAFConfig(). + WithRootFS(coreruleset.FS). + WithDirectives(string(rec)). + WithDirectives("Include @crs-setup.conf.example"). + WithDirectives("Include @owasp_crs/*.conf") + + wafs := make([]coraza.WAF, 5) + for i := range wafs { + waf, err := coraza.NewWAF(conf) + if err != nil { + t.Fatal(err) + } + wafs[i] = waf + } + + runtime.GC() + runtime.ReadMemStats(&m) + peakHeap := m.HeapAlloc + + for _, waf := range wafs { + if closer, ok := waf.(experimental.WAFCloser); ok { + closer.Close() + } + } + + runtime.GC() + runtime.ReadMemStats(&m) + afterHeap := m.HeapAlloc + + t.Logf("base=%dMiB peak=%dMiB after_close=%dMiB released=%dMiB", + baseHeap/1024/1024, peakHeap/1024/1024, + afterHeap/1024/1024, (peakHeap-afterHeap)/1024/1024) +} + func crsWAF(t testing.TB) coraza.WAF { t.Helper() rec, err := os.ReadFile(filepath.Join("..", "..", "coraza.conf-recommended")) @@ -321,6 +530,14 @@ SecAction "id:900005,\ if err != nil { t.Fatal(err) } + if closer, ok := waf.(experimental.WAFCloser); ok { + // Avoid registering per-iteration Cleanup callbacks in benchmarks, as that + // can retain WAF instances and skew memory/benchmark results. Benchmarks + // calling crsWAF are expected to close the WAF explicitly if needed. 
+ if _, isBenchmark := t.(*testing.B); !isBenchmark { + t.Cleanup(func() { closer.Close() }) + } + } return waf } diff --git a/testing/coreruleset/go.mod b/testing/coreruleset/go.mod index 6358a7cb5..9ae66f406 100644 --- a/testing/coreruleset/go.mod +++ b/testing/coreruleset/go.mod @@ -1,10 +1,10 @@ module github.com/corazawaf/coraza/v3/testing/coreruleset -go 1.24.0 +go 1.25.0 require ( github.com/bmatcuk/doublestar/v4 v4.9.1 - github.com/corazawaf/coraza-coreruleset/v4 v4.20.0 + github.com/corazawaf/coraza-coreruleset/v4 v4.24.0 github.com/corazawaf/coraza/v3 v3.3.3 github.com/coreruleset/albedo v0.3.0 github.com/coreruleset/go-ftw v1.3.0 @@ -15,7 +15,7 @@ require ( github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver v1.5.0 // indirect github.com/Masterminds/sprig v2.22.0+incompatible // indirect - github.com/corazawaf/libinjection-go v0.2.2 // indirect + github.com/corazawaf/libinjection-go v0.3.2 // indirect github.com/coreruleset/ftw-tests-schema/v2 v2.2.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect @@ -49,11 +49,11 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/valllabh/ocsf-schema-golang v1.0.3 // indirect github.com/yargevad/filepathx v1.0.0 // indirect - golang.org/x/crypto v0.45.0 // indirect - golang.org/x/net v0.47.0 // indirect - golang.org/x/sync v0.18.0 // indirect - golang.org/x/sys v0.38.0 // indirect - golang.org/x/text v0.31.0 // indirect + golang.org/x/crypto v0.49.0 // indirect + golang.org/x/net v0.52.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.35.0 // indirect golang.org/x/time v0.9.0 // indirect google.golang.org/protobuf v1.35.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/testing/coreruleset/go.sum b/testing/coreruleset/go.sum index f28097436..c53ba2dfa 100644 --- a/testing/coreruleset/go.sum +++ b/testing/coreruleset/go.sum @@ -8,10 
+8,10 @@ github.com/bmatcuk/doublestar/v4 v4.9.1 h1:X8jg9rRZmJd4yRy7ZeNDRnM+T3ZfHv15JiBJ/ github.com/bmatcuk/doublestar/v4 v4.9.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/corazawaf/coraza-coreruleset v0.0.0-20240226094324-415b1017abdc h1:OlJhrgI3I+FLUCTI3JJW8MoqyM78WbqJjecqMnqG+wc= github.com/corazawaf/coraza-coreruleset v0.0.0-20240226094324-415b1017abdc/go.mod h1:7rsocqNDkTCira5T0M7buoKR2ehh7YZiPkzxRuAgvVU= -github.com/corazawaf/coraza-coreruleset/v4 v4.20.0 h1:rV976KQN49oTFaYzNqHHdIQYmA3Qr4kxCqH8SVJLKK8= -github.com/corazawaf/coraza-coreruleset/v4 v4.20.0/go.mod h1:tRjsdtj39+at47dLCpE8ChoDa2FK2IAwTWIpDT8Z62g= -github.com/corazawaf/libinjection-go v0.2.2 h1:Chzodvb6+NXh6wew5/yhD0Ggioif9ACrQGR4qjTCs1g= -github.com/corazawaf/libinjection-go v0.2.2/go.mod h1:OP4TM7xdJ2skyXqNX1AN1wN5nNZEmJNuWbNPOItn7aw= +github.com/corazawaf/coraza-coreruleset/v4 v4.24.0 h1:7Ys2vZegaDIwDeDcRuCQNjMzNaDLklqXogJsucoE1tk= +github.com/corazawaf/coraza-coreruleset/v4 v4.24.0/go.mod h1:tRjsdtj39+at47dLCpE8ChoDa2FK2IAwTWIpDT8Z62g= +github.com/corazawaf/libinjection-go v0.3.2 h1:9rrKt0lpg4WvUXt+lwS06GywfqRXXsa/7JcOw5cQLwI= +github.com/corazawaf/libinjection-go v0.3.2/go.mod h1:Ik/+w3UmTWH9yn366RgS9D95K3y7Atb5m/H/gXzzPCk= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreruleset/albedo v0.3.0 h1:GC0CbI8/qzzBbiP95z7A+58pjlN8P5jheVrUwEfzRfQ= github.com/coreruleset/albedo v0.3.0/go.mod h1:dulcaAKNDKBnMw7CpK7V61zmGy5D4ASXSrtRpuDnYK8= @@ -111,35 +111,25 @@ github.com/valllabh/ocsf-schema-golang v1.0.3 h1:eR8k/3jP/OOqB8LRCtdJ4U+vlgd/gk5 github.com/valllabh/ocsf-schema-golang v1.0.3/go.mod h1:sZ3as9xqm1SSK5feFWIR2CuGeGRhsM7TR1MbpBctzPk= github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= -golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= -golang.org/x/crypto 
v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= -golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= -golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= -golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= -golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= -golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys 
v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= -golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/testing/e2e/e2e_test.go b/testing/e2e/e2e_test.go index f02be4a16..2eae554b1 100644 --- a/testing/e2e/e2e_test.go +++ b/testing/e2e/e2e_test.go @@ -15,6 +15,7 @@ import ( "github.com/mccutchen/go-httpbin/v2/httpbin" "github.com/corazawaf/coraza/v3" + "github.com/corazawaf/coraza/v3/experimental" txhttp "github.com/corazawaf/coraza/v3/http" "github.com/corazawaf/coraza/v3/http/e2e" ) @@ 
-29,6 +30,9 @@ func TestE2e(t *testing.T) { if err != nil { t.Fatal(err) } + if closer, ok := waf.(experimental.WAFCloser); ok { + defer closer.Close() + } httpbin := httpbin.New() diff --git a/testing/e2e/ndjson_e2e_test.go b/testing/e2e/ndjson_e2e_test.go new file mode 100644 index 000000000..d10274ad1 --- /dev/null +++ b/testing/e2e/ndjson_e2e_test.go @@ -0,0 +1,172 @@ +// Copyright 2026 OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +//go:build !tinygo + +package e2e_test + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/corazawaf/coraza/v3" + "github.com/corazawaf/coraza/v3/experimental" + _ "github.com/corazawaf/coraza/v3/experimental/bodyprocessors" + txhttp "github.com/corazawaf/coraza/v3/http" +) + +const ndjsonDirectives = ` +SecRuleEngine On +SecRequestBodyAccess On +SecRule REQUEST_HEADERS:Content-Type "@rx ^application/(x-ndjson|jsonlines|json-seq)" \ + "id:1,phase:1,pass,nolog,ctl:requestBodyProcessor=JSONSTREAM" +SecRule ARGS_POST "@contains evil" "id:100,phase:2,deny,status:403,log,msg:'Evil payload in NDJSON'" +SecRule ARGS_POST "@detectSQLi" "id:101,phase:2,t:none,t:urlDecodeUni,t:removeNulls,deny,status:403,log,msg:'SQLi in NDJSON'" +` + +func newNDJSONTestServer(t *testing.T) (*httptest.Server, func()) { + t.Helper() + + conf := coraza.NewWAFConfig().WithDirectives(ndjsonDirectives).WithRequestBodyAccess() + waf, err := coraza.NewWAF(conf) + if err != nil { + t.Fatalf("failed to create WAF: %v", err) + } + + backend := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "OK: %s", body) + }) + + s := httptest.NewServer(txhttp.WrapHandler(waf, backend)) + cleanup := func() { + s.Close() + if closer, ok := waf.(experimental.WAFCloser); ok { + closer.Close() + } + } + return s, cleanup +} + +func TestNDJSON_E2E_CleanRecord(t *testing.T) { + s, cleanup := newNDJSONTestServer(t) + defer 
cleanup() + + body := `{"name":"Alice","role":"user"}` + "\n" + + `{"name":"Bob","role":"admin"}` + "\n" + + resp, err := http.Post(s.URL+"/api/users", "application/x-ndjson", strings.NewReader(body)) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200 for clean NDJSON, got %d", resp.StatusCode) + } +} + +func TestNDJSON_E2E_MaliciousRecord(t *testing.T) { + s, cleanup := newNDJSONTestServer(t) + defer cleanup() + + // Third record contains malicious payload that should be blocked + body := `{"name":"Alice","role":"user"}` + "\n" + + `{"name":"Bob","role":"admin"}` + "\n" + + `{"name":"evil payload","role":"attacker"}` + "\n" + + `{"name":"Charlie","role":"user"}` + "\n" + + resp, err := http.Post(s.URL+"/api/users", "application/x-ndjson", strings.NewReader(body)) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + t.Errorf("expected 403 for malicious NDJSON record, got %d", resp.StatusCode) + } +} + +func TestNDJSON_E2E_ContentTypeJSONLines(t *testing.T) { + s, cleanup := newNDJSONTestServer(t) + defer cleanup() + + body := `{"id":1,"value":"safe data"}` + "\n" + + `{"id":2,"value":"more safe data"}` + "\n" + + resp, err := http.Post(s.URL+"/api/data", "application/jsonlines", strings.NewReader(body)) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200 for clean jsonlines request, got %d", resp.StatusCode) + } +} + +func TestNDJSON_E2E_SQLiInRecord(t *testing.T) { + s, cleanup := newNDJSONTestServer(t) + defer cleanup() + + // Record with SQLi attempt + body := `{"id":1,"search":"1' ORDER BY 3--+"}` + "\n" + + resp, err := http.Post(s.URL+"/api/search", "application/x-ndjson", strings.NewReader(body)) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer 
resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + t.Errorf("expected 403 for SQLi in NDJSON record, got %d", resp.StatusCode) + } +} + +func TestNDJSON_E2E_EmptyLines(t *testing.T) { + s, cleanup := newNDJSONTestServer(t) + defer cleanup() + + // Stream with empty lines interspersed (should be ignored) + body := `{"name":"Alice"}` + "\n" + + "\n" + + `{"name":"Bob"}` + "\n" + + "\n" + + resp, err := http.Post(s.URL+"/api/users", "application/x-ndjson", strings.NewReader(body)) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200 for NDJSON with empty lines, got %d", resp.StatusCode) + } +} + +func TestNDJSON_E2E_NonNDJSONContentTypeNotProcessed(t *testing.T) { + s, cleanup := newNDJSONTestServer(t) + defer cleanup() + + // With application/json content-type the body is not processed as JSONSTREAM + // so the "evil" keyword should not be detected via ARGS_POST + body := `{"name":"evil payload"}` + + resp, err := http.Post(s.URL+"/api/data", "application/json", strings.NewReader(body)) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + // application/json goes through the JSON body processor, not JSONSTREAM. + // The JSON processor stores fields in ARGS_POST too, so ARGS_POST rule still fires. + // Just confirm the request is processed (200 or 403 are both valid here). + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusForbidden { + t.Errorf("unexpected status %d for non-NDJSON content type", resp.StatusCode) + } +} diff --git a/testing/engine.go b/testing/engine.go index fa8818ec2..463c336b3 100644 --- a/testing/engine.go +++ b/testing/engine.go @@ -113,7 +113,10 @@ func (t *Test) SetRawRequest(request []byte) error { } // parse body if i < len(spl) { - return t.SetRequestBody(strings.Join(spl[i:], "\r\n")) + // i is the index of the empty line separator. 
+ // Skip the separator by joining from i+1. + // If i is the last element (i+1 == len(spl)), spl[i+1:] is empty, which is correct. + return t.SetRequestBody(strings.Join(spl[i+1:], "\r\n")) } return nil @@ -188,6 +191,19 @@ func (t *Test) RunPhases() error { func (t *Test) OutputInterruptionErrors() []string { var errors []string + // Check if interruption expectation matches actual state + if t.ExpectedOutput.Interruption == nil && t.transaction.IsInterrupted() { + errors = append(errors, fmt.Sprintf("Expected no interruption, but transaction was interrupted by rule %d with action '%s'", + t.transaction.Interruption().RuleID, t.transaction.Interruption().Action)) + return errors + } + + if t.ExpectedOutput.Interruption != nil && !t.transaction.IsInterrupted() { + errors = append(errors, "Expected interruption, but transaction was not interrupted") + return errors + } + + // If we expect an interruption and got one, validate the details if t.ExpectedOutput.Interruption != nil && t.transaction.IsInterrupted() { if t.ExpectedOutput.Interruption.Action != t.transaction.Interruption().Action { errors = append(errors, fmt.Sprintf("Interruption.Action: expected: '%s', got: '%s'", diff --git a/testing/engine/allow_detection_only.go b/testing/engine/allow_detection_only.go new file mode 100644 index 000000000..027cecea8 --- /dev/null +++ b/testing/engine/allow_detection_only.go @@ -0,0 +1,119 @@ +// Copyright 2026 Juan Pablo Tosso and the OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +//go:build !coraza.rule.multiphase_evaluation + +package engine + +import ( + "github.com/corazawaf/coraza/v3/testing/profile" +) + +// These two profiles use the same rules to show the behavioral difference between On and DetectionOnly. +// In On mode, allow is a disruptive action that skips subsequent rules. +// In DetectionOnly mode, disruptive actions (including allow) do not affect rule flow, +// so all matching rules are evaluated. 
+var allowRules = ` +SecDebugLogLevel 5 +SecDefaultAction "phase:1,log,pass" + +# Full allow: rule 2 is skipped in On mode, but still evaluated in DetectionOnly +SecRule REQUEST_URI "/allow_me" "id:1,phase:1,allow,log,msg:'ALLOWED'" +SecRule REQUEST_URI "/allow_me" "id:2,phase:1,deny,log,msg:'Should be skipped by allow'" + +# allow:phase skips remaining phase 1 rules in On mode, but not in DetectionOnly +SecRule REQUEST_URI "/partial_allow" "id:11,phase:1,allow:phase,log,msg:'Allowed in this phase only'" +SecRule REQUEST_URI "/partial_allow" "id:12,phase:1,deny,log,msg:'Should be skipped by allow phase'" +SecRule REQUEST_URI "/partial_allow" "id:13,phase:1,deny,log,msg:'Should be skipped by allow phase'" +SecRule REQUEST_URI "/partial_allow" "id:22,phase:2,deny,log,status:500,msg:'Denied in phase 2'" +SecRule REQUEST_URI "/partial_allow" "id:23,phase:2,deny,log,status:501,msg:'Denied in phase 2'" +` + +// allow_on.yaml: baseline with SecRuleEngine On. +// Rule 1 allows, so rule 2 is skipped. +// Rule 11 allows phase 1, so rules 12 and 13 are skipped +// but rule 22 in phase 2 still triggers and interrupts. +// rule 23 in phase 2 is not triggered because of the previous interruption. 
+var _ = profile.RegisterProfile(profile.Profile{ + Meta: profile.Meta{ + Author: "M4tteoP", + Description: "Test allow action with SecRuleEngine On (baseline for DetectionOnly comparison)", + Enabled: true, + Name: "allow_on.yaml", + }, + Tests: []profile.Test{ + { + Title: "allow with engine on", + Stages: []profile.Stage{ + { + Stage: profile.SubStage{ + Input: profile.StageInput{ + URI: "/allow_me?key=value", + }, + Output: profile.ExpectedOutput{ + TriggeredRules: []int{1}, + NonTriggeredRules: []int{2}, + }, + }, + }, + { + Stage: profile.SubStage{ + Input: profile.StageInput{ + URI: "/partial_allow?key=value", + }, + Output: profile.ExpectedOutput{ + TriggeredRules: []int{11, 22}, + NonTriggeredRules: []int{12, 13, 23}, + Interruption: &profile.ExpectedInterruption{ + Status: 500, + RuleID: 22, + Action: "deny", + }, + }, + }, + }, + }, + }, + }, + Rules: "SecRuleEngine On\n" + allowRules, +}) + +// allow_detection_only.yaml: same rules as above but with SecRuleEngine DetectionOnly. +// In DetectionOnly mode, allow does not affect rule flow: all matching rules are still evaluated. +// This differs from On mode where allow skips subsequent rules. 
+var _ = profile.RegisterProfile(profile.Profile{ + Meta: profile.Meta{ + Author: "M4tteoP", + Description: "Test allow action with SecRuleEngine DetectionOnly (same rules as allow_on.yaml)", + Enabled: true, + Name: "allow_detection_only.yaml", + }, + Tests: []profile.Test{ + { + Title: "allow with engine detection only", + Stages: []profile.Stage{ + { + Stage: profile.SubStage{ + Input: profile.StageInput{ + URI: "/allow_me?key=value", + }, + Output: profile.ExpectedOutput{ + TriggeredRules: []int{1, 2}, + }, + }, + }, + { + Stage: profile.SubStage{ + Input: profile.StageInput{ + URI: "/partial_allow?key=value", + }, + Output: profile.ExpectedOutput{ + TriggeredRules: []int{11, 12, 13, 22, 23}, + }, + }, + }, + }, + }, + }, + Rules: "SecRuleEngine DetectionOnly\n" + allowRules, +}) diff --git a/testing/engine/ctl.go b/testing/engine/ctl.go index 2959c5d66..5b4ee5df8 100644 --- a/testing/engine/ctl.go +++ b/testing/engine/ctl.go @@ -46,6 +46,109 @@ var _ = profile.RegisterProfile(profile.Profile{ }, }, }, + { + Title: "ruleRemoveTargetById whole collection", + Stages: []profile.Stage{ + { + Stage: profile.SubStage{ + Input: profile.StageInput{ + Method: "GET", + URI: "/test.php?foo=bar&baz=qux", + }, + Output: profile.ExpectedOutput{ + // Rule 200 removes all ARGS_GET from rule 201, so rule 201 should not match + NonTriggeredRules: []int{201}, + TriggeredRules: []int{200}, + }, + }, + }, + }, + }, + { + Title: "ruleRemoveTargetById regex key", + Stages: []profile.Stage{ + { + Stage: profile.SubStage{ + Input: profile.StageInput{ + Method: "GET", + // json.0.desc and json.1.desc match the regex; they are the only args + URI: "/api/jobs?json.0.desc=attack&json.1.desc=attack", + }, + Output: profile.ExpectedOutput{ + // Rule 300 logs and removes ARGS_GET:/^json\.\d+\.desc$/ from rule 301. + // Rule 301 would normally match the attack args but they are excluded by ctl. 
+ TriggeredRules: []int{300}, + NonTriggeredRules: []int{301}, + }, + }, + }, + }, + }, + { + Title: "ruleRemoveTargetById regex key (POST JSON body)", + Stages: []profile.Stage{ + { + Stage: profile.SubStage{ + Input: profile.StageInput{ + Method: "POST", + URI: "/api/jsonjobs", + Headers: map[string]string{ + "Content-Type": "application/json", + }, + // JSON array → ARGS_POST: json.0.desc=attack, json.1.desc=attack + Data: `[{"desc": "attack"}, {"desc": "attack"}]`, + }, + Output: profile.ExpectedOutput{ + // Rule 310 sets the JSON body processor. + // Rule 311 removes ARGS_POST matching /^json\.\d+\.desc$/ from rule 312. + // Rule 312 would normally match "attack" but must be suppressed. + TriggeredRules: []int{310, 311}, + NonTriggeredRules: []int{312}, + }, + }, + }, + }, + }, + { + Title: "ruleRemoveTargetByTag regex key", + Stages: []profile.Stage{ + { + Stage: profile.SubStage{ + Input: profile.StageInput{ + Method: "GET", + // json.0.desc and json.1.desc match the regex; rule 321 is tagged OWASP_CRS + URI: "/api/tag-test?json.0.desc=attack&json.1.desc=attack", + }, + Output: profile.ExpectedOutput{ + // Rule 320 removes ARGS_GET:/^json\.\d+\.desc$/ from all rules tagged OWASP_CRS. + // Rule 321 (tagged OWASP_CRS) would normally match the attack args but must be suppressed. + TriggeredRules: []int{320}, + NonTriggeredRules: []int{321}, + }, + }, + }, + }, + }, + { + Title: "ruleRemoveTargetByMsg regex key", + Stages: []profile.Stage{ + { + Stage: profile.SubStage{ + Input: profile.StageInput{ + Method: "GET", + // json.0.desc and json.1.desc match the regex; rule 331 has msg:'web shell detection' + URI: "/api/msg-test?json.0.desc=attack&json.1.desc=attack", + }, + Output: profile.ExpectedOutput{ + // Rule 330 removes ARGS_GET:/^json\.\d+\.desc$/ from all rules with msg:'web shell detection'. + // Rule 331 would normally match the attack args but must be suppressed. 
+ TriggeredRules: []int{330}, + NonTriggeredRules: []int{331}, + }, + }, + }, + }, + }, }, Rules: ` SecDebugLogLevel 9 @@ -76,5 +179,36 @@ SecRule REQUEST_BODY "pizza" "id:103,log,phase:2" SecRule ARGS_POST:pineapple "pizza" "id:105,log,phase:2" SecAction "id:444,phase:2,log" + +# ruleRemoveTargetById whole collection test: removes all ARGS_GET from rule 201 +SecAction "id:200,phase:1,ctl:ruleRemoveTargetById=201;ARGS_GET,log" +SecRule ARGS_GET "@rx ." "id:201, phase:1, log" + +# ruleRemoveTargetById regex key test (GET): +# Rule 300 removes ARGS_GET matching /^json\.\d+\.desc$/ from rule 301. +# Matching args (json.0.desc, json.1.desc) must NOT trigger rule 301. +SecRule REQUEST_URI "@beginsWith /api/jobs" "id:300,phase:1,pass,log,ctl:ruleRemoveTargetById=301;ARGS_GET:/^json\.\d+\.desc$/" +SecRule ARGS_GET "@rx attack" "id:301,phase:1,log" + +# ruleRemoveTargetById regex key test (POST JSON body): +# Rule 310 activates JSON body processor for application/json requests. +# Rule 311 removes ARGS_POST matching /^json\.\d+\.desc$/ from rule 312 when URI starts with /api/jsonjobs. +# JSON body [{"desc":"attack"},{"desc":"attack"}] → ARGS_POST: json.0.desc=attack, json.1.desc=attack. +# Rule 312 would normally match "attack" in ARGS_POST but must be suppressed by rule 311's CTL. +SecRule REQUEST_HEADERS:content-type "@beginsWith application/json" "id:310,phase:1,pass,log,ctl:requestBodyProcessor=JSON" +SecRule REQUEST_URI "@beginsWith /api/jsonjobs" "id:311,phase:1,pass,log,ctl:ruleRemoveTargetById=312;ARGS_POST:/^json\.\d+\.desc$/" +SecRule ARGS_POST "@rx attack" "id:312,phase:2,log" + +# ruleRemoveTargetByTag regex key test: +# Rule 320 removes ARGS_GET matching /^json\.\d+\.desc$/ from all rules tagged OWASP_CRS. +# Rule 321 is tagged OWASP_CRS and would normally match the attack args but must be suppressed. 
+SecRule REQUEST_URI "@beginsWith /api/tag-test" "id:320,phase:1,pass,log,ctl:ruleRemoveTargetByTag=OWASP_CRS;ARGS_GET:/^json\.\d+\.desc$/" +SecRule ARGS_GET "@rx attack" "id:321,phase:1,log,tag:OWASP_CRS" + +# ruleRemoveTargetByMsg regex key test: +# Rule 330 removes ARGS_GET matching /^json\.\d+\.desc$/ from all rules with msg:'web shell detection'. +# Rule 331 has msg:'web shell detection' and would normally match the attack args but must be suppressed. +SecRule REQUEST_URI "@beginsWith /api/msg-test" "id:330,phase:1,pass,log,ctl:ruleRemoveTargetByMsg=web shell detection;ARGS_GET:/^json\.\d+\.desc$/" +SecRule ARGS_GET "@rx attack" "id:331,phase:1,log,msg:'web shell detection'" `, }) diff --git a/testing/engine/directives_updateactions.go b/testing/engine/directives_updateactions.go index e5fd692d1..d72279c6e 100644 --- a/testing/engine/directives_updateactions.go +++ b/testing/engine/directives_updateactions.go @@ -28,6 +28,7 @@ var _ = profile.RegisterProfile(profile.Profile{ TriggeredRules: []int{ 1004, }, + // No interruption expected because pass action should replace deny }, }, }, @@ -52,13 +53,85 @@ var _ = profile.RegisterProfile(profile.Profile{ }, }, }, + { + Title: "SecRuleUpdateActionById with non-disruptive actions preserves disruptive action", + Stages: []profile.Stage{ + { + Stage: profile.SubStage{ + Input: profile.StageInput{ + URI: "/test?param=trigger", + }, + Output: profile.ExpectedOutput{ + TriggeredRules: []int{ + 2001, + }, + Interruption: &profile.ExpectedInterruption{ + Status: 403, + RuleID: 2001, + Action: "deny", + }, + }, + }, + }, + }, + }, + { + Title: "SecRuleUpdateActionById issue #1414 - deny to pass should not block", + Stages: []profile.Stage{ + { + Stage: profile.SubStage{ + Input: profile.StageInput{ + URI: "/test?id=0", + }, + Output: profile.ExpectedOutput{ + TriggeredRules: []int{ + 3001, + }, + // No interruption expected - this is the bug from issue #1414 + }, + }, + }, + }, + }, + { + Title: "SecRuleUpdateActionById 
range update - deny to pass should not block", + Stages: []profile.Stage{ + { + Stage: profile.SubStage{ + Input: profile.StageInput{ + URI: "/test?id=0", + }, + Output: profile.ExpectedOutput{ + TriggeredRules: []int{ + 4001, 4002, + }, + // No interruption expected - range update changed deny to pass + }, + }, + }, + }, + }, }, Rules: ` + # Test 1: Updating deny to pass SecRule ARGS "@contains value1" "phase:1,id:1004,deny" SecRule ARGS "@contains value1" "phase:1,id:1005,log" SecRuleUpdateActionById 1004 "pass" SecRule ARGS "@contains value2" "phase:2,id:1014,block,deny" SecRuleUpdateActionById 1014 "redirect:'https://www.example.com/',status:302" + + # Test 2: Updating with non-disruptive actions should preserve disruptive action + SecRule ARGS:param "@contains trigger" "phase:1,id:2001,deny,status:403" + SecRuleUpdateActionById 2001 "log,auditlog" + + # Test 3: Issue #1414 - exact scenario from the GitHub issue + SecRule ARGS:id "@eq 0" "id:3001, phase:1,deny,status:403,msg:'Invalid id',log,auditlog" + SecRuleUpdateActionById 3001 "log,pass" + + # Test 4: Range update - same scenario as test 3 but with two rules updated via range + SecRule ARGS:id "@eq 0" "id:4001, phase:1,deny,status:403,msg:'Invalid id',log,auditlog" + SecRule ARGS:id "@eq 0" "id:4002, phase:1,deny,status:403,msg:'Invalid id',log,auditlog" + SecRuleUpdateActionById 4001-4002 "log,pass" `, }) diff --git a/testing/engine/json.go b/testing/engine/json.go index a1c554c4a..512b9507e 100644 --- a/testing/engine/json.go +++ b/testing/engine/json.go @@ -46,7 +46,7 @@ var _ = profile.RegisterProfile(profile.Profile{ Headers: map[string]string{ "Content-Type": "application/json", }, - Data: `{"test":123, "test2": 456, "test3": [22, 44, 55], "test4": 3}`, + Data: `{"test": 123, "test2": 456, "test3": [22, 44, 55], "test4": 3}`, }, }, }, diff --git a/testing/engine/multipart.go b/testing/engine/multipart.go index 9c0923d47..ba271c436 100644 --- a/testing/engine/multipart.go +++ 
b/testing/engine/multipart.go @@ -38,7 +38,7 @@ Regards, -- airween -----0000-- +----0000-- `, }, Output: profile.ExpectedOutput{ @@ -86,7 +86,7 @@ var _ = profile.RegisterProfile(profile.Profile{ \x0EContent-Disposition: form-data; name="_msg_body" The Content-Disposition header contains an invalid character (0x0E). -----0000-- +----0000-- `, }, Output: profile.ExpectedOutput{ @@ -113,7 +113,7 @@ Content-\x20Disposition: form-data; name="file"; filename="1.php" 0x20 character is expected to be the last invalid character before the valid range. Therefore, the parser should fail and raise MULTIPART_STRICT_ERROR. -----0000-- +----0000-- `, }, Output: profile.ExpectedOutput{ diff --git a/testing/engine/multiphase.go b/testing/engine/multiphase.go index 9f6658e1a..cc787984d 100644 --- a/testing/engine/multiphase.go +++ b/testing/engine/multiphase.go @@ -167,6 +167,11 @@ var _ = profile.RegisterProfile(profile.Profile{ Output: profile.ExpectedOutput{ TriggeredRules: []int{1, 3, 4}, NonTriggeredRules: []int{2}, + Interruption: &profile.ExpectedInterruption{ + Status: 504, + RuleID: 3, + Action: "deny", + }, }, }, }, diff --git a/testing/engine_output_test.go b/testing/engine_output_test.go new file mode 100644 index 000000000..9afe28225 --- /dev/null +++ b/testing/engine_output_test.go @@ -0,0 +1,679 @@ +// Copyright 2022 Juan Pablo Tosso and the OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +package testing + +import ( + "strings" + "testing" + + "github.com/corazawaf/coraza/v3" + "github.com/corazawaf/coraza/v3/testing/profile" +) + +func TestOutputInterruptionErrors_NoInterruptionExpectedButGot(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). 
+ WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /block" "id:1,phase:1,deny,status:403" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/block" + test.ExpectedOutput.Interruption = nil // No interruption expected + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputInterruptionErrors() + if len(errors) == 0 { + t.Error("Expected errors when interruption happened but wasn't expected") + } + expectedMsg := "Expected no interruption, but transaction was interrupted" + if !strings.Contains(errors[0], expectedMsg) { + t.Errorf("Expected error message to contain '%s', got: %s", expectedMsg, errors[0]) + } + if !strings.Contains(errors[0], "rule 1") { + t.Errorf("Expected error message to contain rule ID, got: %s", errors[0]) + } +} + +func TestOutputInterruptionErrors_InterruptionExpectedButDidntHappen(t *testing.T) { + waf, err := coraza.NewWAF(coraza.NewWAFConfig()) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/allow" + test.ExpectedOutput.Interruption = &profile.ExpectedInterruption{ + Action: "deny", + Status: 403, + RuleID: 1, + } + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputInterruptionErrors() + if len(errors) == 0 { + t.Error("Expected errors when interruption was expected but didn't happen") + } + expectedMsg := "Expected interruption, but transaction was not interrupted" + if errors[0] != expectedMsg { + t.Errorf("Expected error message '%s', got: %s", expectedMsg, errors[0]) + } +} + +func TestOutputInterruptionErrors_InterruptionDetailsMatch(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). 
+ WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /block" "id:123,phase:1,deny,status:403" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/block" + test.ExpectedOutput.Interruption = &profile.ExpectedInterruption{ + Action: "deny", + Status: 403, + Data: "", + RuleID: 123, + } + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputInterruptionErrors() + if len(errors) != 0 { + t.Errorf("Expected no errors when interruption details match, got: %v", errors) + } +} + +func TestOutputInterruptionErrors_ActionMismatch(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). + WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /block" "id:1,phase:1,deny,status:403" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/block" + test.ExpectedOutput.Interruption = &profile.ExpectedInterruption{ + Action: "drop", + Status: 403, + RuleID: 1, + } + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputInterruptionErrors() + if len(errors) == 0 { + t.Error("Expected errors when interruption action doesn't match") + } + expectedMsg := "Interruption.Action: expected: 'drop', got: 'deny'" + if errors[0] != expectedMsg { + t.Errorf("Expected error message '%s', got: %s", expectedMsg, errors[0]) + } +} + +func TestOutputInterruptionErrors_StatusMismatch(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). 
+ WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /block" "id:1,phase:1,deny,status:403" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/block" + test.ExpectedOutput.Interruption = &profile.ExpectedInterruption{ + Action: "deny", + Status: 404, + RuleID: 1, + } + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputInterruptionErrors() + if len(errors) == 0 { + t.Error("Expected errors when interruption status doesn't match") + } + expectedMsg := "Interruption.Status: expected: '404', got: '403'" + if errors[0] != expectedMsg { + t.Errorf("Expected error message '%s', got: %s", expectedMsg, errors[0]) + } +} + +func TestOutputInterruptionErrors_RuleIDMismatch(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). + WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /block" "id:123,phase:1,deny,status:403" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/block" + test.ExpectedOutput.Interruption = &profile.ExpectedInterruption{ + Action: "deny", + Status: 403, + RuleID: 456, + } + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputInterruptionErrors() + if len(errors) == 0 { + t.Error("Expected errors when interruption RuleID doesn't match") + } + expectedMsg := "Interruption.RuleID: expected: '456', got: '123'" + if errors[0] != expectedMsg { + t.Errorf("Expected error message '%s', got: %s", expectedMsg, errors[0]) + } +} + +func TestOutputErrors_LogContains(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). 
+ WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /test" "id:100,phase:1,log,msg:'Test message'" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/test" + test.ExpectedOutput.LogContains = "Test message" + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputErrors() + if len(errors) != 0 { + t.Errorf("Expected no errors when log contains expected message, got: %v", errors) + } +} + +func TestOutputErrors_LogContainsMissing(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). + WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /test" "id:100,phase:1,log,msg:'Different message'" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/test" + test.ExpectedOutput.LogContains = "Missing message" + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputErrors() + if len(errors) == 0 { + t.Error("Expected errors when log doesn't contain expected message") + } + expectedMsg := "Expected log to contain 'Missing message'" + if errors[0] != expectedMsg { + t.Errorf("Expected error message '%s', got: %s", expectedMsg, errors[0]) + } +} + +func TestOutputErrors_NoLogContains(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). 
+ WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /test" "id:100,phase:1,log,msg:'Test message'" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/test" + test.ExpectedOutput.NoLogContains = "Should not be here" + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputErrors() + if len(errors) != 0 { + t.Errorf("Expected no errors when log doesn't contain unwanted message, got: %v", errors) + } +} + +func TestOutputErrors_NoLogContainsFails(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). + WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /test" "id:100,phase:1,log,msg:'Forbidden message'" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/test" + test.ExpectedOutput.NoLogContains = "Forbidden message" + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputErrors() + if len(errors) == 0 { + t.Error("Expected errors when log contains unwanted message") + } + expectedMsg := "Expected log to not contain 'Forbidden message'" + if errors[0] != expectedMsg { + t.Errorf("Expected error message '%s', got: %s", expectedMsg, errors[0]) + } +} + +func TestOutputErrors_TriggeredRules(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). 
+ WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /test1" "id:101,phase:1,log" + SecRule REQUEST_URI "@streq /test2" "id:102,phase:1,log" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/test1" + test.ExpectedOutput.TriggeredRules = []int{101} + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputErrors() + if len(errors) != 0 { + t.Errorf("Expected no errors when expected rules are triggered, got: %v", errors) + } +} + +func TestOutputErrors_TriggeredRulesNotTriggered(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). + WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /test" "id:101,phase:1,log" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/other" + test.ExpectedOutput.TriggeredRules = []int{101, 102} + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputErrors() + if len(errors) != 2 { + t.Errorf("Expected 2 errors for 2 missing rules, got: %v", errors) + } + if !strings.Contains(errors[0], "Expected rule '101' to be triggered") { + t.Errorf("Expected error about rule 101, got: %s", errors[0]) + } + if !strings.Contains(errors[1], "Expected rule '102' to be triggered") { + t.Errorf("Expected error about rule 102, got: %s", errors[1]) + } +} + +func TestOutputErrors_NonTriggeredRules(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). 
+ WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /test" "id:101,phase:1,log" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/other" + test.ExpectedOutput.NonTriggeredRules = []int{101} + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputErrors() + if len(errors) != 0 { + t.Errorf("Expected no errors when rules are not triggered as expected, got: %v", errors) + } +} + +func TestOutputErrors_NonTriggeredRulesActuallyTriggered(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). + WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /test" "id:101,phase:1,log" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/test" + test.ExpectedOutput.NonTriggeredRules = []int{101} + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + errors := test.OutputErrors() + if len(errors) == 0 { + t.Error("Expected errors when non-triggered rules are actually triggered") + } + expectedMsg := "Expected rule '101' to not be triggered" + if errors[0] != expectedMsg { + t.Errorf("Expected error message '%s', got: %s", expectedMsg, errors[0]) + } +} + +func TestSetEncodedRequest(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + + // Base64 encoded: "GET /encoded HTTP/1.1\r\nHost: example.com\r\n\r\n" + encodedReq := "R0VUIC9lbmNvZGVkIEhUVFAvMS4xDQpIb3N0OiBleGFtcGxlLmNvbQ0KDQo=" + + if err := test.SetEncodedRequest(encodedReq); err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if test.RequestMethod != "GET" { + t.Errorf("Expected method GET, got %s", test.RequestMethod) + } + if test.RequestURI != "/encoded" { + t.Errorf("Expected URI /encoded, got %s", test.RequestURI) + } + if test.RequestProtocol != "HTTP/1.1" { + t.Errorf("Expected protocol HTTP/1.1, got %s", 
test.RequestProtocol) + } + if test.RequestHeaders["Host"] != "example.com" { + t.Errorf("Expected Host header example.com, got %s", test.RequestHeaders["Host"]) + } +} + +func TestSetEncodedRequest_Empty(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + + if err := test.SetEncodedRequest(""); err != nil { + t.Errorf("Empty encoded request should not error, got: %v", err) + } +} + +func TestSetEncodedRequest_Invalid(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + + if err := test.SetEncodedRequest("invalid-base64!!!"); err == nil { + t.Error("Expected error for invalid base64, got nil") + } +} + +func TestSetRawRequest_WithNewlineOnly(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + + req := "POST /path HTTP/1.1\nHost: www.example.com\nContent-Type: application/json\n\n{\"key\":\"value\"}" + if err := test.SetRawRequest([]byte(req)); err != nil { + t.Errorf("Unexpected error with \\n line endings: %v", err) + } + + if test.RequestMethod != "POST" { + t.Errorf("Expected POST, got %s", test.RequestMethod) + } + if test.RequestHeaders["Content-Type"] != "application/json" { + t.Errorf("Expected Content-Type application/json, got %s", test.RequestHeaders["Content-Type"]) + } +} + +func TestSetRawRequest_Empty(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + + if err := test.SetRawRequest([]byte{}); err != nil { + t.Errorf("Empty request should not error, got: %v", err) + } +} + +func TestSetRawRequest_InvalidRequestLine(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + + // Request line with only 2 parts instead of 3 + req := "GET /path\r\nHost: www.example.com\r\n\r\n" + if err := test.SetRawRequest([]byte(req)); err == nil { + t.Error("Expected error for invalid request line, got nil") + } +} + +func 
TestSetRawRequest_InvalidHeader(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + + // Header without colon + req := "GET /path HTTP/1.1\r\nInvalidHeader\r\n\r\n" + if err := test.SetRawRequest([]byte(req)); err == nil { + t.Error("Expected error for invalid header, got nil") + } +} + +func TestSetRawRequest_SingleLine(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + + // Single line request (invalid) + req := "GET /path HTTP/1.1" + if err := test.SetRawRequest([]byte(req)); err == nil { + t.Error("Expected error for single line request, got nil") + } +} + +func TestSetRawRequest_WithBody(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + + req := "POST /path HTTP/1.1\r\nHost: www.example.com\r\n\r\ntest=body&data=value" + if err := test.SetRawRequest([]byte(req)); err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // The body should start after the empty line separator when parsed from raw request + expectedBody := "test=body&data=value" + if test.body != expectedBody { + t.Errorf("Expected body '%s', got '%s'", expectedBody, test.body) + } +} + +func TestDisableMagic(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + test.DisableMagic() + + bodyContent := "test body content" + if err := test.SetRequestBody(bodyContent); err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // With magic disabled, content-length should not be set automatically + if _, ok := test.RequestHeaders["content-length"]; ok { + t.Error("Expected content-length to not be set when magic is disabled") + } +} + +func TestMagicEnabled(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + // Magic is enabled by default + + bodyContent := "test body content" + if err := test.SetRequestBody(bodyContent); err != nil { + t.Errorf("Unexpected error: 
%v", err) + } + + // With magic enabled, content-length should be set automatically + if test.RequestHeaders["content-length"] != "17" { + t.Errorf("Expected content-length to be '17', got '%s'", test.RequestHeaders["content-length"]) + } +} + +func TestSetRequestBody_Nil(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + + if err := test.SetRequestBody(nil); err != nil { + t.Errorf("Nil body should not error, got: %v", err) + } +} + +func TestSetRequestBody_EmptyString(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + + if err := test.SetRequestBody(""); err != nil { + t.Errorf("Empty body should not error, got: %v", err) + } +} + +func TestSetResponseBody_Nil(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). + WithResponseBodyAccess(), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + + if err := test.SetResponseBody(nil); err != nil { + t.Errorf("Nil response body should not error, got: %v", err) + } +} + +func TestSetResponseBody_EmptyString(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). 
+ WithResponseBodyAccess(), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + + if err := test.SetResponseBody(""); err != nil { + t.Errorf("Empty response body should not error, got: %v", err) + } +} + +func TestBodyToString_StringArray(t *testing.T) { + result := bodyToString([]string{"line1", "line2", "line3"}) + expected := "line1\r\nline2\r\nline3\r\n\r\n" + if result != expected { + t.Errorf("Expected '%s', got '%s'", expected, result) + } +} + +func TestBodyToString_String(t *testing.T) { + input := "simple string body" + result := bodyToString(input) + if result != input { + t.Errorf("Expected '%s', got '%s'", input, result) + } +} + +func TestBodyToString_InvalidType(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for invalid type, but didn't panic") + } + }() + bodyToString(123) // Should panic +} + +func TestLogContains(t *testing.T) { + waf, err := coraza.NewWAF( + coraza.NewWAFConfig(). 
+ WithDirectives(` + SecRuleEngine On + SecRule REQUEST_URI "@streq /test" "id:200,phase:1,log,msg:'Unique test message'" + `), + ) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + test := NewTest("test", waf) + test.RequestURI = "/test" + + if err := test.RunPhases(); err != nil { + t.Error(err) + } + + if !test.LogContains("Unique test message") { + t.Error("Expected LogContains to return true for message in log") + } + + if test.LogContains("Message not in log") { + t.Error("Expected LogContains to return false for message not in log") + } +} + +func TestTransaction(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + test := NewTest("test", waf) + + tx := test.Transaction() + if tx == nil { + t.Error("Expected non-nil transaction") + } +} diff --git a/types/waf.go b/types/waf.go index 6cec408df..a4ac71cc9 100644 --- a/types/waf.go +++ b/types/waf.go @@ -78,6 +78,32 @@ func (re RuleEngineStatus) String() string { return "unknown" } +// UploadKeepFilesStatus represents the status of the upload keep files directive. +type UploadKeepFilesStatus int + +const ( + // UploadKeepFilesOff will delete all uploaded files after transaction (default) + UploadKeepFilesOff UploadKeepFilesStatus = iota + // UploadKeepFilesOn will keep all uploaded files after transaction + UploadKeepFilesOn + // UploadKeepFilesRelevantOnly will keep uploaded files only if a log-relevant rule matched + // (that is, a matched rule with logging enabled, excluding rules marked with nolog). 
+ UploadKeepFilesRelevantOnly +) + +// ParseUploadKeepFilesStatus parses the upload keep files status +func ParseUploadKeepFilesStatus(s string) (UploadKeepFilesStatus, error) { + switch strings.ToLower(s) { + case "on": + return UploadKeepFilesOn, nil + case "off": + return UploadKeepFilesOff, nil + case "relevantonly": + return UploadKeepFilesRelevantOnly, nil + } + return -1, fmt.Errorf("invalid upload keep files status: %q", s) +} + // BodyLimitAction represents the action to take when // the body size exceeds the configured limit. type BodyLimitAction int @@ -94,6 +120,8 @@ const ( type AuditLogPart byte const ( + // AuditLogPartHeader is the audit log header part (mandatory) + AuditLogPartHeader AuditLogPart = 'A' // AuditLogPartRequestHeaders is the request headers part AuditLogPartRequestHeaders AuditLogPart = 'B' // AuditLogPartRequestBody is the request body part @@ -114,6 +142,8 @@ const ( AuditLogPartUploadedFiles AuditLogPart = 'J' // AuditLogPartRulesMatched is the matched rules part AuditLogPartRulesMatched AuditLogPart = 'K' + // AuditLogPartEndMarker is the final boundary, signifies the end of the entry (mandatory) + AuditLogPartEndMarker AuditLogPart = 'Z' ) // AuditLogParts represents the parts of the audit log @@ -155,13 +185,15 @@ func ParseAuditLogParts(opts string) (AuditLogParts, error) { return nil, errors.New("audit log parts is required to end with Z") } - parts := opts[1 : len(opts)-1] - for _, p := range parts { + // Validate the middle parts (everything between A and Z) + middleParts := opts[1 : len(opts)-1] + for _, p := range middleParts { if !slices.Contains(orderedAuditLogParts, AuditLogPart(p)) { return AuditLogParts(""), fmt.Errorf("invalid audit log parts %q", opts) } } - return AuditLogParts(parts), nil + // Return all parts including A and Z + return AuditLogParts(opts), nil } // ApplyAuditLogParts applies audit log parts modifications to the base parts. 
diff --git a/types/waf_test.go b/types/waf_test.go index 0aa63dbf0..83d3e5447 100644 --- a/types/waf_test.go +++ b/types/waf_test.go @@ -12,7 +12,7 @@ func TestParseAuditLogParts(t *testing.T) { expectedHasError bool }{ {"", nil, true}, - {"ABCDEFGHIJKZ", []AuditLogPart("BCDEFGHIJK"), false}, + {"ABCDEFGHIJKZ", []AuditLogPart("ABCDEFGHIJKZ"), false}, {"DEFGHZ", nil, true}, {"ABCD", nil, true}, {"AMZ", nil, true}, @@ -98,7 +98,7 @@ func TestApplyAuditLogParts(t *testing.T) { name: "absolute value (starts with A, ends with Z)", base: AuditLogParts("BC"), modification: "ABCDEFZ", - expectedParts: AuditLogParts("BCDEF"), + expectedParts: AuditLogParts("ABCDEFZ"), expectedHasError: false, }, { diff --git a/waf.go b/waf.go index 75e8a1327..44d15d9a2 100644 --- a/waf.go +++ b/waf.go @@ -8,7 +8,6 @@ import ( "fmt" "strings" - "github.com/corazawaf/coraza/v3/experimental" "github.com/corazawaf/coraza/v3/internal/corazawaf" "github.com/corazawaf/coraza/v3/internal/environment" "github.com/corazawaf/coraza/v3/internal/seclang" @@ -44,6 +43,10 @@ func NewWAF(config WAFConfig) (WAF, error) { waf.Logger = c.debugLogger } + if c.ruleObserver != nil { + waf.Rules.SetObserver(c.ruleObserver) + } + parser := seclang.NewParser(waf) if c.fsRoot != nil { @@ -147,7 +150,17 @@ func (w wafWrapper) NewTransactionWithID(id string) types.Transaction { return w.waf.NewTransactionWithOptions(corazawaf.Options{Context: context.Background(), ID: id}) } -// NewTransaction implements the same method on WAF. -func (w wafWrapper) NewTransactionWithOptions(opts experimental.Options) types.Transaction { +// NewTransactionWithOptions implements the same method on WAF. +func (w wafWrapper) NewTransactionWithOptions(opts corazawaf.Options) types.Transaction { return w.waf.NewTransactionWithOptions(opts) } + +// RulesCount returns the number of rules in this WAF. +func (w wafWrapper) RulesCount() int { + return w.waf.Rules.Count() +} + +// Close releases cached resources owned by this WAF instance. 
+func (w wafWrapper) Close() error { + return w.waf.Close() +} diff --git a/waf_test.go b/waf_test.go index 693da9b83..25189d2a0 100644 --- a/waf_test.go +++ b/waf_test.go @@ -13,6 +13,11 @@ import ( "github.com/corazawaf/coraza/v3/types" ) +// wafWithRules mirrors experimental.WAFWithRules for testing without import cycle. +type wafWithRules interface { + RulesCount() int +} + func TestRequestBodyLimit(t *testing.T) { testCases := map[string]struct { expectedErr error @@ -178,3 +183,33 @@ func TestPopulateAuditLog(t *testing.T) { }) } } + +func TestRulesCount(t *testing.T) { + waf, err := NewWAF(NewWAFConfig()) + if err != nil { + t.Fatal(err) + } + + rules, ok := waf.(wafWithRules) + if !ok { + t.Fatal("WAF does not implement WAFWithRules") + } + if rules.RulesCount() != 0 { + t.Fatalf("expected 0 rules, got %d", rules.RulesCount()) + } + + waf, err = NewWAF(NewWAFConfig(). + WithDirectives(`SecRule REMOTE_ADDR "127.0.0.1" "id:1,phase:1,deny,status:403"`). + WithDirectives(`SecRule REQUEST_URI "/test" "id:2,phase:1,deny,status:403"`)) + if err != nil { + t.Fatal(err) + } + + rules, ok = waf.(wafWithRules) + if !ok { + t.Fatal("WAF does not implement WAFWithRules") + } + if rules.RulesCount() != 2 { + t.Fatalf("expected 2 rules, got %d", rules.RulesCount()) + } +}