diff --git a/.github/workflows/close-issues.yml b/.github/workflows/close-issues.yml
index 6d00bc25b..a6a6d4363 100644
--- a/.github/workflows/close-issues.yml
+++ b/.github/workflows/close-issues.yml
@@ -3,6 +3,8 @@ on:
schedule:
- cron: "30 1 * * *"
+permissions: {}
+
jobs:
close-issues:
runs-on: ubuntu-latest
diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml
index e98e67f9d..535107a26 100644
--- a/.github/workflows/codeql-analysis.yml
+++ b/.github/workflows/codeql-analysis.yml
@@ -3,10 +3,20 @@ on:
pull_request:
schedule:
- cron: '0 6 * * 6'
+
+permissions: {}
+
+concurrency:
+ group: codeql-${{ github.ref }}
+ cancel-in-progress: ${{ github.event_name == 'pull_request' }}
+
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
+ permissions:
+ security-events: write
+ contents: read
steps:
- name: Checkout repository
diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml
index a601a92a6..faca0e3d8 100644
--- a/.github/workflows/fuzz.yml
+++ b/.github/workflows/fuzz.yml
@@ -6,15 +6,20 @@ on:
- cron: "05 14 * * *"
workflow_dispatch:
+permissions: {}
+
jobs:
fuzz:
name: Fuzz tests
runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ issues: write
steps:
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6
- uses: actions/setup-go@44694675825211faa026b3c33043df3e48a5fa00 # v6
with:
- go-version: ">=1.24.0"
+ go-version: ">=1.25.0"
- run: go run mage.go fuzz
- run: |
gh issue create --title "$GITHUB_WORKFLOW #$GITHUB_RUN_NUMBER failed" \
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 6e373b8a7..eb6a69972 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -14,9 +14,17 @@ on:
- "**/*.md"
- "LICENSE"
+permissions: {}
+
+concurrency:
+ group: lint-${{ github.ref }}
+ cancel-in-progress: ${{ github.event_name == 'pull_request' }}
+
jobs:
lint:
runs-on: ubuntu-latest
+ permissions:
+ contents: read
steps:
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6
- name: Install Go
diff --git a/.github/workflows/regression.yml b/.github/workflows/regression.yml
index 0c8974611..d8b4e1acb 100644
--- a/.github/workflows/regression.yml
+++ b/.github/workflows/regression.yml
@@ -8,37 +8,51 @@ on:
- "**/*.md"
- "LICENSE"
pull_request:
+ branches:
+ - main
paths-ignore:
- "**/*.md"
- "LICENSE"
-jobs:
- # Generate matrix of tags for all permutations of the tests
- generate-matrix:
- runs-on: ubuntu-latest
- outputs:
- tags: ${{ steps.generate.outputs.tags }}
- steps:
- - name: Checkout code
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6
+permissions: {}
+
+concurrency:
+ group: regression-${{ github.ref }}
+ cancel-in-progress: ${{ github.event_name == 'pull_request' }}
- - name: Generate tag combinations
- id: generate
- run: |
- go run mage.go tagsmatrix > tags.json
- echo "tags=$(cat tags.json)" >> "$GITHUB_OUTPUT"
- shell: bash
+jobs:
test:
- needs: generate-matrix
strategy:
fail-fast: false
matrix:
- go-version: [1.24.x, 1.25.x]
+ go-version: [1.25.x]
os: [ubuntu-latest]
- build-flag: ${{ fromJson(needs.generate-matrix.outputs.tags) }}
+ build-flag:
+ - ""
+ - "coraza.rule.mandatory_rule_id_check"
+ - "coraza.rule.case_sensitive_args_keys"
+ - "coraza.rule.no_regex_multiline"
+ - "coraza.no_memoize"
+ - "coraza.rule.multiphase_evaluation"
+ - "no_fs_access"
+ - "coraza.rule.multiphase_evaluation,coraza.rule.mandatory_rule_id_check"
+ - "coraza.rule.multiphase_evaluation,coraza.rule.case_sensitive_args_keys"
+ - "coraza.rule.multiphase_evaluation,coraza.rule.no_regex_multiline"
+ - "no_fs_access,coraza.no_memoize"
+ - "coraza.rule.mandatory_rule_id_check,coraza.rule.case_sensitive_args_keys,coraza.rule.no_regex_multiline"
+ - "coraza.rule.multiphase_evaluation,coraza.rule.mandatory_rule_id_check,coraza.rule.case_sensitive_args_keys,coraza.rule.no_regex_multiline,coraza.no_memoize,no_fs_access"
+ include:
+ - go-version: 1.26.x
+ os: ubuntu-latest
+ build-flag: ""
+ - go-version: 1.26.x
+ os: ubuntu-latest
+ build-flag: "coraza.rule.multiphase_evaluation,coraza.rule.mandatory_rule_id_check,coraza.rule.case_sensitive_args_keys,coraza.rule.no_regex_multiline,coraza.no_memoize,no_fs_access"
runs-on: ${{ matrix.os }}
+ permissions:
+ contents: read
env:
- GOLANG_BASE_VERSION: "1.24.x"
+ GOLANG_BASE_VERSION: "1.25.x"
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6
@@ -48,9 +62,9 @@ jobs:
go-version: ${{ matrix.go-version }}
cache: true
- name: Tests and coverage
- run: |
- export BUILD_TAGS=${{ matrix.build-flag }}
- go run mage.go coverage
+ env:
+ BUILD_TAGS: ${{ matrix.build-flag }}
+ run: go run mage.go coverage
- name: "Codecov: General"
uses: codecov/codecov-action@5a1091511ad55cbe89839c7260b706298ca349f7 # v5
if: ${{ matrix.go-version == env.GOLANG_BASE_VERSION }}
@@ -84,12 +98,13 @@ jobs:
runs-on: ubuntu-latest
permissions:
checks: read
+ contents: read
steps:
- name: GitHub Checks
uses: poseidon/wait-for-status-checks@899c768d191b56eef585c18f8558da19e1f3e707 # v0.6.0
with:
token: ${{ secrets.GITHUB_TOKEN }}
- delay: 120s # give some time to matrix jobs
+ delay: 30s
interval: 10s # default value
timeout: 3600s # default value
ignore: "codecov/patch,codecov/project"
diff --git a/.github/workflows/tinygo.yml b/.github/workflows/tinygo.yml
index 00d2de703..9e15119f7 100644
--- a/.github/workflows/tinygo.yml
+++ b/.github/workflows/tinygo.yml
@@ -14,15 +14,23 @@ on:
- "**/*.md"
- "LICENSE"
+permissions: {}
+
+concurrency:
+ group: tinygo-${{ github.ref }}
+ cancel-in-progress: ${{ github.event_name == 'pull_request' }}
+
jobs:
test:
strategy:
matrix:
- go-version: [1.24.x]
+ go-version: [1.25.x]
# tinygo-version is meant to stay aligned with the one used in corazawaf/coraza-proxy-wasm
tinygo-version: [0.40.1]
os: [ubuntu-latest]
runs-on: ${{ matrix.os }}
+ permissions:
+ contents: read
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6
@@ -53,5 +61,5 @@ jobs:
- name: Tests
run: tinygo test -v -short ./internal/...
- - name: Tests memoize
- run: tinygo test -v -short -tags=memoize_builders ./internal/...
+ - name: Tests no_memoize
+ run: tinygo test -v -short -tags=coraza.no_memoize ./internal/...
diff --git a/README.md b/README.md
index f284bd945..eea6ad613 100644
--- a/README.md
+++ b/README.md
@@ -100,9 +100,11 @@ have compatibility guarantees across minor versions - use with care.
the operator with `plugins.RegisterOperator` to reduce binary size / startup overhead.
* `coraza.rule.multiphase_evaluation` - enables evaluation of rule variables in the phases that they are ready, not
only the phase the rule is defined for.
-* `memoize_builders` - enables memoization of builders for regex and aho-corasick
-dictionaries to reduce memory consumption in deployments that launch several coraza
-instances. For more context check [this issue](https://github.com/corazawaf/coraza-caddy/issues/76)
+* `coraza.no_memoize` - disables the default memoization of regex and aho-corasick builders.
+Memoization is enabled by default and uses a global cache to reuse compiled patterns across WAF
+instances, reducing memory consumption and startup overhead. In long-lived processes that perform
+live reloads, use `WAF.Close()` (via `experimental.WAFCloser`) to release cached entries when a
+WAF is destroyed, or use this tag to opt out of memoization entirely.
* `no_fs_access` - indicates that the target environment has no access to FS in order to not leverage OS' filesystem related functionality e.g. file body buffers.
* `coraza.rule.case_sensitive_args_keys` - enables case-sensitive matching for ARGS keys, aligning Coraza behavior with RFC 3986 specification. It will be enabled by default in the next major version.
* `coraza.rule.no_regex_multiline` - disables enabling by default regexes multiline modifiers in `@rx` operator. It aligns with CRS expected behavior, reduces false positives and might improve performances. No multiline regexes by default will be enabled in the next major version. For more context check [this PR](https://github.com/corazawaf/coraza/pull/876)
diff --git a/config.go b/config.go
index ffb3d5894..6795c9591 100644
--- a/config.go
+++ b/config.go
@@ -94,6 +94,7 @@ type wafRule struct {
// int is a signed integer type that is at least 32 bits in size (platform-dependent size).
// We still basically assume 64-bit usage where int are big sizes.
type wafConfig struct {
+ ruleObserver func(rule types.RuleMetadata)
rules []wafRule
auditLog *auditLogConfig
requestBodyAccess bool
@@ -119,6 +120,12 @@ func (c *wafConfig) WithRules(rules ...*corazawaf.Rule) WAFConfig {
return ret
}
+func (c *wafConfig) WithRuleObserver(observer func(rule types.RuleMetadata)) WAFConfig {
+ ret := c.clone()
+ ret.ruleObserver = observer
+ return ret
+}
+
func (c *wafConfig) WithDirectivesFromFile(path string) WAFConfig {
ret := c.clone()
ret.rules = append(ret.rules, wafRule{file: path})
diff --git a/coraza.conf-recommended b/coraza.conf-recommended
index 8ef869eae..b0d3654bf 100644
--- a/coraza.conf-recommended
+++ b/coraza.conf-recommended
@@ -34,6 +34,23 @@ SecRule REQUEST_HEADERS:Content-Type "^application/json" \
SecRule REQUEST_HEADERS:Content-Type "^application/[a-z0-9.-]+[+]json" \
"id:'200006',phase:1,t:none,t:lowercase,pass,nolog,ctl:requestBodyProcessor=JSON"
+# Configures the maximum JSON recursion depth Coraza will accept.
+SecRequestBodyJsonDepthLimit 1024
+
+# Enable the JSON stream request body processor for streaming JSON formats:
+# NDJSON / JSON Lines (one complete JSON object per line) and RFC 7464
+# JSON Sequences. Commonly used for bulk data imports, log streaming, and
+# batch API endpoints. Each object is indexed by record: json.0.field, etc.
+#
+#SecRule REQUEST_HEADERS:Content-Type "^application/x-ndjson" \
+# "id:'200007',phase:1,t:none,t:lowercase,pass,nolog,ctl:requestBodyProcessor=JSONSTREAM"
+#
+#SecRule REQUEST_HEADERS:Content-Type "^application/jsonlines" \
+# "id:'200008',phase:1,t:none,t:lowercase,pass,nolog,ctl:requestBodyProcessor=JSONSTREAM"
+#
+#SecRule REQUEST_HEADERS:Content-Type "^application/json-seq" \
+# "id:'200010',phase:1,t:none,t:lowercase,pass,nolog,ctl:requestBodyProcessor=JSONSTREAM"
+
# Maximum request body size we will accept for buffering. If you support
# file uploads, this value must has to be as large as the largest file
# you are willing to accept.
@@ -89,6 +106,9 @@ SecResponseBodyAccess On
# configuration below to catch documents but avoid static files
# (e.g., images and archives).
#
+# Note: Add 'application/json' and 'application/x-ndjson' here if you want
+# to inspect JSON and NDJSON response bodies.
+#
SecResponseBodyMimeType text/plain text/html text/xml
# Buffer response bodies of up to 512 KB in length.
@@ -118,11 +138,11 @@ SecDataDir /tmp/
#
#SecUploadDir /opt/coraza/var/upload/
-# If On, the WAF will store the uploaded files in the SecUploadDir
-# directory.
-# Note: SecUploadKeepFiles is currently NOT supported by Coraza
+# Controls whether intercepted uploaded files will be kept after
+# the transaction is processed. Possible values: On, Off, RelevantOnly.
+# RelevantOnly will keep files only when a matching rule is logged (rules with 'nolog' do not qualify).
#
-#SecUploadKeepFiles Off
+#SecUploadKeepFiles RelevantOnly
# Uploaded files are by default created with permissions that do not allow
# any other user to access them. You may need to relax that if you want to
diff --git a/examples/http-server/go.mod b/examples/http-server/go.mod
index f2c2e1835..d1e0de7a3 100644
--- a/examples/http-server/go.mod
+++ b/examples/http-server/go.mod
@@ -1,11 +1,11 @@
module github.com/corazawaf/coraza/v3/examples/http-server
-go 1.24.0
+go 1.25.0
require github.com/corazawaf/coraza/v3 v3.3.3
require (
- github.com/corazawaf/libinjection-go v0.2.2 // indirect
+ github.com/corazawaf/libinjection-go v0.3.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/magefile/mage v1.15.1-0.20250615140142-78acbaf2e3ae // indirect
github.com/petar-dambovaliev/aho-corasick v0.0.0-20250424160509-463d218d4745 // indirect
@@ -13,9 +13,9 @@ require (
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/valllabh/ocsf-schema-golang v1.0.3 // indirect
- golang.org/x/net v0.43.0 // indirect
- golang.org/x/sync v0.16.0 // indirect
- golang.org/x/tools v0.35.0 // indirect
+ golang.org/x/net v0.52.0 // indirect
+ golang.org/x/sync v0.20.0 // indirect
+ golang.org/x/tools v0.42.0 // indirect
google.golang.org/protobuf v1.35.1 // indirect
rsc.io/binaryregexp v0.2.0 // indirect
)
diff --git a/examples/http-server/go.sum b/examples/http-server/go.sum
index df496a76e..3cba2f706 100644
--- a/examples/http-server/go.sum
+++ b/examples/http-server/go.sum
@@ -2,8 +2,8 @@ github.com/corazawaf/coraza-coreruleset v0.0.0-20240226094324-415b1017abdc h1:Ol
github.com/corazawaf/coraza-coreruleset v0.0.0-20240226094324-415b1017abdc/go.mod h1:7rsocqNDkTCira5T0M7buoKR2ehh7YZiPkzxRuAgvVU=
github.com/corazawaf/coraza/v3 v3.3.3 h1:kqjStHAgWqwP5dh7n0vhTOF0a3t+VikNS/EaMiG0Fhk=
github.com/corazawaf/coraza/v3 v3.3.3/go.mod h1:xSaXWOhFMSbrV8qOOfBKAyw3aOqfwaSaOy5BgSF8XlA=
-github.com/corazawaf/libinjection-go v0.2.2 h1:Chzodvb6+NXh6wew5/yhD0Ggioif9ACrQGR4qjTCs1g=
-github.com/corazawaf/libinjection-go v0.2.2/go.mod h1:OP4TM7xdJ2skyXqNX1AN1wN5nNZEmJNuWbNPOItn7aw=
+github.com/corazawaf/libinjection-go v0.3.2 h1:9rrKt0lpg4WvUXt+lwS06GywfqRXXsa/7JcOw5cQLwI=
+github.com/corazawaf/libinjection-go v0.3.2/go.mod h1:Ik/+w3UmTWH9yn366RgS9D95K3y7Atb5m/H/gXzzPCk=
github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7DlmewI=
github.com/foxcpp/go-mockdns v1.1.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld2qVbOu7Wk=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
@@ -25,16 +25,16 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/valllabh/ocsf-schema-golang v1.0.3 h1:eR8k/3jP/OOqB8LRCtdJ4U+vlgd/gk5y3KMXoodrsrw=
github.com/valllabh/ocsf-schema-golang v1.0.3/go.mod h1:sZ3as9xqm1SSK5feFWIR2CuGeGRhsM7TR1MbpBctzPk=
-golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
-golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
-golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
-golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
-golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
-golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
-golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
-golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
-golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
-golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
+golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
+golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
+golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
+golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
+golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
+golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
+golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
+golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
+golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
+golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE=
diff --git a/experimental/bodyprocessors/jsonstream.go b/experimental/bodyprocessors/jsonstream.go
new file mode 100644
index 000000000..b1eaad125
--- /dev/null
+++ b/experimental/bodyprocessors/jsonstream.go
@@ -0,0 +1,390 @@
+// Copyright 2026 OWASP Coraza contributors
+// SPDX-License-Identifier: Apache-2.0
+
+package bodyprocessors
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+
+ "github.com/tidwall/gjson"
+
+ "github.com/corazawaf/coraza/v3/experimental/plugins"
+ "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
+)
+
+const (
+ // DefaultStreamRecursionLimit is the default recursion limit for streaming JSON processing
+ // This protects against deeply nested JSON objects in each line
+ DefaultStreamRecursionLimit = 1024
+
+ // recordSeparator is the ASCII RS character (0x1E) used in RFC 7464 JSON Sequences
+ recordSeparator = '\x1e'
+)
+
+// indexedCollection is an interface for collections that support indexed key-value storage.
+type indexedCollection interface {
+ SetIndex(string, int, string)
+}
+
+// jsonStreamBodyProcessor handles streaming JSON formats.
+// Each record/line in the input is expected to be a complete, valid JSON object.
+// Empty lines are ignored. Each JSON object is flattened and indexed by record number.
+//
+// Supported formats:
+// - NDJSON (application/x-ndjson): Each line is a complete JSON object
+// - JSON Lines (application/jsonlines): Alias for NDJSON
+// - JSON Sequence (application/json-seq): RFC 7464 format with RS (0x1E) record separator
+//
+// The processor auto-detects the format based on the presence of RS characters.
+type jsonStreamBodyProcessor struct{}
+
+var _ plugintypes.StreamingBodyProcessor = &jsonStreamBodyProcessor{}
+
+func (js *jsonStreamBodyProcessor) ProcessRequest(reader io.Reader, v plugintypes.TransactionVariables, _ plugintypes.BodyProcessorOptions) error {
+ col := v.ArgsPost()
+
+ // Store the raw body for TX variables.
+ // Note: This creates a memory copy of the entire body, similar to the regular JSON processor.
+ // This is necessary for operators like @validateSchema that need access to the raw content.
+ // Memory usage: 2x the body size (once in buffer, once in parsed variables)
+ var rawBody strings.Builder
+
+ // Create a TeeReader to read the body and store it simultaneously
+ tee := io.TeeReader(reader, &rawBody)
+
+ // Use default recursion limit for now
+ // TODO: Use RequestBodyRecursionLimit from BodyProcessorOptions when available
+ lineNum, err := processJSONStream(tee, col, DefaultStreamRecursionLimit)
+ if err != nil {
+ return err
+ }
+
+ // Store the raw JSON stream in the TX variable for potential validation
+ if txVar := v.TX(); txVar != nil {
+ txVar.Set("jsonstream_request_body", []string{rawBody.String()})
+ txVar.Set("jsonstream_request_line_count", []string{strconv.Itoa(lineNum)})
+ }
+
+ return nil
+}
+
+func (js *jsonStreamBodyProcessor) ProcessResponse(reader io.Reader, v plugintypes.TransactionVariables, _ plugintypes.BodyProcessorOptions) error {
+ col := v.ResponseArgs()
+
+ // Store the raw body for TX variables.
+ // Note: This creates a memory copy of the entire body, similar to the regular JSON processor.
+ // Memory usage: 2x the body size (once in buffer, once in parsed variables)
+ var rawBody strings.Builder
+
+ // Create a TeeReader to read the body and store it simultaneously
+ tee := io.TeeReader(reader, &rawBody)
+
+ // Use default recursion limit for response bodies too
+ // TODO: Consider using a different limit for responses when configurable
+ lineNum, err := processJSONStream(tee, col, DefaultStreamRecursionLimit)
+ if err != nil {
+ return err
+ }
+
+ // Store the raw JSON stream in the TX variable for potential validation
+ if txVar := v.TX(); txVar != nil && v.ResponseBody() != nil {
+ txVar.Set("jsonstream_response_body", []string{rawBody.String()})
+ txVar.Set("jsonstream_response_line_count", []string{strconv.Itoa(lineNum)})
+ }
+
+ return nil
+}
+
+// processJSONStream processes a stream of JSON objects incrementally.
+// Supports both NDJSON (newline-delimited) and RFC 7464 JSON Sequence (RS-delimited) formats.
+// The format is auto-detected by peeking at the first chunk of data.
+// Returns the number of records processed and any error encountered.
+func processJSONStream(reader io.Reader, col indexedCollection, maxRecursion int) (int, error) {
+ return processJSONStreamWithCallback(reader, maxRecursion, func(_ int, fields map[string]string, _ string) error {
+ for key, value := range fields {
+ col.SetIndex(key, 0, value)
+ }
+ return nil
+ })
+}
+
+// processJSONStreamWithCallback processes a stream of JSON objects, calling fn for each record.
+// The callback receives the record number, a map of pre-formatted fields (with record-prefixed keys
+// like "json.0.name"), and the raw record text. If fn returns a non-nil error, processing stops.
+// Returns the number of records processed and any error encountered.
+func processJSONStreamWithCallback(reader io.Reader, maxRecursion int,
+ fn func(recordNum int, fields map[string]string, rawRecord string) error) (int, error) {
+ bufReader := bufio.NewReader(reader)
+
+ // Peek at the first chunk to detect format without consuming the entire stream
+ // Use 4KB as a reasonable peek size - enough to detect RS in typical streams
+ peekSize := 4096
+ peekBytes, err := bufReader.Peek(peekSize)
+ if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
+ return 0, fmt.Errorf("error peeking stream: %w", err)
+ }
+
+ // Check if we have any data at all
+ if len(peekBytes) == 0 {
+ return 0, errors.New("no valid JSON objects found in stream")
+ }
+
+ // Auto-detect format: if peek contains RS characters, use JSON Sequence parsing
+ // Otherwise, use NDJSON (newline) parsing
+ if containsRS(peekBytes) {
+ return processJSONSequenceStreamWithCallback(bufReader, maxRecursion, fn)
+ }
+ return processNDJSONStreamWithCallback(bufReader, maxRecursion, fn)
+}
+
+// containsRS checks if a byte slice contains the RS character
+func containsRS(data []byte) bool {
+ for _, b := range data {
+ if b == recordSeparator {
+ return true
+ }
+ }
+ return false
+}
+
+// newRecordScanner creates a bufio.Scanner with the standard buffer sizes used for JSON record scanning.
+func newRecordScanner(reader io.Reader, split bufio.SplitFunc) *bufio.Scanner {
+ scanner := bufio.NewScanner(reader)
+ if split != nil {
+ scanner.Split(split)
+ }
+
+ // Increase scanner buffer to handle large JSON objects (default is 64KB)
+ // Set max to 1MB to match typical JSON object sizes while preventing memory exhaustion
+ const maxScanTokenSize = 1024 * 1024 // 1MB
+ buf := make([]byte, 64*1024)
+ scanner.Buffer(buf, maxScanTokenSize)
+
+ return scanner
+}
+
+// scanRecords iterates over scanner records, parsing each as JSON and calling fn.
+// The formatRecord function wraps each parsed record with the format-specific delimiters
+// (e.g., trailing \n for NDJSON, or RS prefix + trailing \n for RFC 7464),
+// so the rawRecord passed to fn can be relayed as-is to preserve the original format.
+func scanRecords(scanner *bufio.Scanner, maxRecursion int, formatRecord func(string) string,
+ fn func(recordNum int, fields map[string]string, rawRecord string) error) (int, error) {
+ recordNum := 0
+
+ for scanner.Scan() {
+ record := strings.TrimSpace(scanner.Text())
+ if record == "" {
+ continue
+ }
+
+ fields, err := parseJSONRecord(record, recordNum, maxRecursion)
+ if err != nil {
+ return recordNum, err
+ }
+
+ if err := fn(recordNum, fields, formatRecord(record)); err != nil {
+ return recordNum + 1, err
+ }
+
+ recordNum++
+ }
+
+ if err := scanner.Err(); err != nil {
+ return recordNum, fmt.Errorf("error reading stream: %w", err)
+ }
+
+ if recordNum == 0 {
+ return 0, errors.New("no valid JSON objects found in stream")
+ }
+
+ return recordNum, nil
+}
+
+func formatNDJSON(record string) string {
+ return record + "\n"
+}
+
+func formatJSONSequence(record string) string {
+ return "\x1e" + record + "\n"
+}
+
+// processNDJSONStreamWithCallback processes NDJSON format (newline-delimited JSON objects) from a reader,
+// calling fn for each record.
+func processNDJSONStreamWithCallback(reader io.Reader, maxRecursion int,
+ fn func(recordNum int, fields map[string]string, rawRecord string) error) (int, error) {
+ return scanRecords(newRecordScanner(reader, nil), maxRecursion, formatNDJSON, fn)
+}
+
+// processJSONSequenceStreamWithCallback processes RFC 7464 JSON Sequence format (RS-delimited JSON objects)
+// from a reader, calling fn for each record.
+func processJSONSequenceStreamWithCallback(reader io.Reader, maxRecursion int,
+ fn func(recordNum int, fields map[string]string, rawRecord string) error) (int, error) {
+ return scanRecords(newRecordScanner(reader, splitOnRS), maxRecursion, formatJSONSequence, fn)
+}
+
+// splitOnRS is a custom split function for bufio.Scanner that splits on RS (0x1E) characters.
+// This enables streaming processing of RFC 7464 JSON Sequence format.
+func splitOnRS(data []byte, atEOF bool) (advance int, token []byte, err error) {
+ // Skip leading RS characters
+ start := 0
+ for start < len(data) && data[start] == recordSeparator {
+ start++
+ }
+
+ // If we've consumed all data and we're at EOF, we're done
+ if atEOF && start >= len(data) {
+ return len(data), nil, nil
+ }
+
+ // Find the next RS character after start
+ for i := start; i < len(data); i++ {
+ if data[i] == recordSeparator {
+ // Found RS, return the record between start and i
+ return i + 1, data[start:i], nil
+ }
+ }
+
+ // If we're at EOF, return remaining data as the last record
+ if atEOF && start < len(data) {
+ return len(data), data[start:], nil
+ }
+
+ // Request more data
+ return 0, nil, nil
+}
+
+// parseJSONRecord parses a single JSON record and returns a map of fields with record-prefixed keys.
+// Keys are formatted as "json.{recordNum}.field.subfield".
+func parseJSONRecord(jsonText string, recordNum int, maxRecursion int) (map[string]string, error) {
+ // Validate JSON before parsing
+ if !gjson.Valid(jsonText) {
+ // Use 1-based numbering for user-friendly error messages
+ return nil, fmt.Errorf("invalid JSON at record %d", recordNum+1)
+ }
+
+ // Parse the JSON record
+ data, err := readJSONWithLimit(jsonText, maxRecursion)
+ if err != nil {
+ // Use 1-based numbering for user-friendly error messages
+ return nil, fmt.Errorf("error parsing JSON at record %d: %w", recordNum+1, err)
+ }
+
+ // Build result with record-prefixed keys
+ // Example: json.0.field, json.1.field, etc.
+ fields := make(map[string]string, len(data))
+ for key, value := range data {
+ // Replace the "json" prefix with "json.{recordNum}"
+ // Original key format: "json.field.subfield"
+ // New key format: "json.0.field.subfield"
+ if strings.HasPrefix(key, "json.") {
+ key = fmt.Sprintf("json.%d.%s", recordNum, key[5:]) // Skip "json."
+ } else if key == "json" {
+ key = fmt.Sprintf("json.%d", recordNum)
+ }
+ fields[key] = value
+ }
+
+ return fields, nil
+}
+
+// readJSONWithLimit is a helper that calls readJSON but with protection against deep nesting.
+// This is a separate copy from internal/bodyprocessors/json.go due to package boundaries.
+// TODO: Remove this when readItems supports maxRecursion parameter natively (see #1110)
+func readJSONWithLimit(s string, maxRecursion int) (map[string]string, error) {
+ json := gjson.Parse(s)
+ res := make(map[string]string)
+ key := []byte("json")
+ err := readItemsWithLimit(json, key, maxRecursion, res)
+ return res, err
+}
+
+// readItemsWithLimit is similar to readItems in internal/bodyprocessors/json.go but with recursion limit.
+// This is a separate copy due to package boundaries (experimental vs internal).
+// TODO: Remove this when readItems supports maxRecursion parameter natively (see #1110)
+func readItemsWithLimit(json gjson.Result, objKey []byte, maxRecursion int, res map[string]string) error {
+ arrayLen := 0
+ var iterationError error
+
+ if maxRecursion == 0 {
+ return errors.New("max recursion reached while reading json object")
+ }
+
+ json.ForEach(func(key, value gjson.Result) bool {
+ prevParentLength := len(objKey)
+ objKey = append(objKey, '.')
+ if key.Type == gjson.String {
+ objKey = append(objKey, key.Str...)
+ } else {
+ objKey = strconv.AppendInt(objKey, int64(key.Num), 10)
+ arrayLen++
+ }
+
+ var val string
+ switch value.Type {
+ case gjson.JSON:
+ iterationError = readItemsWithLimit(value, objKey, maxRecursion-1, res)
+ if iterationError != nil {
+ return false
+ }
+ objKey = objKey[:prevParentLength]
+ return true
+ case gjson.String:
+ val = value.Str
+ case gjson.Null:
+ val = ""
+ default:
+ val = value.Raw
+ }
+
+ res[string(objKey)] = val
+ objKey = objKey[:prevParentLength]
+
+ return true
+ })
+ if arrayLen > 0 {
+ res[string(objKey)] = strconv.Itoa(arrayLen)
+ }
+ return iterationError
+}
+
+func (js *jsonStreamBodyProcessor) ProcessRequestRecords(reader io.Reader, _ plugintypes.BodyProcessorOptions,
+ fn func(recordNum int, fields map[string]string, rawRecord string) error) error {
+ recordCount, err := processJSONStreamWithCallback(reader, DefaultStreamRecursionLimit, fn)
+ if err != nil {
+ return err
+ }
+ if recordCount == 0 {
+ return errors.New("no valid JSON objects found in stream")
+ }
+ return nil
+}
+
+func (js *jsonStreamBodyProcessor) ProcessResponseRecords(reader io.Reader, _ plugintypes.BodyProcessorOptions,
+ fn func(recordNum int, fields map[string]string, rawRecord string) error) error {
+ recordCount, err := processJSONStreamWithCallback(reader, DefaultStreamRecursionLimit, fn)
+ if err != nil {
+ return err
+ }
+ if recordCount == 0 {
+ return errors.New("no valid JSON objects found in stream")
+ }
+ return nil
+}
+
+func init() {
+ // Register the processor with multiple names for different content-types
+ plugins.RegisterBodyProcessor("jsonstream", func() plugintypes.BodyProcessor {
+ return &jsonStreamBodyProcessor{}
+ })
+ plugins.RegisterBodyProcessor("ndjson", func() plugintypes.BodyProcessor {
+ return &jsonStreamBodyProcessor{}
+ })
+ plugins.RegisterBodyProcessor("jsonlines", func() plugintypes.BodyProcessor {
+ return &jsonStreamBodyProcessor{}
+ })
+}
diff --git a/experimental/bodyprocessors/jsonstream_test.go b/experimental/bodyprocessors/jsonstream_test.go
new file mode 100644
index 000000000..848515b6f
--- /dev/null
+++ b/experimental/bodyprocessors/jsonstream_test.go
@@ -0,0 +1,1113 @@
+// Copyright 2026 OWASP Coraza contributors
+// SPDX-License-Identifier: Apache-2.0
+
+package bodyprocessors_test
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+ "testing"
+
+ _ "github.com/corazawaf/coraza/v3/experimental/bodyprocessors"
+
+ "github.com/corazawaf/coraza/v3/experimental/plugins"
+ "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
+ "github.com/corazawaf/coraza/v3/internal/corazawaf"
+)
+
+func jsonstreamProcessor(t *testing.T) plugintypes.BodyProcessor {
+ t.Helper()
+ jsp, err := plugins.GetBodyProcessor("jsonstream")
+ if err != nil {
+ t.Fatal(err)
+ }
+ return jsp
+}
+
+func TestJSONStreamSingleLine(t *testing.T) {
+ input := `{"name": "John", "age": 30}
+`
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ argsPost := v.ArgsPost()
+
+ // Check expected keys
+ if name := argsPost.Get("json.0.name"); len(name) == 0 || name[0] != "John" {
+ t.Errorf("json.0.name should be 'John', got: %v", name)
+ }
+
+ if age := argsPost.Get("json.0.age"); len(age) == 0 || age[0] != "30" {
+ t.Errorf("json.0.age should be '30', got: %v", age)
+ }
+}
+
+func TestJSONStreamMultipleLines(t *testing.T) {
+ input := `{"name": "John", "age": 30}
+{"name": "Jane", "age": 25}
+{"name": "Bob", "age": 35}
+`
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ argsPost := v.ArgsPost()
+
+ // Check all three lines
+ tests := []struct {
+ line int
+ name string
+ age string
+ }{
+ {0, "John", "30"},
+ {1, "Jane", "25"},
+ {2, "Bob", "35"},
+ }
+
+ for _, tt := range tests {
+ nameKey := fmt.Sprintf("json.%d.name", tt.line)
+ ageKey := fmt.Sprintf("json.%d.age", tt.line)
+
+ if name := argsPost.Get(nameKey); len(name) == 0 || name[0] != tt.name {
+ t.Errorf("%s should be '%s', got: %v", nameKey, tt.name, name)
+ }
+
+ if age := argsPost.Get(ageKey); len(age) == 0 || age[0] != tt.age {
+ t.Errorf("%s should be '%s', got: %v", ageKey, tt.age, age)
+ }
+ }
+}
+
+func TestJSONStreamNestedObjects(t *testing.T) {
+ input := `{"user": {"name": "John", "id": 1}, "active": true}
+{"user": {"name": "Jane", "id": 2}, "active": false}
+`
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ argsPost := v.ArgsPost()
+
+ // Check nested fields
+ if name := argsPost.Get("json.0.user.name"); len(name) == 0 || name[0] != "John" {
+ t.Errorf("json.0.user.name should be 'John', got: %v", name)
+ }
+
+ if id := argsPost.Get("json.0.user.id"); len(id) == 0 || id[0] != "1" {
+ t.Errorf("json.0.user.id should be '1', got: %v", id)
+ }
+
+ if active := argsPost.Get("json.0.active"); len(active) == 0 || active[0] != "true" {
+ t.Errorf("json.0.active should be 'true', got: %v", active)
+ }
+
+ if name := argsPost.Get("json.1.user.name"); len(name) == 0 || name[0] != "Jane" {
+ t.Errorf("json.1.user.name should be 'Jane', got: %v", name)
+ }
+
+ if active := argsPost.Get("json.1.active"); len(active) == 0 || active[0] != "false" {
+ t.Errorf("json.1.active should be 'false', got: %v", active)
+ }
+}
+
+func TestJSONStreamWithArrays(t *testing.T) {
+ input := `{"name": "John", "tags": ["admin", "user"]}
+`
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ argsPost := v.ArgsPost()
+
+ // Check array fields
+ if tags := argsPost.Get("json.0.tags"); len(tags) == 0 || tags[0] != "2" {
+ t.Errorf("json.0.tags should be '2' (array length), got: %v", tags)
+ }
+
+ if tag0 := argsPost.Get("json.0.tags.0"); len(tag0) == 0 || tag0[0] != "admin" {
+ t.Errorf("json.0.tags.0 should be 'admin', got: %v", tag0)
+ }
+
+ if tag1 := argsPost.Get("json.0.tags.1"); len(tag1) == 0 || tag1[0] != "user" {
+ t.Errorf("json.0.tags.1 should be 'user', got: %v", tag1)
+ }
+}
+
+func TestJSONStreamSkipEmptyLines(t *testing.T) {
+ input := `{"name": "John"}
+
+{"name": "Jane"}
+
+`
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ argsPost := v.ArgsPost()
+
+ // Empty lines should be skipped, so we should only have line 0 and 1
+ if name := argsPost.Get("json.0.name"); len(name) == 0 || name[0] != "John" {
+ t.Errorf("json.0.name should be 'John', got: %v", name)
+ }
+
+ if name := argsPost.Get("json.1.name"); len(name) == 0 || name[0] != "Jane" {
+ t.Errorf("json.1.name should be 'Jane', got: %v", name)
+ }
+
+ // Line 2 should not exist
+ if name := argsPost.Get("json.2.name"); len(name) != 0 {
+ t.Errorf("json.2.name should not exist, got: %v", name)
+ }
+}
+
+func TestJSONStreamArrayAsRoot(t *testing.T) {
+ input := `[1, 2, 3]
+[4, 5, 6]
+`
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ argsPost := v.ArgsPost()
+
+ // Check first array
+ if arr := argsPost.Get("json.0"); len(arr) == 0 || arr[0] != "3" {
+ t.Errorf("json.0 should be '3' (array length), got: %v", arr)
+ }
+
+ if val := argsPost.Get("json.0.0"); len(val) == 0 || val[0] != "1" {
+ t.Errorf("json.0.0 should be '1', got: %v", val)
+ }
+
+ // Check second array
+ if arr := argsPost.Get("json.1"); len(arr) == 0 || arr[0] != "3" {
+ t.Errorf("json.1 should be '3' (array length), got: %v", arr)
+ }
+
+ if val := argsPost.Get("json.1.0"); len(val) == 0 || val[0] != "4" {
+ t.Errorf("json.1.0 should be '4', got: %v", val)
+ }
+}
+
+func TestJSONStreamNullAndBooleans(t *testing.T) {
+ input := `{"null": null, "true": true, "false": false}
+`
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ argsPost := v.ArgsPost()
+
+ // Check null value (should be empty string)
+ if null := argsPost.Get("json.0.null"); len(null) == 0 || null[0] != "" {
+ t.Errorf("json.0.null should be empty string, got: %v", null)
+ }
+
+ // Check boolean values
+ if trueVal := argsPost.Get("json.0.true"); len(trueVal) == 0 || trueVal[0] != "true" {
+ t.Errorf("json.0.true should be 'true', got: %v", trueVal)
+ }
+
+ if falseVal := argsPost.Get("json.0.false"); len(falseVal) == 0 || falseVal[0] != "false" {
+ t.Errorf("json.0.false should be 'false', got: %v", falseVal)
+ }
+}
+
+func TestJSONStreamInvalidJSON(t *testing.T) {
+ input := `{"name": "John"}
+{invalid json}
+{"name": "Jane"}
+`
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err == nil {
+ t.Fatalf("expected error for invalid JSON, got none")
+ }
+
+ if !strings.Contains(err.Error(), "invalid JSON") {
+ t.Errorf("expected 'invalid JSON' error, got: %v", err)
+ }
+}
+
+func TestJSONStreamEmptyStream(t *testing.T) {
+ input := ""
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err == nil {
+ t.Fatalf("expected error for empty stream, got none")
+ }
+
+ if !strings.Contains(err.Error(), "no valid JSON objects") {
+ t.Errorf("expected 'no valid JSON objects' error, got: %v", err)
+ }
+}
+
+func TestJSONStreamOnlyEmptyLines(t *testing.T) {
+ input := "\n\n\n"
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err == nil {
+ t.Fatalf("expected error for only empty lines, got none")
+ }
+
+ if !strings.Contains(err.Error(), "no valid JSON objects") {
+ t.Errorf("expected 'no valid JSON objects' error, got: %v", err)
+ }
+}
+
+func TestJSONStreamRecursionLimit(t *testing.T) {
+ // Create a deeply nested JSON object that exceeds the default limit
+ // Default limit is 1024, so we create 1500 levels to trigger the error
+ deeplyNested := strings.Repeat(`{"a":`, 1500) + "1" + strings.Repeat(`}`, 1500)
+ input := deeplyNested + "\n"
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ // This should fail because it exceeds DefaultStreamRecursionLimit (1024)
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err == nil {
+ t.Fatalf("expected error due to recursion limit, got none")
+ }
+
+ if !strings.Contains(err.Error(), "max recursion") {
+ t.Errorf("expected recursion error, got: %v", err)
+ }
+}
+
+func TestJSONStreamTXVariables(t *testing.T) {
+ input := `{"name": "John"}
+{"name": "Jane"}
+`
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ // Check TX variables
+ txVars := v.TX()
+
+ // Check raw body storage
+ rawBody := txVars.Get("jsonstream_request_body")
+ if len(rawBody) == 0 || rawBody[0] != input {
+ t.Errorf("jsonstream_request_body not stored correctly")
+ }
+
+ // Check line count
+ lineCount := txVars.Get("jsonstream_request_line_count")
+ if len(lineCount) == 0 || lineCount[0] != "2" {
+ t.Errorf("jsonstream_request_line_count should be 2, got: %v", lineCount)
+ }
+}
+
+func TestJSONStreamProcessResponse(t *testing.T) {
+ input := `{"status": "ok"}
+{"status": "error"}
+`
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessResponse(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ // Check response args
+ responseArgs := v.ResponseArgs()
+
+ if status0 := responseArgs.Get("json.0.status"); len(status0) == 0 || status0[0] != "ok" {
+ t.Errorf("json.0.status should be 'ok', got: %v", status0)
+ }
+
+ if status1 := responseArgs.Get("json.1.status"); len(status1) == 0 || status1[0] != "error" {
+ t.Errorf("json.1.status should be 'error', got: %v", status1)
+ }
+}
+
+func TestJSONSequenceRFC7464(t *testing.T) {
+ // RFC 7464 format uses ASCII RS (0x1E) as record separator
+ const RS = "\x1e"
+
+ input := RS + `{"name": "John", "age": 30}` + "\n" +
+ RS + `{"name": "Jane", "age": 25}` + "\n" +
+ RS + `{"name": "Bob", "age": 35}` + "\n"
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ argsPost := v.ArgsPost()
+
+ // Check all three records
+ tests := []struct {
+ record int
+ name string
+ age string
+ }{
+ {0, "John", "30"},
+ {1, "Jane", "25"},
+ {2, "Bob", "35"},
+ }
+
+ for _, tt := range tests {
+ nameKey := fmt.Sprintf("json.%d.name", tt.record)
+ ageKey := fmt.Sprintf("json.%d.age", tt.record)
+
+ if name := argsPost.Get(nameKey); len(name) == 0 || name[0] != tt.name {
+ t.Errorf("%s should be '%s', got: %v", nameKey, tt.name, name)
+ }
+
+ if age := argsPost.Get(ageKey); len(age) == 0 || age[0] != tt.age {
+ t.Errorf("%s should be '%s', got: %v", ageKey, tt.age, age)
+ }
+ }
+
+ // Check line count
+ txVars := v.TX()
+ lineCount := txVars.Get("jsonstream_request_line_count")
+ if len(lineCount) == 0 || lineCount[0] != "3" {
+ t.Errorf("jsonstream_request_line_count should be 3, got: %v", lineCount)
+ }
+}
+
+func TestJSONSequenceNestedObjects(t *testing.T) {
+ const RS = "\x1e"
+
+ input := RS + `{"user": {"name": "John", "id": 1}, "active": true}` + "\n" +
+ RS + `{"user": {"name": "Jane", "id": 2}, "active": false}` + "\n"
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ argsPost := v.ArgsPost()
+
+ // Check nested fields
+ if name := argsPost.Get("json.0.user.name"); len(name) == 0 || name[0] != "John" {
+ t.Errorf("json.0.user.name should be 'John', got: %v", name)
+ }
+
+ if id := argsPost.Get("json.0.user.id"); len(id) == 0 || id[0] != "1" {
+ t.Errorf("json.0.user.id should be '1', got: %v", id)
+ }
+
+ if active := argsPost.Get("json.0.active"); len(active) == 0 || active[0] != "true" {
+ t.Errorf("json.0.active should be 'true', got: %v", active)
+ }
+}
+
+func TestJSONSequenceWithoutTrailingNewlines(t *testing.T) {
+ const RS = "\x1e"
+
+ // RFC 7464 says newlines are optional, test without them
+ input := RS + `{"name": "John"}` + RS + `{"name": "Jane"}` + RS + `{"name": "Bob"}`
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ argsPost := v.ArgsPost()
+
+ if name := argsPost.Get("json.0.name"); len(name) == 0 || name[0] != "John" {
+ t.Errorf("json.0.name should be 'John', got: %v", name)
+ }
+
+ if name := argsPost.Get("json.1.name"); len(name) == 0 || name[0] != "Jane" {
+ t.Errorf("json.1.name should be 'Jane', got: %v", name)
+ }
+
+ if name := argsPost.Get("json.2.name"); len(name) == 0 || name[0] != "Bob" {
+ t.Errorf("json.2.name should be 'Bob', got: %v", name)
+ }
+}
+
+func TestJSONSequenceEmptyRecords(t *testing.T) {
+ const RS = "\x1e"
+
+ // Empty records (multiple RS in a row) should be skipped
+ input := RS + RS + `{"name": "John"}` + "\n" + RS + "\n" + RS + `{"name": "Jane"}` + "\n"
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ argsPost := v.ArgsPost()
+
+ // Should only have 2 records (empty ones skipped)
+ if name := argsPost.Get("json.0.name"); len(name) == 0 || name[0] != "John" {
+ t.Errorf("json.0.name should be 'John', got: %v", name)
+ }
+
+ if name := argsPost.Get("json.1.name"); len(name) == 0 || name[0] != "Jane" {
+ t.Errorf("json.1.name should be 'Jane', got: %v", name)
+ }
+
+ // Third record should not exist
+ if name := argsPost.Get("json.2.name"); len(name) != 0 {
+ t.Errorf("json.2.name should not exist, got: %v", name)
+ }
+}
+
+func TestJSONSequenceInvalidJSON(t *testing.T) {
+ const RS = "\x1e"
+
+ input := RS + `{"name": "John"}` + "\n" +
+ RS + `{invalid json}` + "\n" +
+ RS + `{"name": "Jane"}` + "\n"
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err == nil {
+ t.Fatalf("expected error for invalid JSON, got none")
+ }
+
+ if !strings.Contains(err.Error(), "invalid JSON") {
+ t.Errorf("expected 'invalid JSON' error, got: %v", err)
+ }
+}
+
+func TestFormatAutoDetection(t *testing.T) {
+ const RS = "\x1e"
+
+ tests := []struct {
+ name string
+ input string
+ format string
+ }{
+ {
+ name: "NDJSON without RS",
+ input: `{"name": "John"}` + "\n" + `{"name": "Jane"}` + "\n",
+ format: "NDJSON",
+ },
+ {
+ name: "JSON Sequence with RS",
+ input: RS + `{"name": "John"}` + "\n" + RS + `{"name": "Jane"}` + "\n",
+ format: "JSON Sequence",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(tt.input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error for %s: %v", tt.format, err)
+ return
+ }
+
+ argsPost := v.ArgsPost()
+
+ // Both formats should produce the same output
+ if name := argsPost.Get("json.0.name"); len(name) == 0 || name[0] != "John" {
+ t.Errorf("%s: json.0.name should be 'John', got: %v", tt.format, name)
+ }
+
+ if name := argsPost.Get("json.1.name"); len(name) == 0 || name[0] != "Jane" {
+ t.Errorf("%s: json.1.name should be 'Jane', got: %v", tt.format, name)
+ }
+ })
+ }
+}
+
+// --- Benchmarks ---
+
+// buildNDJSONStream generates an NDJSON stream with the given number of records using the record template.
+func buildNDJSONStream(numRecords int, record string) string {
+ var sb strings.Builder
+ sb.Grow(numRecords * (len(record) + 1))
+ for i := 0; i < numRecords; i++ {
+ sb.WriteString(record)
+ sb.WriteByte('\n')
+ }
+ return sb.String()
+}
+
+// buildRFC7464Stream generates an RFC 7464 JSON Sequence stream.
+func buildRFC7464Stream(numRecords int, record string) string {
+ var sb strings.Builder
+ sb.Grow(numRecords * (len(record) + 2))
+ for i := 0; i < numRecords; i++ {
+ sb.WriteByte('\x1e')
+ sb.WriteString(record)
+ sb.WriteByte('\n')
+ }
+ return sb.String()
+}
+
+const (
+ smallRecord = `{"id":1,"name":"Alice"}`
+ mediumRecord = `{"user_id":1234567890,"name":"User Name","email":"user@example.com","role":"admin","active":true,"tags":["tag1","tag2","tag3"]}`
+ nestedRecord = `{"user":{"name":"Alice","address":{"city":"NYC","zip":"10001"}},"scores":[95,87,92],"meta":{"created":"2026-01-01","active":true}}`
+)
+
+func BenchmarkJSONStreamProcessor(b *testing.B) {
+ jsp, err := plugins.GetBodyProcessor("jsonstream")
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ benchmarks := []struct {
+ name string
+ numRecords int
+ record string
+ }{
+ {"small/1", 1, smallRecord},
+ {"small/10", 10, smallRecord},
+ {"small/100", 100, smallRecord},
+ {"small/1000", 1000, smallRecord},
+ {"medium/1", 1, mediumRecord},
+ {"medium/10", 10, mediumRecord},
+ {"medium/100", 100, mediumRecord},
+ {"medium/1000", 1000, mediumRecord},
+ {"nested/1", 1, nestedRecord},
+ {"nested/10", 10, nestedRecord},
+ {"nested/100", 100, nestedRecord},
+ {"nested/1000", 1000, nestedRecord},
+ }
+
+ for _, bm := range benchmarks {
+ input := buildNDJSONStream(bm.numRecords, bm.record)
+ b.Run("ProcessRequest/"+bm.name, func(b *testing.B) {
+ b.SetBytes(int64(len(input)))
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ v := corazawaf.NewTransactionVariables()
+ if err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{}); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkJSONStreamCallback(b *testing.B) {
+ bp, err := plugins.GetBodyProcessor("jsonstream")
+ if err != nil {
+ b.Fatal(err)
+ }
+ sp, ok := bp.(plugintypes.StreamingBodyProcessor)
+ if !ok {
+ b.Fatal("jsonstream processor does not implement StreamingBodyProcessor")
+ }
+
+ benchmarks := []struct {
+ name string
+ numRecords int
+ record string
+ }{
+ {"small/1", 1, smallRecord},
+ {"small/10", 10, smallRecord},
+ {"small/100", 100, smallRecord},
+ {"small/1000", 1000, smallRecord},
+ {"medium/1", 1, mediumRecord},
+ {"medium/10", 10, mediumRecord},
+ {"medium/100", 100, mediumRecord},
+ {"medium/1000", 1000, mediumRecord},
+ {"nested/1", 1, nestedRecord},
+ {"nested/10", 10, nestedRecord},
+ {"nested/100", 100, nestedRecord},
+ {"nested/1000", 1000, nestedRecord},
+ }
+
+ noop := func(_ int, _ map[string]string, _ string) error { return nil }
+
+ for _, bm := range benchmarks {
+ input := buildNDJSONStream(bm.numRecords, bm.record)
+ b.Run(bm.name, func(b *testing.B) {
+ b.SetBytes(int64(len(input)))
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ if err := sp.ProcessRequestRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{}, noop); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkJSONStreamRFC7464(b *testing.B) {
+ bp, err := plugins.GetBodyProcessor("jsonstream")
+ if err != nil {
+ b.Fatal(err)
+ }
+ sp, ok := bp.(plugintypes.StreamingBodyProcessor)
+ if !ok {
+ b.Fatal("jsonstream processor does not implement StreamingBodyProcessor")
+ }
+
+ benchmarks := []struct {
+ name string
+ numRecords int
+ record string
+ }{
+ {"small/10", 10, smallRecord},
+ {"small/100", 100, smallRecord},
+ {"medium/100", 100, mediumRecord},
+ {"nested/100", 100, nestedRecord},
+ }
+
+ noop := func(_ int, _ map[string]string, _ string) error { return nil }
+
+ for _, bm := range benchmarks {
+ input := buildRFC7464Stream(bm.numRecords, bm.record)
+ b.Run(bm.name, func(b *testing.B) {
+ b.SetBytes(int64(len(input)))
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ if err := sp.ProcessRequestRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{}, noop); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ }
+}
+
+func TestProcessorRegistration(t *testing.T) {
+ // Test that all three aliases are registered
+ aliases := []string{"jsonstream", "ndjson", "jsonlines"}
+
+ for _, alias := range aliases {
+ t.Run(alias, func(t *testing.T) {
+ processor, err := plugins.GetBodyProcessor(alias)
+ if err != nil {
+ t.Errorf("Failed to get processor '%s': %v", alias, err)
+ }
+ if processor == nil {
+ t.Errorf("Processor '%s' is nil", alias)
+ }
+ })
+ }
+}
+
+func TestJSONStreamLargeToken(t *testing.T) {
+ // Create a JSON object that exceeds 1MB to trigger scanner buffer error
+ largeValue := strings.Repeat("x", 2*1024*1024) // 2MB string
+ input := fmt.Sprintf(`{"large": "%s"}`, largeValue) + "\n"
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ // Should get a scanner error about token too long
+ if err == nil {
+ t.Fatalf("expected error for token too large, got none")
+ }
+
+ if !strings.Contains(err.Error(), "error reading stream") && !strings.Contains(err.Error(), "token too long") {
+ t.Logf("Got error (this is expected): %v", err)
+ }
+}
+
+func TestJSONSequenceLargeToken(t *testing.T) {
+ const RS = "\x1e"
+ // Create a JSON object that exceeds 1MB to trigger scanner buffer error
+ largeValue := strings.Repeat("x", 2*1024*1024) // 2MB string
+ input := RS + fmt.Sprintf(`{"large": "%s"}`, largeValue) + "\n"
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ // Should get a scanner error about token too long
+ if err == nil {
+ t.Fatalf("expected error for token too large, got none")
+ }
+
+ if !strings.Contains(err.Error(), "error reading stream") && !strings.Contains(err.Error(), "token too long") {
+ t.Logf("Got error (this is expected): %v", err)
+ }
+}
+
+func TestProcessResponseWithoutResponseBody(t *testing.T) {
+ input := `{"status": "ok"}
+`
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ // Process response without setting up ResponseBody
+ err := jsp.ProcessResponse(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ // Check response args were still populated
+ responseArgs := v.ResponseArgs()
+ if status := responseArgs.Get("json.0.status"); len(status) == 0 || status[0] != "ok" {
+ t.Errorf("json.0.status should be 'ok', got: %v", status)
+ }
+
+ // TX variables should be set (but response body related ones may not be if ResponseBody() is nil)
+ txVars := v.TX()
+ lineCount := txVars.Get("jsonstream_response_line_count")
+ if len(lineCount) == 0 || lineCount[0] != "1" {
+ t.Logf("jsonstream_response_line_count: %v (may be empty if ResponseBody() is nil)", lineCount)
+ }
+}
+
+// errorReader is a reader that always returns an error
+type errorReader struct{}
+
+func (e errorReader) Read(p []byte) (n int, err error) {
+ return 0, errors.New("simulated read error")
+}
+
+func TestJSONStreamPeekError(t *testing.T) {
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ // Use an error reader to trigger peek error
+ err := jsp.ProcessRequest(errorReader{}, v, plugintypes.BodyProcessorOptions{})
+
+ if err == nil {
+ t.Fatalf("expected error from peek, got none")
+ }
+
+ if !strings.Contains(err.Error(), "error peeking stream") {
+ t.Errorf("expected 'error peeking stream' error, got: %v", err)
+ }
+}
+
+func TestJSONSequenceOnlyRS(t *testing.T) {
+ const RS = "\x1e"
+
+ // Only RS characters, no actual JSON
+ input := RS + RS + RS
+
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+
+ if err == nil {
+ t.Fatalf("expected error for only RS characters, got none")
+ }
+
+ if !strings.Contains(err.Error(), "no valid JSON objects") {
+ t.Errorf("expected 'no valid JSON objects' error, got: %v", err)
+ }
+}
+
+// --- Streaming callback tests ---
+
+func jsonstreamStreamingProcessor(t *testing.T) plugintypes.StreamingBodyProcessor {
+ t.Helper()
+ bp := jsonstreamProcessor(t)
+ sp, ok := bp.(plugintypes.StreamingBodyProcessor)
+ if !ok {
+ t.Fatal("jsonstream processor does not implement StreamingBodyProcessor")
+ }
+ return sp
+}
+
+func TestStreamingCallbackPerRecord(t *testing.T) {
+ input := `{"name": "Alice", "age": 30}
+{"name": "Bob", "age": 25}
+{"name": "Charlie", "age": 35}
+`
+ sp := jsonstreamStreamingProcessor(t)
+
+ var records []int
+ var rawRecords []string
+
+ err := sp.ProcessRequestRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{},
+ func(recordNum int, fields map[string]string, rawRecord string) error {
+ records = append(records, recordNum)
+ rawRecords = append(rawRecords, rawRecord)
+ return nil
+ })
+
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if len(records) != 3 {
+ t.Fatalf("expected 3 records, got %d", len(records))
+ }
+
+ for i, r := range records {
+ if r != i {
+ t.Errorf("expected recordNum %d, got %d", i, r)
+ }
+ }
+
+ if !strings.Contains(rawRecords[0], "Alice") {
+ t.Errorf("expected raw record 0 to contain Alice, got: %s", rawRecords[0])
+ }
+ if !strings.Contains(rawRecords[1], "Bob") {
+ t.Errorf("expected raw record 1 to contain Bob, got: %s", rawRecords[1])
+ }
+}
+
+func TestStreamingFieldsHaveRecordPrefix(t *testing.T) {
+ input := `{"name": "Alice"}
+{"name": "Bob"}
+`
+ sp := jsonstreamStreamingProcessor(t)
+
+ var allFields []map[string]string
+
+ err := sp.ProcessRequestRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{},
+ func(recordNum int, fields map[string]string, _ string) error {
+ // Copy fields since the map may be reused
+ fieldsCopy := make(map[string]string, len(fields))
+ for k, v := range fields {
+ fieldsCopy[k] = v
+ }
+ allFields = append(allFields, fieldsCopy)
+ return nil
+ })
+
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if len(allFields) != 2 {
+ t.Fatalf("expected 2 records, got %d", len(allFields))
+ }
+
+ // Record 0 should have json.0.name
+ if v, ok := allFields[0]["json.0.name"]; !ok || v != "Alice" {
+ t.Errorf("expected json.0.name=Alice, got %q (ok=%v)", v, ok)
+ }
+
+ // Record 1 should have json.1.name
+ if v, ok := allFields[1]["json.1.name"]; !ok || v != "Bob" {
+ t.Errorf("expected json.1.name=Bob, got %q (ok=%v)", v, ok)
+ }
+
+ // Record 0 should NOT have json.1.* keys
+ for k := range allFields[0] {
+ if strings.HasPrefix(k, "json.1.") {
+ t.Errorf("record 0 should not have key %q", k)
+ }
+ }
+}
+
+func TestStreamingInterruptionStops(t *testing.T) {
+ input := `{"name": "Alice"}
+{"name": "Bob"}
+{"name": "Charlie"}
+{"name": "Dave"}
+`
+ sp := jsonstreamStreamingProcessor(t)
+ errBlocked := errors.New("blocked")
+ processedRecords := 0
+
+ err := sp.ProcessRequestRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{},
+ func(recordNum int, fields map[string]string, _ string) error {
+ processedRecords++
+ if recordNum == 1 {
+ return errBlocked
+ }
+ return nil
+ })
+
+ if err != errBlocked {
+ t.Fatalf("expected errBlocked, got: %v", err)
+ }
+
+ if processedRecords != 2 {
+ t.Errorf("expected 2 records processed (0 and 1), got %d", processedRecords)
+ }
+}
+
+func TestStreamingEmptyStream(t *testing.T) {
+ sp := jsonstreamStreamingProcessor(t)
+
+ err := sp.ProcessRequestRecords(strings.NewReader(""), plugintypes.BodyProcessorOptions{},
+ func(_ int, _ map[string]string, _ string) error {
+ t.Error("callback should not be called for empty stream")
+ return nil
+ })
+
+ if err == nil {
+ t.Error("expected error for empty stream")
+ }
+}
+
+func TestStreamingRFC7464WithCallback(t *testing.T) {
+ input := "\x1e{\"name\": \"Alice\"}\n\x1e{\"name\": \"Bob\"}\n"
+
+ sp := jsonstreamStreamingProcessor(t)
+
+ var records []int
+ var allFields []map[string]string
+
+ err := sp.ProcessRequestRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{},
+ func(recordNum int, fields map[string]string, _ string) error {
+ records = append(records, recordNum)
+ fieldsCopy := make(map[string]string, len(fields))
+ for k, v := range fields {
+ fieldsCopy[k] = v
+ }
+ allFields = append(allFields, fieldsCopy)
+ return nil
+ })
+
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if len(records) != 2 {
+ t.Fatalf("expected 2 records, got %d", len(records))
+ }
+
+ if v := allFields[0]["json.0.name"]; v != "Alice" {
+ t.Errorf("expected json.0.name=Alice, got %q", v)
+ }
+ if v := allFields[1]["json.1.name"]; v != "Bob" {
+ t.Errorf("expected json.1.name=Bob, got %q", v)
+ }
+}
+
+func TestStreamingResponseRecords(t *testing.T) {
+ input := `{"status": "ok"}
+{"status": "error"}
+`
+ sp := jsonstreamStreamingProcessor(t)
+ processedRecords := 0
+
+ err := sp.ProcessResponseRecords(strings.NewReader(input), plugintypes.BodyProcessorOptions{},
+ func(recordNum int, fields map[string]string, _ string) error {
+ processedRecords++
+ return nil
+ })
+
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if processedRecords != 2 {
+ t.Errorf("expected 2 records, got %d", processedRecords)
+ }
+}
+
+func TestStreamingBackwardCompat(t *testing.T) {
+ // Verify that ProcessRequest still works unchanged after refactoring
+ input := `{"name": "Alice"}
+{"name": "Bob"}
+`
+ jsp := jsonstreamProcessor(t)
+ v := corazawaf.NewTransactionVariables()
+
+ err := jsp.ProcessRequest(strings.NewReader(input), v, plugintypes.BodyProcessorOptions{})
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ argsPost := v.ArgsPost()
+ if vals := argsPost.Get("json.0.name"); len(vals) == 0 || vals[0] != "Alice" {
+ t.Errorf("expected json.0.name=Alice via ProcessRequest, got %v", vals)
+ }
+ if vals := argsPost.Get("json.1.name"); len(vals) == 0 || vals[0] != "Bob" {
+ t.Errorf("expected json.1.name=Bob via ProcessRequest, got %v", vals)
+ }
+}
diff --git a/experimental/plugins/bodyprocessors.go b/experimental/plugins/bodyprocessors.go
index 10ef7b552..2b4f1ea02 100644
--- a/experimental/plugins/bodyprocessors.go
+++ b/experimental/plugins/bodyprocessors.go
@@ -14,3 +14,9 @@ import (
func RegisterBodyProcessor(name string, fn func() plugintypes.BodyProcessor) {
bodyprocessors.RegisterBodyProcessor(name, fn)
}
+
+// GetBodyProcessor returns a body processor by name.
+// If the body processor is not found, it returns an error
+func GetBodyProcessor(name string) (plugintypes.BodyProcessor, error) {
+ return bodyprocessors.GetBodyProcessor(name)
+}
diff --git a/experimental/plugins/plugintypes/bodyprocessor.go b/experimental/plugins/plugintypes/bodyprocessor.go
index d3052a53c..ff9f3c04c 100644
--- a/experimental/plugins/plugintypes/bodyprocessor.go
+++ b/experimental/plugins/plugintypes/bodyprocessor.go
@@ -21,6 +21,8 @@ type BodyProcessorOptions struct {
FileMode fs.FileMode
// DirMode is the mode of the directory that will be created
DirMode fs.FileMode
+ // RequestBodyRecursionLimit is the maximum recursion level accepted in a body processor
+ RequestBodyRecursionLimit int
}
// BodyProcessor interface is used to create
@@ -32,3 +34,24 @@ type BodyProcessor interface {
ProcessRequest(reader io.Reader, variables TransactionVariables, options BodyProcessorOptions) error
ProcessResponse(reader io.Reader, variables TransactionVariables, options BodyProcessorOptions) error
}
+
+// StreamingBodyProcessor extends BodyProcessor with per-record streaming support.
+// Body processors that handle multi-record formats (NDJSON, JSON-Seq) can implement
+// this interface to enable per-record rule evaluation instead of evaluating rules
+// only after the entire body has been consumed.
+//
+// The callback receives pre-formatted field keys including the record number prefix
+// (e.g., "json.0.name", "json.1.age") and the raw record text. Returning a non-nil
+// error from the callback stops processing immediately.
+type StreamingBodyProcessor interface {
+ BodyProcessor
+
+ // ProcessRequestRecords reads records one at a time from the reader and calls fn
+ // for each record's parsed fields. Processing stops if fn returns a non-nil error.
+ ProcessRequestRecords(reader io.Reader, options BodyProcessorOptions,
+ fn func(recordNum int, fields map[string]string, rawRecord string) error) error
+
+ // ProcessResponseRecords is the response equivalent of ProcessRequestRecords.
+ ProcessResponseRecords(reader io.Reader, options BodyProcessorOptions,
+ fn func(recordNum int, fields map[string]string, rawRecord string) error) error
+}
diff --git a/experimental/plugins/plugintypes/operator.go b/experimental/plugins/plugintypes/operator.go
index 3d950aee3..0aa598d46 100644
--- a/experimental/plugins/plugintypes/operator.go
+++ b/experimental/plugins/plugintypes/operator.go
@@ -5,6 +5,12 @@ package plugintypes
import "io/fs"
+// Memoizer caches the result of expensive function calls by key.
+// Implementations must be safe for concurrent use.
+type Memoizer interface {
+ Do(key string, fn func() (any, error)) (any, error)
+}
+
// OperatorOptions is used to store the options for a rule operator
type OperatorOptions struct {
// Arguments is used to store the operator args
@@ -18,6 +24,9 @@ type OperatorOptions struct {
// Datasets contains input datasets or dictionaries
Datasets map[string][]string
+
+ // Memoizer caches expensive compilations (regex, aho-corasick).
+ Memoizer Memoizer
}
// Operator interface is used to define rule @operators
diff --git a/experimental/rule_observer.go b/experimental/rule_observer.go
new file mode 100644
index 000000000..dea42fea4
--- /dev/null
+++ b/experimental/rule_observer.go
@@ -0,0 +1,25 @@
+// Copyright 2026 Juan Pablo Tosso and the OWASP Coraza contributors
+// SPDX-License-Identifier: Apache-2.0
+
+package experimental
+
+import (
+ "github.com/corazawaf/coraza/v3"
+ "github.com/corazawaf/coraza/v3/types"
+)
+
+// wafConfigWithRuleObserver is the private capability interface
+type wafConfigWithRuleObserver interface {
+ WithRuleObserver(func(rule types.RuleMetadata)) coraza.WAFConfig
+}
+
+// WAFConfigWithRuleObserver applies a rule observer if supported.
+func WAFConfigWithRuleObserver(
+ cfg coraza.WAFConfig,
+ observer func(rule types.RuleMetadata),
+) coraza.WAFConfig {
+ if c, ok := cfg.(wafConfigWithRuleObserver); ok {
+ return c.WithRuleObserver(observer)
+ }
+ return cfg
+}
diff --git a/experimental/rule_observer_test.go b/experimental/rule_observer_test.go
new file mode 100644
index 000000000..ec70ab20f
--- /dev/null
+++ b/experimental/rule_observer_test.go
@@ -0,0 +1,82 @@
+// Copyright 2026 Juan Pablo Tosso and the OWASP Coraza contributors
+// SPDX-License-Identifier: Apache-2.0
+
+package experimental_test
+
+import (
+ "testing"
+
+ "github.com/corazawaf/coraza/v3"
+ "github.com/corazawaf/coraza/v3/experimental"
+ "github.com/corazawaf/coraza/v3/types"
+)
+
+func TestRuleObserver(t *testing.T) {
+ testCases := map[string]struct {
+ directives string
+ withObserver bool
+ expectRules int
+ }{
+ "no observer configured": {
+ directives: `
+ SecRule REQUEST_URI "@contains /test" "id:1000,phase:1,deny"
+ `,
+ withObserver: false,
+ expectRules: 0,
+ },
+ "single rule observed": {
+ directives: `
+ SecRule REQUEST_URI "@contains /test" "id:1001,phase:1,deny"
+ `,
+ withObserver: true,
+ expectRules: 1,
+ },
+ "multiple rules observed": {
+ directives: `
+ SecRule REQUEST_URI "@contains /a" "id:1002,phase:1,deny"
+ SecRule REQUEST_URI "@contains /b" "id:1003,phase:2,deny"
+ `,
+ withObserver: true,
+ expectRules: 2,
+ },
+ }
+
+ for name, tc := range testCases {
+ t.Run(name, func(t *testing.T) {
+ var observed []types.RuleMetadata
+
+ cfg := coraza.NewWAFConfig().
+ WithDirectives(tc.directives)
+
+ if tc.withObserver {
+ cfg = experimental.WAFConfigWithRuleObserver(cfg, func(rule types.RuleMetadata) {
+ observed = append(observed, rule)
+ })
+ }
+
+ waf, err := coraza.NewWAF(cfg)
+ if err != nil {
+ t.Fatalf("unexpected error creating WAF: %v", err)
+ }
+ if waf == nil {
+ t.Fatal("waf is nil")
+ }
+
+ if len(observed) != tc.expectRules {
+ t.Fatalf("expected %d observed rules, got %d", tc.expectRules, len(observed))
+ }
+
+ for _, rule := range observed {
+ if rule.ID() == 0 {
+ t.Fatal("expected rule ID to be set")
+ }
+ if rule.File() == "" {
+ t.Fatal("expected rule file to be set")
+ }
+ if rule.Line() == 0 {
+ t.Fatal("expected rule line to be set")
+ }
+ }
+ })
+ }
+}
diff --git a/experimental/streaming.go b/experimental/streaming.go
new file mode 100644
index 000000000..2d3028433
--- /dev/null
+++ b/experimental/streaming.go
@@ -0,0 +1,33 @@
+// Copyright 2026 Juan Pablo Tosso and the OWASP Coraza contributors
+// SPDX-License-Identifier: Apache-2.0
+
+package experimental
+
+import (
+ "io"
+
+ "github.com/corazawaf/coraza/v3/types"
+)
+
+// StreamingTransaction extends Transaction with streaming body processing capabilities.
+// Transactions created by a WAF instance implement this interface when the WAF supports
+// streaming body processors (e.g., NDJSON, JSON-Seq).
+//
+// Unlike the standard ProcessRequestBody/ProcessResponseBody methods which require the
+// full body to be buffered first, streaming methods read records directly from input,
+// evaluate rules per record, and write clean records to output for relay to the backend.
+type StreamingTransaction interface {
+ types.Transaction
+
+ // ProcessRequestBodyFromStream reads records from input, evaluates Phase 2 rules
+ // per record, and writes clean records to output. If a record triggers an interruption,
+ // processing stops and the interruption is returned.
+ //
+ // For non-streaming body processors, this falls back to buffering the input,
+ // processing it normally, and copying the result to output.
+ ProcessRequestBodyFromStream(input io.Reader, output io.Writer) (*types.Interruption, error)
+
+ // ProcessResponseBodyFromStream reads records from input, evaluates Phase 4 rules
+ // per record, and writes clean records to output.
+ ProcessResponseBodyFromStream(input io.Reader, output io.Writer) (*types.Interruption, error)
+}
diff --git a/experimental/waf.go b/experimental/waf.go
index bc167a2a9..c2fd967ca 100644
--- a/experimental/waf.go
+++ b/experimental/waf.go
@@ -4,6 +4,8 @@
package experimental
import (
+ "io"
+
"github.com/corazawaf/coraza/v3/internal/corazawaf"
"github.com/corazawaf/coraza/v3/types"
)
@@ -15,3 +17,19 @@ type Options = corazawaf.Options
type WAFWithOptions interface {
NewTransactionWithOptions(Options) types.Transaction
}
+
+// WAFWithRules is an interface that allows to inspect the number of
+// rules loaded in a WAF instance. This is useful for connectors that
+// need to verify rule loading or implement configuration caching.
+type WAFWithRules interface {
+ // RulesCount returns the number of rules in this WAF.
+ RulesCount() int
+}
+
+// WAFCloser allows closing a WAF instance to release cached resources
+// such as compiled regex patterns. Transactions in-flight are unaffected
+// as they hold their own references to compiled objects.
+// This will be promoted to the public WAF interface in v4.
+type WAFCloser interface {
+ io.Closer
+}
diff --git a/go.mod b/go.mod
index d58a70941..5098dac9b 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module github.com/corazawaf/coraza/v3
-go 1.24.0
+go 1.25.0
// Testing dependencies:
// - go-mockdns
@@ -19,7 +19,7 @@ go 1.24.0
require (
github.com/anuraaga/go-modsecurity v0.0.0-20220824035035-b9a4099778df
github.com/corazawaf/coraza-coreruleset v0.0.0-20240226094324-415b1017abdc
- github.com/corazawaf/libinjection-go v0.2.2
+ github.com/corazawaf/libinjection-go v0.3.2
github.com/foxcpp/go-mockdns v1.1.0
github.com/jcchavezs/mergefs v0.1.0
github.com/kaptinlin/jsonschema v0.4.6
@@ -28,8 +28,8 @@ require (
github.com/petar-dambovaliev/aho-corasick v0.0.0-20250424160509-463d218d4745
github.com/tidwall/gjson v1.18.0
github.com/valllabh/ocsf-schema-golang v1.0.3
- golang.org/x/net v0.43.0
- golang.org/x/sync v0.16.0
+ golang.org/x/net v0.52.0
+ golang.org/x/sync v0.20.0
rsc.io/binaryregexp v0.2.0
)
@@ -45,10 +45,10 @@ require (
github.com/stretchr/testify v1.11.1 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
- golang.org/x/mod v0.26.0 // indirect
- golang.org/x/sys v0.35.0 // indirect
- golang.org/x/text v0.28.0 // indirect
- golang.org/x/tools v0.35.0 // indirect
+ golang.org/x/mod v0.33.0 // indirect
+ golang.org/x/sys v0.42.0 // indirect
+ golang.org/x/text v0.35.0 // indirect
+ golang.org/x/tools v0.42.0 // indirect
google.golang.org/protobuf v1.35.1 // indirect
)
diff --git a/go.sum b/go.sum
index 9ca14c92f..0a7340c52 100644
--- a/go.sum
+++ b/go.sum
@@ -2,8 +2,8 @@ github.com/anuraaga/go-modsecurity v0.0.0-20220824035035-b9a4099778df h1:YWiVl53
github.com/anuraaga/go-modsecurity v0.0.0-20220824035035-b9a4099778df/go.mod h1:7jguE759ADzy2EkxGRXigiC0ER1Yq2IFk2qNtwgzc7U=
github.com/corazawaf/coraza-coreruleset v0.0.0-20240226094324-415b1017abdc h1:OlJhrgI3I+FLUCTI3JJW8MoqyM78WbqJjecqMnqG+wc=
github.com/corazawaf/coraza-coreruleset v0.0.0-20240226094324-415b1017abdc/go.mod h1:7rsocqNDkTCira5T0M7buoKR2ehh7YZiPkzxRuAgvVU=
-github.com/corazawaf/libinjection-go v0.2.2 h1:Chzodvb6+NXh6wew5/yhD0Ggioif9ACrQGR4qjTCs1g=
-github.com/corazawaf/libinjection-go v0.2.2/go.mod h1:OP4TM7xdJ2skyXqNX1AN1wN5nNZEmJNuWbNPOItn7aw=
+github.com/corazawaf/libinjection-go v0.3.2 h1:9rrKt0lpg4WvUXt+lwS06GywfqRXXsa/7JcOw5cQLwI=
+github.com/corazawaf/libinjection-go v0.3.2/go.mod h1:Ik/+w3UmTWH9yn366RgS9D95K3y7Atb5m/H/gXzzPCk=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7DlmewI=
@@ -57,8 +57,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
-golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
-golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
+golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
+golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
@@ -67,16 +67,16 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ=
-golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
-golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
+golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
+golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
-golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
+golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -87,8 +87,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
-golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
+golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -103,16 +103,16 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
-golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
+golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
+golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.15.0/go.mod h1:hpksKq4dtpQWS1uQ61JkdqWM3LscIS6Slf+VVkm+wQk=
-golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
-golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
+golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
+golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
diff --git a/go.work b/go.work
index 3e2fdc963..c03028a31 100644
--- a/go.work
+++ b/go.work
@@ -1,4 +1,4 @@
-go 1.24.0
+go 1.25.0
use (
.
diff --git a/http/interceptor_test.go b/http/interceptor_test.go
index cab89b793..dcd0c9361 100644
--- a/http/interceptor_test.go
+++ b/http/interceptor_test.go
@@ -92,10 +92,10 @@ func TestWriteWithWriteHeader(t *testing.T) {
res := httptest.NewRecorder()
rw, responseProcessor := wrap(res, req, tx)
- rw.WriteHeader(204)
+ rw.WriteHeader(201)
// although we called WriteHeader, status code should be applied until
// responseProcessor is called.
- if unwanted, have := 204, res.Code; unwanted == have {
+ if unwanted, have := 201, res.Code; unwanted == have {
t.Errorf("unexpected status code %d", have)
}
@@ -114,7 +114,7 @@ func TestWriteWithWriteHeader(t *testing.T) {
t.Errorf("unexpected error: %v", err)
}
- if want, have := 204, res.Code; want != have {
+ if want, have := 201, res.Code; want != have {
t.Errorf("unexpected status code, want %d, have %d", want, have)
}
}
@@ -203,10 +203,10 @@ func TestReadFrom(t *testing.T) {
}
rw, responseProcessor := wrap(resWithReaderFrom, req, tx)
- rw.WriteHeader(204)
+ rw.WriteHeader(201)
// although we called WriteHeader, status code should be applied until
// responseProcessor is called.
- if unwanted, have := 204, res.Code; unwanted == have {
+ if unwanted, have := 201, res.Code; unwanted == have {
t.Errorf("unexpected status code %d", have)
}
@@ -225,7 +225,7 @@ func TestReadFrom(t *testing.T) {
t.Errorf("unexpected error: %v", err)
}
- if want, have := 204, res.Code; want != have {
+ if want, have := 201, res.Code; want != have {
t.Errorf("unexpected status code, want %d, have %d", want, have)
}
}
diff --git a/http/middleware.go b/http/middleware.go
index 86e398df0..5ec73ce91 100644
--- a/http/middleware.go
+++ b/http/middleware.go
@@ -55,8 +55,10 @@ func processRequest(tx types.Transaction, req *http.Request) (*types.Interruptio
// Transfer-Encoding header is removed by go/http
// We manually add it to make rules relying on it work (E.g. CRS rule 920171)
- if req.TransferEncoding != nil {
- tx.AddRequestHeader("Transfer-Encoding", req.TransferEncoding[0])
+ // All values must be added to allow the WAF to detect HTTP request smuggling
+ // attempts (e.g. TE.TE attacks).
+ for _, te := range req.TransferEncoding {
+ tx.AddRequestHeader("Transfer-Encoding", te)
}
in = tx.ProcessRequestHeaders()
@@ -92,6 +94,9 @@ func processRequest(tx types.Transaction, req *http.Request) (*types.Interruptio
}
}
+ // ProcessRequestBody evaluates Phase 2 rules. For streaming body processors
+ // (e.g., NDJSON, JSON-Seq), rules are evaluated per record. For standard
+ // body processors, rules are evaluated once after the full body is processed.
return tx.ProcessRequestBody()
}
diff --git a/http/middleware_test.go b/http/middleware_test.go
index a1c958fdb..9d9450042 100644
--- a/http/middleware_test.go
+++ b/http/middleware_test.go
@@ -108,6 +108,32 @@ SecRule &REQUEST_HEADERS:Transfer-Encoding "!@eq 0" "id:1,phase:1,deny"
}
}
+func TestProcessRequestMultipleTransferEncodings(t *testing.T) {
+ // Multiple Transfer-Encoding values are a classic HTTP request smuggling vector (TE.TE attacks).
+ // All values should be forwarded to the WAF.
+ waf, _ := coraza.NewWAF(coraza.NewWAFConfig().
+ WithDirectives(`
+SecRule REQUEST_HEADERS:Transfer-Encoding "@contains identity" "id:1,phase:1,deny"
+`))
+ tx := waf.NewTransaction()
+
+ req, _ := http.NewRequest("GET", "https://www.coraza.io/test", nil)
+ req.TransferEncoding = []string{"chunked", "identity"}
+
+ it, err := processRequest(tx, req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if it == nil {
+ t.Fatal("Expected interruption: second Transfer-Encoding value should be processed")
+ } else if it.RuleID != 1 {
+ t.Fatalf("Expected rule 1 to be triggered, got rule %d", it.RuleID)
+ }
+ if err := tx.Close(); err != nil {
+ t.Fatal(err)
+ }
+}
+
func createMultipartRequest(t *testing.T) *http.Request {
t.Helper()
diff --git a/internal/actions/actions.go b/internal/actions/actions.go
index ed8889fb9..df1b33cac 100644
--- a/internal/actions/actions.go
+++ b/internal/actions/actions.go
@@ -1,6 +1,67 @@
// Copyright 2022 Juan Pablo Tosso and the OWASP Coraza contributors
// SPDX-License-Identifier: Apache-2.0
+// Package actions implements SecLang rule actions for processing and control flow.
+//
+// # Overview
+//
+// Actions define how the system handles HTTP requests when rule conditions match.
+// Actions are defined as part of a SecRule or as parameters for SecAction or SecDefaultAction.
+// A rule can have no or several actions which need to be separated by a comma.
+//
+// # Action Categories
+//
+// Actions are categorized into five types:
+//
+// 1. Disruptive Actions
+//
+// Trigger Coraza operations such as blocking or allowing transactions.
+// Only one disruptive action per rule applies; if multiple are specified,
+// the last one takes precedence. Disruptive actions will NOT be executed
+// if SecRuleEngine is set to DetectionOnly.
+//
+// Examples: deny, drop, redirect, allow, block, pass
+//
+// 2. Non-disruptive Actions
+//
+// Perform operations without affecting rule flow, such as variable modifications,
+// logging, or setting metadata. These actions execute regardless of SecRuleEngine mode.
+//
+// Examples: log, nolog, setvar, msg, logdata, severity, tag
+//
+// 3. Flow Actions
+//
+// Control rule processing and execution flow. These actions determine which rules
+// are evaluated and in what order.
+//
+// Examples: chain, skip, skipAfter
+//
+// 4. Meta-data Actions
+//
+// Provide information about rules, such as identification, versioning, and classification.
+// These actions do not affect transaction processing.
+//
+// Examples: id, rev, msg, tag, severity, maturity, ver
+//
+// 5. Data Actions
+//
+// Containers that hold data for use by other actions, such as status codes
+// for blocking responses.
+//
+// Examples: status (used with deny/redirect)
+//
+// # Usage
+//
+// Actions are specified in SecRule directives as comma-separated values:
+//
+// SecRule ARGS "@rx attack" "id:100,deny,log,msg:'Attack detected'"
+//
+// # Important Notes
+//
+// When using the allow action for allowlisting, it's recommended to add
+// ctl:ruleEngine=On to ensure the rule executes even in DetectionOnly mode.
+//
+// For the complete list of available actions, see: https://coraza.io/docs/seclang/actions/
package actions
import (
diff --git a/internal/actions/allow.go b/internal/actions/allow.go
index 139c2c827..1d3b7713b 100644
--- a/internal/actions/allow.go
+++ b/internal/actions/allow.go
@@ -56,7 +56,7 @@ func (a *allowFn) Init(_ plugintypes.RuleMetadata, data string) error {
func (a *allowFn) Evaluate(r plugintypes.RuleMetadata, txS plugintypes.TransactionState) {
tx := txS.(*corazawaf.Transaction)
- tx.AllowType = a.allow
+ tx.Allow(a.allow)
}
func (a *allowFn) Type() plugintypes.ActionType {
diff --git a/internal/actions/capture.go b/internal/actions/capture.go
index bab58bd00..5a846aa7e 100644
--- a/internal/actions/capture.go
+++ b/internal/actions/capture.go
@@ -21,8 +21,10 @@ import (
//
// Example:
// ```
-// SecRule REQUEST_BODY "^username=(\w{25,})" phase:2,capture,t:none,chain,id:105
-// SecRule TX:1 "(?:(?:a(dmin|nonymous)))"
+//
+// SecRule REQUEST_BODY "^username=(\w{25,})" "phase:2,capture,t:none,chain,id:105"
+// SecRule TX:1 "(?:(?:a(dmin|nonymous)))"
+//
// ```
type captureFn struct{}
diff --git a/internal/actions/ctl.go b/internal/actions/ctl.go
index b8a04e008..a091d7e6a 100644
--- a/internal/actions/ctl.go
+++ b/internal/actions/ctl.go
@@ -6,6 +6,7 @@ package actions
import (
"errors"
"fmt"
+ "regexp"
"strconv"
"strings"
@@ -73,7 +74,13 @@ const (
//
// Here are some notes about the options:
//
-// 1. Option `ruleRemoveTargetById`, `ruleRemoveTargetByMsg`, and `ruleRemoveTargetByTag`, users don't need to use the char ! before the target list.
+// 1. Option `ruleRemoveTargetById`, `ruleRemoveTargetByMsg`, and `ruleRemoveTargetByTag` accept a collection key in two forms:
+// - **Exact string**: `ARGS:user` — removes only the variable whose name is exactly `user`.
+// - **Regular expression** (delimited by `/`): `ARGS:/^json\.\d+\.field$/` — removes all variables whose
+// names match the pattern. The closing `/` must not be preceded by an odd number of backslashes
+// (e.g. `/foo\/` is treated as the literal string `/foo\/`, not a regex). An empty pattern (`//`) is rejected.
+// Pattern matching is always case-insensitive because variable names are lowercased before comparison.
+// Users do not need to use the `!` character before the target list.
//
// 2. Option `ruleRemoveById` is triggered at run time and should be specified before the rule in which it is disabling.
//
@@ -99,17 +106,30 @@ const (
// SecRule REQUEST_URI "@beginsWith /index.php" "phase:1,t:none,pass,\
// nolog,ctl:ruleRemoveTargetById=981260;ARGS:user"
//
+// # white-list all JSON array fields matching a pattern for rule #932125 when the REQUEST_URI begins with /api/jobs
+//
+// SecRule REQUEST_URI "@beginsWith /api/jobs" "phase:1,t:none,pass,\
+// nolog,ctl:ruleRemoveTargetById=932125;ARGS:/^json\.\d+\.jobdescription$/"
+//
// ```
type ctlFn struct {
action ctlFunctionType
value string
collection variables.RuleVariable
colKey string
+ colKeyRx *regexp.Regexp
}
-func (a *ctlFn) Init(_ plugintypes.RuleMetadata, data string) error {
+func (a *ctlFn) Init(m plugintypes.RuleMetadata, data string) error {
+ // Type-assert RuleMetadata to *corazawaf.Rule to access the rule's memoizer.
+ // When the assertion fails (e.g., in tests using a stub RuleMetadata), the
+ // memoizer remains nil and regex compilation proceeds without caching.
+ var memoizer plugintypes.Memoizer
+ if r, ok := m.(*corazawaf.Rule); ok {
+ memoizer = r.Memoizer()
+ }
var err error
- a.action, a.value, a.collection, a.colKey, err = parseCtl(data)
+ a.action, a.value, a.collection, a.colKey, a.colKeyRx, err = parseCtl(data, memoizer)
return err
}
@@ -130,7 +150,7 @@ func (a *ctlFn) Evaluate(_ plugintypes.RuleMetadata, txS plugintypes.Transaction
tx := txS.(*corazawaf.Transaction)
switch a.action {
case ctlRuleRemoveTargetByID:
- ran, err := rangeToInts(tx.WAF.Rules.GetRules(), a.value)
+ start, end, err := parseIDOrRange(a.value)
if err != nil {
tx.DebugLogger().Error().
Str("ctl", "RuleRemoveTargetByID").
@@ -138,21 +158,23 @@ func (a *ctlFn) Evaluate(_ plugintypes.RuleMetadata, txS plugintypes.Transaction
Msg("Invalid range")
return
}
- for _, id := range ran {
- tx.RemoveRuleTargetByID(id, a.collection, a.colKey)
+ for _, r := range tx.WAF.Rules.GetRules() {
+ if r.ID_ >= start && r.ID_ <= end {
+ tx.RemoveRuleTargetByID(r.ID_, a.collection, a.colKey, a.colKeyRx)
+ }
}
case ctlRuleRemoveTargetByTag:
rules := tx.WAF.Rules.GetRules()
for _, r := range rules {
if utils.InSlice(a.value, r.Tags_) {
- tx.RemoveRuleTargetByID(r.ID(), a.collection, a.colKey)
+ tx.RemoveRuleTargetByID(r.ID(), a.collection, a.colKey, a.colKeyRx)
}
}
case ctlRuleRemoveTargetByMsg:
rules := tx.WAF.Rules.GetRules()
for _, r := range rules {
if r.Msg != nil && r.Msg.String() == a.value {
- tx.RemoveRuleTargetByID(r.ID(), a.collection, a.colKey)
+ tx.RemoveRuleTargetByID(r.ID(), a.collection, a.colKey, a.colKeyRx)
}
}
case ctlAuditEngine:
@@ -260,7 +282,7 @@ func (a *ctlFn) Evaluate(_ plugintypes.RuleMetadata, txS plugintypes.Transaction
tx.RemoveRuleByID(id)
} else {
- ran, err := rangeToInts(tx.WAF.Rules.GetRules(), a.value)
+ start, end, err := parseRange(a.value)
if err != nil {
tx.DebugLogger().Error().
Str("ctl", "RuleRemoveByID").
@@ -268,9 +290,7 @@ func (a *ctlFn) Evaluate(_ plugintypes.RuleMetadata, txS plugintypes.Transaction
Msg("Invalid range")
return
}
- for _, id := range ran {
- tx.RemoveRuleByID(id)
- }
+ tx.RemoveRuleByIDRange(start, end)
}
case ctlRuleRemoveByMsg:
rules := tx.WAF.Rules.GetRules()
@@ -375,18 +395,40 @@ func (a *ctlFn) Type() plugintypes.ActionType {
return plugintypes.ActionTypeNondisruptive
}
-func parseCtl(data string) (ctlFunctionType, string, variables.RuleVariable, string, error) {
+func parseCtl(data string, memoizer plugintypes.Memoizer) (ctlFunctionType, string, variables.RuleVariable, string, *regexp.Regexp, error) {
action, ctlVal, ok := strings.Cut(data, "=")
if !ok {
- return ctlUnknown, "", 0, "", errors.New("invalid syntax")
+ return ctlUnknown, "", 0, "", nil, errors.New("invalid syntax")
}
value, col, ok := strings.Cut(ctlVal, ";")
var colkey, colname string
if ok {
colname, colkey, _ = strings.Cut(col, ":")
+ colkey = strings.TrimSpace(colkey)
}
collection, _ := variables.Parse(strings.TrimSpace(colname))
- colkey = strings.ToLower(colkey)
+ var keyRx *regexp.Regexp
+ if isRegex, rxPattern := utils.HasRegex(colkey); isRegex {
+ if len(rxPattern) == 0 {
+ return ctlUnknown, "", 0, "", nil, errors.New("empty regex pattern in ctl collection key")
+ }
+ var err error
+ if memoizer != nil {
+ re, compileErr := memoizer.Do(rxPattern, func() (any, error) { return regexp.Compile(rxPattern) })
+ if compileErr != nil {
+ return ctlUnknown, "", 0, "", nil, fmt.Errorf("invalid regex in ctl collection key: %w", compileErr)
+ }
+ keyRx = re.(*regexp.Regexp)
+ } else {
+ keyRx, err = regexp.Compile(rxPattern)
+ if err != nil {
+ return ctlUnknown, "", 0, "", nil, fmt.Errorf("invalid regex in ctl collection key: %w", err)
+ }
+ }
+ colkey = ""
+ } else {
+ colkey = strings.ToLower(colkey)
+ }
var act ctlFunctionType
switch action {
case "auditEngine":
@@ -430,49 +472,46 @@ func parseCtl(data string) (ctlFunctionType, string, variables.RuleVariable, str
case "debugLogLevel":
act = ctlDebugLogLevel
default:
- return ctlUnknown, "", 0x00, "", fmt.Errorf("unknown ctl action %q", action)
+ return ctlUnknown, "", 0x00, "", nil, fmt.Errorf("unknown ctl action %q", action)
}
- return act, value, collection, strings.TrimSpace(colkey), nil
+ return act, value, collection, colkey, keyRx, nil
}
-func rangeToInts(rules []corazawaf.Rule, input string) ([]int, error) {
- if len(input) == 0 {
- return nil, errors.New("empty input")
+// parseRange parses a range string of the form "start-end" and returns the start and end
+// values as integers. It returns an error if the input is not a valid range.
+func parseRange(input string) (start, end int, err error) {
+ in0, in1, ok := strings.Cut(input, "-")
+ if !ok {
+ return 0, 0, errors.New("no range separator found")
}
-
- var (
- ids []int
- start, end int
- err error
- )
-
- if in0, in1, ok := strings.Cut(input, "-"); ok {
- start, err = strconv.Atoi(in0)
- if err != nil {
- return nil, err
- }
- end, err = strconv.Atoi(in1)
- if err != nil {
- return nil, err
- }
-
- if start > end {
- return nil, errors.New("invalid range, start > end")
- }
- } else {
- id, err := strconv.Atoi(input)
- if err != nil {
- return nil, err
- }
- start, end = id, id
+ start, err = strconv.Atoi(in0)
+ if err != nil {
+ return 0, 0, err
+ }
+ end, err = strconv.Atoi(in1)
+ if err != nil {
+ return 0, 0, err
}
+ if start > end {
+ return 0, 0, errors.New("invalid range, start > end")
+ }
+ return start, end, nil
+}
- for _, r := range rules {
- if r.ID_ >= start && r.ID_ <= end {
- ids = append(ids, r.ID_)
- }
+// parseIDOrRange parses either a single integer ID or a range string of the form "start-end".
+// For a single ID, start and end are equal.
+func parseIDOrRange(input string) (start, end int, err error) {
+ if len(input) == 0 {
+ return 0, 0, errors.New("empty input")
+ }
+ if _, _, ok := strings.Cut(input, "-"); ok {
+ return parseRange(input)
+ }
+ id, err := strconv.Atoi(input)
+ if err != nil {
+ return 0, 0, err
}
- return ids, nil
+ return id, id, nil
}
func ctl() plugintypes.Action {
diff --git a/internal/actions/ctl_test.go b/internal/actions/ctl_test.go
index 252d3879b..38d27be1a 100644
--- a/internal/actions/ctl_test.go
+++ b/internal/actions/ctl_test.go
@@ -9,8 +9,8 @@ import (
"testing"
"github.com/corazawaf/coraza/v3/debuglog"
- "github.com/corazawaf/coraza/v3/internal/corazarules"
"github.com/corazawaf/coraza/v3/internal/corazawaf"
+ "github.com/corazawaf/coraza/v3/internal/memoize"
"github.com/corazawaf/coraza/v3/types"
"github.com/corazawaf/coraza/v3/types/variables"
)
@@ -24,6 +24,24 @@ func TestCtl(t *testing.T) {
"ruleRemoveTargetById": {
input: "ruleRemoveTargetById=123",
},
+ "ruleRemoveTargetById range": {
+ // Rule 1 is in WAF; range 1-5 should match it without error
+ input: "ruleRemoveTargetById=1-5;ARGS:test",
+ checkTX: func(t *testing.T, tx *corazawaf.Transaction, logEntry string) {
+ if wantToNotContain := "Invalid range"; strings.Contains(logEntry, wantToNotContain) {
+ t.Errorf("unexpected error in log: %q", logEntry)
+ }
+ },
+ },
+ "ruleRemoveTargetById regex key": {
+ // Rule 1 is in WAF; the regex /^test.*/ should remove matching ARGS targets
+ input: "ruleRemoveTargetById=1;ARGS:/^test.*/",
+ checkTX: func(t *testing.T, tx *corazawaf.Transaction, logEntry string) {
+ if strings.Contains(logEntry, "Invalid") || strings.Contains(logEntry, "invalid") {
+ t.Errorf("unexpected error in log: %q", logEntry)
+ }
+ },
+ },
"ruleRemoveTargetByTag": {
input: "ruleRemoveTargetByTag=tag1",
},
@@ -49,7 +67,7 @@ func TestCtl(t *testing.T) {
"auditLogParts": {
input: "auditLogParts=ABZ",
checkTX: func(t *testing.T, tx *corazawaf.Transaction, logEntry string) {
- if want, have := types.AuditLogPartRequestHeaders, tx.AuditLogParts[0]; want != have {
+ if want, have := types.AuditLogPartRequestHeaders, tx.AuditLogParts[1]; want != have {
t.Errorf("Failed to set audit log parts, want %s, have %s", string(want), string(have))
}
},
@@ -173,6 +191,16 @@ func TestCtl(t *testing.T) {
},
"ruleRemoveById range": {
input: "ruleRemoveById=1-3",
+ checkTX: func(t *testing.T, tx *corazawaf.Transaction, logEntry string) {
+ if len(tx.GetRuleRemoveByIDRanges()) != 1 {
+ t.Errorf("expected 1 range entry, got %d", len(tx.GetRuleRemoveByIDRanges()))
+ return
+ }
+ rng := tx.GetRuleRemoveByIDRanges()[0]
+ if rng[0] != 1 || rng[1] != 3 {
+ t.Errorf("unexpected range [%d, %d], want [1, 3]", rng[0], rng[1])
+ }
+ },
},
"ruleRemoveById incorrect": {
input: "ruleRemoveById=W",
@@ -368,7 +396,7 @@ func TestCtl(t *testing.T) {
func TestParseCtl(t *testing.T) {
t.Run("invalid ctl", func(t *testing.T) {
- ctl, _, _, _, err := parseCtl("invalid")
+ ctl, _, _, _, _, err := parseCtl("invalid", nil)
if err == nil {
t.Errorf("expected error, got nil")
}
@@ -379,7 +407,7 @@ func TestParseCtl(t *testing.T) {
})
t.Run("malformed ctl", func(t *testing.T) {
- ctl, _, _, _, err := parseCtl("unknown=")
+ ctl, _, _, _, _, err := parseCtl("unknown=", nil)
if err == nil {
t.Errorf("expected error, got nil")
}
@@ -389,35 +417,83 @@ func TestParseCtl(t *testing.T) {
}
})
+ t.Run("invalid regex in colKey", func(t *testing.T) {
+ _, _, _, _, _, err := parseCtl("ruleRemoveTargetById=1;ARGS:/[invalid/", nil)
+ if err == nil {
+ t.Errorf("expected error for invalid regex, got nil")
+ }
+ })
+
+ t.Run("empty regex pattern in colKey", func(t *testing.T) {
+ _, _, _, _, _, err := parseCtl("ruleRemoveTargetById=1;ARGS://", nil)
+ if err == nil {
+ t.Errorf("expected error for empty regex pattern, got nil")
+ }
+ })
+
+ t.Run("escaped slash not treated as regex", func(t *testing.T) {
+ _, _, _, key, rx, err := parseCtl(`ruleRemoveTargetById=1;ARGS:/user\/`, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err.Error())
+ }
+ if rx != nil {
+ t.Errorf("expected nil regex for escaped-slash key, got: %s", rx.String())
+ }
+ if key != `/user\/` {
+ t.Errorf("unexpected key, want %q, have %q", `/user\/`, key)
+ }
+ })
+
+ t.Run("memoizer with valid regex", func(t *testing.T) {
+ m := memoize.NewMemoizer(99)
+ _, _, _, _, keyRx, err := parseCtl("ruleRemoveTargetById=1;ARGS:/^test.*/", m)
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err.Error())
+ }
+ if keyRx == nil {
+ t.Error("expected non-nil compiled regex, got nil")
+ }
+ })
+
+ t.Run("memoizer with invalid regex", func(t *testing.T) {
+ m := memoize.NewMemoizer(100)
+ _, _, _, _, _, err := parseCtl("ruleRemoveTargetById=1;ARGS:/[invalid/", m)
+ if err == nil {
+ t.Error("expected error for invalid regex with memoizer, got nil")
+ }
+ })
+
tCases := []struct {
input string
expectAction ctlFunctionType
expectValue string
expectCollection variables.RuleVariable
expectKey string
+ expectKeyRx string
}{
- {"auditEngine=On", ctlAuditEngine, "On", variables.Unknown, ""},
- {"auditLogParts=A", ctlAuditLogParts, "A", variables.Unknown, ""},
- {"requestBodyAccess=On", ctlRequestBodyAccess, "On", variables.Unknown, ""},
- {"requestBodyLimit=100", ctlRequestBodyLimit, "100", variables.Unknown, ""},
- {"requestBodyProcessor=JSON", ctlRequestBodyProcessor, "JSON", variables.Unknown, ""},
- {"forceRequestBodyVariable=On", ctlForceRequestBodyVariable, "On", variables.Unknown, ""},
- {"responseBodyAccess=On", ctlResponseBodyAccess, "On", variables.Unknown, ""},
- {"responseBodyLimit=100", ctlResponseBodyLimit, "100", variables.Unknown, ""},
- {"responseBodyProcessor=JSON", ctlResponseBodyProcessor, "JSON", variables.Unknown, ""},
- {"forceResponseBodyVariable=On", ctlForceResponseBodyVariable, "On", variables.Unknown, ""},
- {"ruleEngine=On", ctlRuleEngine, "On", variables.Unknown, ""},
- {"ruleRemoveById=1", ctlRuleRemoveByID, "1", variables.Unknown, ""},
- {"ruleRemoveById=1-9", ctlRuleRemoveByID, "1-9", variables.Unknown, ""},
- {"ruleRemoveByMsg=MY_MSG", ctlRuleRemoveByMsg, "MY_MSG", variables.Unknown, ""},
- {"ruleRemoveByTag=MY_TAG", ctlRuleRemoveByTag, "MY_TAG", variables.Unknown, ""},
- {"ruleRemoveTargetByMsg=MY_MSG;ARGS:user", ctlRuleRemoveTargetByMsg, "MY_MSG", variables.Args, "user"},
- {"ruleRemoveTargetById=2;REQUEST_FILENAME:", ctlRuleRemoveTargetByID, "2", variables.RequestFilename, ""},
+ {"auditEngine=On", ctlAuditEngine, "On", variables.Unknown, "", ""},
+ {"auditLogParts=A", ctlAuditLogParts, "A", variables.Unknown, "", ""},
+ {"requestBodyAccess=On", ctlRequestBodyAccess, "On", variables.Unknown, "", ""},
+ {"requestBodyLimit=100", ctlRequestBodyLimit, "100", variables.Unknown, "", ""},
+ {"requestBodyProcessor=JSON", ctlRequestBodyProcessor, "JSON", variables.Unknown, "", ""},
+ {"forceRequestBodyVariable=On", ctlForceRequestBodyVariable, "On", variables.Unknown, "", ""},
+ {"responseBodyAccess=On", ctlResponseBodyAccess, "On", variables.Unknown, "", ""},
+ {"responseBodyLimit=100", ctlResponseBodyLimit, "100", variables.Unknown, "", ""},
+ {"responseBodyProcessor=JSON", ctlResponseBodyProcessor, "JSON", variables.Unknown, "", ""},
+ {"forceResponseBodyVariable=On", ctlForceResponseBodyVariable, "On", variables.Unknown, "", ""},
+ {"ruleEngine=On", ctlRuleEngine, "On", variables.Unknown, "", ""},
+ {"ruleRemoveById=1", ctlRuleRemoveByID, "1", variables.Unknown, "", ""},
+ {"ruleRemoveById=1-9", ctlRuleRemoveByID, "1-9", variables.Unknown, "", ""},
+ {"ruleRemoveByMsg=MY_MSG", ctlRuleRemoveByMsg, "MY_MSG", variables.Unknown, "", ""},
+ {"ruleRemoveByTag=MY_TAG", ctlRuleRemoveByTag, "MY_TAG", variables.Unknown, "", ""},
+ {"ruleRemoveTargetByMsg=MY_MSG;ARGS:user", ctlRuleRemoveTargetByMsg, "MY_MSG", variables.Args, "user", ""},
+ {"ruleRemoveTargetById=2;REQUEST_FILENAME:", ctlRuleRemoveTargetByID, "2", variables.RequestFilename, "", ""},
+ {"ruleRemoveTargetById=2;ARGS:/^json\\.\\d+\\.description$/", ctlRuleRemoveTargetByID, "2", variables.Args, "", `^json\.\d+\.description$`},
}
for _, tCase := range tCases {
testName, _, _ := strings.Cut(tCase.input, "=")
t.Run(testName, func(t *testing.T) {
- action, value, collection, colKey, err := parseCtl(tCase.input)
+ action, value, collection, colKey, colKeyRx, err := parseCtl(tCase.input, nil)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
@@ -433,53 +509,96 @@ func TestParseCtl(t *testing.T) {
if colKey != tCase.expectKey {
t.Errorf("unexpected key, want: %s, have: %s", tCase.expectKey, colKey)
}
+ if tCase.expectKeyRx == "" {
+ if colKeyRx != nil {
+ t.Errorf("unexpected non-nil regex, have: %s", colKeyRx.String())
+ }
+ } else {
+ if colKeyRx == nil {
+ t.Errorf("expected non-nil regex matching %q, got nil", tCase.expectKeyRx)
+ } else if colKeyRx.String() != tCase.expectKeyRx {
+ t.Errorf("unexpected regex, want: %s, have: %s", tCase.expectKeyRx, colKeyRx.String())
+ }
+ }
})
}
}
-func TestCtlParseRange(t *testing.T) {
- rules := []corazawaf.Rule{
- {
- RuleMetadata: corazarules.RuleMetadata{
- ID_: 5,
- },
- },
- {
- RuleMetadata: corazarules.RuleMetadata{
- ID_: 15,
- },
- },
+func TestCtlParseIDOrRange(t *testing.T) {
+ tCases := []struct {
+ input string
+ expectStart int
+ expectEnd int
+ expectErr bool
+ }{
+ {"1-2", 1, 2, false},
+ {"4-5", 4, 5, false},
+ {"4-15", 4, 15, false},
+ {"5", 5, 5, false},
+ {"", 0, 0, true},
+ {"test", 0, 0, true},
+ {"test-2", 0, 0, true},
+ {"2-test", 0, 0, true},
+ {"-", 0, 0, true},
+ {"4-5-15", 0, 0, true},
+ }
+ for _, tCase := range tCases {
+ t.Run(tCase.input, func(t *testing.T) {
+ start, end, err := parseIDOrRange(tCase.input)
+ if tCase.expectErr && err == nil {
+ t.Error("expected error for input")
+ }
+
+ if !tCase.expectErr && err != nil {
+ t.Errorf("unexpected error for input: %s", err.Error())
+ }
+
+ if !tCase.expectErr {
+ if start != tCase.expectStart {
+ t.Errorf("unexpected start, want %d, have %d", tCase.expectStart, start)
+ }
+ if end != tCase.expectEnd {
+ t.Errorf("unexpected end, want %d, have %d", tCase.expectEnd, end)
+ }
+ }
+ })
}
+}
+func TestCtlParseRange(t *testing.T) {
tCases := []struct {
- _range string
- expectedNumberOfIds int
- expectErr bool
+ input string
+ expectStart int
+ expectEnd int
+ expectErr bool
}{
- {"1-2", 0, false},
- {"4-5", 1, false},
- {"4-15", 2, false},
- {"5", 1, false},
- {"", 0, true},
- {"test", 0, true},
- {"test-2", 0, true},
- {"2-test", 0, true},
- {"-", 0, true},
- {"4-5-15", 0, true},
+ {"1-2", 1, 2, false},
+ {"4-15", 4, 15, false},
+ {"5-5", 5, 5, false},
+ {"test-2", 0, 0, true},
+ {"2-test", 0, 0, true},
+ {"5-4", 0, 0, true}, // start > end
+ {"-", 0, 0, true},
+ {"nodash", 0, 0, true}, // no range separator
}
for _, tCase := range tCases {
- t.Run(tCase._range, func(t *testing.T) {
- ints, err := rangeToInts(rules, tCase._range)
+ t.Run(tCase.input, func(t *testing.T) {
+ start, end, err := parseRange(tCase.input)
if tCase.expectErr && err == nil {
- t.Error("expected error for range")
+ t.Error("expected error for input")
}
if !tCase.expectErr && err != nil {
- t.Errorf("unexpected error for range: %s", err.Error())
+ t.Errorf("unexpected error for input: %s", err.Error())
}
- if !tCase.expectErr && len(ints) != tCase.expectedNumberOfIds {
- t.Error("unexpected number of ids")
+ if !tCase.expectErr {
+ if start != tCase.expectStart {
+ t.Errorf("unexpected start, want %d, have %d", tCase.expectStart, start)
+ }
+ if end != tCase.expectEnd {
+ t.Errorf("unexpected end, want %d, have %d", tCase.expectEnd, end)
+ }
}
})
}
diff --git a/internal/auditlog/formats.go b/internal/auditlog/formats.go
index 316566927..dcaf142cc 100644
--- a/internal/auditlog/formats.go
+++ b/internal/auditlog/formats.go
@@ -21,6 +21,7 @@ package auditlog
import (
"fmt"
+ "net/http"
"strings"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
@@ -45,81 +46,102 @@ func (nativeFormatter) Format(al plugintypes.AuditLog) ([]byte, error) {
res.WriteString(boundaryPrefix)
res.WriteByte(byte(part))
res.WriteString("--\n")
- // [27/Jul/2016:05:46:16 +0200] V5guiH8AAQEAADTeJ2wAAAAK 192.168.3.1 50084 192.168.3.111 80
- _, _ = fmt.Fprintf(&res, "[%s] %s %s %d %s %d", al.Transaction().Timestamp(), al.Transaction().ID(),
- al.Transaction().ClientIP(), al.Transaction().ClientPort(), al.Transaction().HostIP(), al.Transaction().HostPort())
+
+ addSeparator := true
+
switch part {
+ case types.AuditLogPartHeader:
+ // Part A: Audit log header containing only the timestamp and transaction info line
+ // Note: Part A does not have an empty line separator after it
+ _, _ = fmt.Fprintf(&res, "[%s] %s %s %d %s %d\n",
+ al.Transaction().Timestamp(), al.Transaction().ID(),
+ al.Transaction().ClientIP(), al.Transaction().ClientPort(),
+ al.Transaction().HostIP(), al.Transaction().HostPort())
+ addSeparator = false
case types.AuditLogPartRequestHeaders:
- // GET /url HTTP/1.1
- // Host: example.com
- // User-Agent: Mozilla/5.0
- // Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
- // Accept-Language: en-US,en;q=0.5
- // Accept-Encoding: gzip, deflate
- // Referer: http://example.com/index.html
- // Connection: keep-alive
- // Content-Type: application/x-www-form-urlencoded
- // Content-Length: 6
- _, _ = fmt.Fprintf(
- &res,
- "\n%s %s %s",
- al.Transaction().Request().Method(),
- al.Transaction().Request().URI(),
- al.Transaction().Request().Protocol(),
- )
- for k, vv := range al.Transaction().Request().Headers() {
- for _, v := range vv {
- res.WriteByte('\n')
- res.WriteString(k)
- res.WriteString(": ")
- res.WriteString(v)
+ // Part B: Request headers
+ if al.Transaction().HasRequest() {
+ _, _ = fmt.Fprintf(
+ &res,
+ "%s %s %s",
+ al.Transaction().Request().Method(),
+ al.Transaction().Request().URI(),
+ al.Transaction().Request().Protocol(),
+ )
+ for k, vv := range al.Transaction().Request().Headers() {
+ for _, v := range vv {
+ res.WriteByte('\n')
+ res.WriteString(k)
+ res.WriteString(": ")
+ res.WriteString(v)
+ }
}
+ res.WriteByte('\n')
}
case types.AuditLogPartRequestBody:
- if body := al.Transaction().Request().Body(); body != "" {
- res.WriteByte('\n')
- res.WriteString(body)
+ // Part C: Request body
+ if al.Transaction().HasRequest() {
+ if body := al.Transaction().Request().Body(); body != "" {
+ res.WriteString(body)
+ res.WriteByte('\n')
+ }
}
case types.AuditLogPartIntermediaryResponseBody:
- if body := al.Transaction().Response().Body(); body != "" {
- res.WriteByte('\n')
- res.WriteString(al.Transaction().Response().Body())
+ // Part E: Intermediary response body
+ if al.Transaction().HasResponse() {
+ if body := al.Transaction().Response().Body(); body != "" {
+ res.WriteString(body)
+ res.WriteByte('\n')
+ }
}
case types.AuditLogPartResponseHeaders:
- for k, vv := range al.Transaction().Response().Headers() {
- for _, v := range vv {
- res.WriteByte('\n')
- res.WriteString(k)
- res.WriteString(": ")
- res.WriteString(v)
+ // Part F: Response headers
+ if al.Transaction().HasResponse() {
+ // Write status line: HTTP/1.1 200 OK
+ protocol := al.Transaction().Response().Protocol()
+ if protocol == "" {
+ protocol = "HTTP/1.1"
+ }
+ status := al.Transaction().Response().Status()
+ statusText := http.StatusText(status)
+ _, _ = fmt.Fprintf(&res, "%s %d %s\n", protocol, status, statusText)
+
+ // Write headers
+ for k, vv := range al.Transaction().Response().Headers() {
+ for _, v := range vv {
+ res.WriteString(k)
+ res.WriteString(": ")
+ res.WriteString(v)
+ res.WriteByte('\n')
+ }
}
}
case types.AuditLogPartAuditLogTrailer:
- // Stopwatch: 1470025005945403 1715 (- - -)
- // Stopwatch2: 1470025005945403 1715; combined=26, p1=0, p2=0, p3=0, p4=0, p5=26, ↩
- // sr=0, sw=0, l=0, gc=0
- // Response-Body-Transformed: Dechunked
- // Producer: ModSecurity for Apache/2.9.1 (http://www.modsecurity.org/).
- // Server: Apache
- // Engine-Mode: "ENABLED"
-
- // AuditLogTrailer is also expected to contain the error message generated by the rule, if any
+ // Part H: Audit log trailer
for _, alEntry := range al.Messages() {
alWithErrMsg, ok := alEntry.(auditLogWithErrMesg)
if ok && alWithErrMsg.ErrorMessage() != "" {
- res.WriteByte('\n')
res.WriteString(alWithErrMsg.ErrorMessage())
+ res.WriteByte('\n')
}
}
-
- _, _ = fmt.Fprintf(&res, "\nStopwatch: %s\nResponse-Body-Transformed: %s\nProducer: %s\nServer: %s", "", "", "", "")
case types.AuditLogPartRulesMatched:
+ // Part K: Matched rules
for _, alEntry := range al.Messages() {
- res.WriteByte('\n')
res.WriteString(alEntry.Data().Raw())
+ res.WriteByte('\n')
}
+ case types.AuditLogPartEndMarker:
+ // Part Z: Final boundary marker with no content
+ default:
+ // For any other parts (D, G, I, J) that aren't explicitly handled,
+ // they remain empty
+ }
+
+ // Add separator newline for all parts except A
+ if addSeparator {
+ res.WriteByte('\n')
}
- res.WriteByte('\n')
}
return []byte(res.String()), nil
diff --git a/internal/auditlog/formats_test.go b/internal/auditlog/formats_test.go
index 79edba98b..a64b7d728 100644
--- a/internal/auditlog/formats_test.go
+++ b/internal/auditlog/formats_test.go
@@ -56,9 +56,9 @@ func TestNativeFormatter(t *testing.T) {
if !strings.Contains(f.MIME(), "x-coraza-auditlog-native") {
t.Errorf("failed to match MIME, expected json and got %s", f.MIME())
}
- // Log contains random strings, do a simple sanity check
- if !bytes.Contains(data, []byte("[02/Jan/2006:15:04:20 -0700] 123 0 0")) {
- t.Errorf("failed to match log, \ngot: %s\n", string(data))
+ // Log contains random boundary strings
+ if len(data) == 0 {
+ t.Errorf("expected non-empty log output")
}
scanner := bufio.NewScanner(bytes.NewReader(data))
@@ -69,22 +69,387 @@ func TestNativeFormatter(t *testing.T) {
}
separator := lines[0]
- checkLine(t, lines, 2, "GET /test.php HTTP/1.1")
- checkLine(t, lines, 3, "some: request header")
+ // Part B
+ checkLine(t, lines, 0, mutateSeparator(separator, 'B'))
+ checkLine(t, lines, 1, "GET /test.php HTTP/1.1")
+ checkLine(t, lines, 2, "some: request header")
+ checkLine(t, lines, 3, "")
+ // Part C
checkLine(t, lines, 4, mutateSeparator(separator, 'C'))
- checkLine(t, lines, 6, "some request body")
+ checkLine(t, lines, 5, "some request body")
+ checkLine(t, lines, 6, "")
+ // Part E
checkLine(t, lines, 7, mutateSeparator(separator, 'E'))
- checkLine(t, lines, 9, "some response body")
+ checkLine(t, lines, 8, "some response body")
+ checkLine(t, lines, 9, "")
+ // Part F
checkLine(t, lines, 10, mutateSeparator(separator, 'F'))
+ checkLine(t, lines, 11, "HTTP/1.1 200 OK")
checkLine(t, lines, 12, "some: response header")
- checkLine(t, lines, 13, mutateSeparator(separator, 'H'))
+ checkLine(t, lines, 13, "")
+ // Part H
+ checkLine(t, lines, 14, mutateSeparator(separator, 'H'))
checkLine(t, lines, 15, "error message")
- checkLine(t, lines, 16, "Stopwatch: ")
- checkLine(t, lines, 17, "Response-Body-Transformed: ")
- checkLine(t, lines, 18, "Producer: ")
- checkLine(t, lines, 19, "Server: ")
- checkLine(t, lines, 20, mutateSeparator(separator, 'K'))
- checkLine(t, lines, 22, `SecAction "id:100"`)
+ checkLine(t, lines, 16, "")
+ // Part K
+ checkLine(t, lines, 17, mutateSeparator(separator, 'K'))
+ checkLine(t, lines, 18, `SecAction "id:100"`)
+ checkLine(t, lines, 19, "")
+ })
+
+ t.Run("with parts A and Z", func(t *testing.T) {
+ al := &Log{
+ Parts_: []types.AuditLogPart{
+ types.AuditLogPartHeader,
+ types.AuditLogPartRequestHeaders,
+ types.AuditLogPartEndMarker,
+ },
+ Transaction_: Transaction{
+ Timestamp_: "02/Jan/2006:15:04:20 -0700",
+ UnixTimestamp_: 0,
+ ID_: "123",
+ Request_: &TransactionRequest{
+ URI_: "/test.php",
+ Method_: "GET",
+ Headers_: map[string][]string{
+ "some": {
+ "request header",
+ },
+ },
+ Protocol_: "HTTP/1.1",
+ },
+ },
+ }
+ data, err := f.Format(al)
+ if err != nil {
+ t.Error(err)
+ }
+
+ scanner := bufio.NewScanner(bytes.NewReader(data))
+ var lines []string
+ for scanner.Scan() {
+ lines = append(lines, scanner.Text())
+ }
+ separator := lines[0]
+
+ // Part A (header) - no empty line after it
+ checkLine(t, lines, 0, mutateSeparator(separator, 'A'))
+ checkLine(t, lines, 1, "[02/Jan/2006:15:04:20 -0700] 123 0 0")
+
+ // Part B (request headers) - has empty line after it
+ checkLine(t, lines, 2, mutateSeparator(separator, 'B'))
+ checkLine(t, lines, 3, "GET /test.php HTTP/1.1")
+ checkLine(t, lines, 4, "some: request header")
+ checkLine(t, lines, 5, "")
+
+ // Part Z (end marker) - has empty line after it
+ checkLine(t, lines, 6, mutateSeparator(separator, 'Z'))
+ checkLine(t, lines, 7, "")
+ })
+
+ t.Run("complete example matching ModSecurity format", func(t *testing.T) {
+ al := &Log{
+ Parts_: []types.AuditLogPart{
+ types.AuditLogPartHeader,
+ types.AuditLogPartRequestHeaders,
+ types.AuditLogPartIntermediaryResponseHeaders, // Part D
+ types.AuditLogPartResponseHeaders, // Part F
+ types.AuditLogPartAuditLogTrailer, // Part H
+ types.AuditLogPartEndMarker,
+ },
+ Transaction_: Transaction{
+ Timestamp_: "20/Feb/2025:13:20:33 +0000",
+ UnixTimestamp_: 1740576033,
+ ID_: "174005763366.604533",
+ ClientIP_: "192.168.65.1",
+ ClientPort_: 38532,
+ HostIP_: "172.21.0.3",
+ HostPort_: 8080,
+ Request_: &TransactionRequest{
+ URI_: "/status/200",
+ Method_: "GET",
+ Protocol_: "HTTP/1.1",
+ HTTPVersion_: "1.1",
+ Headers_: map[string][]string{
+ "Accept": {"*/*"},
+ "Connection": {"close"},
+ "Host": {"localhost"},
+ },
+ },
+ Response_: &TransactionResponse{
+ Status_: 200,
+ Protocol_: "HTTP/1.1",
+ Headers_: map[string][]string{
+ "Server": {"nginx"},
+ "Connection": {"close"},
+ },
+ },
+ },
+ Messages_: []plugintypes.AuditLogMessage{
+ &Message{
+ ErrorMessage_: "ModSecurity: Warning. Test message",
+ },
+ },
+ }
+
+ data, err := f.Format(al)
+ if err != nil {
+ t.Error(err)
+ }
+
+ output := string(data)
+ // Verify structure matches ModSecurity format
+ if !bytes.Contains(data, []byte("174005763366.604533")) {
+ t.Errorf("Missing transaction ID in output:\n%s", output)
+ }
+ if !bytes.Contains(data, []byte("GET /status/200 HTTP/1.1")) {
+ t.Errorf("Missing request line in output:\n%s", output)
+ }
+ if !bytes.Contains(data, []byte("ModSecurity: Warning. Test message")) {
+ t.Errorf("Missing error message in output:\n%s", output)
+ }
+
+ scanner := bufio.NewScanner(bytes.NewReader(data))
+ var lines []string
+ for scanner.Scan() {
+ lines = append(lines, scanner.Text())
+ }
+
+ // Verify part A exists
+ if !bytes.Contains(data, []byte("-A--")) {
+ t.Error("Missing part A boundary")
+ }
+ // Verify part Z exists
+ if !bytes.Contains(data, []byte("-Z--")) {
+ t.Error("Missing part Z boundary")
+ }
+ // Verify part A has transaction info
+ checkLine(t, lines, 1, "[20/Feb/2025:13:20:33 +0000] 174005763366.604533 192.168.65.1 38532 172.21.0.3 8080")
+ })
+
+ t.Run("apache example 1 - cookies and multiple messages", func(t *testing.T) {
+ al := &Log{
+ Parts_: []types.AuditLogPart{
+ types.AuditLogPartHeader,
+ types.AuditLogPartRequestHeaders,
+ types.AuditLogPartResponseHeaders,
+ types.AuditLogPartIntermediaryResponseBody,
+ types.AuditLogPartAuditLogTrailer,
+ types.AuditLogPartEndMarker,
+ },
+ Transaction_: Transaction{
+ Timestamp_: "20/Feb/2025:15:15:26.453565 +0000",
+ UnixTimestamp_: 1740064526,
+ ID_: "Z7dHDrSPGgnIk-ru4hvJcQAAAIA",
+ ClientIP_: "192.168.65.1",
+ ClientPort_: 42378,
+ HostIP_: "172.22.0.3",
+ HostPort_: 8080,
+ Request_: &TransactionRequest{
+ URI_: "/",
+ Method_: "GET",
+ Protocol_: "HTTP/1.1",
+ Headers_: map[string][]string{
+ "Accept": {"*/*"},
+ "Connection": {"close"},
+ "Cookie": {"$Version=1; session=\"deadbeef; PHPSESSID=secret; dummy=qaz\""},
+ "Host": {"localhost"},
+ "Origin": {"https://www.example.com"},
+ "Referer": {"https://www.example.com/"},
+ "User-Agent": {"OWASP CRS test agent"},
+ },
+ },
+ Response_: &TransactionResponse{
+ Status_: 200,
+ Protocol_: "HTTP/1.1",
+ Headers_: map[string][]string{
+ "Content-Length": {"0"},
+ "Connection": {"close"},
+ },
+ },
+ },
+ Messages_: []plugintypes.AuditLogMessage{
+ &Message{
+ ErrorMessage_: `Message: Warning. String match "1" at REQUEST_COOKIES:$Version. [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-921-PROTOCOL-ATTACK.conf"] [line "332"] [id "921250"]`,
+ },
+ &Message{
+ ErrorMessage_: `Message: Warning. Operator GE matched 5 at TX:blocking_inbound_anomaly_score. [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-949-BLOCKING-EVALUATION.conf"] [line "233"] [id "949110"]`,
+ },
+ &Message{
+ ErrorMessage_: `Apache-Error: [file "apache2_util.c"] [line 275] [level 3] [client 192.168.65.1] ModSecurity: Warning.`,
+ },
+ },
+ }
+
+ data, err := f.Format(al)
+ if err != nil {
+ t.Error(err)
+ }
+
+ output := string(data)
+ // Verify key elements exist
+ if !bytes.Contains(data, []byte("Z7dHDrSPGgnIk-ru4hvJcQAAAIA")) {
+ t.Errorf("Missing transaction ID in output")
+ }
+ if !bytes.Contains(data, []byte("GET / HTTP/1.1")) {
+ t.Errorf("Missing request line in output")
+ }
+ if !bytes.Contains(data, []byte("Cookie: $Version=1")) {
+ t.Errorf("Missing cookie header in output")
+ }
+ if !bytes.Contains(data, []byte("HTTP/1.1 200")) {
+ t.Errorf("Missing response status in output:\n%s", output)
+ }
+ // Verify multiple messages are present
+ if !bytes.Contains(data, []byte("REQUEST-921-PROTOCOL-ATTACK.conf")) {
+ t.Errorf("Missing first message in output")
+ }
+ if !bytes.Contains(data, []byte("REQUEST-949-BLOCKING-EVALUATION.conf")) {
+ t.Errorf("Missing second message in output")
+ }
+ if !bytes.Contains(data, []byte("Apache-Error:")) {
+ t.Errorf("Missing Apache-Error message in output")
+ }
+
+ // Verify parts A, B, F, E, H, Z exist
+ if !bytes.Contains(data, []byte("-A--")) {
+ t.Error("Missing part A")
+ }
+ if !bytes.Contains(data, []byte("-B--")) {
+ t.Error("Missing part B")
+ }
+ if !bytes.Contains(data, []byte("-F--")) {
+ t.Error("Missing part F")
+ }
+ if !bytes.Contains(data, []byte("-E--")) {
+ t.Error("Missing part E")
+ }
+ if !bytes.Contains(data, []byte("-H--")) {
+ t.Error("Missing part H")
+ }
+ if !bytes.Contains(data, []byte("-Z--")) {
+ t.Error("Missing part Z")
+ }
+ })
+
+ t.Run("apache example 2 - SQL injection with multiple rules", func(t *testing.T) {
+ al := &Log{
+ Parts_: []types.AuditLogPart{
+ types.AuditLogPartHeader,
+ types.AuditLogPartRequestHeaders,
+ types.AuditLogPartResponseHeaders,
+ types.AuditLogPartIntermediaryResponseBody,
+ types.AuditLogPartAuditLogTrailer,
+ types.AuditLogPartEndMarker,
+ },
+ Transaction_: Transaction{
+ Timestamp_: "23/Feb/2025:22:40:32.479855 +0000",
+ UnixTimestamp_: 1740350432,
+ ID_: "Z7uj4AkDMIUwf_JHM4k9hAAAAAY",
+ ClientIP_: "192.168.65.1",
+ ClientPort_: 53953,
+ HostIP_: "172.21.0.3",
+ HostPort_: 8080,
+ Request_: &TransactionRequest{
+ URI_: "/get?var=sdfsd%27or%201%20%3e%201",
+ Method_: "GET",
+ Protocol_: "HTTP/1.0",
+ Headers_: map[string][]string{
+ "Accept": {"text/xml,application/xml,application/xhtml+xml,text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5"},
+ "Connection": {"close"},
+ "Host": {"localhost"},
+ "User-Agent": {"OWASP CRS test agent"},
+ },
+ },
+ Response_: &TransactionResponse{
+ Status_: 200,
+ Protocol_: "HTTP/1.1",
+ Headers_: map[string][]string{
+ "Content-Length": {"0"},
+ "Connection": {"close"},
+ },
+ },
+ },
+ Messages_: []plugintypes.AuditLogMessage{
+ &Message{
+ ErrorMessage_: `Message: Warning. Found 5 byte(s) in ARGS:var outside range: 38,44-46,48-58,61,65-90,95,97-122. [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-920-PROTOCOL-ENFORCEMENT.conf"] [line "1739"] [id "920273"]`,
+ },
+ &Message{
+ ErrorMessage_: `Message: Warning. detected SQLi using libinjection with fingerprint 's&1' [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-942-APPLICATION-ATTACK-SQLI.conf"] [line "66"] [id "942100"]`,
+ },
+ &Message{
+ ErrorMessage_: `Message: Warning. Pattern match "(?i)(?:/\\*)+[\"'` + "`" + `]+[\\s\\x0b]?(?:--|[#\\{]|/\\*)?|[\"'` + "`" + `](?:[\\s\\x0b]*(?:(?:x?or|and|div|like|between)" at ARGS:var. [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-942-APPLICATION-ATTACK-SQLI.conf"] [line "822"] [id "942180"]`,
+ },
+ &Message{
+ ErrorMessage_: `Message: Warning. Operator GE matched 5 at TX:blocking_inbound_anomaly_score. [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-949-BLOCKING-EVALUATION.conf"] [line "233"] [id "949110"] [msg "Inbound Anomaly Score Exceeded (Total Score: 33)"]`,
+ },
+ },
+ }
+
+ data, err := f.Format(al)
+ if err != nil {
+ t.Error(err)
+ }
+
+ output := string(data)
+ // Verify SQL injection detection messages
+ if !bytes.Contains(data, []byte("Z7uj4AkDMIUwf_JHM4k9hAAAAAY")) {
+ t.Errorf("Missing transaction ID")
+ }
+ if !bytes.Contains(data, []byte("/get?var=sdfsd%27or%201%20%3e%201")) {
+ t.Errorf("Missing URI with SQL injection attempt")
+ }
+ if !bytes.Contains(data, []byte("libinjection")) {
+ t.Errorf("Missing libinjection message")
+ }
+ if !bytes.Contains(data, []byte("920273")) {
+ t.Errorf("Missing rule ID 920273")
+ }
+ if !bytes.Contains(data, []byte("942100")) {
+ t.Errorf("Missing rule ID 942100")
+ }
+ if !bytes.Contains(data, []byte("942180")) {
+ t.Errorf("Missing rule ID 942180")
+ }
+ if !bytes.Contains(data, []byte("Inbound Anomaly Score Exceeded (Total Score: 33)")) {
+ t.Errorf("Missing anomaly score message")
+ }
+
+ scanner := bufio.NewScanner(bytes.NewReader(data))
+ var lines []string
+ for scanner.Scan() {
+ lines = append(lines, scanner.Text())
+ }
+
+ // Verify the timestamp with microseconds
+ if !bytes.Contains(data, []byte("[23/Feb/2025:22:40:32.479855 +0000]")) {
+ t.Errorf("Missing timestamp with microseconds in output:\n%s", output)
+ }
+
+ // Verify part H contains all messages (they should be concatenated)
+ partHFound := false
+ for i, line := range lines {
+ if strings.Contains(line, "-H--") {
+ partHFound = true
+ // Check that messages follow part H marker
+ if i+1 < len(lines) {
+ // At least one message should be present
+ messagesFound := 0
+ for j := i + 1; j < len(lines) && !strings.Contains(lines[j], "--"); j++ {
+ if strings.Contains(lines[j], "Message:") {
+ messagesFound++
+ }
+ }
+ if messagesFound == 0 {
+ t.Errorf("No messages found after part H marker")
+ }
+ }
+ break
+ }
+ }
+ if !partHFound {
+ t.Error("Part H marker not found")
+ }
})
}
diff --git a/internal/auditlog/https_writer_test.go b/internal/auditlog/https_writer_test.go
index 65d883a8d..5fc4e193f 100644
--- a/internal/auditlog/https_writer_test.go
+++ b/internal/auditlog/https_writer_test.go
@@ -57,7 +57,7 @@ func TestHTTPAuditLog(t *testing.T) {
t.Fatal("Body is empty")
}
if !bytes.Contains(body, []byte("test123")) {
- t.Fatal("Body does not match")
+ t.Fatalf("Body does not match, got:\n%s", string(body))
}
}))
defer server.Close()
diff --git a/internal/bodyprocessors/json.go b/internal/bodyprocessors/json.go
index 101851f7e..5176d9369 100644
--- a/internal/bodyprocessors/json.go
+++ b/internal/bodyprocessors/json.go
@@ -4,6 +4,7 @@
package bodyprocessors
import (
+ "errors"
"io"
"strconv"
"strings"
@@ -17,17 +18,18 @@ type jsonBodyProcessor struct{}
var _ plugintypes.BodyProcessor = &jsonBodyProcessor{}
-func (js *jsonBodyProcessor) ProcessRequest(reader io.Reader, v plugintypes.TransactionVariables, _ plugintypes.BodyProcessorOptions) error {
- // Read the entire body to store it and process it
+func (js *jsonBodyProcessor) ProcessRequest(reader io.Reader, v plugintypes.TransactionVariables, bpo plugintypes.BodyProcessorOptions) error {
+ // Read the entire body into memory for two purposes:
+ // 1. Store raw JSON in TX variables for operators like @validateSchema
+ // 2. Parse and flatten for ARGS_POST collection
s := strings.Builder{}
if _, err := io.Copy(&s, reader); err != nil {
return err
}
ss := s.String()
-
- // Process as normal
+ // Process with recursion limit
col := v.ArgsPost()
- data, err := readJSON(ss)
+ data, err := readJSON(ss, bpo.RequestBodyRecursionLimit)
if err != nil {
return err
}
@@ -45,6 +47,8 @@ func (js *jsonBodyProcessor) ProcessRequest(reader io.Reader, v plugintypes.Tran
return nil
}
+const ignoreJSONRecursionLimit = -1
+
func (js *jsonBodyProcessor) ProcessResponse(reader io.Reader, v plugintypes.TransactionVariables, _ plugintypes.BodyProcessorOptions) error {
// Read the entire body to store it and process it
s := strings.Builder{}
@@ -52,10 +56,9 @@ func (js *jsonBodyProcessor) ProcessResponse(reader io.Reader, v plugintypes.Tra
return err
}
ss := s.String()
-
- // Process as normal
+ // Process with no recursion limit as we don't have a directive for response body
col := v.ResponseArgs()
- data, err := readJSON(ss)
+ data, err := readJSON(ss, ignoreJSONRecursionLimit)
if err != nil {
return err
}
@@ -73,22 +76,33 @@ func (js *jsonBodyProcessor) ProcessResponse(reader io.Reader, v plugintypes.Tra
return nil
}
-func readJSON(s string) (map[string]string, error) {
- json := gjson.Parse(s)
+func readJSON(s string, maxRecursion int) (map[string]string, error) {
res := make(map[string]string)
key := []byte("json")
- readItems(json, key, res)
- return res, nil
+
+ if !gjson.Valid(s) {
+ return res, errors.New("invalid JSON")
+ }
+ json := gjson.Parse(s)
+ err := readItems(json, key, maxRecursion, res)
+ return res, err
}
// Transform JSON to a map[string]string
+// This function is recursive and will call itself for nested objects.
+// The limit in recursion is defined by maxItems.
// Example input: {"data": {"name": "John", "age": 30}, "items": [1,2,3]}
// Example output: map[string]string{"json.data.name": "John", "json.data.age": "30", "json.items.0": "1", "json.items.1": "2", "json.items.2": "3"}
// Example input: [{"data": {"name": "John", "age": 30}, "items": [1,2,3]}]
// Example output: map[string]string{"json.0.data.name": "John", "json.0.data.age": "30", "json.0.items.0": "1", "json.0.items.1": "2", "json.0.items.2": "3"}
-// TODO add some anti DOS protection
-func readItems(json gjson.Result, objKey []byte, res map[string]string) {
+func readItems(json gjson.Result, objKey []byte, maxRecursion int, res map[string]string) error {
arrayLen := 0
+ var iterationError error
+ if maxRecursion == 0 {
+ // We reached the limit of nesting we want to handle. This protects against
+ // DoS attacks using deeply nested JSON structures (e.g., {"a":{"a":{"a":...}}}).
+ return errors.New("max recursion reached while reading json object")
+ }
json.ForEach(func(key, value gjson.Result) bool {
// Avoid string concatenation to maintain a single buffer for key aggregation.
prevParentLength := len(objKey)
@@ -103,7 +117,11 @@ func readItems(json gjson.Result, objKey []byte, res map[string]string) {
var val string
switch value.Type {
case gjson.JSON:
- readItems(value, objKey, res)
+ // call recursively with one less item to avoid doing infinite recursion
+ iterationError = readItems(value, objKey, maxRecursion-1, res)
+ if iterationError != nil {
+ return false
+ }
objKey = objKey[:prevParentLength]
return true
case gjson.String:
@@ -123,6 +141,7 @@ func readItems(json gjson.Result, objKey []byte, res map[string]string) {
if arrayLen > 0 {
res[string(objKey)] = strconv.Itoa(arrayLen)
}
+ return iterationError
}
func init() {
diff --git a/internal/bodyprocessors/json_test.go b/internal/bodyprocessors/json_test.go
index 38e190e03..c4042342f 100644
--- a/internal/bodyprocessors/json_test.go
+++ b/internal/bodyprocessors/json_test.go
@@ -4,13 +4,23 @@
package bodyprocessors
import (
+ "errors"
+ "strings"
"testing"
+
+ "github.com/tidwall/gjson"
+)
+
+const (
+ deeplyNestedJSONObject = 15000
+ maxRecursion = 10000
)
var jsonTests = []struct {
name string
json string
want map[string]string
+ err error
}{
{
name: "map",
@@ -55,6 +65,7 @@ var jsonTests = []struct {
"json.f.0.0": "1",
"json.f.0.0.0.z": "abc",
},
+ err: nil,
},
{
name: "array",
@@ -115,6 +126,35 @@ var jsonTests = []struct {
"json.1.f.0.0": "1",
"json.1.f.0.0.0.z": "abc",
},
+ err: nil,
+ },
+ {
+ name: "unbalanced_brackets",
+ json: `{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a": 1 }}}}}}}}}}}}}}}}}}}}}}`,
+ want: map[string]string{},
+ err: errors.New("invalid JSON"),
+ },
+ {
+ name: "broken2",
+ json: `{"test": 123, "test2": 456, "test3": [22, 44, 55], "test4": 3}`,
+ want: map[string]string{
+ "json.test3.0": "22",
+ "json.test3.1": "44",
+ "json.test3.2": "55",
+ "json.test4": "3",
+ "json.test": "123",
+ "json.test2": "456",
+ "json.test3": "3",
+ },
+ err: nil,
+ },
+ {
+ name: "bomb",
+ json: strings.Repeat(`{"a":`, deeplyNestedJSONObject) + "1" + strings.Repeat(`}`, deeplyNestedJSONObject),
+ want: map[string]string{
+ "json." + strings.Repeat(`a.`, deeplyNestedJSONObject-1) + "a": "1",
+ },
+ err: errors.New("max recursion reached while reading json object"),
},
{
name: "empty_object",
@@ -143,18 +183,25 @@ func TestReadJSON(t *testing.T) {
for _, tc := range jsonTests {
tt := tc
t.Run(tt.name, func(t *testing.T) {
- jsonMap, err := readJSON(tt.json)
- if err != nil {
- t.Error(err)
- }
+ jsonMap, err := readJSON(tt.json, maxRecursion)
// Special case for nested_empty - just check that the function doesn't error
if tt.name == "nested_empty" {
+ if err != nil {
+ t.Error(err)
+ }
// Print the keys for debugging
t.Logf("Actual keys for nested_empty: %v", mapKeys(jsonMap))
return
}
+ if err != nil {
+ if tt.err == nil || err.Error() != tt.err.Error() {
+ t.Error(err)
+ }
+ return
+ }
+
for k, want := range tt.want {
if have, ok := jsonMap[k]; ok {
if want != have {
@@ -183,11 +230,10 @@ func mapKeys(m map[string]string) []string {
}
func TestInvalidJSON(t *testing.T) {
- _, err := readJSON(`{invalid json`)
- if err != nil {
- // We expect no error since gjson.Parse doesn't return errors for invalid JSON
- // Instead, it returns a Result with Type == Null
- t.Error("Expected no error for invalid JSON, got:", err)
+ _, err := readJSON(`{invalid json`, maxRecursion)
+ if err == nil {
+ // We expect an error for invalid JSON since we now validate
+ t.Error("Expected error for invalid JSON, got nil")
}
}
@@ -196,7 +242,7 @@ func BenchmarkReadJSON(b *testing.B) {
tt := tc
b.Run(tt.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
- _, err := readJSON(tt.json)
+ _, err := readJSON(tt.json, maxRecursion)
if err != nil {
b.Error(err)
}
@@ -204,3 +250,71 @@ func BenchmarkReadJSON(b *testing.B) {
})
}
}
+
+// readJSONNoValidation is readJSON without the gjson.Valid pre-check.
+// Used only in benchmarks to measure the overhead of validation.
+func readJSONNoValidation(s string, maxRecursion int) (map[string]string, error) {
+ json := gjson.Parse(s)
+ res := make(map[string]string)
+ key := []byte("json")
+ err := readItems(json, key, maxRecursion, res)
+ return res, err
+}
+
+// BenchmarkValidationOverhead measures the cost of pre-validating JSON with gjson.Valid
+// in the context of the full readJSON pipeline (Valid + Parse + readItems).
+// gjson.Parse is lazy (~9ns regardless of input size), so the real overhead is
+// gjson.Valid vs the readItems traversal that does the actual parsing work.
+func BenchmarkValidationOverhead(b *testing.B) {
+ benchCases := []struct {
+ name string
+ json string
+ }{
+ {
+ name: "small_object",
+ json: `{"name":"John","age":30}`,
+ },
+ {
+ name: "medium_object",
+ json: `{"user":{"name":"John","email":"john@example.com","roles":["admin","user"]},"settings":{"theme":"dark","notifications":true},"metadata":{"created":"2026-01-01","updated":"2026-02-15"}}`,
+ },
+ {
+ name: "large_array",
+ json: func() string {
+ var sb strings.Builder
+ sb.WriteString("[")
+ for i := 0; i < 100; i++ {
+ if i > 0 {
+ sb.WriteString(",")
+ }
+ sb.WriteString(`{"id":` + strings.Repeat("1", 5) + `,"name":"user","active":true}`)
+ }
+ sb.WriteString("]")
+ return sb.String()
+ }(),
+ },
+ {
+ name: "nested_10_levels",
+ json: strings.Repeat(`{"a":`, 10) + "1" + strings.Repeat(`}`, 10),
+ },
+ }
+
+ for _, bc := range benchCases {
+ b.Run("WithValidation/"+bc.name, func(b *testing.B) {
+ b.SetBytes(int64(len(bc.json)))
+ for i := 0; i < b.N; i++ {
+ if _, err := readJSON(bc.json, maxRecursion); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ b.Run("WithoutValidation/"+bc.name, func(b *testing.B) {
+ b.SetBytes(int64(len(bc.json)))
+ for i := 0; i < b.N; i++ {
+ if _, err := readJSONNoValidation(bc.json, maxRecursion); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ }
+}
diff --git a/internal/bodyprocessors/multipart.go b/internal/bodyprocessors/multipart.go
index ded38b2b1..c3ec5e733 100644
--- a/internal/bodyprocessors/multipart.go
+++ b/internal/bodyprocessors/multipart.go
@@ -58,6 +58,7 @@ func (mbp *multipartBodyProcessor) ProcessRequest(reader io.Reader, v plugintype
filename := originFileName(p)
if filename != "" {
var size int64
+ seenUnexpectedEOF := false
if environment.HasAccessToFS {
// Only copy file to temp when not running in TinyGo
temp, err := os.CreateTemp(storagePath, "crzmp*")
@@ -68,16 +69,22 @@ func (mbp *multipartBodyProcessor) ProcessRequest(reader io.Reader, v plugintype
defer temp.Close()
sz, err := io.Copy(temp, p)
if err != nil {
- v.MultipartStrictError().(*collections.Single).Set("1")
- return err
+ if !errors.Is(err, io.ErrUnexpectedEOF) {
+ v.MultipartStrictError().(*collections.Single).Set("1")
+ return err
+ }
+ seenUnexpectedEOF = true
}
size = sz
filesTmpNamesCol.Add("", temp.Name())
} else {
sz, err := io.Copy(io.Discard, p)
if err != nil {
- v.MultipartStrictError().(*collections.Single).Set("1")
- return err
+ if !errors.Is(err, io.ErrUnexpectedEOF) {
+ v.MultipartStrictError().(*collections.Single).Set("1")
+ return err
+ }
+ seenUnexpectedEOF = true
}
size = sz
}
@@ -85,17 +92,26 @@ func (mbp *multipartBodyProcessor) ProcessRequest(reader io.Reader, v plugintype
filesCol.Add("", filename)
fileSizesCol.SetIndex(filename, 0, fmt.Sprintf("%d", size))
filesNamesCol.Add("", p.FormName())
+ filesCombinedSizeCol.(*collections.Single).Set(fmt.Sprintf("%d", totalSize))
+ if seenUnexpectedEOF {
+ break
+ }
} else {
// if is a field
data, err := io.ReadAll(p)
if err != nil {
- v.MultipartStrictError().(*collections.Single).Set("1")
- return err
+ if !errors.Is(err, io.ErrUnexpectedEOF) {
+ v.MultipartStrictError().(*collections.Single).Set("1")
+ return err
+ }
}
totalSize += int64(len(data))
postCol.Add(p.FormName(), string(data))
+ filesCombinedSizeCol.(*collections.Single).Set(fmt.Sprintf("%d", totalSize))
+ if errors.Is(err, io.ErrUnexpectedEOF) {
+ break
+ }
}
- filesCombinedSizeCol.(*collections.Single).Set(fmt.Sprintf("%d", totalSize))
}
return nil
}
diff --git a/internal/bodyprocessors/multipart_test.go b/internal/bodyprocessors/multipart_test.go
index 9b6b36fc8..97b9b63b1 100644
--- a/internal/bodyprocessors/multipart_test.go
+++ b/internal/bodyprocessors/multipart_test.go
@@ -210,3 +210,123 @@ func TestMultipartUnmatchedBoundary(t *testing.T) {
}
}
}
+
+func TestIncompleteMultipartPayload(t *testing.T) {
+ testCases := []struct {
+ name string
+ input string
+ }{
+ {
+ name: "inMiddleOfBoundary",
+ input: `
+-----------------------------9051914041544843365972754266
+Content-Disposition: form-data; name="text"
+
+text default
+-----------------------------9051914041544843365972754266
+Content-Disposition: form-data; name="file1"; filename="a.txt"
+Content-Type: text/plain
+
+Content of a.txt.
+
+-----------------------------905191404154484336
+`,
+ },
+ {
+ name: "inMiddleOfHeader",
+ input: `
+-----------------------------9051914041544843365972754266
+Content-Disposition: form-data; name="text"
+
+text default
+-----------------------------9051914041544843365972754266
+Content-Disposition: form-data; name="file1"; filename="a.txt"
+Content-Type: text/plain
+
+Content of a.txt.
+
+-----------------------------9051914041544843365972754266
+Content-Disposition: form-data; name="fil`,
+ },
+ {
+ name: "inMiddleOfContent",
+ input: `
+-----------------------------9051914041544843365972754266
+Content-Disposition: form-data; name="text"
+
+text default
+-----------------------------9051914041544843365972754266
+Content-Disposition: form-data; name="file1"; filename="a.txt"
+Content-Type: text/plain
+
+Content of a.txt.
+
+-----------------------------9051914041544843365972754266
+Content-Disposition: form-data; name="file2"; filename="a.html"
+Content-Type: text/html
+
+
Content of `,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ payload := strings.TrimSpace(tc.input)
+
+ mp := multipartProcessor(t)
+
+ v := corazawaf.NewTransactionVariables()
+ if err := mp.ProcessRequest(strings.NewReader(payload), v, plugintypes.BodyProcessorOptions{
+ Mime: "multipart/form-data; boundary=---------------------------9051914041544843365972754266",
+ }); err != nil {
+ t.Fatal(err)
+ }
+ // first we validate we got the headers
+ headers := v.MultipartPartHeaders()
+ header1 := "Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\""
+ header2 := "Content-Type: text/plain"
+ if h := headers.Get("file1"); len(h) == 0 {
+ t.Fatal("expected headers for file1")
+ } else {
+ if len(h) != 2 {
+ t.Fatal("expected 2 headers for file1")
+ }
+ if (h[0] != header1 && h[0] != header2) || (h[1] != header1 && h[1] != header2) {
+ t.Fatalf("Got invalid multipart headers")
+ }
+ }
+
+ // Verify form field data was correctly processed before the incomplete part
+ argsPost := v.ArgsPost()
+ if textValues := argsPost.Get("text"); len(textValues) == 0 {
+ t.Fatal("expected ArgsPost to contain 'text' field")
+ } else if textValues[0] != "text default" {
+ t.Fatalf("expected ArgsPost 'text' to be 'text default', got %q", textValues[0])
+ }
+ })
+ }
+}
+
+func TestIncompleteMultipartPayloadInFormField(t *testing.T) {
+ payload := strings.TrimSpace(`
+-----------------------------9051914041544843365972754266
+Content-Disposition: form-data; name="text"
+
+text defa`)
+
+ mp := multipartProcessor(t)
+
+ v := corazawaf.NewTransactionVariables()
+ if err := mp.ProcessRequest(strings.NewReader(payload), v, plugintypes.BodyProcessorOptions{
+ Mime: "multipart/form-data; boundary=---------------------------9051914041544843365972754266",
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ // Verify the partial form field data was processed
+ argsPost := v.ArgsPost()
+ if textValues := argsPost.Get("text"); len(textValues) == 0 {
+ t.Fatal("expected ArgsPost to contain 'text' field")
+ } else if textValues[0] != "text defa" {
+ t.Fatalf("expected ArgsPost 'text' to be 'text defa', got %q", textValues[0])
+ }
+}
diff --git a/internal/collections/map.go b/internal/collections/map.go
index ba6adf35e..7ca642809 100644
--- a/internal/collections/map.go
+++ b/internal/collections/map.go
@@ -60,16 +60,30 @@ func (c *Map) Get(key string) []string {
// FindRegex returns all map elements whose key matches the regular expression.
func (c *Map) FindRegex(key *regexp.Regexp) []types.MatchData {
- var result []types.MatchData
+ n := 0
+ // Collect matching data slices in a single pass to avoid evaluating the regex twice per key.
+ var matched [][]keyValue
for k, data := range c.data {
if key.MatchString(k) {
- for _, d := range data {
- result = append(result, &corazarules.MatchData{
- Variable_: c.variable,
- Key_: d.key,
- Value_: d.value,
- })
+ n += len(data)
+ matched = append(matched, data)
+ }
+ }
+ if n == 0 {
+ return nil
+ }
+ buf := make([]corazarules.MatchData, n)
+ result := make([]types.MatchData, n)
+ i := 0
+ for _, data := range matched {
+ for _, d := range data {
+ buf[i] = corazarules.MatchData{
+ Variable_: c.variable,
+ Key_: d.key,
+ Value_: d.value,
}
+ result[i] = &buf[i]
+ i++
}
}
return result
@@ -77,7 +91,6 @@ func (c *Map) FindRegex(key *regexp.Regexp) []types.MatchData {
// FindString returns all map elements whose key matches the string.
func (c *Map) FindString(key string) []types.MatchData {
- var result []types.MatchData
if key == "" {
return c.FindAll()
}
@@ -87,29 +100,44 @@ func (c *Map) FindString(key string) []types.MatchData {
if !c.isCaseSensitive {
key = strings.ToLower(key)
}
- // if key is not empty
- if e, ok := c.data[key]; ok {
- for _, aVar := range e {
- result = append(result, &corazarules.MatchData{
- Variable_: c.variable,
- Key_: aVar.key,
- Value_: aVar.value,
- })
+ e, ok := c.data[key]
+ if !ok || len(e) == 0 {
+ return nil
+ }
+ buf := make([]corazarules.MatchData, len(e))
+ result := make([]types.MatchData, len(e))
+ for i, aVar := range e {
+ buf[i] = corazarules.MatchData{
+ Variable_: c.variable,
+ Key_: aVar.key,
+ Value_: aVar.value,
}
+ result[i] = &buf[i]
}
return result
}
// FindAll returns all map elements.
func (c *Map) FindAll() []types.MatchData {
- var result []types.MatchData
+ n := 0
+ for _, data := range c.data {
+ n += len(data)
+ }
+ if n == 0 {
+ return nil
+ }
+ buf := make([]corazarules.MatchData, n)
+ result := make([]types.MatchData, n)
+ i := 0
for _, data := range c.data {
for _, d := range data {
- result = append(result, &corazarules.MatchData{
+ buf[i] = corazarules.MatchData{
Variable_: c.variable,
Key_: d.key,
Value_: d.value,
- })
+ }
+ result[i] = &buf[i]
+ i++
}
}
return result
diff --git a/internal/collections/map_test.go b/internal/collections/map_test.go
index 7c47d29c5..60dc0c68f 100644
--- a/internal/collections/map_test.go
+++ b/internal/collections/map_test.go
@@ -107,6 +107,120 @@ func TestNewCaseSensitiveKeyMap(t *testing.T) {
}
+func TestFindAllBulkAllocIndependence(t *testing.T) {
+ m := NewMap(variables.ArgsGet)
+ m.Add("key1", "value1")
+ m.Add("key2", "value2")
+ m.Add("key3", "value3")
+
+ results := m.FindAll()
+ if len(results) != 3 {
+ t.Fatalf("expected 3 results, got %d", len(results))
+ }
+
+ // Mutate first result's value through the MatchData interface
+ // and verify others are not affected
+ values := make([]string, len(results))
+ for i, r := range results {
+ values[i] = r.Value()
+ }
+
+ // Verify all values are distinct and correct
+ seen := map[string]bool{}
+ for _, v := range values {
+ if seen[v] {
+ t.Errorf("duplicate value found: %s", v)
+ }
+ seen[v] = true
+ }
+ if !seen["value1"] || !seen["value2"] || !seen["value3"] {
+ t.Errorf("expected value1, value2, value3 but got %v", values)
+ }
+}
+
+func TestFindStringBulkAlloc(t *testing.T) {
+ m := NewMap(variables.ArgsGet)
+ m.Add("key", "val1")
+ m.Add("key", "val2")
+
+ results := m.FindString("key")
+ if len(results) != 2 {
+ t.Fatalf("expected 2 results, got %d", len(results))
+ }
+
+ // Each result should have distinct values
+ if results[0].Value() == results[1].Value() {
+ t.Errorf("expected distinct values, got %q and %q", results[0].Value(), results[1].Value())
+ }
+}
+
+func TestFindRegexBulkAlloc(t *testing.T) {
+ m := NewMap(variables.ArgsGet)
+ m.Add("abc", "val1")
+ m.Add("abd", "val2")
+ m.Add("xyz", "val3")
+
+ re := regexp.MustCompile("^ab")
+ results := m.FindRegex(re)
+ if len(results) != 2 {
+ t.Fatalf("expected 2 results, got %d", len(results))
+ }
+
+ // Verify keys match regex
+ for _, r := range results {
+ if r.Key() != "abc" && r.Key() != "abd" {
+ t.Errorf("unexpected key: %s", r.Key())
+ }
+ }
+}
+
+func TestFindAllEmptyMap(t *testing.T) {
+ m := NewMap(variables.ArgsGet)
+ results := m.FindAll()
+ if results != nil {
+ t.Errorf("expected nil for empty map, got %v", results)
+ }
+}
+
+func BenchmarkFindAll(b *testing.B) {
+ b.ReportAllocs()
+ m := NewMap(variables.RequestHeaders)
+ for i := 0; i < 20; i++ {
+ m.Add(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("value-%d", i))
+ }
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _ = m.FindAll()
+ }
+}
+
+func BenchmarkFindRegex(b *testing.B) {
+ b.ReportAllocs()
+ m := NewMap(variables.RequestHeaders)
+ for i := 0; i < 20; i++ {
+ m.Add(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("value-%d", i))
+ }
+ // Matches keys ending in 0-9 (x-header-0 .. x-header-9), roughly half.
+ re := regexp.MustCompile(`^x-header-\d$`)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _ = m.FindRegex(re)
+ }
+}
+
+func BenchmarkFindString(b *testing.B) {
+ b.ReportAllocs()
+ m := NewMap(variables.RequestHeaders)
+ // Single key with multiple values
+ for i := 0; i < 20; i++ {
+ m.Add("x-forwarded-for", fmt.Sprintf("10.0.0.%d", i))
+ }
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _ = m.FindString("x-forwarded-for")
+ }
+}
+
func BenchmarkTxSetGet(b *testing.B) {
keys := make(map[int]string, b.N)
for i := 0; i < b.N; i++ {
diff --git a/internal/collections/named.go b/internal/collections/named.go
index c3b1a0033..6d6c922f4 100644
--- a/internal/collections/named.go
+++ b/internal/collections/named.go
@@ -101,37 +101,49 @@ type NamedCollectionNames struct {
}
func (c *NamedCollectionNames) FindRegex(key *regexp.Regexp) []types.MatchData {
- var res []types.MatchData
-
+ n := 0
+ // Collect matching data slices in a single pass to avoid evaluating the regex twice per key.
+ var matched [][]keyValue
for k, data := range c.collection.data {
- if !key.MatchString(k) {
- continue
+ if key.MatchString(k) {
+ n += len(data)
+ matched = append(matched, data)
}
+ }
+ if n == 0 {
+ return nil
+ }
+ buf := make([]corazarules.MatchData, n)
+ res := make([]types.MatchData, n)
+ i := 0
+ for _, data := range matched {
for _, d := range data {
- res = append(res, &corazarules.MatchData{
+ buf[i] = corazarules.MatchData{
Variable_: c.variable,
Key_: d.key,
Value_: d.key,
- })
+ }
+ res[i] = &buf[i]
+ i++
}
}
return res
}
func (c *NamedCollectionNames) FindString(key string) []types.MatchData {
- var res []types.MatchData
-
- for k, data := range c.collection.data {
- if k != key {
- continue
- }
- for _, d := range data {
- res = append(res, &corazarules.MatchData{
- Variable_: c.variable,
- Key_: d.key,
- Value_: d.key,
- })
+ data, ok := c.collection.data[key]
+ if !ok || len(data) == 0 {
+ return nil
+ }
+ buf := make([]corazarules.MatchData, len(data))
+ res := make([]types.MatchData, len(data))
+ for i, d := range data {
+ buf[i] = corazarules.MatchData{
+ Variable_: c.variable,
+ Key_: d.key,
+ Value_: d.key,
}
+ res[i] = &buf[i]
}
return res
}
@@ -141,16 +153,25 @@ func (c *NamedCollectionNames) Get(key string) []string {
}
func (c *NamedCollectionNames) FindAll() []types.MatchData {
- var res []types.MatchData
- // Iterates over all the data in the map and adds the key element also to the Key field (The key value may be the value
- // that is matched, but it is still also the key of the pair and it is needed to print the matched var name)
+ n := 0
+ for _, data := range c.collection.data {
+ n += len(data)
+ }
+ if n == 0 {
+ return nil
+ }
+ buf := make([]corazarules.MatchData, n)
+ res := make([]types.MatchData, n)
+ i := 0
for _, data := range c.collection.data {
for _, d := range data {
- res = append(res, &corazarules.MatchData{
+ buf[i] = corazarules.MatchData{
Variable_: c.variable,
Key_: d.key,
Value_: d.key,
- })
+ }
+ res[i] = &buf[i]
+ i++
}
}
return res
diff --git a/internal/corazawaf/rule.go b/internal/corazawaf/rule.go
index c7b047707..d2e3b5a12 100644
--- a/internal/corazawaf/rule.go
+++ b/internal/corazawaf/rule.go
@@ -14,7 +14,7 @@ import (
"github.com/corazawaf/coraza/v3/experimental/plugins/macro"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
"github.com/corazawaf/coraza/v3/internal/corazarules"
- "github.com/corazawaf/coraza/v3/internal/memoize"
+ utils "github.com/corazawaf/coraza/v3/internal/strings"
"github.com/corazawaf/coraza/v3/types"
"github.com/corazawaf/coraza/v3/types/variables"
)
@@ -97,6 +97,11 @@ type Rule struct {
transformations []ruleTransformationParams
transformationsID int
+ // transformationPrefixIDs holds the chain ID at each step of the transformation chain.
+ // transformationPrefixIDs[i] is the ID representing transformations [0..i].
+ // This enables prefix-based caching: rules sharing a common transformation prefix
+ // can reuse cached intermediate results instead of recomputing from scratch.
+ transformationPrefixIDs []int
// Slice of initialized actions to be evaluated during
// the rule evaluation process
@@ -146,6 +151,8 @@ type Rule struct {
// chainedRules containing rules with just PhaseUnknown variables, may potentially
// be anticipated. This boolean ensures that it happens
withPhaseUnknownVariable bool
+
+ memoizer plugintypes.Memoizer
}
func (r *Rule) ParentID() int {
@@ -161,7 +168,7 @@ const chainLevelZero = 0
// Evaluate will evaluate the current rule for the indicated transaction
// If the operator matches, actions will be evaluated, and it will return
// the matched variables, keys and values (MatchData)
-func (r *Rule) Evaluate(phase types.RulePhase, tx plugintypes.TransactionState, cache map[transformationKey]*transformationValue) {
+func (r *Rule) Evaluate(phase types.RulePhase, tx plugintypes.TransactionState, cache map[transformationKey]transformationValue) {
// collectiveMatchedValues lives across recursive calls of doEvaluate
var collectiveMatchedValues []types.MatchData
@@ -180,7 +187,7 @@ func (r *Rule) Evaluate(phase types.RulePhase, tx plugintypes.TransactionState,
const noID = 0
-func (r *Rule) doEvaluate(logger debuglog.Logger, phase types.RulePhase, tx *Transaction, collectiveMatchedValues *[]types.MatchData, chainLevel int, cache map[transformationKey]*transformationValue) []types.MatchData {
+func (r *Rule) doEvaluate(logger debuglog.Logger, phase types.RulePhase, tx *Transaction, collectiveMatchedValues *[]types.MatchData, chainLevel int, cache map[transformationKey]transformationValue) []types.MatchData {
tx.Capture = r.Capture
if multiphaseEvaluation {
@@ -230,7 +237,7 @@ func (r *Rule) doEvaluate(logger debuglog.Logger, phase types.RulePhase, tx *Tra
for _, c := range ecol {
if c.Variable == v.Variable {
// TODO shall we check the pointer?
- v.Exceptions = append(v.Exceptions, ruleVariableException{c.KeyStr, nil})
+ v.Exceptions = append(v.Exceptions, ruleVariableException{c.KeyStr, c.KeyRx})
}
}
@@ -372,15 +379,22 @@ func (r *Rule) doEvaluate(logger debuglog.Logger, phase types.RulePhase, tx *Tra
}
for _, a := range r.actions {
- if a.Function.Type() == plugintypes.ActionTypeFlow {
- // Flow actions are evaluated also if the rule engine is set to DetectionOnly
+ // All actions are evaluated independently from the engine being On or in DetectionOnly.
+ // The action evaluation is responsible of checking the engine mode and decide if the disruptive action
+ // has to be enforced or not. This allows finer control to the actions, such us creating the detectionOnlyInterruption and
+ // allowing RelevantOnly audit logs in detection only mode.
+ switch a.Function.Type() {
+ case plugintypes.ActionTypeFlow:
logger.Debug().Str("action", a.Name).Int("phase", int(phase)).Msg("Evaluating flow action for rule")
- a.Function.Evaluate(r, tx)
- } else if a.Function.Type() == plugintypes.ActionTypeDisruptive && tx.RuleEngine == types.RuleEngineOn {
+ case plugintypes.ActionTypeDisruptive:
// The parser enforces that the disruptive action is just one per rule (if more than one, only the last one is kept)
- logger.Debug().Str("action", a.Name).Msg("Executing disruptive action for rule")
- a.Function.Evaluate(r, tx)
+ logger.Debug().Str("action", a.Name).Int("phase", int(phase)).Msg("Executing disruptive action for rule")
+ default:
+ // Only flow and disruptive actions are supposed to be evaluated here, non disruptive actions
+ // are evaluated previously, during the variable matching.
+ continue
}
+ a.Function.Evaluate(r, tx)
}
if r.ID_ != noID {
// we avoid matching chains and secmarkers
@@ -397,7 +411,7 @@ func (r *Rule) transformMultiMatchArg(arg types.MatchData) ([]string, []error) {
return r.executeTransformationsMultimatch(arg.Value())
}
-func (r *Rule) transformArg(arg types.MatchData, argIdx int, cache map[transformationKey]*transformationValue) (string, []error) {
+func (r *Rule) transformArg(arg types.MatchData, argIdx int, cache map[transformationKey]transformationValue) (string, []error) {
switch {
case len(r.transformations) == 0:
return arg.Value(), nil
@@ -409,23 +423,53 @@ func (r *Rule) transformArg(arg types.MatchData, argIdx int, cache map[transform
// NOTE: See comment on transformationKey struct to understand this hacky code
argKey := arg.Key()
argKeyPtr := unsafe.StringData(argKey)
- key := transformationKey{
- argKey: argKeyPtr,
- argIndex: argIdx,
- argVariable: arg.Variable(),
- transformationsID: r.transformationsID,
+
+ // Search from longest prefix (full chain) backwards for a cache hit.
+ // Best case: full chain cached → single map lookup, done.
+ // Typical case: shared prefix cached → start computing from there.
+ startIdx := 0
+ value := arg.Value()
+ var errs []error
+
+ for i := len(r.transformationPrefixIDs) - 1; i >= 0; i-- {
+ key := transformationKey{
+ argKey: argKeyPtr,
+ argIndex: argIdx,
+ argVariable: arg.Variable(),
+ transformationsID: r.transformationPrefixIDs[i],
+ }
+ if cached, ok := cache[key]; ok {
+ if i == len(r.transformationPrefixIDs)-1 {
+ // Full chain cached — nothing more to compute
+ return cached.arg, cached.errs
+ }
+ value = cached.arg
+ errs = cached.errs
+ startIdx = i + 1
+ break
+ }
}
- if cached, ok := cache[key]; ok {
- return cached.arg, cached.errs
- } else {
- ars, es := r.executeTransformations(arg.Value())
- errs := es
- cache[key] = &transformationValue{
- arg: ars,
- errs: es,
+
+ // Execute remaining transformations, caching each intermediate step
+ // so later rules sharing a prefix can reuse our work.
+ for i := startIdx; i < len(r.transformations); i++ {
+ v, _, err := r.transformations[i].Function(value)
+ if err != nil {
+ errs = append(errs, err)
+ } else {
+ value = v
}
- return ars, errs
+
+ key := transformationKey{
+ argKey: argKeyPtr,
+ argIndex: argIdx,
+ argVariable: arg.Variable(),
+ transformationsID: r.transformationPrefixIDs[i],
+ }
+ cache[key] = transformationValue{arg: value, errs: errs}
}
+
+ return value, errs
}
}
@@ -467,14 +511,28 @@ func (r *Rule) AddAction(name string, action plugintypes.Action) error {
return nil
}
+// ClearDisruptiveActions removes all disruptive actions from the rule.
+//
+// This is used by directives like SecRuleUpdateActionById to clear existing
+// disruptive actions before applying new ones, matching ModSecurity behavior
+// where updating with a disruptive action replaces the previous one.
+func (r *Rule) ClearDisruptiveActions() {
+ actionType := plugintypes.ActionTypeDisruptive
+ filtered := make([]ruleActionParams, 0, len(r.actions))
+ for _, action := range r.actions {
+ if action.Function.Type() != actionType {
+ filtered = append(filtered, action)
+ }
+ }
+ r.actions = filtered
+}
+
// hasRegex checks the received key to see if it is between forward slashes.
// if it is, it will return true and the content of the regular expression inside the slashes.
// otherwise it will return false and the same key.
+// Delegates to utils.HasRegex which properly handles escaped slashes.
func hasRegex(key string) (bool, string) {
- if len(key) > 2 && key[0] == '/' && key[len(key)-1] == '/' {
- return true, key[1 : len(key)-1]
- }
- return false, key
+ return utils.HasRegex(key)
}
// caseSensitiveVariable returns true if the variable is case sensitive
@@ -515,7 +573,10 @@ func (r *Rule) AddVariable(v variables.RuleVariable, key string, iscount bool) e
}
var re *regexp.Regexp
if isRegex, rx := hasRegex(key); isRegex {
- if vare, err := memoize.Do(rx, func() (any, error) { return regexp.Compile(rx) }); err != nil {
+ if !caseSensitiveVariable(v) {
+ rx = strings.ToLower(rx)
+ }
+ if vare, err := r.memoizeDo(rx, func() (any, error) { return regexp.Compile(rx) }); err != nil {
return err
} else {
re = vare.(*regexp.Regexp)
@@ -559,7 +620,10 @@ func needToSplitConcatenatedVariable(v variables.RuleVariable, ve variables.Rule
func (r *Rule) AddVariableNegation(v variables.RuleVariable, key string) error {
var re *regexp.Regexp
if isRegex, rx := hasRegex(key); isRegex {
- if vare, err := memoize.Do(rx, func() (any, error) { return regexp.Compile(rx) }); err != nil {
+ if !caseSensitiveVariable(v) {
+ rx = strings.ToLower(rx)
+ }
+ if vare, err := r.memoizeDo(rx, func() (any, error) { return regexp.Compile(rx) }); err != nil {
return err
} else {
re = vare.(*regexp.Regexp)
@@ -613,6 +677,7 @@ func (r *Rule) AddTransformation(name string, t plugintypes.Transformation) erro
}
r.transformations = append(r.transformations, ruleTransformationParams{Function: t})
r.transformationsID = transformationID(r.transformationsID, name)
+ r.transformationPrefixIDs = append(r.transformationPrefixIDs, r.transformationsID)
return nil
}
@@ -620,6 +685,8 @@ func (r *Rule) AddTransformation(name string, t plugintypes.Transformation) erro
// it is mostly used by the "none" transformation
func (r *Rule) ClearTransformations() {
r.transformations = []ruleTransformationParams{}
+ r.transformationsID = 0
+ r.transformationPrefixIDs = nil
}
// SetOperator sets the operator of the rule
@@ -674,6 +741,23 @@ func (r *Rule) executeTransformations(value string) (string, []error) {
return value, errs
}
+// SetMemoizer sets the memoizer used for caching compiled regexes in variable selectors.
+func (r *Rule) SetMemoizer(m plugintypes.Memoizer) {
+ r.memoizer = m
+}
+
+// Memoizer returns the memoizer used for caching compiled regexes in variable selectors.
+func (r *Rule) Memoizer() plugintypes.Memoizer {
+ return r.memoizer
+}
+
+func (r *Rule) memoizeDo(key string, fn func() (any, error)) (any, error) {
+ if r.memoizer != nil {
+ return r.memoizer.Do(key, fn)
+ }
+ return fn()
+}
+
// NewRule returns a new initialized rule
// By default, the rule is set to phase 2
func NewRule() *Rule {
diff --git a/internal/corazawaf/rule_test.go b/internal/corazawaf/rule_test.go
index d999bb0c8..b397854f3 100644
--- a/internal/corazawaf/rule_test.go
+++ b/internal/corazawaf/rule_test.go
@@ -101,7 +101,7 @@ func TestNoMatchEvaluateBecauseOfException(t *testing.T) {
_ = r.AddAction("dummyDeny", action)
tx := NewWAF().NewTransaction()
tx.AddGetRequestArgument("test", "0")
- tx.RemoveRuleTargetByID(1, tc.variable, "test")
+ tx.RemoveRuleTargetByID(1, tc.variable, "test", nil)
var matchedValues []types.MatchData
matchdata := r.doEvaluate(debuglog.Noop(), types.PhaseRequestHeaders, tx, &matchedValues, 0, tx.transformationCache)
if len(matchdata) != 0 {
@@ -114,6 +114,52 @@ func TestNoMatchEvaluateBecauseOfException(t *testing.T) {
}
}
+func TestNoMatchEvaluateBecauseOfWholeCollectionException(t *testing.T) {
+ testCases := []struct {
+ name string
+ variable variables.RuleVariable
+ }{
+ {
+ name: "Test ArgsGet whole collection exception",
+ variable: variables.ArgsGet,
+ },
+ {
+ name: "Test Args whole collection exception",
+ variable: variables.Args,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ r := NewRule()
+ r.Msg, _ = macro.NewMacro("Message")
+ r.LogData, _ = macro.NewMacro("Data Message")
+ r.ID_ = 1
+ r.LogID_ = "1"
+ if err := r.AddVariable(tc.variable, "", false); err != nil {
+ t.Error(err)
+ }
+ dummyEqOp := &dummyEqOperator{}
+ r.SetOperator(dummyEqOp, "@eq", "0")
+ action := &dummyDenyAction{}
+ _ = r.AddAction("dummyDeny", action)
+ tx := NewWAF().NewTransaction()
+ tx.AddGetRequestArgument("test", "0")
+ tx.AddGetRequestArgument("other", "0")
+ // Remove with empty key should exclude the entire collection
+ tx.RemoveRuleTargetByID(1, tc.variable, "", nil)
+ var matchedValues []types.MatchData
+ matchdata := r.doEvaluate(debuglog.Noop(), types.PhaseRequestHeaders, tx, &matchedValues, 0, tx.transformationCache)
+ if len(matchdata) != 0 {
+ t.Errorf("Expected 0 matchdata when whole collection is excluded, got %d", len(matchdata))
+ }
+ if tx.interruption != nil {
+ t.Errorf("Expected no interruption because whole collection is excluded")
+ }
+ })
+ }
+}
+
type dummyFlowAction struct{}
func (*dummyFlowAction) Init(_ plugintypes.RuleMetadata, _ string) error {
@@ -297,6 +343,39 @@ func TestVariablesRxAreCaseSensitive(t *testing.T) {
}
}
+func TestVariablesRxAreLowercasedForCaseInsensitiveCollections(t *testing.T) {
+ rule := NewRule()
+ if err := rule.AddVariable(variables.TX, "/MULTIPART_HEADERS_CONTENT_TYPES_.*/", false); err != nil {
+ t.Error(err)
+ }
+ if rule.variables[0].KeyRx == nil {
+ t.Fatal("expected regex to be set")
+ }
+ if rule.variables[0].KeyRx.String() != "multipart_headers_content_types_.*" {
+ t.Errorf("expected lowercased regex for case-insensitive variable TX, got %q", rule.variables[0].KeyRx.String())
+ }
+}
+
+func TestVariableNegationRxLowercasedForCaseInsensitiveCollections(t *testing.T) {
+ rule := NewRule()
+ if err := rule.AddVariable(variables.TX, "", false); err != nil {
+ t.Error(err)
+ }
+ if err := rule.AddVariableNegation(variables.TX, "/SOME_PATTERN.*/"); err != nil {
+ t.Error(err)
+ }
+ if len(rule.variables[0].Exceptions) != 1 {
+ t.Fatal("expected 1 exception")
+ }
+ ex := rule.variables[0].Exceptions[0]
+ if ex.KeyRx == nil {
+ t.Fatal("expected regex exception to be set")
+ }
+ if ex.KeyRx.String() != "some_pattern.*" {
+ t.Errorf("expected lowercased regex for case-insensitive variable TX exception, got %q", ex.KeyRx.String())
+ }
+}
+
func TestInferredPhase(t *testing.T) {
var b inferredPhases
@@ -495,7 +574,7 @@ func TestExecuteTransformationsMultiMatchReturnsMultipleErrors(t *testing.T) {
}
func TestTransformArgSimple(t *testing.T) {
- transformationCache := map[transformationKey]*transformationValue{}
+ transformationCache := map[transformationKey]transformationValue{}
md := &corazarules.MatchData{
Variable_: variables.RequestURI,
Key_: "REQUEST_URI",
@@ -513,10 +592,11 @@ func TestTransformArgSimple(t *testing.T) {
if arg != "/testAB" {
t.Errorf("Expected \"/testAB\", got \"%s\"", arg)
}
- if len(transformationCache) != 1 {
- t.Errorf("Expected 1 transformations in cache, got %d", len(transformationCache))
+ // Prefix caching stores an entry for each transformation step
+ if len(transformationCache) != 2 {
+ t.Errorf("Expected 2 transformations in cache (one per step), got %d", len(transformationCache))
}
- // Repeating the same transformation, expecting still one element in the cache (that means it is a cache hit)
+ // Repeating the same transformation, expecting still two elements in the cache (cache hit)
arg, errs = rule.transformArg(md, 0, transformationCache)
if errs != nil {
t.Fatalf("Unexpected errors executing transformations: %v", errs)
@@ -524,13 +604,13 @@ func TestTransformArgSimple(t *testing.T) {
if arg != "/testAB" {
t.Errorf("Expected \"/testAB\", got \"%s\"", arg)
}
- if len(transformationCache) != 1 {
- t.Errorf("Expected 1 transformations in cache, got %d", len(transformationCache))
+ if len(transformationCache) != 2 {
+ t.Errorf("Expected 2 transformations in cache, got %d", len(transformationCache))
}
}
func TestTransformArgNoCacheForTXVariable(t *testing.T) {
- transformationCache := map[transformationKey]*transformationValue{}
+ transformationCache := map[transformationKey]transformationValue{}
md := &corazarules.MatchData{
Variable_: variables.TX,
Key_: "Custom_TX_Variable",
@@ -550,6 +630,84 @@ func TestTransformArgNoCacheForTXVariable(t *testing.T) {
}
}
+func TestTransformArgPrefixSharing(t *testing.T) {
+ transformationCache := map[transformationKey]transformationValue{}
+ md := &corazarules.MatchData{
+ Variable_: variables.RequestURI,
+ Key_: "REQUEST_URI",
+ Value_: "/test",
+ }
+
+ // Rule 1 has chain: AppendA
+ rule1 := NewRule()
+ _ = rule1.AddTransformation("AppendA", transformationAppendA)
+
+ // Rule 2 has chain: AppendA, AppendB (shares prefix with rule1)
+ rule2 := NewRule()
+ _ = rule2.AddTransformation("AppendA", transformationAppendA)
+ _ = rule2.AddTransformation("AppendB", transformationAppendB)
+
+ // Evaluate rule1 first — caches the AppendA intermediate
+ arg1, errs := rule1.transformArg(md, 0, transformationCache)
+ if errs != nil {
+ t.Fatalf("Unexpected errors: %v", errs)
+ }
+ if arg1 != "/testA" {
+ t.Errorf("Expected \"/testA\", got %q", arg1)
+ }
+ if len(transformationCache) != 1 {
+ t.Errorf("Expected 1 cache entry after rule1, got %d", len(transformationCache))
+ }
+
+ // Evaluate rule2 — should reuse the cached AppendA result and only compute AppendB
+ arg2, errs := rule2.transformArg(md, 0, transformationCache)
+ if errs != nil {
+ t.Fatalf("Unexpected errors: %v", errs)
+ }
+ if arg2 != "/testAB" {
+ t.Errorf("Expected \"/testAB\", got %q", arg2)
+ }
+ // Should now have 2 entries: AppendA (shared) and AppendA+AppendB
+ if len(transformationCache) != 2 {
+ t.Errorf("Expected 2 cache entries after rule2 (prefix reuse), got %d", len(transformationCache))
+ }
+}
+
+func TestClearTransformationsResetsID(t *testing.T) {
+ transformationCache := map[transformationKey]transformationValue{}
+ md := &corazarules.MatchData{
+ Variable_: variables.RequestURI,
+ Key_: "REQUEST_URI",
+ Value_: "test",
+ }
+
+ // Rule A: t:AppendA, t:none, t:AppendB — effective chain is only AppendB
+ ruleA := NewRule()
+ _ = ruleA.AddTransformation("AppendA", transformationAppendA)
+ ruleA.ClearTransformations()
+ _ = ruleA.AddTransformation("AppendB", transformationAppendB)
+
+ // Rule B: t:AppendA, t:AppendB — effective chain is AppendA then AppendB
+ ruleB := NewRule()
+ _ = ruleB.AddTransformation("AppendA", transformationAppendA)
+ _ = ruleB.AddTransformation("AppendB", transformationAppendB)
+
+ argA, _ := ruleA.transformArg(md, 0, transformationCache)
+ argB, _ := ruleB.transformArg(md, 0, transformationCache)
+
+ if argA != "testB" {
+ t.Errorf("Rule A (t:none resets): expected \"testB\", got %q", argA)
+ }
+ if argB != "testAB" {
+ t.Errorf("Rule B (t:AppendA,t:AppendB): expected \"testAB\", got %q", argB)
+ }
+ // They must produce different results — if ClearTransformations didn't reset the ID,
+ // they'd collide in the cache and one would get the wrong result.
+ if argA == argB {
+ t.Error("Rule A and Rule B produced the same result — ClearTransformations likely didn't reset transformationsID")
+ }
+}
+
func TestCaptureNotPropagatedToInnerChainRule(t *testing.T) {
r := NewRule()
r.ID_ = 1
diff --git a/internal/corazawaf/rulegroup.go b/internal/corazawaf/rulegroup.go
index 0682dc5d1..0924aa2f5 100644
--- a/internal/corazawaf/rulegroup.go
+++ b/internal/corazawaf/rulegroup.go
@@ -18,7 +18,8 @@ import (
// It is not concurrent safe, so it's not recommended to use it
// after compilation
type RuleGroup struct {
- rules []Rule
+ rules []Rule
+ observer func(rule types.RuleMetadata)
}
// Add a rule to the collection
@@ -61,9 +62,19 @@ func (rg *RuleGroup) Add(rule *Rule) error {
}
rg.rules = append(rg.rules, *rule)
+
+ if rg.observer != nil {
+ rg.observer(rule)
+ }
+
return nil
}
+// SetObserver assigns the observer function to the group.
+func (rg *RuleGroup) SetObserver(observer func(rule types.RuleMetadata)) {
+ rg.observer = observer
+}
+
// GetRules returns the slice of rules,
func (rg *RuleGroup) GetRules() []Rule {
return rg.rules
@@ -148,7 +159,7 @@ RulesLoop:
r := &rg.rules[i]
// if there is already an interruption and the phase isn't logging
// we break the loop
- if tx.interruption != nil && phase != types.PhaseLogging {
+ if tx.IsInterrupted() && phase != types.PhaseLogging {
break RulesLoop
}
// Rules with phase 0 will always run
@@ -168,12 +179,17 @@ RulesLoop:
}
// we skip the rule in case it's in the excluded list
- for _, trb := range tx.ruleRemoveByID {
- if trb == r.ID_ {
+ if _, skip := tx.ruleRemoveByID[r.ID_]; skip {
+ tx.DebugLogger().Debug().
+ Int("rule_id", r.ID_).
+ Msg("Skipping rule")
+ continue RulesLoop
+ }
+ for _, rng := range tx.ruleRemoveByIDRanges {
+ if r.ID_ >= rng[0] && r.ID_ <= rng[1] {
tx.DebugLogger().Debug().
Int("rule_id", r.ID_).
Msg("Skipping rule")
-
continue RulesLoop
}
}
@@ -247,7 +263,7 @@ RulesLoop:
tx.Skip = 0
tx.stopWatches[phase] = time.Now().UnixNano() - ts
- return tx.interruption != nil
+ return tx.IsInterrupted()
}
// NewRuleGroup creates an empty RuleGroup that
diff --git a/internal/corazawaf/transaction.go b/internal/corazawaf/transaction.go
index 730a850b1..f51052176 100644
--- a/internal/corazawaf/transaction.go
+++ b/internal/corazawaf/transaction.go
@@ -14,6 +14,7 @@ import (
"net/url"
"os"
"path/filepath"
+ "regexp"
"strconv"
"strings"
"time"
@@ -54,6 +55,12 @@ type Transaction struct {
// True if the transaction has been disrupted by any rule
interruption *types.Interruption
+ // detectionOnlyInterruption keeps track of the interruption that would have been performed if the engine was On.
+ // It provides visibility of what would have happened in On mode when the engine is set to "DetectionOnly"
+ // and is used to correctly emit relevant only audit logs in DetectionOnly mode (When the rules would have
+ // caused an interruption if the engine was On).
+ detectionOnlyInterruption *types.Interruption
+
// This is used to store log messages
// Deprecated since Coraza 3.0.5: this variable is not used, logdata values are stored in the matched rules
Logdata string
@@ -62,6 +69,9 @@ type Transaction struct {
SkipAfter string
// AllowType is used by the allow disruptive action to skip evaluating rules after being allowed
+ // Note: Rely on tx.Allow(allowType) for tweaking this field. This field is exposed for backwards
+ // compatibility, but it is not recommended to be used directly.
+ // TODO(4.x): Evaluate to make it private
AllowType corazatypes.AllowType
// Copies from the WAF instance that may be overwritten by the ctl action
@@ -89,7 +99,11 @@ type Transaction struct {
responseBodyBuffer *BodyBuffer
// Rules with this id are going to be skipped while processing a phase
- ruleRemoveByID []int
+ ruleRemoveByID map[int]struct{}
+
+ // ruleRemoveByIDRanges stores ranges of rule IDs to be skipped during a phase.
+ // Ranges avoid expanding all IDs into the ruleRemoveByID map.
+ ruleRemoveByIDRanges [][2]int
// ruleRemoveTargetByID is used by ctl to remove rule targets by id during the
// transaction. All other "target removers" like "ByTag" are an abstraction of "ById"
@@ -122,7 +136,7 @@ type Transaction struct {
variables TransactionVariables
- transformationCache map[transformationKey]*transformationValue
+ transformationCache map[transformationKey]transformationValue
}
func (tx *Transaction) ID() string {
@@ -305,9 +319,36 @@ func (tx *Transaction) Collection(idx variables.RuleVariable) collection.Collect
return collections.Noop
}
+// Interrupt sets the interruption for the transaction.
+// It complies with DetectionOnly definition which requires that disruptive actions are not executed.
+// Depending on the RuleEngine mode:
+// If On: it immediately interrupts the transaction and generates a response.
+// If DetectionOnly: it keeps track of what the interruption would have been if the engine was "On",
+// allowing consistent logging and visibility of potential disruptions without actually interrupting the transaction.
func (tx *Transaction) Interrupt(interruption *types.Interruption) {
- if tx.RuleEngine == types.RuleEngineOn {
+ switch tx.RuleEngine {
+ case types.RuleEngineOn:
tx.interruption = interruption
+ case types.RuleEngineDetectionOnly:
+ // In DetectionOnly mode, the interruption is not actually triggered, which means that
+ // further rules will continue to be evaluated and more actions can be executed.
+ // Let's keep only the first interruption here, matching the one that would have been triggered
+ // if the engine was on.
+ if tx.detectionOnlyInterruption == nil {
+ tx.detectionOnlyInterruption = interruption
+ }
+ }
+}
+
+// Allow sets the allow type for the transaction.
+// It complies with DetectionOnly definition which requires not executing disruptive actions.
+// Depending on the RuleEngine mode:
+// If On: it will cause the transaction to skip rules according to the allow type (phase, request, all).
+// If DetectionOnly: allow is not enforced.
+// TODO(4.x): evaluate to expose it in the interface.
+func (tx *Transaction) Allow(allowType corazatypes.AllowType) {
+ if tx.RuleEngine == types.RuleEngineOn {
+ tx.AllowType = allowType
}
}
@@ -530,19 +571,19 @@ func (tx *Transaction) MatchRule(r *Rule, mds []types.MatchData) {
MatchedDatas_: mds,
Context_: tx.context,
}
- // Populate MatchedRule disruption related fields only if the Engine is capable of performing disruptive actions
- if tx.RuleEngine == types.RuleEngineOn {
- var exists bool
- for _, a := range r.actions {
- // There can be only at most one disruptive action per rule
- if a.Function.Type() == plugintypes.ActionTypeDisruptive {
- mr.DisruptiveAction_, exists = corazarules.DisruptiveActionMap[a.Name]
- if !exists {
- mr.DisruptiveAction_ = corazarules.DisruptiveActionUnknown
- }
- mr.Disruptive_ = true
- break
- }
+
+ // Starting from Coraza 3.4, MatchedRule are including the disruptive action (DisruptiveAction_)
+ // also in DetectionOnly mode. This improves visibility of what would have happened if the engine was on.
+ // The Disruptive_ boolean still allows to identify actual disruptions from "potential" disruptions.
+ // Disruptive_ field is also used during logging to print different messages if the disruption has been real or not
+ // so it is important to set it according to the RuleEngine mode.
+ for _, a := range r.actions {
+ // There can be only one disruptive action per rule
+ if a.Function.Type() == plugintypes.ActionTypeDisruptive {
+ // if not found it will default to DisruptiveActionUnknown.
+ mr.DisruptiveAction_ = corazarules.DisruptiveActionMap[a.Name]
+ mr.Disruptive_ = tx.RuleEngine == types.RuleEngineOn
+ break
}
}
@@ -614,7 +655,7 @@ func (tx *Transaction) GetField(rv ruleVariableParams) []types.MatchData {
isException := false
lkey := strings.ToLower(c.Key())
for _, ex := range rv.Exceptions {
- if (ex.KeyRx != nil && ex.KeyRx.MatchString(lkey)) || strings.ToLower(ex.KeyStr) == lkey {
+ if (ex.KeyRx != nil && ex.KeyRx.MatchString(lkey)) || strings.ToLower(ex.KeyStr) == lkey || (ex.KeyStr == "" && ex.KeyRx == nil) {
isException = true
break
}
@@ -639,12 +680,17 @@ func (tx *Transaction) GetField(rv ruleVariableParams) []types.MatchData {
return matches
}
-// RemoveRuleTargetByID Removes the VARIABLE:KEY from the rule ID
-// It's mostly used by CTL to dynamically remove targets from rules
-func (tx *Transaction) RemoveRuleTargetByID(id int, variable variables.RuleVariable, key string) {
+// RemoveRuleTargetByID removes the VARIABLE:KEY from the rule ID.
+// It is mostly used by CTL to dynamically remove targets from rules.
+// key is an exact string to match against the variable name; keyRx is an
+// optional compiled regular expression that, when non-nil, is used instead of
+// key for pattern-based matching (e.g. removing all ARGS matching
+// /^json\.\d+\.field$/ from a given rule).
+func (tx *Transaction) RemoveRuleTargetByID(id int, variable variables.RuleVariable, key string, keyRx *regexp.Regexp) {
c := ruleVariableParams{
Variable: variable,
KeyStr: key,
+ KeyRx: keyRx,
}
if multiphaseEvaluation && (variable == variables.Args || variable == variables.ArgsNames) {
@@ -669,7 +715,22 @@ func (tx *Transaction) RemoveRuleTargetByID(id int, variable variables.RuleVaria
// RemoveRuleByID Removes a rule from the transaction
// It does not affect the WAF rules
func (tx *Transaction) RemoveRuleByID(id int) {
- tx.ruleRemoveByID = append(tx.ruleRemoveByID, id)
+ if tx.ruleRemoveByID == nil {
+ tx.ruleRemoveByID = map[int]struct{}{}
+ }
+ tx.ruleRemoveByID[id] = struct{}{}
+}
+
+// RemoveRuleByIDRange marks rules in the ID range [start, end] (inclusive) to be
+// skipped during transaction processing. It does not affect the WAF rules.
+func (tx *Transaction) RemoveRuleByIDRange(start, end int) {
+ tx.ruleRemoveByIDRanges = append(tx.ruleRemoveByIDRanges, [2]int{start, end})
+}
+
+// GetRuleRemoveByIDRanges returns the list of rule ID ranges that will be skipped
+// during transaction processing.
+func (tx *Transaction) GetRuleRemoveByIDRanges() [][2]int {
+ return tx.ruleRemoveByIDRanges
}
// ProcessConnection should be called at very beginning of a request process, it is
@@ -820,7 +881,7 @@ func (tx *Transaction) SetServerName(serverName string) {
//
// note: Remember to check for a possible intervention.
func (tx *Transaction) ProcessRequestHeaders() *types.Interruption {
- if tx.RuleEngine == types.RuleEngineOff {
+ if tx.IsRuleEngineOff() {
// Rule engine is disabled
return nil
}
@@ -830,7 +891,7 @@ func (tx *Transaction) ProcessRequestHeaders() *types.Interruption {
return tx.interruption
}
- if tx.interruption != nil {
+ if tx.IsInterrupted() {
tx.debugLogger.Error().Msg("Calling ProcessRequestHeaders but there is a preexisting interruption")
return tx.interruption
}
@@ -852,7 +913,7 @@ func setAndReturnBodyLimitInterruption(tx *Transaction, status int) (*types.Inte
// it returns an interruption if the writing bytes go beyond the request body limit.
// It won't copy the bytes if the body access isn't accessible.
func (tx *Transaction) WriteRequestBody(b []byte) (*types.Interruption, int, error) {
- if tx.RuleEngine == types.RuleEngineOff {
+ if tx.IsRuleEngineOff() {
return nil, 0, nil
}
@@ -917,7 +978,7 @@ type ByteLenger interface {
// it returns an interruption if the writing bytes go beyond the request body limit.
// It won't read the reader if the body access isn't accessible.
func (tx *Transaction) ReadRequestBodyFrom(r io.Reader) (*types.Interruption, int, error) {
- if tx.RuleEngine == types.RuleEngineOff {
+ if tx.IsRuleEngineOff() {
return nil, 0, nil
}
@@ -996,11 +1057,10 @@ func (tx *Transaction) ReadRequestBodyFrom(r io.Reader) (*types.Interruption, in
//
// Remember to check for a possible intervention.
func (tx *Transaction) ProcessRequestBody() (*types.Interruption, error) {
- if tx.RuleEngine == types.RuleEngineOff {
+ if tx.IsRuleEngineOff() {
return nil, nil
}
-
- if tx.interruption != nil {
+ if tx.IsInterrupted() {
tx.debugLogger.Error().Msg("Calling ProcessRequestBody but there is a preexisting interruption")
return tx.interruption, nil
}
@@ -1024,9 +1084,9 @@ func (tx *Transaction) ProcessRequestBody() (*types.Interruption, error) {
tx.WAF.Rules.Eval(types.PhaseRequestBody, tx)
return tx.interruption, nil
}
- mime := ""
+ mimeType := ""
if m := tx.variables.requestHeaders.Get("content-type"); len(m) > 0 {
- mime = m[0]
+ mimeType = m[0]
}
reader, err := tx.requestBodyBuffer.Reader()
@@ -1039,7 +1099,7 @@ func (tx *Transaction) ProcessRequestBody() (*types.Interruption, error) {
// Default variables.ReqbodyProcessor values
// XML and JSON must be forced with ctl:requestBodyProcessor=JSON
if tx.ForceRequestBodyVariable {
- // We force URLENCODED if mime is x-www... or we have an empty RBP and ForceRequestBodyVariable
+ // We force URLENCODED if mimeType is x-www... or we have an empty RBP and ForceRequestBodyVariable
if rbp == "" {
rbp = "URLENCODED"
}
@@ -1062,10 +1122,18 @@ func (tx *Transaction) ProcessRequestBody() (*types.Interruption, error) {
Str("body_processor", rbp).
Msg("Attempting to process request body")
- if err := bodyprocessor.ProcessRequest(reader, tx.Variables(), plugintypes.BodyProcessorOptions{
- Mime: mime,
- StoragePath: tx.WAF.UploadDir,
- }); err != nil {
+ bpOpts := plugintypes.BodyProcessorOptions{
+ Mime: mimeType,
+ StoragePath: tx.WAF.UploadDir,
+ RequestBodyRecursionLimit: tx.WAF.RequestBodyJsonDepthLimit,
+ }
+
+ // If the body processor supports streaming, evaluate rules per record
+ if sp, ok := bodyprocessor.(plugintypes.StreamingBodyProcessor); ok {
+ return tx.processRequestBodyStreaming(sp, reader, bpOpts)
+ }
+
+ if err := bodyprocessor.ProcessRequest(reader, tx.Variables(), bpOpts); err != nil {
tx.debugLogger.Error().Err(err).Msg("Failed to process request body")
tx.generateRequestBodyError(err)
tx.WAF.Rules.Eval(types.PhaseRequestBody, tx)
@@ -1076,6 +1144,298 @@ func (tx *Transaction) ProcessRequestBody() (*types.Interruption, error) {
return tx.interruption, nil
}
+// errStreamInterrupted is a sentinel error used to stop streaming body processing
+// when a rule triggers an interruption.
+var errStreamInterrupted = errors.New("stream processing interrupted")
+
+// processRequestBodyStreaming evaluates Phase 2 rules after each record in a streaming body.
+// ArgsPost is cleared and repopulated for each record, while TX variables persist across records
+// for cross-record correlation (e.g., anomaly scoring).
+func (tx *Transaction) processRequestBodyStreaming(sp plugintypes.StreamingBodyProcessor, reader io.Reader, opts plugintypes.BodyProcessorOptions) (*types.Interruption, error) {
+ err := sp.ProcessRequestRecords(reader, opts, func(recordNum int, fields map[string]string, _ string) error {
+ // Clear ArgsPost and repopulate with this record's fields only
+ tx.variables.argsPost.Reset()
+ for key, value := range fields {
+ tx.variables.argsPost.SetIndex(key, 0, value)
+ }
+
+ // Evaluate Phase 2 rules for this record
+ tx.WAF.Rules.Eval(types.PhaseRequestBody, tx)
+
+ if tx.interruption != nil {
+ tx.debugLogger.Debug().
+ Int("record_num", recordNum).
+ Msg("Stream processing interrupted by rule")
+ return errStreamInterrupted
+ }
+ return nil
+ })
+
+ if err != nil && err != errStreamInterrupted {
+ tx.debugLogger.Error().Err(err).Msg("Failed to process streaming request body")
+ tx.generateRequestBodyError(err)
+ if tx.interruption == nil {
+ tx.WAF.Rules.Eval(types.PhaseRequestBody, tx)
+ }
+ }
+
+ return tx.interruption, nil
+}
+
+// processResponseBodyStreaming evaluates Phase 4 rules after each record in a streaming response body.
+// ArgsResponse is cleared and repopulated for each record, while TX variables persist across records.
+func (tx *Transaction) processResponseBodyStreaming(sp plugintypes.StreamingBodyProcessor, reader io.Reader, opts plugintypes.BodyProcessorOptions) (*types.Interruption, error) {
+ err := sp.ProcessResponseRecords(reader, opts, func(recordNum int, fields map[string]string, _ string) error {
+ // Clear ResponseArgs and repopulate with this record's fields only
+ tx.variables.responseArgs.Reset()
+ for key, value := range fields {
+ tx.variables.responseArgs.SetIndex(key, 0, value)
+ }
+
+ // Evaluate Phase 4 rules for this record
+ tx.WAF.Rules.Eval(types.PhaseResponseBody, tx)
+
+ if tx.interruption != nil {
+ tx.debugLogger.Debug().
+ Int("record_num", recordNum).
+ Msg("Response stream processing interrupted by rule")
+ return errStreamInterrupted
+ }
+ return nil
+ })
+
+ if err != nil && err != errStreamInterrupted {
+ tx.debugLogger.Error().Err(err).Msg("Failed to process streaming response body")
+ tx.generateResponseBodyError(err)
+ if tx.interruption == nil {
+ tx.WAF.Rules.Eval(types.PhaseResponseBody, tx)
+ }
+ }
+
+ return tx.interruption, nil
+}
+
+// ProcessRequestBodyFromStream processes a streaming request body by reading records
+// from input, evaluating Phase 2 rules per record, and writing clean records to output.
+// This enables true streaming without buffering the entire body.
+//
+// Unlike ProcessRequestBody(), this method reads directly from the provided input
+// rather than from the transaction's body buffer. Records that pass rule evaluation
+// are written to output for relay to the backend.
+//
+// This method handles Phase 2 rule evaluation internally.
+func (tx *Transaction) ProcessRequestBodyFromStream(input io.Reader, output io.Writer) (*types.Interruption, error) {
+ if tx.RuleEngine == types.RuleEngineOff {
+ // Pass through: copy input to output without evaluation
+ if _, err := io.Copy(output, input); err != nil {
+ return nil, err
+ }
+ return nil, nil
+ }
+
+ if tx.interruption != nil {
+ return tx.interruption, nil
+ }
+
+ if tx.lastPhase != types.PhaseRequestHeaders {
+ tx.debugLogger.Debug().Msg("Skipping ProcessRequestBodyFromStream: wrong phase")
+ return nil, nil
+ }
+
+ rbp := tx.variables.reqbodyProcessor.Get()
+ if tx.ForceRequestBodyVariable && rbp == "" {
+ rbp = "URLENCODED"
+ tx.variables.reqbodyProcessor.Set(rbp)
+ }
+ rbp = strings.ToLower(rbp)
+ if rbp == "" {
+ // No body processor, pass through
+ if _, err := io.Copy(output, input); err != nil {
+ return nil, err
+ }
+ tx.WAF.Rules.Eval(types.PhaseRequestBody, tx)
+ return tx.interruption, nil
+ }
+
+ bodyprocessor, err := bodyprocessors.GetBodyProcessor(rbp)
+ if err != nil {
+ tx.generateRequestBodyError(errors.New("invalid body processor"))
+ tx.WAF.Rules.Eval(types.PhaseRequestBody, tx)
+ return tx.interruption, nil
+ }
+
+ sp, ok := bodyprocessor.(plugintypes.StreamingBodyProcessor)
+ if !ok {
+ // Non-streaming body processor: buffer input, process normally, then copy to output
+ it, _, err := tx.ReadRequestBodyFrom(input)
+ if err != nil {
+ return nil, err
+ }
+ if it != nil {
+ return it, nil
+ }
+
+ it, err = tx.ProcessRequestBody()
+ if err != nil || it != nil {
+ return it, err
+ }
+
+ // Copy buffered body to output
+ rbr, err := tx.RequestBodyReader()
+ if err != nil {
+ return nil, err
+ }
+ if _, err := io.Copy(output, rbr); err != nil {
+ return nil, err
+ }
+ return nil, nil
+ }
+
+ mime := ""
+ if m := tx.variables.requestHeaders.Get("content-type"); len(m) > 0 {
+ mime = m[0]
+ }
+
+ tx.debugLogger.Debug().
+ Str("body_processor", rbp).
+ Msg("Attempting to process streaming request body with relay")
+
+ streamErr := sp.ProcessRequestRecords(input, plugintypes.BodyProcessorOptions{
+ Mime: mime,
+ StoragePath: tx.WAF.UploadDir,
+ }, func(recordNum int, fields map[string]string, rawRecord string) error {
+ // Clear ArgsPost and repopulate with this record's fields only
+ tx.variables.argsPost.Reset()
+ for key, value := range fields {
+ tx.variables.argsPost.SetIndex(key, 0, value)
+ }
+
+ // Evaluate Phase 2 rules for this record
+ tx.WAF.Rules.Eval(types.PhaseRequestBody, tx)
+
+ if tx.interruption != nil {
+ tx.debugLogger.Debug().
+ Int("record_num", recordNum).
+ Msg("Stream relay interrupted by rule")
+ return errStreamInterrupted
+ }
+
+ // Record passed evaluation, write to output for relay.
+ // rawRecord includes format-specific delimiters (e.g., \n for NDJSON,
+ // RS prefix + \n for RFC 7464), preserving the original stream format.
+ if _, err := io.WriteString(output, rawRecord); err != nil {
+ return fmt.Errorf("failed to write record to output: %w", err)
+ }
+
+ return nil
+ })
+
+ if streamErr != nil && streamErr != errStreamInterrupted {
+ tx.debugLogger.Error().Err(streamErr).Msg("Failed to process streaming request body relay")
+ tx.generateRequestBodyError(streamErr)
+ if tx.interruption == nil {
+ tx.WAF.Rules.Eval(types.PhaseRequestBody, tx)
+ }
+ }
+
+ return tx.interruption, nil
+}
+
+// ProcessResponseBodyFromStream processes a streaming response body by reading records
+// from input, evaluating Phase 4 rules per record, and writing clean records to output.
+// This enables true streaming of response bodies without full buffering.
+func (tx *Transaction) ProcessResponseBodyFromStream(input io.Reader, output io.Writer) (*types.Interruption, error) {
+ if tx.RuleEngine == types.RuleEngineOff {
+ if _, err := io.Copy(output, input); err != nil {
+ return nil, err
+ }
+ return nil, nil
+ }
+
+ if tx.interruption != nil {
+ return tx.interruption, nil
+ }
+
+ if tx.lastPhase != types.PhaseResponseHeaders {
+ tx.debugLogger.Debug().Msg("Skipping ProcessResponseBodyFromStream: wrong phase")
+ return nil, nil
+ }
+
+ bp := tx.variables.resBodyProcessor.Get()
+ if bp == "" {
+ // No body processor, pass through
+ if _, err := io.Copy(output, input); err != nil {
+ return nil, err
+ }
+ tx.WAF.Rules.Eval(types.PhaseResponseBody, tx)
+ return tx.interruption, nil
+ }
+
+ bodyprocessor, err := bodyprocessors.GetBodyProcessor(bp)
+ if err != nil {
+ tx.generateResponseBodyError(errors.New("invalid body processor"))
+ tx.WAF.Rules.Eval(types.PhaseResponseBody, tx)
+ return tx.interruption, nil
+ }
+
+ sp, ok := bodyprocessor.(plugintypes.StreamingBodyProcessor)
+ if !ok {
+ // Non-streaming: buffer, process, copy
+ if _, _, err := tx.ReadResponseBodyFrom(input); err != nil {
+ return nil, err
+ }
+ it, err := tx.ProcessResponseBody()
+ if err != nil || it != nil {
+ return it, err
+ }
+ rbr, err := tx.ResponseBodyReader()
+ if err != nil {
+ return nil, err
+ }
+ if _, err := io.Copy(output, rbr); err != nil {
+ return nil, err
+ }
+ return nil, nil
+ }
+
+ tx.debugLogger.Debug().
+ Str("body_processor", bp).
+ Msg("Attempting to process streaming response body with relay")
+
+ streamErr := sp.ProcessResponseRecords(input, plugintypes.BodyProcessorOptions{}, func(recordNum int, fields map[string]string, rawRecord string) error {
+ tx.variables.responseArgs.Reset()
+ for key, value := range fields {
+ tx.variables.responseArgs.SetIndex(key, 0, value)
+ }
+
+ tx.WAF.Rules.Eval(types.PhaseResponseBody, tx)
+
+ if tx.interruption != nil {
+ tx.debugLogger.Debug().
+ Int("record_num", recordNum).
+ Msg("Response stream relay interrupted by rule")
+ return errStreamInterrupted
+ }
+
+ // rawRecord includes format-specific delimiters, preserving the original stream format.
+ if _, err := io.WriteString(output, rawRecord); err != nil {
+ return fmt.Errorf("failed to write response record to output: %w", err)
+ }
+
+ return nil
+ })
+
+ if streamErr != nil && streamErr != errStreamInterrupted {
+ tx.debugLogger.Error().Err(streamErr).Msg("Failed to process streaming response body relay")
+ tx.generateResponseBodyError(streamErr)
+ if tx.interruption == nil {
+ tx.WAF.Rules.Eval(types.PhaseResponseBody, tx)
+ }
+ }
+
+ return tx.interruption, nil
+}
+
// ProcessResponseHeaders performs the analysis on the response headers.
//
// This method performs the analysis on the response headers. Note, however,
@@ -1083,7 +1443,7 @@ func (tx *Transaction) ProcessRequestBody() (*types.Interruption, error) {
//
// Note: Remember to check for a possible intervention.
func (tx *Transaction) ProcessResponseHeaders(code int, proto string) *types.Interruption {
- if tx.RuleEngine == types.RuleEngineOff {
+ if tx.IsRuleEngineOff() {
return nil
}
@@ -1093,7 +1453,7 @@ func (tx *Transaction) ProcessResponseHeaders(code int, proto string) *types.Int
return tx.interruption
}
- if tx.interruption != nil {
+ if tx.IsInterrupted() {
tx.debugLogger.Error().Msg("Calling ProcessResponseHeaders but there is a preexisting interruption")
return tx.interruption
}
@@ -1125,7 +1485,7 @@ func (tx *Transaction) IsResponseBodyProcessable() bool {
// it returns an interruption if the writing bytes go beyond the response body limit.
// It won't copy the bytes if the body access isn't accessible.
func (tx *Transaction) WriteResponseBody(b []byte) (*types.Interruption, int, error) {
- if tx.RuleEngine == types.RuleEngineOff {
+ if tx.IsRuleEngineOff() {
return nil, 0, nil
}
@@ -1176,7 +1536,7 @@ func (tx *Transaction) WriteResponseBody(b []byte) (*types.Interruption, int, er
// it returns an interruption if the writing bytes go beyond the response body limit.
// It won't read the reader if the body access isn't accessible.
func (tx *Transaction) ReadResponseBodyFrom(r io.Reader) (*types.Interruption, int, error) {
- if tx.RuleEngine == types.RuleEngineOff {
+ if tx.IsRuleEngineOff() {
return nil, 0, nil
}
@@ -1246,11 +1606,11 @@ func (tx *Transaction) ReadResponseBodyFrom(r io.Reader) (*types.Interruption, i
//
// note Remember to check for a possible intervention.
func (tx *Transaction) ProcessResponseBody() (*types.Interruption, error) {
- if tx.RuleEngine == types.RuleEngineOff {
+ if tx.IsRuleEngineOff() {
return nil, nil
}
- if tx.interruption != nil {
+ if tx.IsInterrupted() {
tx.debugLogger.Error().Msg("Calling ProcessResponseBody but there is a preexisting interruption")
return tx.interruption, nil
}
@@ -1295,6 +1655,11 @@ func (tx *Transaction) ProcessResponseBody() (*types.Interruption, error) {
tx.debugLogger.Debug().Str("body_processor", bp).Msg("Attempting to process response body")
+ // If the body processor supports streaming, evaluate rules per record
+ if sp, ok := b.(plugintypes.StreamingBodyProcessor); ok {
+ return tx.processResponseBodyStreaming(sp, reader, plugintypes.BodyProcessorOptions{})
+ }
+
if err := b.ProcessResponse(reader, tx.Variables(), plugintypes.BodyProcessorOptions{}); err != nil {
tx.debugLogger.Error().Err(err).Msg("Failed to process response body")
tx.generateResponseBodyError(err)
@@ -1319,12 +1684,11 @@ func (tx *Transaction) ProcessLogging() {
// If Rule engine is disabled, Log phase rules are not going to be evaluated.
// This avoids trying to rely on variables not set by previous rules that
// have not been executed
- if tx.RuleEngine != types.RuleEngineOff {
+ if !tx.IsRuleEngineOff() {
tx.WAF.Rules.Eval(types.PhaseLogging, tx)
}
if tx.AuditEngine == types.AuditEngineOff {
- // Audit engine disabled
tx.debugLogger.Debug().
Msg("Transaction not marked for audit logging, AuditEngine is disabled")
return
@@ -1342,11 +1706,14 @@ func (tx *Transaction) ProcessLogging() {
status := tx.variables.responseStatus.Get()
if tx.IsInterrupted() {
status = strconv.Itoa(tx.interruption.Status)
+ } else if tx.IsDetectionOnlyInterrupted() {
+ // This allows to check for relevant status even in detection only mode.
+ // Fixes https://github.com/corazawaf/coraza/issues/1333
+ status = strconv.Itoa(tx.detectionOnlyInterruption.Status)
}
if re != nil && !re.Match([]byte(status)) {
// Not relevant status
- tx.debugLogger.Debug().
- Msg("Transaction status not marked for audit logging")
+ tx.debugLogger.Debug().Msg("Transaction status not marked for audit logging")
return
}
}
@@ -1382,14 +1749,35 @@ func (tx *Transaction) IsInterrupted() bool {
return tx.interruption != nil
}
+// TODO(4.x): evaluate to expose it in the interface.
+func (tx *Transaction) IsDetectionOnlyInterrupted() bool {
+ return tx.detectionOnlyInterruption != nil
+}
+
func (tx *Transaction) Interruption() *types.Interruption {
return tx.interruption
}
+func (tx *Transaction) DetectionOnlyInterruption() *types.Interruption {
+ return tx.detectionOnlyInterruption
+}
+
func (tx *Transaction) MatchedRules() []types.MatchedRule {
return tx.matchedRules
}
+// hasLogRelevantMatchedRules returns true if any matched rule has Log enabled.
+// Rules with nolog (e.g. CRS initialization rules) are excluded, matching
+// the same filtering used for audit log part K.
+func (tx *Transaction) hasLogRelevantMatchedRules() bool {
+ for _, mr := range tx.matchedRules {
+ if mrWithLog, ok := mr.(*corazarules.MatchedRule); ok && mrWithLog.Log() {
+ return true
+ }
+ }
+ return false
+}
+
func (tx *Transaction) LastPhase() types.RulePhase {
return tx.lastPhase
}
@@ -1563,12 +1951,20 @@ func (tx *Transaction) Close() error {
var errs []error
if environment.HasAccessToFS {
- // TODO(jcchavezs): filesTmpNames should probably be a new kind of collection that
- // is aware of the files and then attempt to delete them when the collection
- // is resetted or an item is removed.
- for _, file := range tx.variables.filesTmpNames.Get("") {
- if err := os.Remove(file); err != nil {
- errs = append(errs, fmt.Errorf("removing temporary file: %v", err))
+ // UploadKeepFilesRelevantOnly keeps temporary files only when there are
+ // log-relevant matched rules (i.e., rules that would be logged; rules
+ // with actions such as "nolog" are intentionally excluded here).
+ keepFiles := tx.WAF.UploadKeepFiles == types.UploadKeepFilesOn ||
+ (tx.WAF.UploadKeepFiles == types.UploadKeepFilesRelevantOnly && tx.hasLogRelevantMatchedRules())
+
+ if !keepFiles {
+ // TODO(jcchavezs): filesTmpNames should probably be a new kind of collection that
+ // is aware of the files and then attempt to delete them when the collection
+ // is reset or an item is removed.
+ for _, file := range tx.variables.filesTmpNames.Get("") {
+ if err := os.Remove(file); err != nil {
+ errs = append(errs, fmt.Errorf("removing temporary file: %v", err))
+ }
}
}
}
diff --git a/internal/corazawaf/transaction_test.go b/internal/corazawaf/transaction_test.go
index b5eb2905c..62f1935f0 100644
--- a/internal/corazawaf/transaction_test.go
+++ b/internal/corazawaf/transaction_test.go
@@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
+ "os"
"regexp"
"runtime/debug"
"strconv"
@@ -22,6 +23,7 @@ import (
"github.com/corazawaf/coraza/v3/internal/collections"
"github.com/corazawaf/coraza/v3/internal/corazarules"
"github.com/corazawaf/coraza/v3/internal/environment"
+ "github.com/corazawaf/coraza/v3/internal/operators"
utils "github.com/corazawaf/coraza/v3/internal/strings"
"github.com/corazawaf/coraza/v3/types"
"github.com/corazawaf/coraza/v3/types/variables"
@@ -752,6 +754,72 @@ func TestAuditLogFields(t *testing.T) {
}
}
+func TestMatchRuleDisruptiveActionPopulated(t *testing.T) {
+ tests := []struct {
+ name string
+ engine types.RuleEngineStatus
+ wantDisruptive bool
+ wantAction corazarules.DisruptiveAction
+ wantInterrupted bool
+ wantDetectionOnlyInterrupted bool
+ }{
+ {
+ name: "engine on",
+ engine: types.RuleEngineOn,
+ wantDisruptive: true,
+ wantAction: corazarules.DisruptiveActionDeny,
+ wantInterrupted: true,
+ wantDetectionOnlyInterrupted: false,
+ },
+ {
+ name: "engine detection only",
+ engine: types.RuleEngineDetectionOnly,
+ wantDisruptive: false,
+ wantAction: corazarules.DisruptiveActionDeny,
+ wantInterrupted: false,
+ wantDetectionOnlyInterrupted: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ waf := NewWAF()
+ tx := waf.NewTransaction()
+ tx.RuleEngine = tt.engine
+
+ rule := NewRule()
+ rule.ID_ = 1
+ if err := rule.AddVariable(variables.ArgsGet, "", false); err != nil {
+ t.Fatal(err)
+ }
+ rule.SetOperator(&dummyEqOperator{}, "@eq", "0")
+ _ = rule.AddAction("deny", &dummyDenyAction{})
+
+ tx.AddGetRequestArgument("test", "0")
+
+ var matchedValues []types.MatchData
+ rule.doEvaluate(debuglog.Noop(), types.PhaseRequestHeaders, tx, &matchedValues, 0, tx.transformationCache)
+
+ if len(tx.matchedRules) != 1 {
+ t.Fatalf("expected 1 matched rule, got %d", len(tx.matchedRules))
+ }
+ mr := tx.matchedRules[0].(*corazarules.MatchedRule)
+ if mr.Disruptive_ != tt.wantDisruptive {
+ t.Errorf("Disruptive_: got %t, want %t", mr.Disruptive_, tt.wantDisruptive)
+ }
+ if mr.DisruptiveAction_ != tt.wantAction {
+ t.Errorf("DisruptiveAction_: got %d, want %d", mr.DisruptiveAction_, tt.wantAction)
+ }
+ if tx.IsInterrupted() != tt.wantInterrupted {
+ t.Errorf("IsInterrupted: got %t, want %t", tx.IsInterrupted(), tt.wantInterrupted)
+ }
+ if tx.IsDetectionOnlyInterrupted() != tt.wantDetectionOnlyInterrupted {
+ t.Errorf("IsDetectionOnlyInterrupted: got %t, want %t", tx.IsDetectionOnlyInterrupted(), tt.wantDetectionOnlyInterrupted)
+ }
+ })
+ }
+}
+
func TestResetCapture(t *testing.T) {
tx := makeTransaction(t)
tx.Capture = true
@@ -1326,6 +1394,61 @@ func BenchmarkTxGetField(b *testing.B) {
b.ReportAllocs()
}
+// makeTransactionWithJSONArgs creates a transaction that includes JSON-array-style
+// GET arguments (json.0.field … json.9.field) on top of the standard args.
+// This simulates the real-world pattern that motivates regex key exceptions.
+func makeTransactionWithJSONArgs(t testing.TB) *Transaction {
+ t.Helper()
+ tx := makeTransaction(t)
+ for i := 0; i < 10; i++ {
+ tx.AddGetRequestArgument(fmt.Sprintf("json.%d.jobdescription", i), "value")
+ }
+ return tx
+}
+
+// BenchmarkTxGetFieldWithShortRegexException measures the overhead of GetField
+// when a short regex exception (e.g. ^id$) is applied against the args collection.
+func BenchmarkTxGetFieldWithShortRegexException(b *testing.B) {
+ tx := makeTransactionWithJSONArgs(b)
+ rvp := ruleVariableParams{
+ Variable: variables.Args,
+ Exceptions: []ruleVariableException{
+ {KeyRx: regexp.MustCompile(`^id$`)},
+ },
+ }
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ tx.GetField(rvp)
+ }
+ b.StopTimer()
+ if err := tx.Close(); err != nil {
+ b.Fatalf("Failed to close transaction: %s", err.Error())
+ }
+}
+
+// BenchmarkTxGetFieldWithMediumRegexException measures the overhead of GetField
+// when a medium-complexity regex exception (e.g. ^json\.\d+\.jobdescription$) is
+// applied — the typical pattern used in URI-scoped CRS exclusions.
+func BenchmarkTxGetFieldWithMediumRegexException(b *testing.B) {
+ tx := makeTransactionWithJSONArgs(b)
+ rvp := ruleVariableParams{
+ Variable: variables.Args,
+ Exceptions: []ruleVariableException{
+ {KeyRx: regexp.MustCompile(`^json\.\d+\.jobdescription$`)},
+ },
+ }
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ tx.GetField(rvp)
+ }
+ b.StopTimer()
+ if err := tx.Close(); err != nil {
+ b.Fatalf("Failed to close transaction: %s", err.Error())
+ }
+}
+
func TestTxProcessURI(t *testing.T) {
waf := NewWAF()
tx := waf.NewTransaction()
@@ -1721,7 +1844,7 @@ func TestResponseBodyForceProcessing(t *testing.T) {
if _, err := tx.ProcessRequestBody(); err != nil {
t.Fatal(err)
}
- tx.ProcessResponseHeaders(200, "HTTP/1")
+ tx.ProcessResponseHeaders(200, "HTTP/1.1")
if _, _, err := tx.WriteResponseBody([]byte(`{"key":"value"}`)); err != nil {
t.Fatal(err)
}
@@ -1790,6 +1913,121 @@ func TestCloseFails(t *testing.T) {
}
}
+func TestUploadKeepFiles(t *testing.T) {
+ if !environment.HasAccessToFS {
+ t.Skip("skipping test as it requires access to filesystem")
+ }
+
+ createTmpFile := func(t *testing.T) string {
+ t.Helper()
+ f, err := os.CreateTemp(t.TempDir(), "crztest*")
+ if err != nil {
+ t.Fatal(err)
+ }
+ name := f.Name()
+ if err := f.Close(); err != nil {
+ t.Fatalf("failed to close temp file: %v", err)
+ }
+ return name
+ }
+
+ t.Run("Off deletes files", func(t *testing.T) {
+ waf := NewWAF()
+ waf.UploadKeepFiles = types.UploadKeepFilesOff
+ tx := waf.NewTransaction()
+ tmpFile := createTmpFile(t)
+
+ col := tx.Variables().FilesTmpNames().(*collections.Map)
+ col.Add("", tmpFile)
+
+ if err := tx.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := os.Stat(tmpFile); !os.IsNotExist(err) {
+ t.Fatal("expected temp file to be deleted when UploadKeepFiles is Off")
+ }
+ })
+
+ t.Run("On keeps files", func(t *testing.T) {
+ waf := NewWAF()
+ waf.UploadKeepFiles = types.UploadKeepFilesOn
+ tx := waf.NewTransaction()
+ tmpFile := createTmpFile(t)
+
+ col := tx.Variables().FilesTmpNames().(*collections.Map)
+ col.Add("", tmpFile)
+
+ if err := tx.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := os.Stat(tmpFile); err != nil {
+ t.Fatal("expected temp file to be kept when UploadKeepFiles is On")
+ }
+ })
+
+ t.Run("RelevantOnly keeps files when log rules matched", func(t *testing.T) {
+ waf := NewWAF()
+ waf.UploadKeepFiles = types.UploadKeepFilesRelevantOnly
+ tx := waf.NewTransaction()
+ tmpFile := createTmpFile(t)
+
+ // Simulate a matched rule with Log enabled
+ tx.matchedRules = append(tx.matchedRules, &corazarules.MatchedRule{Log_: true})
+
+ col := tx.Variables().FilesTmpNames().(*collections.Map)
+ col.Add("", tmpFile)
+
+ if err := tx.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := os.Stat(tmpFile); err != nil {
+ t.Fatal("expected temp file to be kept when UploadKeepFiles is RelevantOnly and log rules matched")
+ }
+ })
+
+ t.Run("RelevantOnly deletes files when only nolog rules matched", func(t *testing.T) {
+ waf := NewWAF()
+ waf.UploadKeepFiles = types.UploadKeepFilesRelevantOnly
+ tx := waf.NewTransaction()
+ tmpFile := createTmpFile(t)
+
+ // Simulate a matched rule with Log disabled (e.g. CRS initialization rules)
+ tx.matchedRules = append(tx.matchedRules, &corazarules.MatchedRule{Log_: false})
+
+ col := tx.Variables().FilesTmpNames().(*collections.Map)
+ col.Add("", tmpFile)
+
+ if err := tx.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := os.Stat(tmpFile); !os.IsNotExist(err) {
+ t.Fatal("expected temp file to be deleted when UploadKeepFiles is RelevantOnly and only nolog rules matched")
+ }
+ })
+
+ t.Run("RelevantOnly deletes files when no rules matched", func(t *testing.T) {
+ waf := NewWAF()
+ waf.UploadKeepFiles = types.UploadKeepFilesRelevantOnly
+ tx := waf.NewTransaction()
+ tmpFile := createTmpFile(t)
+
+ col := tx.Variables().FilesTmpNames().(*collections.Map)
+ col.Add("", tmpFile)
+
+ if err := tx.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := os.Stat(tmpFile); !os.IsNotExist(err) {
+ t.Fatal("expected temp file to be deleted when UploadKeepFiles is RelevantOnly and no rules matched")
+ }
+ })
+}
+
func TestRequestFilename(t *testing.T) {
tests := []struct {
name string
@@ -1864,3 +2102,189 @@ func TestRequestFilename(t *testing.T) {
})
}
}
+
+func newTestUnconditionalMatch(t testing.TB) plugintypes.Operator {
+ t.Helper()
+ op, err := operators.Get("unconditionalMatch", plugintypes.OperatorOptions{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ return op
+}
+
+func TestRemoveRuleByID(t *testing.T) {
+ waf := NewWAF()
+ op := newTestUnconditionalMatch(t)
+
+ // Add two rules with different IDs
+ rule1 := NewRule()
+ rule1.ID_ = 100
+ rule1.LogID_ = "100"
+ rule1.Phase_ = types.PhaseRequestHeaders
+ rule1.operator = &ruleOperatorParams{
+ Operator: op,
+ Function: "@unconditionalMatch",
+ }
+ rule1.Log = true
+ if err := waf.Rules.Add(rule1); err != nil {
+ t.Fatal(err)
+ }
+
+ rule2 := NewRule()
+ rule2.ID_ = 200
+ rule2.LogID_ = "200"
+ rule2.Phase_ = types.PhaseRequestHeaders
+ rule2.operator = &ruleOperatorParams{
+ Operator: op,
+ Function: "@unconditionalMatch",
+ }
+ rule2.Log = true
+ if err := waf.Rules.Add(rule2); err != nil {
+ t.Fatal(err)
+ }
+
+ tx := waf.NewTransaction()
+ defer tx.Close()
+
+ // Remove rule 100
+ tx.RemoveRuleByID(100)
+
+ // Verify the map was lazily initialized
+ if tx.ruleRemoveByID == nil {
+ t.Fatal("ruleRemoveByID should not be nil after RemoveRuleByID")
+ }
+
+ // Remove another rule
+ tx.RemoveRuleByID(100) // duplicate removal should be idempotent
+ tx.RemoveRuleByID(200)
+
+ if len(tx.ruleRemoveByID) != 2 {
+ t.Errorf("expected 2 entries in ruleRemoveByID map, got %d", len(tx.ruleRemoveByID))
+ }
+}
+
+func TestRemoveRuleByIDRange(t *testing.T) {
+ waf := NewWAF()
+
+ // Use nil-operator rules (SecAction-style): they always match regardless of variables.
+ for _, id := range []int{100, 150, 200, 300} {
+ r := NewRule()
+ r.ID_ = id
+ r.LogID_ = strconv.Itoa(id)
+ r.Phase_ = types.PhaseRequestHeaders
+ // nil operator means the rule always matches
+ if err := waf.Rules.Add(r); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ t.Run("range is stored", func(t *testing.T) {
+ tx := waf.NewTransaction()
+ defer tx.Close()
+
+ tx.RemoveRuleByIDRange(100, 200)
+ if len(tx.ruleRemoveByIDRanges) != 1 {
+ t.Fatalf("expected 1 range entry, got %d", len(tx.ruleRemoveByIDRanges))
+ }
+ if tx.ruleRemoveByIDRanges[0][0] != 100 || tx.ruleRemoveByIDRanges[0][1] != 200 {
+ t.Errorf("unexpected range: %v", tx.ruleRemoveByIDRanges[0])
+ }
+ })
+
+ t.Run("rules in range are skipped during eval", func(t *testing.T) {
+ tx := waf.NewTransaction()
+ defer tx.Close()
+
+ // Remove rules with IDs 100-200; rule 300 should still be evaluated.
+ tx.RemoveRuleByIDRange(100, 200)
+ waf.Rules.Eval(types.PhaseRequestHeaders, tx)
+
+ matchedIDs := make(map[int]bool)
+ for _, mr := range tx.MatchedRules() {
+ matchedIDs[mr.Rule().ID()] = true
+ }
+
+ for _, skipped := range []int{100, 150, 200} {
+ if matchedIDs[skipped] {
+ t.Errorf("rule %d should have been skipped but was matched", skipped)
+ }
+ }
+ if !matchedIDs[300] {
+ t.Errorf("rule 300 should have been matched but was not")
+ }
+ })
+
+ t.Run("multiple ranges", func(t *testing.T) {
+ tx := waf.NewTransaction()
+ defer tx.Close()
+
+ tx.RemoveRuleByIDRange(100, 100)
+ tx.RemoveRuleByIDRange(300, 300)
+ if len(tx.ruleRemoveByIDRanges) != 2 {
+ t.Fatalf("expected 2 range entries, got %d", len(tx.ruleRemoveByIDRanges))
+ }
+
+ waf.Rules.Eval(types.PhaseRequestHeaders, tx)
+
+ matchedIDs := make(map[int]bool)
+ for _, mr := range tx.MatchedRules() {
+ matchedIDs[mr.Rule().ID()] = true
+ }
+ if matchedIDs[100] {
+ t.Error("rule 100 should have been skipped")
+ }
+ if matchedIDs[300] {
+ t.Error("rule 300 should have been skipped")
+ }
+ if !matchedIDs[150] {
+ t.Error("rule 150 should have been matched")
+ }
+ if !matchedIDs[200] {
+ t.Error("rule 200 should have been matched")
+ }
+ })
+
+ t.Run("range reset on transaction reuse", func(t *testing.T) {
+ tx := waf.NewTransaction()
+ tx.RemoveRuleByIDRange(100, 200)
+ tx.Close()
+
+ // Get a new transaction (pool reuse may return the same object)
+ tx2 := waf.NewTransaction()
+ defer tx2.Close()
+
+ if len(tx2.ruleRemoveByIDRanges) != 0 {
+ t.Errorf("expected ruleRemoveByIDRanges to be reset, got %d entries", len(tx2.ruleRemoveByIDRanges))
+ }
+ })
+}
+
+func BenchmarkRuleEvalWithRemovedRules(b *testing.B) {
+ waf := NewWAF()
+ op := newTestUnconditionalMatch(b)
+
+ rule := NewRule()
+ rule.ID_ = 1000
+ rule.LogID_ = "1000"
+ rule.Phase_ = types.PhaseRequestHeaders
+ rule.operator = &ruleOperatorParams{
+ Operator: op,
+ Function: "@unconditionalMatch",
+ }
+ if err := waf.Rules.Add(rule); err != nil {
+ b.Fatal(err)
+ }
+
+ tx := waf.NewTransaction()
+ defer tx.Close()
+
+ for i := 1; i <= 100; i++ {
+ tx.RemoveRuleByID(i)
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ waf.Rules.Eval(types.PhaseRequestHeaders, tx)
+ }
+}
diff --git a/internal/corazawaf/waf.go b/internal/corazawaf/waf.go
index 32ac51fb1..54e28eaac 100644
--- a/internal/corazawaf/waf.go
+++ b/internal/corazawaf/waf.go
@@ -12,17 +12,28 @@ import (
"os"
"regexp"
"strconv"
+ gosync "sync"
+ "sync/atomic"
"time"
"github.com/corazawaf/coraza/v3/debuglog"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
"github.com/corazawaf/coraza/v3/internal/auditlog"
"github.com/corazawaf/coraza/v3/internal/environment"
+ "github.com/corazawaf/coraza/v3/internal/memoize"
stringutils "github.com/corazawaf/coraza/v3/internal/strings"
"github.com/corazawaf/coraza/v3/internal/sync"
"github.com/corazawaf/coraza/v3/types"
)
+var wafIDCounter atomic.Uint64
+
+// Default settings
+const (
+ // DefaultRequestBodyJsonDepthLimit is the default limit for the depth of JSON objects in the request body
+ DefaultRequestBodyJsonDepthLimit = 1024
+)
+
// WAF instance is used to store configurations and rules
// Every web application should have a different WAF instance,
// but you can share an instance if you are ok with sharing
@@ -44,6 +55,9 @@ type WAF struct {
// Request body page file limit
RequestBodyLimit int64
+ // Request body JSON recursive depth limit
+ RequestBodyJsonDepthLimit int
+
// Request body in memory limit
requestBodyInMemoryLimit *int64
@@ -80,9 +94,9 @@ type WAF struct {
// Path to store data files (ex. cache)
DataDir string
- // If true, the WAF will store the uploaded files in the UploadDir
- // directory
- UploadKeepFiles bool
+ // UploadKeepFiles controls whether uploaded files are kept after the transaction.
+ // On: always keep, Off: always delete (default), RelevantOnly: keep only if log-relevant rules matched (excluding nolog rules).
+ UploadKeepFiles types.UploadKeepFilesStatus
// UploadFileMode instructs the waf to set the file mode for uploaded files
UploadFileMode fs.FileMode
// UploadFileLimit is the maximum size of the uploaded file to be stored
@@ -135,6 +149,10 @@ type WAF struct {
// Configures the maximum number of ARGS that will be accepted for processing.
ArgumentLimit int
+
+ memoizerID uint64
+ memoizer *memoize.Memoizer
+ closeOnce gosync.Once
}
// Options is used to pass options to the WAF instance
@@ -180,14 +198,15 @@ func (w *WAF) newTransaction(opts Options) *Transaction {
tx.AuditLogFormat = w.AuditLogFormat
tx.ForceRequestBodyVariable = false
tx.RequestBodyAccess = w.RequestBodyAccess
- tx.RequestBodyLimit = int64(w.RequestBodyLimit)
+ tx.RequestBodyLimit = w.RequestBodyLimit
tx.ResponseBodyAccess = w.ResponseBodyAccess
- tx.ResponseBodyLimit = int64(w.ResponseBodyLimit)
+ tx.ResponseBodyLimit = w.ResponseBodyLimit
tx.RuleEngine = w.RuleEngine
tx.HashEngine = false
tx.HashEnforcement = false
tx.lastPhase = 0
tx.ruleRemoveByID = nil
+ tx.ruleRemoveByIDRanges = nil
tx.ruleRemoveTargetByID = map[int][]ruleVariableParams{}
tx.Skip = 0
tx.AllowType = 0
@@ -198,13 +217,13 @@ func (w *WAF) newTransaction(opts Options) *Transaction {
tx.Timestamp = time.Now().UnixNano()
tx.audit = false
- // Always non-nil if buffers / collections were already initialized so we don't do any of them
+ // Always non-nil if buffers / collections were already initialized, so we don't do any of them
// based on the presence of RequestBodyBuffer.
if tx.requestBodyBuffer == nil {
// if no requestBodyInMemoryLimit has been set we default to the requestBodyLimit
requestBodyInMemoryLimit := w.RequestBodyLimit
if w.requestBodyInMemoryLimit != nil {
- requestBodyInMemoryLimit = int64(*w.requestBodyInMemoryLimit)
+ requestBodyInMemoryLimit = *w.requestBodyInMemoryLimit
}
tx.requestBodyBuffer = NewBodyBuffer(types.BodyBufferOptions{
@@ -221,7 +240,7 @@ func (w *WAF) newTransaction(opts Options) *Transaction {
})
tx.variables = *NewTransactionVariables()
- tx.transformationCache = map[transformationKey]*transformationValue{}
+ tx.transformationCache = map[transformationKey]transformationValue{}
}
// set capture variables
@@ -299,6 +318,7 @@ func NewWAF() *WAF {
RequestBodyAccess: false,
RequestBodyLimit: 134217728, // Hard limit equal to _1gib
RequestBodyLimitAction: types.BodyLimitActionReject,
+ RequestBodyJsonDepthLimit: DefaultRequestBodyJsonDepthLimit,
ResponseBodyAccess: false,
ResponseBodyLimit: 524288, // Hard limit equal to _1gib
auditLogWriter: logWriter,
@@ -319,6 +339,10 @@ func NewWAF() *WAF {
waf.TmpDir = os.TempDir()
}
+ id := wafIDCounter.Add(1)
+ waf.memoizerID = id
+ waf.memoizer = memoize.NewMemoizer(id)
+
waf.Logger.Debug().Msg("A new WAF instance was created")
return waf
}
@@ -418,5 +442,34 @@ func (w *WAF) Validate() error {
return errors.New("argument limit should be bigger than 0")
}
+ if w.RequestBodyJsonDepthLimit <= 0 {
+ return errors.New("request body json depth limit should be bigger than 0")
+ }
+
+ if environment.HasAccessToFS {
+ if w.UploadKeepFiles != types.UploadKeepFilesOff && w.UploadDir == "" {
+ return errors.New("SecUploadDir is required when SecUploadKeepFiles is enabled")
+ }
+ } else {
+ if w.UploadKeepFiles != types.UploadKeepFilesOff {
+ return errors.New("SecUploadKeepFiles requires filesystem access, which is not available in this build")
+ }
+ }
+
+ return nil
+}
+
+// Memoizer returns the WAF's memoizer for caching compiled patterns.
+func (w *WAF) Memoizer() *memoize.Memoizer {
+ return w.memoizer
+}
+
+// Close releases cached resources owned by this WAF instance.
+// Cached entries shared with other WAF instances remain until all owners release them.
+// Transactions already in-flight are unaffected as they hold their own references.
+func (w *WAF) Close() error {
+ w.closeOnce.Do(func() {
+ memoize.Release(w.memoizerID)
+ })
return nil
}
diff --git a/internal/corazawaf/waf_test.go b/internal/corazawaf/waf_test.go
index 11a2c75ef..898dd4144 100644
--- a/internal/corazawaf/waf_test.go
+++ b/internal/corazawaf/waf_test.go
@@ -7,6 +7,9 @@ import (
"io"
"os"
"testing"
+
+ "github.com/corazawaf/coraza/v3/internal/environment"
+ "github.com/corazawaf/coraza/v3/types"
)
func TestNewTransaction(t *testing.T) {
@@ -107,6 +110,39 @@ func TestValidate(t *testing.T) {
},
}
+ if environment.HasAccessToFS {
+ testCases["upload keep files on without upload dir"] = struct {
+ customizer func(*WAF)
+ expectErr bool
+ }{
+ expectErr: true,
+ customizer: func(w *WAF) {
+ w.UploadKeepFiles = types.UploadKeepFilesOn
+ w.UploadDir = ""
+ },
+ }
+ testCases["upload keep files relevant only without upload dir"] = struct {
+ customizer func(*WAF)
+ expectErr bool
+ }{
+ expectErr: true,
+ customizer: func(w *WAF) {
+ w.UploadKeepFiles = types.UploadKeepFilesRelevantOnly
+ w.UploadDir = ""
+ },
+ }
+ testCases["upload keep files on with upload dir"] = struct {
+ customizer func(*WAF)
+ expectErr bool
+ }{
+ expectErr: false,
+ customizer: func(w *WAF) {
+ w.UploadKeepFiles = types.UploadKeepFilesOn
+ w.UploadDir = "/tmp"
+ },
+ }
+ }
+
for name, tCase := range testCases {
t.Run(name, func(t *testing.T) {
waf := NewWAF()
diff --git a/internal/memoize/README.md b/internal/memoize/README.md
index 5ebe09aba..d0dfd3eb4 100644
--- a/internal/memoize/README.md
+++ b/internal/memoize/README.md
@@ -1,17 +1,19 @@
# Memoize
-Memoize allows to cache certain expensive function calls and
-cache the result. The main advantage in Coraza is to memoize
-the regexes and aho-corasick dictionaries when the connects
-spins up more than one WAF in the same process and hence same
-regexes are being compiled over and over.
+Memoize caches certain expensive function calls (regex and aho-corasick
+compilation) so the same patterns are not recompiled when multiple WAF
+instances in the same process share rules.
-Currently it is opt-in under the `memoize_builders` build tag
-as under a misuse (e.g. using after build time) it could lead
-to a memory leak as currently the cache is global.
+Memoization is **enabled by default** and uses a **global cache** within
+the process. In long-lived processes that reload WAF configurations,
+use `WAF.Close()` (via `experimental.WAFCloser`) to release cached
+entries when a WAF is destroyed. Alternatively, disable memoization with
+the `coraza.no_memoize` build tag.
-**Important:** Connectors with *live reload* functionality (e.g. Caddy)
-could lead to memory leaks which might or might not be negligible in
-most of the cases as usually config changes in a WAF are about a few
-rules, this is old objects will be still alive in memory until the program
-stops.
+## Build variants
+
+| Build tag | Behavior |
+|-----------------------|--------------------------------------------------|
+| *(none)* | Full memoization with `singleflight` (default) |
+| `tinygo` | Memoization without `singleflight` (TinyGo) |
+| `coraza.no_memoize` | No-op — every call compiles fresh |
diff --git a/internal/memoize/noop.go b/internal/memoize/noop.go
index 1d1e5447f..f5b82e085 100644
--- a/internal/memoize/noop.go
+++ b/internal/memoize/noop.go
@@ -1,10 +1,21 @@
// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors
// SPDX-License-Identifier: Apache-2.0
-//go:build !memoize_builders
+//go:build coraza.no_memoize
package memoize
-func Do(_ string, fn func() (any, error)) (any, error) {
- return fn()
-}
+// Memoizer is a no-op implementation when memoization is disabled.
+type Memoizer struct{}
+
+// NewMemoizer returns a no-op Memoizer.
+func NewMemoizer(_ uint64) *Memoizer { return &Memoizer{} }
+
+// Do always calls fn directly without caching.
+func (m *Memoizer) Do(_ string, fn func() (any, error)) (any, error) { return fn() }
+
+// Release is a no-op when memoization is disabled.
+func Release(_ uint64) {}
+
+// Reset is a no-op when memoization is disabled.
+func Reset() {}
diff --git a/internal/memoize/noop_test.go b/internal/memoize/noop_test.go
new file mode 100644
index 000000000..bf5c0ed47
--- /dev/null
+++ b/internal/memoize/noop_test.go
@@ -0,0 +1,206 @@
+// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//go:build coraza.no_memoize
+
+package memoize
+
+import (
+ "errors"
+ "strconv"
+ "testing"
+)
+
+func TestNoopDo(t *testing.T) {
+ m := NewMemoizer(1)
+ calls := 0
+
+ fn := func() (any, error) {
+ calls++
+ return calls, nil
+ }
+
+ result, err := m.Do("key1", fn)
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err.Error())
+ }
+ if want, have := 1, result.(int); want != have {
+ t.Fatalf("unexpected value, want %d, have %d", want, have)
+ }
+
+ // No-op memoizer should call fn again (no caching).
+ result, err = m.Do("key1", fn)
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err.Error())
+ }
+ if want, have := 2, result.(int); want != have {
+ t.Fatalf("expected no caching, want %d, have %d", want, have)
+ }
+}
+
+func TestNoopDoError(t *testing.T) {
+ m := NewMemoizer(1)
+
+ _, err := m.Do("key1", func() (any, error) {
+ return nil, errors.New("fail")
+ })
+ if err == nil {
+ t.Fatal("expected error")
+ }
+}
+
+// TestNoopDoMultipleKeys verifies that different keys each invoke fn independently.
+func TestNoopDoMultipleKeys(t *testing.T) {
+ m := NewMemoizer(1)
+ calls := 0
+
+ fn := func() (any, error) {
+ calls++
+ return calls, nil
+ }
+
+ for i := 1; i <= 3; i++ {
+ result, err := m.Do("key"+strconv.Itoa(i), fn)
+ if err != nil {
+ t.Fatalf("unexpected error on key %d: %s", i, err.Error())
+ }
+ if want, have := i, result.(int); want != have {
+ t.Fatalf("key%d: want %d, have %d", i, want, have)
+ }
+ }
+ if calls != 3 {
+ t.Fatalf("expected 3 fn calls, got %d", calls)
+ }
+}
+
+// TestNoopErrorNotCached verifies that errors returned by fn are not cached:
+// a subsequent call with the same key will invoke fn again.
+func TestNoopErrorNotCached(t *testing.T) {
+ m := NewMemoizer(1)
+ calls := 0
+
+ fn := func() (any, error) {
+ calls++
+ if calls == 1 {
+ return nil, errors.New("transient error")
+ }
+ return calls, nil
+ }
+
+ // First call should return error.
+ _, err := m.Do("key1", fn)
+ if err == nil {
+ t.Fatal("expected error on first call")
+ }
+
+ // Second call should succeed (no caching of error).
+ result, err := m.Do("key1", fn)
+ if err != nil {
+ t.Fatalf("unexpected error on second call: %s", err.Error())
+ }
+ if want, have := 2, result.(int); want != have {
+ t.Fatalf("want %d, have %d", want, have)
+ }
+
+ // Third call: fn invoked again (still no caching).
+ result, err = m.Do("key1", fn)
+ if err != nil {
+ t.Fatalf("unexpected error on third call: %s", err.Error())
+ }
+ if want, have := 3, result.(int); want != have {
+ t.Fatalf("want %d, have %d", want, have)
+ }
+}
+
+// TestNoopRelease verifies that Release is a no-op and does not panic or affect subsequent Do calls.
+func TestNoopRelease(t *testing.T) {
+ m := NewMemoizer(1)
+ calls := 0
+
+ fn := func() (any, error) {
+ calls++
+ return calls, nil
+ }
+
+ result, err := m.Do("key1", fn)
+ if err != nil {
+ t.Fatalf("unexpected error before Release: %s", err.Error())
+ }
+ if want, have := 1, result.(int); want != have {
+ t.Fatalf("before Release: want %d, have %d", want, have)
+ }
+ if calls != 1 {
+ t.Fatalf("expected 1 call before Release, got %d", calls)
+ }
+
+ // Release should not panic and should be a no-op.
+ Release(1)
+
+ // Subsequent calls should still work normally.
+ result, err = m.Do("key1", fn)
+ if err != nil {
+ t.Fatalf("unexpected error after Release: %s", err.Error())
+ }
+ if want, have := 2, result.(int); want != have {
+ t.Fatalf("expected fn called again after Release, want %d, have %d", want, have)
+ }
+}
+
+// TestNoopReset verifies that Reset is a no-op and does not panic or affect subsequent Do calls.
+func TestNoopReset(t *testing.T) {
+ m := NewMemoizer(1)
+ calls := 0
+
+ fn := func() (any, error) {
+ calls++
+ return calls, nil
+ }
+
+ for _, key := range []string{"k1", "k2"} {
+ result, err := m.Do(key, fn)
+ if err != nil {
+ t.Fatalf("unexpected error for %s before Reset: %s", key, err.Error())
+ }
+ if result == nil {
+ t.Fatalf("unexpected nil result for %s", key)
+ }
+ }
+ if calls != 2 {
+ t.Fatalf("expected 2 calls before Reset, got %d", calls)
+ }
+
+ // Reset should not panic.
+ Reset()
+
+ // Calls after Reset should continue working.
+ result, err := m.Do("k1", fn)
+ if err != nil {
+ t.Fatalf("unexpected error after Reset: %s", err.Error())
+ }
+ if want, have := 3, result.(int); want != have {
+ t.Fatalf("expected fn called again after Reset, want %d, have %d", want, have)
+ }
+}
+
+// TestNoopMultipleMemoizers verifies that multiple no-op memoizers are independent
+// (no shared state between different owner IDs).
+func TestNoopMultipleMemoizers(t *testing.T) {
+ m1 := NewMemoizer(1)
+ m2 := NewMemoizer(2)
+ calls := 0
+
+ fn := func() (any, error) {
+ calls++
+ return calls, nil
+ }
+
+ r1, _ := m1.Do("shared", fn)
+ r2, _ := m2.Do("shared", fn)
+
+ if r1.(int) != 1 || r2.(int) != 2 {
+ t.Fatalf("expected independent calls: m1=%d, m2=%d", r1.(int), r2.(int))
+ }
+ if calls != 2 {
+ t.Fatalf("expected 2 fn calls, got %d", calls)
+ }
+}
diff --git a/internal/memoize/nosync.go b/internal/memoize/nosync.go
index d238244d9..45e1b1dee 100644
--- a/internal/memoize/nosync.go
+++ b/internal/memoize/nosync.go
@@ -1,36 +1,90 @@
// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors
// SPDX-License-Identifier: Apache-2.0
-//go:build tinygo && memoize_builders
+//go:build tinygo && !coraza.no_memoize
package memoize
import "sync"
-var doer = makeDoer(new(sync.Map))
+type entry struct {
+ value any
+ mu sync.Mutex
+ owners map[uint64]struct{}
+ deleted bool
+}
+
+var cache sync.Map // key -> *entry
+
+// Memoizer caches expensive function calls with per-owner tracking.
+// TinyGo variant without singleflight.
+type Memoizer struct {
+ ownerID uint64
+}
-// Do executes and returns the results of the given function, unless there was a cached
-// value of the same key. Only one execution is in-flight for a given key at a time.
-// The boolean return value indicates whether v was previously stored.
-func Do(key string, fn func() (interface{}, error)) (interface{}, error) {
- value, err, _ := doer(key, fn)
- return value, err
+// NewMemoizer creates a Memoizer that tracks cached entries under the given owner ID.
+func NewMemoizer(ownerID uint64) *Memoizer {
+ return &Memoizer{ownerID: ownerID}
}
-// makeDoer returns a function that executes and returns the results of the given function
-func makeDoer(cache *sync.Map) func(string, func() (interface{}, error)) (interface{}, error, bool) {
- return func(key string, fn func() (interface{}, error)) (interface{}, error, bool) {
- // Check cache
- value, found := cache.Load(key)
- if found {
- return value, nil, true
+// Do returns a cached value for key, or calls fn and caches the result.
+func (m *Memoizer) Do(key string, fn func() (any, error)) (any, error) {
+ if v, ok := cache.Load(key); ok {
+ e := v.(*entry)
+ e.mu.Lock()
+ if !e.deleted {
+ e.owners[m.ownerID] = struct{}{}
+ e.mu.Unlock()
+ return e.value, nil
}
+ e.mu.Unlock()
+ }
- data, err := fn()
- if err == nil {
- cache.Store(key, data)
+ data, err := fn()
+ if err == nil {
+ e := &entry{
+ value: data,
+ owners: map[uint64]struct{}{m.ownerID: {}},
}
+ cache.Store(key, e)
+ }
+ return data, err
+}
+
+// Release removes ownerID from all cached entries, deleting entries with no remaining owners.
+//
+// Deletions are deferred until after Range completes because TinyGo's sync.Map
+// holds its internal lock for the entire Range call, so calling Delete inside
+// the callback would deadlock.
+func Release(ownerID uint64) {
+ var toDelete []any
+ cache.Range(func(key, value any) bool {
+ e := value.(*entry)
+ e.mu.Lock()
+ delete(e.owners, ownerID)
+ if len(e.owners) == 0 {
+ e.deleted = true
+ toDelete = append(toDelete, key)
+ }
+ e.mu.Unlock()
+ return true
+ })
+ for _, key := range toDelete {
+ cache.Delete(key)
+ }
+}
- return data, err, false
+// Reset clears the entire cache. Intended for testing.
+//
+// Keys are collected first and deleted after Range returns to avoid deadlocking
+// on TinyGo's mutex-based sync.Map (see Release comment).
+func Reset() {
+ var keys []any
+ cache.Range(func(key, _ any) bool {
+ keys = append(keys, key)
+ return true
+ })
+ for _, key := range keys {
+ cache.Delete(key)
}
}
diff --git a/internal/memoize/nosync_test.go b/internal/memoize/nosync_test.go
index c942a522c..85170ab5d 100644
--- a/internal/memoize/nosync_test.go
+++ b/internal/memoize/nosync_test.go
@@ -1,167 +1,335 @@
// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors
// SPDX-License-Identifier: Apache-2.0
-//go:build tinygo && memoize_builders
-
-// https://github.com/kofalt/go-memoize/blob/master/memoize.go
+//go:build tinygo && !coraza.no_memoize
package memoize
import (
"errors"
- "sync"
+ "fmt"
+ "regexp"
"testing"
)
func TestDo(t *testing.T) {
+ t.Cleanup(Reset)
+
+ m := NewMemoizer(1)
expensiveCalls := 0
- // Function tracks how many times its been called
- expensive := func() (interface{}, error) {
+ expensive := func() (any, error) {
expensiveCalls++
return expensiveCalls, nil
}
- // First call SHOULD NOT be cached
- result, err := Do("key1", expensive)
+ result, err := m.Do("key1", expensive)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
-
if want, have := 1, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}
- // Second call on same key SHOULD be cached
- result, err = Do("key1", expensive)
+ result, err = m.Do("key1", expensive)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
-
if want, have := 1, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}
- // First call on a new key SHOULD NOT be cached
- result, err = Do("key2", expensive)
+ result, err = m.Do("key2", expensive)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
-
if want, have := 2, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}
}
-func TestSuccessCall(t *testing.T) {
- do := makeDoer(new(sync.Map))
+func TestFailedCall(t *testing.T) {
+ t.Cleanup(Reset)
- expensiveCalls := 0
+ m := NewMemoizer(1)
+ calls := 0
- // Function tracks how many times its been called
- expensive := func() (interface{}, error) {
- expensiveCalls++
- return expensiveCalls, nil
+ twoForTheMoney := func() (any, error) {
+ calls++
+ if calls == 1 {
+ return calls, errors.New("Try again")
+ }
+ return calls, nil
}
- // First call SHOULD NOT be cached
- result, err, cached := do("key1", expensive)
- if err != nil {
- t.Fatalf("unexpected error: %s", err.Error())
+ result, err := m.Do("key1", twoForTheMoney)
+ if err == nil {
+ t.Fatalf("expected error")
}
-
if want, have := 1, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}
- if want, have := false, cached; want != have {
- t.Fatalf("unexpected caching, want %t, have %t", want, have)
+ result, err = m.Do("key1", twoForTheMoney)
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err.Error())
+ }
+ if want, have := 2, result.(int); want != have {
+ t.Fatalf("unexpected value, want %d, have %d", want, have)
}
- // Second call on same key SHOULD be cached
- result, err, cached = do("key1", expensive)
+ result, err = m.Do("key1", twoForTheMoney)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
-
- if want, have := 1, result.(int); want != have {
+ if want, have := 2, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}
+}
+
+func TestRelease(t *testing.T) {
+ t.Cleanup(Reset)
- if want, have := true, cached; want != have {
- t.Fatalf("unexpected caching, want %t, have %t", want, have)
+ m1 := NewMemoizer(1)
+ m2 := NewMemoizer(2)
+
+ calls := 0
+ fn := func() (any, error) {
+ calls++
+ return calls, nil
}
- // First call on a new key SHOULD NOT be cached
- result, err, cached = do("key2", expensive)
- if err != nil {
- t.Fatalf("unexpected error: %s", err.Error())
+ _, _ = m1.Do("shared", fn)
+ _, _ = m2.Do("shared", fn)
+ _, _ = m1.Do("only-waf1", fn)
+
+ Release(1)
+
+ if _, ok := cache.Load("shared"); !ok {
+ t.Fatal("shared entry should still exist after releasing waf-1")
+ }
+ if _, ok := cache.Load("only-waf1"); ok {
+ t.Fatal("only-waf1 entry should be deleted after releasing its sole owner")
}
- if want, have := 2, result.(int); want != have {
- t.Fatalf("unexpected value, want %d, have %d", want, have)
+ Release(2)
+ if _, ok := cache.Load("shared"); ok {
+ t.Fatal("shared entry should be deleted after releasing all owners")
}
+}
- if want, have := false, cached; want != have {
- t.Fatalf("unexpected caching, want %t, have %t", want, have)
+func TestReset(t *testing.T) {
+ m := NewMemoizer(1)
+ _, _ = m.Do("k1", func() (any, error) { return 1, nil })
+ _, _ = m.Do("k2", func() (any, error) { return 2, nil })
+
+ Reset()
+
+ if _, ok := cache.Load("k1"); ok {
+ t.Fatal("cache should be empty after Reset")
+ }
+ if _, ok := cache.Load("k2"); ok {
+ t.Fatal("cache should be empty after Reset")
}
}
-func TestFailedCall(t *testing.T) {
- do := makeDoer(new(sync.Map))
+// cacheLen counts the number of entries in the global cache.
+func cacheLen() int {
+ n := 0
+ cache.Range(func(_, _ any) bool {
+ n++
+ return true
+ })
+ return n
+}
- calls := 0
+// crsLikePatterns generates n CRS-scale regex patterns.
+func crsLikePatterns(n int) []string {
+ patterns := make([]string, n)
+ for i := range patterns {
+ patterns[i] = fmt.Sprintf(`(?i)pattern_%d_[a-z]{2,8}\d+`, i)
+ }
+ return patterns
+}
- // This function will fail IFF it has not been called before.
- twoForTheMoney := func() (interface{}, error) {
- calls++
+func TestMemoizeScaleMultipleOwners(t *testing.T) {
+ if testing.Short() {
+ t.Log("skipping scale test in short mode")
+ return
+ }
+ t.Cleanup(Reset)
- if calls == 1 {
- return calls, errors.New("Try again")
- } else {
- return calls, nil
+ const (
+ numOwners = 10
+ numPatterns = 300
+ )
+
+ patterns := crsLikePatterns(numPatterns)
+ calls := 0
+ fn := func(p string) func() (any, error) {
+ return func() (any, error) {
+ calls++
+ return regexp.Compile(p)
}
}
- // First call should fail, and not be cached
- result, err, cached := do("key1", twoForTheMoney)
- if err == nil {
- t.Fatalf("expected error")
+ for i := uint64(1); i <= numOwners; i++ {
+ m := NewMemoizer(i)
+ for _, p := range patterns {
+ if _, err := m.Do(p, fn(p)); err != nil {
+ t.Fatal(err)
+ }
+ }
}
- if want, have := 1, result.(int); want != have {
- t.Fatalf("unexpected value, want %d, have %d", want, have)
+ if calls != numPatterns {
+ t.Fatalf("expected %d compilations, got %d", numPatterns, calls)
+ }
+ if n := cacheLen(); n != numPatterns {
+ t.Fatalf("expected %d cache entries, got %d", numPatterns, n)
}
- if want, have := false, cached; want != have {
- t.Fatalf("unexpected caching, want %t, have %t", want, have)
+ // Release all owners; cache should be empty.
+ for i := uint64(1); i <= numOwners; i++ {
+ Release(i)
}
+ if n := cacheLen(); n != 0 {
+ t.Fatalf("expected empty cache after releasing all owners, got %d", n)
+ }
+}
- // Second call should succeed, and not be cached
- result, err, cached = do("key1", twoForTheMoney)
- if err != nil {
- t.Fatalf("unexpected error: %s", err.Error())
+func TestCacheGrowthWithoutClose(t *testing.T) {
+ if testing.Short() {
+ t.Log("skipping scale test in short mode")
+ return
}
+ t.Cleanup(Reset)
- if want, have := 2, result.(int); want != have {
- t.Fatalf("unexpected value, want %d, have %d", want, have)
+ const (
+ numOwners = 100
+ numPatterns = 300
+ )
+
+ patterns := crsLikePatterns(numPatterns)
+ fn := func(p string) func() (any, error) {
+ return func() (any, error) {
+ return regexp.Compile(p)
+ }
}
- if want, have := false, cached; want != have {
- t.Fatalf("unexpected caching, want %t, have %t", want, have)
+ for i := uint64(1); i <= numOwners; i++ {
+ m := NewMemoizer(i)
+ for _, p := range patterns {
+ if _, err := m.Do(p, fn(p)); err != nil {
+ t.Fatal(err)
+ }
+ }
}
- // Third call should succeed, and be cached
- result, err, cached = do("key1", twoForTheMoney)
- if err != nil {
- t.Fatalf("unexpected error: %s", err.Error())
+ // Every entry should have all owners.
+ cache.Range(func(_, value any) bool {
+ e := value.(*entry)
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if len(e.owners) != numOwners {
+ t.Fatalf("expected %d owners per entry, got %d", numOwners, len(e.owners))
+ }
+ return true
+ })
+}
+
+func TestCacheBoundedWithClose(t *testing.T) {
+ if testing.Short() {
+ t.Log("skipping scale test in short mode")
+ return
}
+ t.Cleanup(Reset)
- if want, have := 2, result.(int); want != have {
- t.Fatalf("unexpected value, want %d, have %d", want, have)
+ const (
+ numCycles = 100
+ numPatterns = 300
+ )
+
+ patterns := crsLikePatterns(numPatterns)
+ fn := func(p string) func() (any, error) {
+ return func() (any, error) {
+ return regexp.Compile(p)
+ }
+ }
+
+ for i := uint64(1); i <= numCycles; i++ {
+ m := NewMemoizer(i)
+ for _, p := range patterns {
+ if _, err := m.Do(p, fn(p)); err != nil {
+ t.Fatal(err)
+ }
+ }
+ Release(i)
+ }
+
+ if n := cacheLen(); n != 0 {
+ t.Fatalf("expected empty cache after all releases, got %d", n)
}
+}
+
+func BenchmarkCompileWithoutMemoize(b *testing.B) {
+ patterns := crsLikePatterns(300)
+ for _, numWAFs := range []int{1, 10, 100} {
+ b.Run(fmt.Sprintf("WAFs=%d", numWAFs), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ for w := 0; w < numWAFs; w++ {
+ for _, p := range patterns {
+ if _, err := regexp.Compile(p); err != nil {
+ b.Fatal(err)
+ }
+ }
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkCompileWithMemoize(b *testing.B) {
+ patterns := crsLikePatterns(300)
+ for _, numWAFs := range []int{1, 10, 100} {
+ b.Run(fmt.Sprintf("WAFs=%d", numWAFs), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ Reset()
+ for w := 0; w < numWAFs; w++ {
+ m := NewMemoizer(uint64(w + 1))
+ for _, p := range patterns {
+ if _, err := m.Do(p, func() (any, error) {
+ return regexp.Compile(p)
+ }); err != nil {
+ b.Fatal(err)
+ }
+ }
+ }
+ }
+ })
+ }
+}
- if want, have := true, cached; want != have {
- t.Fatalf("unexpected caching, want %t, have %t", want, have)
+func BenchmarkRelease(b *testing.B) {
+ patterns := crsLikePatterns(300)
+ for _, numOwners := range []int{1, 10, 100} {
+ b.Run(fmt.Sprintf("Owners=%d", numOwners), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ b.StopTimer()
+ Reset()
+ for o := 0; o < numOwners; o++ {
+ m := NewMemoizer(uint64(o + 1))
+ for _, p := range patterns {
+ m.Do(p, func() (any, error) {
+ return regexp.Compile(p)
+ })
+ }
+ }
+ b.StartTimer()
+ for o := 0; o < numOwners; o++ {
+ Release(uint64(o + 1))
+ }
+ }
+ })
}
}
diff --git a/internal/memoize/sync.go b/internal/memoize/sync.go
index a8d5c9610..c8ebd4903 100644
--- a/internal/memoize/sync.go
+++ b/internal/memoize/sync.go
@@ -1,9 +1,7 @@
// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors
// SPDX-License-Identifier: Apache-2.0
-//go:build !tinygo && memoize_builders
-
-// https://github.com/kofalt/go-memoize/blob/master/memoize.go
+//go:build !tinygo && !coraza.no_memoize
package memoize
@@ -13,35 +11,104 @@ import (
"golang.org/x/sync/singleflight"
)
-var doer = makeDoer(new(sync.Map), new(singleflight.Group))
+type entry struct {
+ value any
+ mu sync.Mutex
+ owners map[uint64]struct{}
+ deleted bool
+}
+
+var (
+ cache sync.Map // key -> *entry
+ group singleflight.Group
+)
+
+// Memoizer caches expensive function calls with per-owner tracking.
+type Memoizer struct {
+ ownerID uint64
+}
-// Do executes and returns the results of the given function, unless there was a cached
-// value of the same key. Only one execution is in-flight for a given key at a time.
-// The boolean return value indicates whether v was previously stored.
-func Do(key string, fn func() (interface{}, error)) (interface{}, error) {
- value, err, _ := doer(key, fn)
- return value, err
+// NewMemoizer creates a Memoizer that tracks cached entries under the given owner ID.
+func NewMemoizer(ownerID uint64) *Memoizer {
+ return &Memoizer{ownerID: ownerID}
}
-// makeDoer returns a function that executes and returns the results of the given function
-func makeDoer(cache *sync.Map, group *singleflight.Group) func(string, func() (interface{}, error)) (interface{}, error, bool) {
- return func(key string, fn func() (interface{}, error)) (interface{}, error, bool) {
- // Check cache
- value, found := cache.Load(key)
- if found {
- return value, nil, true
+// addOwner attempts to register the ownerID on the entry.
+// Returns false if the entry has been marked as deleted.
+func (m *Memoizer) addOwner(e *entry) bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.deleted {
+ return false
+ }
+ e.owners[m.ownerID] = struct{}{}
+ return true
+}
+
+// Do returns a cached value for key, or calls fn and caches the result.
+// Only one execution is in-flight for a given key at a time.
+func (m *Memoizer) Do(key string, fn func() (any, error)) (any, error) {
+ // Fast path: check cache
+ if v, ok := cache.Load(key); ok {
+ e := v.(*entry)
+ if m.addOwner(e) {
+ return e.value, nil
}
+ // Entry was deleted concurrently; fall through to slow path.
+ }
- // Combine memoized function with a cache store
- value, err, _ := group.Do(key, func() (interface{}, error) {
- data, innerErr := fn()
- if innerErr == nil {
- cache.Store(key, data)
+ // Slow path: singleflight ensures only one compilation per key
+ val, err, _ := group.Do(key, func() (any, error) {
+ // Double-check after acquiring singleflight
+ if v, ok := cache.Load(key); ok {
+ e := v.(*entry)
+ if m.addOwner(e) {
+ return e.value, nil
}
+ }
- return data, innerErr
- })
+ data, innerErr := fn()
+ if innerErr == nil {
+ e := &entry{
+ value: data,
+ owners: map[uint64]struct{}{m.ownerID: {}},
+ }
+ cache.Store(key, e)
+ }
+ return data, innerErr
+ })
- return value, err, false
+ // Ensure this caller is registered as an owner even if its execution
+ // was deduplicated by singleflight.
+ if err == nil {
+ if v, ok := cache.Load(key); ok {
+ e := v.(*entry)
+ m.addOwner(e)
+ }
}
+
+ return val, err
+}
+
+// Release removes ownerID from all cached entries, deleting entries with no remaining owners.
+func Release(ownerID uint64) {
+ cache.Range(func(key, value any) bool {
+ e := value.(*entry)
+ e.mu.Lock()
+ delete(e.owners, ownerID)
+ if len(e.owners) == 0 {
+ e.deleted = true
+ cache.Delete(key)
+ }
+ e.mu.Unlock()
+ return true
+ })
+}
+
+// Reset clears the entire cache. Intended for testing.
+func Reset() {
+ cache.Range(func(key, _ any) bool {
+ cache.Delete(key)
+ return true
+ })
}
diff --git a/internal/memoize/sync_test.go b/internal/memoize/sync_test.go
index d995d1d41..b1a27bc33 100644
--- a/internal/memoize/sync_test.go
+++ b/internal/memoize/sync_test.go
@@ -1,169 +1,339 @@
// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors
// SPDX-License-Identifier: Apache-2.0
-//go:build !tinygo && memoize_builders
-
-// https://github.com/kofalt/go-memoize/blob/master/memoize.go
+//go:build !tinygo && !coraza.no_memoize
package memoize
import (
"errors"
- "sync"
+ "fmt"
+ "os"
+ "regexp"
"testing"
-
- "golang.org/x/sync/singleflight"
)
func TestDo(t *testing.T) {
+ t.Cleanup(Reset)
+
+ m := NewMemoizer(1)
expensiveCalls := 0
- // Function tracks how many times its been called
- expensive := func() (interface{}, error) {
+ expensive := func() (any, error) {
expensiveCalls++
return expensiveCalls, nil
}
- // First call SHOULD NOT be cached
- result, err := Do("key1", expensive)
+ result, err := m.Do("key1", expensive)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
-
if want, have := 1, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}
- // Second call on same key SHOULD be cached
- result, err = Do("key1", expensive)
+ result, err = m.Do("key1", expensive)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
-
if want, have := 1, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}
- // First call on a new key SHOULD NOT be cached
- result, err = Do("key2", expensive)
+ result, err = m.Do("key2", expensive)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
-
if want, have := 2, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}
}
-func TestSuccessCall(t *testing.T) {
- do := makeDoer(new(sync.Map), &singleflight.Group{})
+func TestFailedCall(t *testing.T) {
+ t.Cleanup(Reset)
- expensiveCalls := 0
+ m := NewMemoizer(1)
+ calls := 0
- // Function tracks how many times its been called
- expensive := func() (interface{}, error) {
- expensiveCalls++
- return expensiveCalls, nil
+ twoForTheMoney := func() (any, error) {
+ calls++
+ if calls == 1 {
+ return calls, errors.New("Try again")
+ }
+ return calls, nil
}
- // First call SHOULD NOT be cached
- result, err, cached := do("key1", expensive)
- if err != nil {
- t.Fatalf("unexpected error: %s", err.Error())
+ result, err := m.Do("key1", twoForTheMoney)
+ if err == nil {
+ t.Fatalf("expected error")
}
-
if want, have := 1, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}
- if want, have := false, cached; want != have {
- t.Fatalf("unexpected caching, want %t, have %t", want, have)
+ result, err = m.Do("key1", twoForTheMoney)
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err.Error())
+ }
+ if want, have := 2, result.(int); want != have {
+ t.Fatalf("unexpected value, want %d, have %d", want, have)
}
- // Second call on same key SHOULD be cached
- result, err, cached = do("key1", expensive)
+ result, err = m.Do("key1", twoForTheMoney)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
-
- if want, have := 1, result.(int); want != have {
+ if want, have := 2, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}
+}
- if want, have := true, cached; want != have {
- t.Fatalf("unexpected caching, want %t, have %t", want, have)
+func TestRelease(t *testing.T) {
+ t.Cleanup(Reset)
+
+ m1 := NewMemoizer(1)
+ m2 := NewMemoizer(2)
+
+ calls := 0
+ fn := func() (any, error) {
+ calls++
+ return calls, nil
}
- // First call on a new key SHOULD NOT be cached
- result, err, cached = do("key2", expensive)
- if err != nil {
- t.Fatalf("unexpected error: %s", err.Error())
+ _, _ = m1.Do("shared", fn)
+ _, _ = m2.Do("shared", fn)
+ _, _ = m1.Do("only-waf1", fn)
+
+ Release(1)
+
+ if _, ok := cache.Load("shared"); !ok {
+ t.Fatal("shared entry should still exist after releasing waf-1")
+ }
+ if _, ok := cache.Load("only-waf1"); ok {
+ t.Fatal("only-waf1 entry should be deleted after releasing its sole owner")
}
- if want, have := 2, result.(int); want != have {
- t.Fatalf("unexpected value, want %d, have %d", want, have)
+ Release(2)
+ if _, ok := cache.Load("shared"); ok {
+ t.Fatal("shared entry should be deleted after releasing all owners")
}
+}
- if want, have := false, cached; want != have {
- t.Fatalf("unexpected caching, want %t, have %t", want, have)
+func TestReset(t *testing.T) {
+ m := NewMemoizer(1)
+ _, _ = m.Do("k1", func() (any, error) { return 1, nil })
+ _, _ = m.Do("k2", func() (any, error) { return 2, nil })
+
+ Reset()
+
+ if _, ok := cache.Load("k1"); ok {
+ t.Fatal("cache should be empty after Reset")
+ }
+ if _, ok := cache.Load("k2"); ok {
+ t.Fatal("cache should be empty after Reset")
}
}
-func TestFailedCall(t *testing.T) {
- do := makeDoer(new(sync.Map), &singleflight.Group{})
+// cacheLen counts the number of entries in the global cache.
+func cacheLen() int {
+ n := 0
+ cache.Range(func(_, _ any) bool {
+ n++
+ return true
+ })
+ return n
+}
- calls := 0
+// crsLikePatterns generates n CRS-scale regex patterns.
+func crsLikePatterns(n int) []string {
+ patterns := make([]string, n)
+ for i := range patterns {
+ patterns[i] = fmt.Sprintf(`(?i)pattern_%d_[a-z]{2,8}\d+`, i)
+ }
+ return patterns
+}
- // This function will fail IFF it has not been called before.
- twoForTheMoney := func() (interface{}, error) {
- calls++
+func TestMemoizeScaleMultipleOwners(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping scale test in short mode")
+ }
+ t.Cleanup(Reset)
- if calls == 1 {
- return calls, errors.New("Try again")
- } else {
- return calls, nil
+ const (
+ numOwners = 10
+ numPatterns = 300
+ )
+
+ patterns := crsLikePatterns(numPatterns)
+ calls := 0
+ fn := func(p string) func() (any, error) {
+ return func() (any, error) {
+ calls++
+ return regexp.Compile(p)
}
}
- // First call should fail, and not be cached
- result, err, cached := do("key1", twoForTheMoney)
- if err == nil {
- t.Fatalf("expected error")
+ for i := uint64(1); i <= numOwners; i++ {
+ m := NewMemoizer(i)
+ for _, p := range patterns {
+ if _, err := m.Do(p, fn(p)); err != nil {
+ t.Fatal(err)
+ }
+ }
}
- if want, have := 1, result.(int); want != have {
- t.Fatalf("unexpected value, want %d, have %d", want, have)
+ if calls != numPatterns {
+ t.Fatalf("expected %d compilations, got %d", numPatterns, calls)
+ }
+ if n := cacheLen(); n != numPatterns {
+ t.Fatalf("expected %d cache entries, got %d", numPatterns, n)
}
- if want, have := false, cached; want != have {
- t.Fatalf("unexpected caching, want %t, have %t", want, have)
+ // Release all owners; cache should be empty.
+ for i := uint64(1); i <= numOwners; i++ {
+ Release(i)
}
+ if n := cacheLen(); n != 0 {
+ t.Fatalf("expected empty cache after releasing all owners, got %d", n)
+ }
+}
- // Second call should succeed, and not be cached
- result, err, cached = do("key1", twoForTheMoney)
- if err != nil {
- t.Fatalf("unexpected error: %s", err.Error())
+func TestCacheGrowthWithoutClose(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping scale test in short mode")
}
+ t.Cleanup(Reset)
- if want, have := 2, result.(int); want != have {
- t.Fatalf("unexpected value, want %d, have %d", want, have)
+ const (
+ numOwners = 100
+ numPatterns = 300
+ )
+
+ patterns := crsLikePatterns(numPatterns)
+ fn := func(p string) func() (any, error) {
+ return func() (any, error) {
+ return regexp.Compile(p)
+ }
}
- if want, have := false, cached; want != have {
- t.Fatalf("unexpected caching, want %t, have %t", want, have)
+ for i := uint64(1); i <= numOwners; i++ {
+ m := NewMemoizer(i)
+ for _, p := range patterns {
+ if _, err := m.Do(p, fn(p)); err != nil {
+ t.Fatal(err)
+ }
+ }
}
- // Third call should succeed, and be cached
- result, err, cached = do("key1", twoForTheMoney)
- if err != nil {
- t.Fatalf("unexpected error: %s", err.Error())
+ // Every entry should have all owners.
+ cache.Range(func(_, value any) bool {
+ e := value.(*entry)
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if len(e.owners) != numOwners {
+ t.Fatalf("expected %d owners per entry, got %d", numOwners, len(e.owners))
+ }
+ return true
+ })
+}
+
+func TestCacheBoundedWithClose(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping scale test in short mode")
}
+ if os.Getenv("CORAZA_MEMOIZE_SCALE") != "1" {
+ t.Skip("skipping scale test; set CORAZA_MEMOIZE_SCALE=1 to enable")
+ }
+ t.Cleanup(Reset)
- if want, have := 2, result.(int); want != have {
- t.Fatalf("unexpected value, want %d, have %d", want, have)
+ const (
+ numCycles = 100
+ numPatterns = 300
+ )
+
+ patterns := crsLikePatterns(numPatterns)
+ fn := func(p string) func() (any, error) {
+ return func() (any, error) {
+ return regexp.Compile(p)
+ }
}
- if want, have := true, cached; want != have {
- t.Fatalf("unexpected caching, want %t, have %t", want, have)
+ for i := uint64(1); i <= numCycles; i++ {
+ m := NewMemoizer(i)
+ for _, p := range patterns {
+ if _, err := m.Do(p, fn(p)); err != nil {
+ t.Fatal(err)
+ }
+ }
+ Release(i)
+ }
+
+	// Every owner was released immediately after its cycle completed,
+	// so no entries from any cycle should remain. The cache must be
+	// empty at this point.
+ if n := cacheLen(); n != 0 {
+ t.Fatalf("expected empty cache after all releases, got %d", n)
+ }
+}
+
+func BenchmarkCompileWithoutMemoize(b *testing.B) {
+ patterns := crsLikePatterns(300)
+ for _, numWAFs := range []int{1, 10, 100} {
+ b.Run(fmt.Sprintf("WAFs=%d", numWAFs), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ for w := 0; w < numWAFs; w++ {
+ for _, p := range patterns {
+ if _, err := regexp.Compile(p); err != nil {
+ b.Fatal(err)
+ }
+ }
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkCompileWithMemoize(b *testing.B) {
+ patterns := crsLikePatterns(300)
+ for _, numWAFs := range []int{1, 10, 100} {
+ b.Run(fmt.Sprintf("WAFs=%d", numWAFs), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ Reset()
+ for w := 0; w < numWAFs; w++ {
+ m := NewMemoizer(uint64(w + 1))
+ for _, p := range patterns {
+ if _, err := m.Do(p, func() (any, error) {
+ return regexp.Compile(p)
+ }); err != nil {
+ b.Fatal(err)
+ }
+ }
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkRelease(b *testing.B) {
+ patterns := crsLikePatterns(300)
+ for _, numOwners := range []int{1, 10, 100} {
+ b.Run(fmt.Sprintf("Owners=%d", numOwners), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ b.StopTimer()
+ Reset()
+ for o := 0; o < numOwners; o++ {
+ m := NewMemoizer(uint64(o + 1))
+ for _, p := range patterns {
+ _, _ = m.Do(p, func() (any, error) {
+ return regexp.Compile(p)
+ })
+ }
+ }
+ b.StartTimer()
+ for o := 0; o < numOwners; o++ {
+ Release(uint64(o + 1))
+ }
+ }
+ })
}
}
diff --git a/internal/operators/operators.go b/internal/operators/operators.go
index 3ddd87989..28658fd45 100644
--- a/internal/operators/operators.go
+++ b/internal/operators/operators.go
@@ -29,6 +29,13 @@ import (
var operators = map[string]plugintypes.OperatorFactory{}
+func memoizeDo(m plugintypes.Memoizer, key string, fn func() (any, error)) (any, error) {
+ if m != nil {
+ return m.Do(key, fn)
+ }
+ return fn()
+}
+
// Get returns an operator by name
func Get(name string, options plugintypes.OperatorOptions) (plugintypes.Operator, error) {
if op, ok := operators[name]; ok {
diff --git a/internal/operators/pm.go b/internal/operators/pm.go
index b66da3be5..35e5b3244 100644
--- a/internal/operators/pm.go
+++ b/internal/operators/pm.go
@@ -11,7 +11,6 @@ import (
ahocorasick "github.com/petar-dambovaliev/aho-corasick"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
- "github.com/corazawaf/coraza/v3/internal/memoize"
)
// Description:
@@ -51,7 +50,7 @@ func newPM(options plugintypes.OperatorOptions) (plugintypes.Operator, error) {
DFA: true,
})
- m, _ := memoize.Do(data, func() (any, error) { return builder.Build(dict), nil })
+ m, _ := memoizeDo(options.Memoizer, data, func() (any, error) { return builder.Build(dict), nil })
// TODO this operator is supposed to support snort data syntax: "@pm A|42|C|44|F"
return &pm{matcher: m.(ahocorasick.AhoCorasick)}, nil
}
diff --git a/internal/operators/pm_from_dataset.go b/internal/operators/pm_from_dataset.go
index e5118a9f0..1c3def538 100644
--- a/internal/operators/pm_from_dataset.go
+++ b/internal/operators/pm_from_dataset.go
@@ -11,7 +11,6 @@ import (
ahocorasick "github.com/petar-dambovaliev/aho-corasick"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
- "github.com/corazawaf/coraza/v3/internal/memoize"
)
// Description:
@@ -46,7 +45,7 @@ func newPMFromDataset(options plugintypes.OperatorOptions) (plugintypes.Operator
DFA: true,
})
- m, _ := memoize.Do(data, func() (any, error) { return builder.Build(dataset), nil })
+ m, _ := memoizeDo(options.Memoizer, data, func() (any, error) { return builder.Build(dataset), nil })
return &pm{matcher: m.(ahocorasick.AhoCorasick)}, nil
}
diff --git a/internal/operators/pm_from_file.go b/internal/operators/pm_from_file.go
index 4d4ae4089..17a32030c 100644
--- a/internal/operators/pm_from_file.go
+++ b/internal/operators/pm_from_file.go
@@ -13,7 +13,6 @@ import (
ahocorasick "github.com/petar-dambovaliev/aho-corasick"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
- "github.com/corazawaf/coraza/v3/internal/memoize"
)
// Description:
@@ -65,7 +64,7 @@ func newPMFromFile(options plugintypes.OperatorOptions) (plugintypes.Operator, e
DFA: false,
})
- m, _ := memoize.Do(strings.Join(options.Path, ",")+filepath, func() (any, error) { return builder.Build(lines), nil })
+ m, _ := memoizeDo(options.Memoizer, strings.Join(options.Path, ",")+filepath, func() (any, error) { return builder.Build(lines), nil })
return &pm{matcher: m.(ahocorasick.AhoCorasick)}, nil
}
diff --git a/internal/operators/restpath.go b/internal/operators/restpath.go
index 2af865e12..d94298d36 100644
--- a/internal/operators/restpath.go
+++ b/internal/operators/restpath.go
@@ -11,7 +11,6 @@ import (
"strings"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
- "github.com/corazawaf/coraza/v3/internal/memoize"
)
var rePathTokenRe = regexp.MustCompile(`\{([^\}]+)\}`)
@@ -48,7 +47,7 @@ func newRESTPath(options plugintypes.OperatorOptions) (plugintypes.Operator, err
data = strings.Replace(data, token[0], fmt.Sprintf("(?P<%s>[^?/]+)", token[1]), 1)
}
- re, err := memoize.Do(data, func() (any, error) { return regexp.Compile(data) })
+ re, err := memoizeDo(options.Memoizer, data, func() (any, error) { return regexp.Compile(data) })
if err != nil {
return nil, err
}
diff --git a/internal/operators/rx.go b/internal/operators/rx.go
index 207b26e51..f85c1f812 100644
--- a/internal/operators/rx.go
+++ b/internal/operators/rx.go
@@ -14,7 +14,6 @@ import (
"rsc.io/binaryregexp"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
- "github.com/corazawaf/coraza/v3/internal/memoize"
)
// Description:
@@ -68,7 +67,7 @@ func newRX(options plugintypes.OperatorOptions) (plugintypes.Operator, error) {
return newBinaryRX(options)
}
- re, err := memoize.Do(data, func() (any, error) { return regexp.Compile(data) })
+ re, err := memoizeDo(options.Memoizer, data, func() (any, error) { return regexp.Compile(data) })
if err != nil {
return nil, err
}
@@ -77,15 +76,27 @@ func newRX(options plugintypes.OperatorOptions) (plugintypes.Operator, error) {
func (o *rx) Evaluate(tx plugintypes.TransactionState, value string) bool {
if tx.Capturing() {
- match := o.re.FindStringSubmatch(value)
- if len(match) == 0 {
+ // FindStringSubmatchIndex returns a slice of index pairs [start0, end0, start1, end1, ...]
+ // instead of allocating new strings for each capture group. We then slice the original
+ // input value[start:end] to get zero-allocation substrings.
+ match := o.re.FindStringSubmatchIndex(value)
+ if match == nil {
return false
}
- for i, c := range match {
+ // match has 2 entries per group: match[2*i] is the start index,
+ // match[2*i+1] is the end index for capture group i. Group 0 is
+ // the full match, groups 1..N are the parenthesized sub-expressions.
+ for i := 0; i < len(match)/2; i++ {
if i == 9 {
return true
}
- tx.CaptureField(i, c)
+ // A negative start index means the group did not participate in the match
+ // (e.g. an optional group like (foo)? when foo is absent).
+ if match[2*i] >= 0 {
+ tx.CaptureField(i, value[match[2*i]:match[2*i+1]])
+ } else {
+ tx.CaptureField(i, "")
+ }
}
return true
} else {
@@ -104,7 +115,7 @@ var _ plugintypes.Operator = (*binaryRX)(nil)
func newBinaryRX(options plugintypes.OperatorOptions) (plugintypes.Operator, error) {
data := options.Arguments
- re, err := memoize.Do(data, func() (any, error) { return binaryregexp.Compile(data) })
+ re, err := memoizeDo(options.Memoizer, data, func() (any, error) { return binaryregexp.Compile(data) })
if err != nil {
return nil, err
}
diff --git a/internal/operators/rx_test.go b/internal/operators/rx_test.go
index ccbb0649d..c4a9d480d 100644
--- a/internal/operators/rx_test.go
+++ b/internal/operators/rx_test.go
@@ -108,6 +108,32 @@ func TestRx(t *testing.T) {
}
}
+func BenchmarkRxCapture(b *testing.B) {
+ pattern := `(?sm)^/api/v(\d+)/users/(\w+)/posts/(\d+)`
+ input := "/api/v3/users/jptosso/posts/42"
+
+ re := regexp.MustCompile(pattern)
+
+ b.Run("FindStringSubmatch", func(b *testing.B) {
+ for b.Loop() {
+ match := re.FindStringSubmatch(input)
+ if len(match) == 0 {
+ b.Fatal("expected match")
+ }
+ _ = match[1]
+ }
+ })
+ b.Run("FindStringSubmatchIndex", func(b *testing.B) {
+ for b.Loop() {
+ match := re.FindStringSubmatchIndex(input)
+ if match == nil {
+ b.Fatal("expected match")
+ }
+ _ = input[match[2]:match[3]]
+ }
+ })
+}
+
func BenchmarkRxSubstringVsMatch(b *testing.B) {
str := "hello world; heelloo Woorld; hello; heeeelloooo wooooooorld;hello world; heelloo Woorld; hello; heeeelloooo wooooooorld;hello world; heelloo Woorld; hello; heeeelloooo wooooooorld;hello world; heelloo Woorld; hello; heeeelloooo wooooooorld;hello world; heelloo Woorld; hello; heeeelloooo wooooooorld;hello world; heelloo Woorld; hello; heeeelloooo wooooooorld;hello world; heelloo Woorld; hello; heeeelloooo wooooooorld;"
rx := regexp.MustCompile(`((h.*e.*l.*l.*o.*)|\d+)`)
diff --git a/internal/operators/validate_nid.go b/internal/operators/validate_nid.go
index af2a30cb2..fab78ab49 100644
--- a/internal/operators/validate_nid.go
+++ b/internal/operators/validate_nid.go
@@ -12,7 +12,6 @@ import (
"strings"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
- "github.com/corazawaf/coraza/v3/internal/memoize"
)
type validateNidFunction = func(input string) bool
@@ -61,7 +60,7 @@ func newValidateNID(options plugintypes.OperatorOptions) (plugintypes.Operator,
return nil, fmt.Errorf("invalid @validateNid argument")
}
- re, err := memoize.Do(expr, func() (any, error) { return regexp.Compile(expr) })
+ re, err := memoizeDo(options.Memoizer, expr, func() (any, error) { return regexp.Compile(expr) })
if err != nil {
return nil, err
}
diff --git a/internal/operators/validate_schema.go b/internal/operators/validate_schema.go
index 26d43499a..b09eb7f28 100644
--- a/internal/operators/validate_schema.go
+++ b/internal/operators/validate_schema.go
@@ -18,7 +18,6 @@ import (
"github.com/kaptinlin/jsonschema"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
- "github.com/corazawaf/coraza/v3/internal/memoize"
"github.com/corazawaf/coraza/v3/types"
)
@@ -79,7 +78,7 @@ func NewValidateSchema(options plugintypes.OperatorOptions) (plugintypes.Operato
}
key := md5Hash(schemaData)
- schema, err := memoize.Do(key, func() (any, error) {
+ schema, err := memoizeDo(options.Memoizer, key, func() (any, error) {
// Preliminarily validate that the schema is valid JSON
var jsonSchema any
if err := json.Unmarshal(schemaData, &jsonSchema); err != nil {
diff --git a/internal/seclang/directives.go b/internal/seclang/directives.go
index 9ee5127e1..07dee1494 100644
--- a/internal/seclang/directives.go
+++ b/internal/seclang/directives.go
@@ -14,10 +14,10 @@ import (
"strings"
"github.com/corazawaf/coraza/v3/debuglog"
+ "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
"github.com/corazawaf/coraza/v3/internal/auditlog"
"github.com/corazawaf/coraza/v3/internal/corazawaf"
"github.com/corazawaf/coraza/v3/internal/environment"
- "github.com/corazawaf/coraza/v3/internal/memoize"
utils "github.com/corazawaf/coraza/v3/internal/strings"
"github.com/corazawaf/coraza/v3/types"
)
@@ -282,6 +282,29 @@ func directiveSecRequestBodyAccess(options *DirectiveOptions) error {
return nil
}
+// Description: Configures the maximum JSON recursion depth limit Coraza will accept.
+// Default: 1024
+// Syntax: SecRequestBodyJsonDepthLimit [LIMIT]
+// ---
+// Anything over the limit will generate a REQBODY_ERROR in the JSON body processor.
+func directiveSecRequestBodyJsonDepthLimit(options *DirectiveOptions) error {
+ if len(options.Opts) == 0 {
+ return errEmptyOptions
+ }
+
+ limit, err := strconv.Atoi(options.Opts)
+ if err != nil {
+ return err
+ }
+
+ if limit <= 0 {
+ return errors.New("limit must be a positive integer")
+ }
+
+ options.WAF.RequestBodyJsonDepthLimit = limit
+ return nil
+}
+
// Description: Configures the rules engine.
// Syntax: SecRuleEngine On|Off|DetectionOnly
// Default: Off
@@ -813,7 +836,7 @@ func directiveSecAuditLogRelevantStatus(options *DirectiveOptions) error {
return errEmptyOptions
}
- re, err := memoize.Do(options.Opts, func() (any, error) { return regexp.Compile(options.Opts) })
+ re, err := options.WAF.Memoizer().Do(options.Opts, func() (any, error) { return regexp.Compile(options.Opts) })
if err != nil {
return err
}
@@ -912,12 +935,34 @@ func directiveSecDataDir(options *DirectiveOptions) error {
return nil
}
+// Description: Configures whether intercepted files will be kept after the transaction is processed.
+// Syntax: SecUploadKeepFiles On|RelevantOnly|Off
+// Default: Off
+// ---
+// The `SecUploadKeepFiles` directive is used to configure whether intercepted files are
+// preserved on disk after the transaction is processed.
+// This directive requires the storage directory to be defined (using `SecUploadDir`).
+//
+// Possible values are:
+// - On: Keep all uploaded files.
+// - Off: Do not keep uploaded files.
+// - RelevantOnly: Keep only uploaded files that matched at least one rule that would be
+// logged (excluding rules with the `nolog` action).
func directiveSecUploadKeepFiles(options *DirectiveOptions) error {
- b, err := parseBoolean(options.Opts)
+ if len(options.Opts) == 0 {
+ return errEmptyOptions
+ }
+
+ status, err := types.ParseUploadKeepFilesStatus(options.Opts)
if err != nil {
return err
}
- options.WAF.UploadKeepFiles = b
+
+ if !environment.HasAccessToFS && status != types.UploadKeepFilesOff {
+ return fmt.Errorf("SecUploadKeepFiles: cannot enable keeping uploaded files: filesystem access is disabled")
+ }
+
+ options.WAF.UploadKeepFiles = status
return nil
}
@@ -944,6 +989,11 @@ func directiveSecUploadFileLimit(options *DirectiveOptions) error {
return err
}
+// Description: Configures the directory where uploaded files will be stored.
+// Syntax: SecUploadDir /path/to/dir
+// Default: ""
+// ---
+// This directive is required when enabling SecUploadKeepFiles.
func directiveSecUploadDir(options *DirectiveOptions) error {
if len(options.Opts) == 0 {
return errEmptyOptions
@@ -1102,6 +1152,17 @@ func updateTargetBySingleID(id int, variables string, options *DirectiveOptions)
return rp.ParseVariables(strings.Trim(variables, "\""))
}
+// hasDisruptiveActions checks if any of the parsed actions are disruptive.
+// Returns true if at least one action has ActionTypeDisruptive, false otherwise.
+func hasDisruptiveActions(actions []ruleAction) bool {
+ for _, action := range actions {
+ if action.Atype == plugintypes.ActionTypeDisruptive {
+ return true
+ }
+ }
+ return false
+}
+
// Description: Updates the action list of the specified rule(s).
// Syntax: SecRuleUpdateActionById ID ACTIONLIST
// ---
@@ -1113,6 +1174,7 @@ func updateTargetBySingleID(id int, variables string, options *DirectiveOptions)
// ```apache
// SecRuleUpdateActionById 12345 "deny,status:403"
// ```
+// The rule ID can be a single ID or a range of IDs. The targets are separated by a pipe character.
func directiveSecRuleUpdateActionByID(options *DirectiveOptions) error {
if len(options.Opts) == 0 {
return errEmptyOptions
@@ -1152,18 +1214,37 @@ func directiveSecRuleUpdateActionByID(options *DirectiveOptions) error {
return fmt.Errorf("invalid range: %s", idOrRange)
}
- for _, rule := range options.WAF.Rules.GetRules() {
- if rule.ID_ < start && rule.ID_ > end {
+ // Parse actions once to check if any are disruptive.
+ // Trim surrounding quotes because the SecLang syntax uses quoted action lists
+ // (e.g., SecRuleUpdateActionById 1004 "pass") and strings.Fields preserves them.
+ trimmedActions := strings.Trim(actions, "\"")
+ parsedActions, err := parseActions(options.WAF.Logger, trimmedActions)
+ if err != nil {
+ return err
+ }
+
+ // Check if any of the new actions are disruptive
+ hasDisruptiveAction := hasDisruptiveActions(parsedActions)
+
+ rules := options.WAF.Rules.GetRules()
+ for i := range rules {
+ if rules[i].ID_ < start || rules[i].ID_ > end {
continue
}
+
+ // Only clear disruptive actions if the update contains a disruptive action
+ if hasDisruptiveAction {
+ rules[i].ClearDisruptiveActions()
+ }
+
rp := RuleParser{
- rule: &rule,
+ rule: &rules[i],
options: RuleOptions{
WAF: options.WAF,
},
defaultActions: map[types.RulePhase][]ruleAction{},
}
- if err := rp.ParseActions(strings.Trim(actions, "\"")); err != nil {
+ if err := rp.applyParsedActions(parsedActions); err != nil {
return err
}
}
@@ -1173,11 +1254,32 @@ func directiveSecRuleUpdateActionByID(options *DirectiveOptions) error {
}
func updateActionBySingleID(id int, actions string, options *DirectiveOptions) error {
-
rule := options.WAF.Rules.FindByID(id)
if rule == nil {
return fmt.Errorf("SecRuleUpdateActionById: rule \"%d\" not found", id)
}
+
+ // Parse actions first to check if any are disruptive.
+ // Trim surrounding quotes from the SecLang action list syntax.
+ trimmedActions := strings.Trim(actions, "\"")
+ parsedActions, err := parseActions(options.WAF.Logger, trimmedActions)
+ if err != nil {
+ return err
+ }
+
+ // Check if any of the new actions are disruptive.
+ // hasDisruptiveActions returns false when parsedActions is empty or contains
+ // only non-disruptive actions, preserving existing disruptive actions on the rule.
+ hasDisruptiveAction := hasDisruptiveActions(parsedActions)
+
+ // Only clear disruptive actions if the update contains a disruptive action
+ // This matches ModSecurity behavior where SecRuleUpdateActionById replaces
+ // disruptive actions but preserves them if only non-disruptive actions are updated
+ if hasDisruptiveAction {
+ rule.ClearDisruptiveActions()
+ }
+
+ // Apply the parsed actions to the rule without re-parsing
rp := RuleParser{
rule: rule,
options: RuleOptions{
@@ -1185,7 +1287,7 @@ func updateActionBySingleID(id int, actions string, options *DirectiveOptions) e
},
defaultActions: map[types.RulePhase][]ruleAction{},
}
- return rp.ParseActions(strings.Trim(actions, "\""))
+ return rp.applyParsedActions(parsedActions)
}
// Description: Updates the target (variable) list of the specified rule(s) by tag.
diff --git a/internal/seclang/directives_test.go b/internal/seclang/directives_test.go
index 373f100be..5cc91ebb0 100644
--- a/internal/seclang/directives_test.go
+++ b/internal/seclang/directives_test.go
@@ -154,8 +154,7 @@ func TestDirectives(t *testing.T) {
"SecUploadKeepFiles": {
{"", expectErrorOnDirective},
{"Ox", expectErrorOnDirective},
- {"On", func(w *corazawaf.WAF) bool { return w.UploadKeepFiles }},
- {"Off", func(w *corazawaf.WAF) bool { return !w.UploadKeepFiles }},
+ {"Off", func(w *corazawaf.WAF) bool { return w.UploadKeepFiles == types.UploadKeepFilesOff }},
},
"SecUploadFileMode": {
{"", expectErrorOnDirective},
@@ -317,6 +316,10 @@ func TestDirectives(t *testing.T) {
{"/tmp-non-existing", expectErrorOnDirective},
{os.TempDir(), func(w *corazawaf.WAF) bool { return w.UploadDir == os.TempDir() }},
}
+ directiveCases["SecUploadKeepFiles"] = append(directiveCases["SecUploadKeepFiles"],
+ directiveCase{"On", func(w *corazawaf.WAF) bool { return w.UploadKeepFiles == types.UploadKeepFilesOn }},
+ directiveCase{"RelevantOnly", func(w *corazawaf.WAF) bool { return w.UploadKeepFiles == types.UploadKeepFilesRelevantOnly }},
+ )
}
for name, dCases := range directiveCases {
diff --git a/internal/seclang/directivesmap.gen.go b/internal/seclang/directivesmap.gen.go
index b6ec36ee4..c9c118674 100644
--- a/internal/seclang/directivesmap.gen.go
+++ b/internal/seclang/directivesmap.gen.go
@@ -13,6 +13,7 @@ var (
_ directive = directiveSecResponseBodyAccess
_ directive = directiveSecRequestBodyLimit
_ directive = directiveSecRequestBodyAccess
+ _ directive = directiveSecRequestBodyJsonDepthLimit
_ directive = directiveSecRuleEngine
_ directive = directiveSecWebAppID
_ directive = directiveSecServerSignature
@@ -75,6 +76,7 @@ var directivesMap = map[string]directive{
"secresponsebodyaccess": directiveSecResponseBodyAccess,
"secrequestbodylimit": directiveSecRequestBodyLimit,
"secrequestbodyaccess": directiveSecRequestBodyAccess,
+ "secrequestbodyjsondepthlimit": directiveSecRequestBodyJsonDepthLimit,
"secruleengine": directiveSecRuleEngine,
"secwebappid": directiveSecWebAppID,
"secserversignature": directiveSecServerSignature,
diff --git a/internal/seclang/parser.go b/internal/seclang/parser.go
index e1fbf3a70..e3dfc111d 100644
--- a/internal/seclang/parser.go
+++ b/internal/seclang/parser.go
@@ -148,7 +148,7 @@ func (p *Parser) parseString(data string) error {
func (p *Parser) evaluateLine(l string) error {
if l == "" || l[0] == '#' {
- panic("invalid line")
+ return errors.New("invalid line")
}
// first we get the directive
dir, opts, _ := strings.Cut(l, " ")
diff --git a/internal/seclang/rule_parser.go b/internal/seclang/rule_parser.go
index feeaa6663..9ed6151d9 100644
--- a/internal/seclang/rule_parser.go
+++ b/internal/seclang/rule_parser.go
@@ -198,6 +198,9 @@ func (rp *RuleParser) ParseOperator(operator string) error {
Root: rp.options.ParserConfig.Root,
Datasets: rp.options.Datasets,
}
+ if rp.options.WAF != nil {
+ opts.Memoizer = rp.options.WAF.Memoizer()
+ }
if wd := rp.options.ParserConfig.WorkingDir; wd != "" {
opts.Path = append(opts.Path, wd)
@@ -265,11 +268,25 @@ func (rp *RuleParser) ParseDefaultActions(actions string) error {
// ParseActions parses a comma separated list of actions:arguments
// Arguments can be wrapper inside quotes
func (rp *RuleParser) ParseActions(actions string) error {
- disabledActions := rp.options.ParserConfig.DisabledRuleActions
act, err := parseActions(rp.options.WAF.Logger, actions)
if err != nil {
return err
}
+ return rp.applyParsedActions(act)
+}
+
+// applyParsedActions applies a list of already-parsed actions to the rule.
+//
+// This is a helper method used internally by ParseActions and directive handlers
+// (such as SecRuleUpdateActionById) to avoid code duplication. It's useful when
+// actions have been parsed once for inspection and need to be applied without
+// re-parsing, avoiding redundant parsing operations.
+//
+// The method validates that none of the actions are disabled, executes metadata
+// actions, merges with default actions for the rule's phase, and initializes
+// all actions on the rule.
+func (rp *RuleParser) applyParsedActions(act []ruleAction) error {
+ disabledActions := rp.options.ParserConfig.DisabledRuleActions
// check if forbidden action:
for _, a := range act {
if utils.InSlice(a.Key, disabledActions) {
@@ -334,9 +351,13 @@ func ParseRule(options RuleOptions) (*corazawaf.Rule, error) {
}
var err error
+ rule := corazawaf.NewRule()
+ if options.WAF != nil {
+ rule.SetMemoizer(options.WAF.Memoizer())
+ }
rp := RuleParser{
options: options,
- rule: corazawaf.NewRule(),
+ rule: rule,
defaultActions: map[types.RulePhase][]ruleAction{},
}
var defaultActionsRaw []string
@@ -389,7 +410,7 @@ func ParseRule(options RuleOptions) (*corazawaf.Rule, error) {
return nil, err
}
}
- rule := rp.Rule()
+ rule = rp.Rule()
rule.File_ = options.ParserConfig.ConfigFile
rule.Line_ = options.ParserConfig.LastLine
diff --git a/internal/strings/strings.go b/internal/strings/strings.go
index dd78e3871..9b0b42e3c 100644
--- a/internal/strings/strings.go
+++ b/internal/strings/strings.go
@@ -129,3 +129,31 @@ func InSlice(a string, list []string) bool {
func WrapUnsafe(buf []byte) string {
return *(*string)(unsafe.Pointer(&buf))
}
+
+// HasRegex checks if s is enclosed in unescaped forward slashes (e.g. "/pattern/"),
+// consistent with the ModSecurity regex delimiter convention. It returns (true, pattern)
+// where pattern is the content between the slashes. Escaped closing slashes (e.g. "/foo\/")
+// are treated as plain strings and return (false, s).
+func HasRegex(s string) (bool, string) {
+ if len(s) < 2 || s[0] != '/' {
+ return false, s
+ }
+ lastIdx := len(s) - 1
+ if s[lastIdx] != '/' {
+ return false, s
+ }
+ // "//" edge-case: empty pattern
+ if lastIdx == 1 {
+ return true, ""
+ }
+ // Count consecutive backslashes immediately before the closing '/'.
+ // An even count (including zero) means the '/' is unescaped.
+ backslashes := 0
+ for i := lastIdx - 1; i >= 0 && s[i] == '\\'; i-- {
+ backslashes++
+ }
+ if backslashes%2 == 0 {
+ return true, s[1:lastIdx]
+ }
+ return false, s
+}
diff --git a/internal/strings/strings_test.go b/internal/strings/strings_test.go
index e398a9a57..5310d5d98 100644
--- a/internal/strings/strings_test.go
+++ b/internal/strings/strings_test.go
@@ -112,3 +112,97 @@ func TestRandomStringConcurrency(t *testing.T) {
go RandomString(10000)
}
}
+
+func TestHasRegex(t *testing.T) {
+ tCases := []struct {
+ name string
+ input string
+ expectIsRegex bool
+ expectPattern string
+ }{
+ {
+ name: "valid regex pattern",
+ input: "/user/",
+ expectIsRegex: true,
+ expectPattern: "user",
+ },
+ {
+ name: "escaped slash at end — not a regex",
+ input: `/user\/`,
+ expectIsRegex: false,
+ expectPattern: `/user\/`,
+ },
+ {
+ name: "double-escaped slash at end — is a regex",
+ input: `/user\\/`,
+ expectIsRegex: true,
+ expectPattern: `user\\`,
+ },
+ {
+ name: "triple-escaped slash at end — not a regex",
+ input: `/user\\\/`,
+ expectIsRegex: false,
+ expectPattern: `/user\\\/`,
+ },
+ {
+ name: "empty pattern //",
+ input: "//",
+ expectIsRegex: true,
+ expectPattern: "",
+ },
+ {
+ name: "too short — single char",
+ input: "/",
+ expectIsRegex: false,
+ expectPattern: "/",
+ },
+ {
+ name: "no leading slash",
+ input: "user/",
+ expectIsRegex: false,
+ expectPattern: "user/",
+ },
+ {
+ name: "no trailing slash",
+ input: "/user",
+ expectIsRegex: false,
+ expectPattern: "/user",
+ },
+ {
+ name: "complex pattern with anchors and quantifiers",
+ input: `/^json\.\d+\.field$/`,
+ expectIsRegex: true,
+ expectPattern: `^json\.\d+\.field$`,
+ },
+ {
+ name: "pattern with character class",
+ input: "/user[0-9]+/",
+ expectIsRegex: true,
+ expectPattern: "user[0-9]+",
+ },
+ {
+ name: "plain string without slashes",
+ input: "username",
+ expectIsRegex: false,
+ expectPattern: "username",
+ },
+ {
+ name: "empty string",
+ input: "",
+ expectIsRegex: false,
+ expectPattern: "",
+ },
+ }
+
+ for _, tCase := range tCases {
+ t.Run(tCase.name, func(t *testing.T) {
+ gotIsRegex, gotPattern := HasRegex(tCase.input)
+ if gotIsRegex != tCase.expectIsRegex {
+ t.Errorf("HasRegex(%q): isRegex = %v, want %v", tCase.input, gotIsRegex, tCase.expectIsRegex)
+ }
+ if gotPattern != tCase.expectPattern {
+ t.Errorf("HasRegex(%q): pattern = %q, want %q", tCase.input, gotPattern, tCase.expectPattern)
+ }
+ })
+ }
+}
diff --git a/internal/transformations/escape_seq_decode.go b/internal/transformations/escape_seq_decode.go
index f740b683b..f245109e3 100644
--- a/internal/transformations/escape_seq_decode.go
+++ b/internal/transformations/escape_seq_decode.go
@@ -97,6 +97,7 @@ func doEscapeSeqDecode(input string, pos int) (string, bool) {
data[d] = input[i+1]
d++
i += 2
+ changed = true
} else {
/* Input character not a backslash, copy it. */
data[d] = input[i]
diff --git a/internal/transformations/escape_seq_decode_test.go b/internal/transformations/escape_seq_decode_test.go
index d9a62f891..5e2a3ebd3 100644
--- a/internal/transformations/escape_seq_decode_test.go
+++ b/internal/transformations/escape_seq_decode_test.go
@@ -30,6 +30,10 @@ func TestEscapeSeqDecode(t *testing.T) {
input: "\\a\\b\\f\\n\\r\\t\\v\\u0000\\?\\'\\\"\\0\\12\\123\\x00\\xff",
want: "\a\b\f\n\r\t\vu0000?'\"\x00\nS\x00\xff",
},
+ {
+ input: "\\z",
+ want: "z",
+ },
}
for _, tc := range tests {
diff --git a/internal/transformations/remove_comments.go b/internal/transformations/remove_comments.go
index e2a136c40..a5eba6c7a 100644
--- a/internal/transformations/remove_comments.go
+++ b/internal/transformations/remove_comments.go
@@ -18,9 +18,11 @@ charLoop:
switch {
case (input[i] == '/') && (i+1 < inputLen) && (input[i+1] == '*'):
incomment = true
+ changed = true
i += 2
case (input[i] == '<') && (i+3 < inputLen) && (input[i+1] == '!') && (input[i+2] == '-') && (input[i+3] == '-'):
incomment = true
+ changed = true
i += 4
case (input[i] == '-') && (i+1 < inputLen) && (input[i+1] == '-'):
input[i] = ' '
diff --git a/internal/transformations/remove_comments_test.go b/internal/transformations/remove_comments_test.go
new file mode 100644
index 000000000..001afe166
--- /dev/null
+++ b/internal/transformations/remove_comments_test.go
@@ -0,0 +1,91 @@
+// Copyright 2024 Juan Pablo Tosso and the OWASP Coraza contributors
+// SPDX-License-Identifier: Apache-2.0
+
+package transformations
+
+import "testing"
+
+func TestRemoveComments(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ want string
+ changed bool
+ }{
+ {
+ name: "no comments",
+ input: "hello world",
+ want: "hello world",
+ changed: false,
+ },
+ {
+ name: "c-style comment",
+ input: "hello /* comment */ world",
+ want: "hello world",
+ changed: true,
+ },
+ {
+ name: "html comment",
+ input: "hello <!-- comment --> world",
+ want: "hello world",
+ changed: true,
+ },
+ {
+ name: "c-style comment only",
+ input: "/* comment */",
+ want: "\x00",
+ changed: true,
+ },
+ {
+ name: "html comment only",
+ input: "<!-- comment -->",
+ want: "\x00",
+ changed: true,
+ },
+ {
+ name: "unclosed c-style comment",
+ input: "hello /* unclosed",
+ want: "hello ",
+ changed: true,
+ },
+ {
+ name: "unclosed html comment",
+ input: "hello <!-- unclosed",