diff --git a/go.mod b/go.mod index 3bb0a9bfef4..8ab791a2306 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/crowdsecurity/crowdsec -go 1.25.1 +go 1.25 require ( entgo.io/ent v0.14.2 @@ -9,7 +9,6 @@ require ( github.com/Masterminds/sprig/v3 v3.2.3 github.com/agext/levenshtein v1.2.3 github.com/alexliesenfeld/health v0.8.1 - github.com/appleboy/gin-jwt/v2 v2.10.3 github.com/aws/aws-lambda-go v1.47.0 github.com/aws/aws-sdk-go-v2 v1.38.3 github.com/aws/aws-sdk-go-v2/config v1.31.6 @@ -34,8 +33,6 @@ require ( github.com/expr-lang/expr v1.17.2 github.com/fatih/color v1.18.0 github.com/fsnotify/fsnotify v1.9.0 - github.com/gin-contrib/gzip v1.2.3 - github.com/gin-gonic/gin v1.10.0 github.com/go-co-op/gocron v1.37.0 github.com/go-openapi/errors v0.22.2 github.com/go-openapi/strfmt v0.23.0 @@ -123,9 +120,6 @@ require ( github.com/aws/smithy-go v1.23.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect - github.com/bytedance/sonic v1.13.2 // indirect - github.com/bytedance/sonic/loader v0.2.4 // indirect - github.com/cloudwego/base64x v0.1.5 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect @@ -135,8 +129,6 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/ebitengine/purego v0.8.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/gabriel-vasile/mimetype v1.4.8 // indirect - github.com/gin-contrib/sse v1.0.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect @@ -146,9 +138,6 @@ require ( github.com/go-openapi/jsonreference v0.21.0 // indirect github.com/go-openapi/loads v0.22.0 // indirect github.com/go-openapi/spec v0.21.0 // indirect - github.com/go-playground/locales v0.14.1 // indirect - github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.26.0 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/glog v1.2.5 // indirect @@ -175,8 +164,6 @@ require ( github.com/kaptinlin/jsonschema v0.4.6 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/klauspost/compress v1.18.0 // indirect - github.com/klauspost/cpuid/v2 v2.2.10 // indirect - github.com/leodido/go-urn v1.4.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magefile/mage v1.15.1-0.20250615140142-78acbaf2e3ae // indirect github.com/mailru/easyjson v0.9.0 // indirect @@ -196,7 +183,6 @@ require ( github.com/oklog/ulid v1.3.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect - github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/petar-dambovaliev/aho-corasick v0.0.0-20250424160509-463d218d4745 // indirect github.com/pierrec/lz4/v4 v4.1.18 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect @@ -217,12 +203,9 @@ require ( github.com/tklauser/go-sysconf v0.3.15 // indirect github.com/tklauser/numcpus v0.10.0 // indirect github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208 // indirect - github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - github.com/ugorji/go/codec v1.2.12 // indirect github.com/valllabh/ocsf-schema-golang v1.0.3 // indirect github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect github.com/wasilibs/wazero-helpers v0.0.0-20250123031827-cd30c44769bb // indirect - github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/zclconf/go-cty v1.14.4 // indirect github.com/zclconf/go-cty-yaml v1.1.0 // indirect @@ -234,7 +217,6 @@ require ( go.opentelemetry.io/otel/trace v1.36.0 // indirect go.uber.org/atomic v1.10.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect - golang.org/x/arch v0.15.0 // indirect golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect golang.org/x/term v0.35.0 // indirect golang.org/x/tools v0.37.0 // indirect diff --git a/go.sum b/go.sum index 49775536f42..295b7a5c6df 100644 --- a/go.sum +++ b/go.sum @@ -31,10 +31,6 @@ github.com/apparentlymart/go-textseg/v13 v13.0.0 h1:Y+KvPE1NYz0xl601PVImeQfFyEy6 github.com/apparentlymart/go-textseg/v13 v13.0.0/go.mod h1:ZK2fH7c4NqDTLtiYLvIkEghdlcqw7yxLeM89kiTRPUo= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= -github.com/appleboy/gin-jwt/v2 v2.10.3 h1:KNcPC+XPRNpuoBh+j+rgs5bQxN+SwG/0tHbIqpRoBGc= -github.com/appleboy/gin-jwt/v2 v2.10.3/go.mod h1:LDUaQ8mF2W6LyXIbd5wqlV2SFebuyYs4RDwqMNgpsp8= -github.com/appleboy/gofight/v2 v2.1.2 h1:VOy3jow4vIK8BRQJoC/I9muxyYlJ2yb9ht2hZoS3rf4= -github.com/appleboy/gofight/v2 v2.1.2/go.mod h1:frW+U1QZEdDgixycTj4CygQ48yLTUhplt43+Wczp3rw= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-lambda-go v1.47.0 h1:0H8s0vumYx/YKs4sE7YM0ktwL2eWse+kfopsRI1sXVI= @@ -95,18 +91,10 @@ github.com/bufbuild/protocompile v0.4.0 h1:LbFKd2XowZvQ/kajzguUp2DC9UEIQhIq77fZZ github.com/bufbuild/protocompile v0.4.0/go.mod h1:3v93+mbWn/v3xzN+31nwkJfrEpAUwp+BagBSZWx+TP8= github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= -github.com/bytedance/sonic v1.13.2 h1:8/H1FempDZqC4VqjptGo14QQlJx8VdZJegxs6wwfqpQ= -github.com/bytedance/sonic v1.13.2/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4= -github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= -github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY= -github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= -github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= -github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= @@ -170,14 +158,6 @@ github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7z github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= -github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= -github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= -github.com/gin-contrib/gzip v1.2.3 h1:dAhT722RuEG330ce2agAs75z7yB+NKvX/ZM1r8w0u2U= -github.com/gin-contrib/gzip v1.2.3/go.mod h1:ad72i4Bzmaypk8M762gNXa2wkxxjbz0icRNnuLJ9a/c= -github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E= -github.com/gin-contrib/sse v1.0.0/go.mod h1:zNuFdwarAygJBht0NTKiSi3jRf6RbqeILZ9Sp6Slhe0= -github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= -github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/go-co-op/gocron v1.37.0 h1:ZYDJGtQ4OMhTLKOKMIch+/CY70Brbb1dGdooLEhh7b0= github.com/go-co-op/gocron v1.37.0/go.mod h1:3L/n6BkO7ABj+TrfSVXLRzsP26zmikL4ISkLQ0O8iNY= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -210,14 +190,6 @@ github.com/go-openapi/swag v0.23.1 h1:lpsStH0n2ittzTnbaSloVZLuB5+fvSY/+hnagBjSNZ github.com/go-openapi/swag v0.23.1/go.mod h1:STZs8TbRvEQQKUA+JZNAm3EWlgaOBGpyFDqQnDHMef0= github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58= github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ= -github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= -github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= -github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k= -github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -366,10 +338,6 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= -github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= -github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -384,8 +352,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= -github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -551,10 +517,7 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= @@ -572,10 +535,6 @@ github.com/tklauser/numcpus v0.10.0 h1:18njr6LDBk1zuna922MgdjQuJFjrdppsZG60sHGfj github.com/tklauser/numcpus v0.10.0/go.mod h1:BiTKazU708GQTYF4mB+cmlpT2Is1gLk7XVuEeem8LsQ= github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208 h1:PM5hJF7HVfNWmCjMdEfbuOBNXSVF2cMFGgQTPdKCbwM= github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208/go.mod h1:BzWtXXrXzZUvMacR0oF/fbDDgUPO8L36tDMmRAf14ns= -github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= -github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= -github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= -github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26 h1:UFHFmFfixpmfRBcxuu+LA9l8MdURWVdVNUHxO5n1d2w= github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26/go.mod h1:IGhd0qMDsUa9acVjsbsT7bu3ktadtGOHI79+idTew/M= github.com/valllabh/ocsf-schema-golang v1.0.3 h1:eR8k/3jP/OOqB8LRCtdJ4U+vlgd/gk5y3KMXoodrsrw= @@ -596,8 +555,6 @@ github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6 github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/xhit/go-simple-mail/v2 v2.16.0 h1:ouGy/Ww4kuaqu2E2UrDw7SvLaziWTB60ICLkIkNVccA= github.com/xhit/go-simple-mail/v2 v2.16.0/go.mod h1:b7P5ygho6SYE+VIqpxA6QkYfv4teeyG4MKqB3utRu98= -github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= -github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -642,8 +599,6 @@ go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= -golang.org/x/arch v0.15.0 h1:QtOrQd0bTUnhNVNndMpLHNWrDmYzZ2KDqSrEymqInZw= -golang.org/x/arch v0.15.0/go.mod h1:JmwW7aLIoRUKgaTzhkiEFxvcEiQGyOg9BMonBJUS7EE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -837,7 +792,6 @@ modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= -nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk= pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE= diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index c5534af1e7b..39595e82c71 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -11,12 +11,12 @@ import ( "testing" "time" - "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" logtest "github.com/sirupsen/logrus/hooks/test" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -28,7 +28,7 @@ const ( ) type LAPI struct { - router *gin.Engine + router *router.Router loginResp models.WatcherAuthResponse bouncerKey string DBConfig *csconfig.DatabaseCfg @@ -77,22 +77,22 @@ func (l *LAPI) RecordResponse(t *testing.T, ctx context.Context, verb string, ur return w } -func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) { - router, config := NewAPITest(t, ctx) - loginResp := LoginToTestAPI(t, ctx, router, config) +func InitMachineTest(t *testing.T, ctx context.Context) (*router.Router, models.WatcherAuthResponse, csconfig.Config) { + rtr, config := NewAPITest(t, ctx) + loginResp := LoginToTestAPI(t, ctx, rtr, config) - return router, loginResp, config + return rtr, loginResp, config } -func LoginToTestAPI(t *testing.T, ctx context.Context, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse { - body := CreateTestMachine(t, ctx, router, "") +func LoginToTestAPI(t *testing.T, ctx context.Context, rtr *router.Router, config csconfig.Config) models.WatcherAuthResponse { + body := CreateTestMachine(t, ctx, rtr, "") ValidateMachine(t, ctx, "test", config.API.Server.DbConfig) w := httptest.NewRecorder() req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) require.NoError(t, err) req.Header.Add("User-Agent", UserAgent) - router.ServeHTTP(w, req) + rtr.ServeHTTP(w, req) loginResp := models.WatcherAuthResponse{} err = json.NewDecoder(w.Body).Decode(&loginResp) diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index faec699a077..69faf515fd4 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -1,20 +1,18 @@ package apiserver import ( - "context" "cmp" + "context" "errors" "fmt" - "io" "io/fs" "net" "net/http" + "net/netip" "os" "runtime" - "strings" "time" - "github.com/gin-gonic/gin" "github.com/go-co-op/gocron" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" @@ -22,7 +20,9 @@ import ( "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares" v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csnet" "github.com/crowdsecurity/crowdsec/pkg/csplugin" @@ -30,6 +30,32 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/logging" ) +// convertTrustedProxies converts []string to []netip.Prefix +func convertTrustedProxies(proxyStrings []string) ([]netip.Prefix, error) { + if proxyStrings == nil { + return nil, nil + } + + proxies := make([]netip.Prefix, 0, len(proxyStrings)) + for _, proxy := range proxyStrings { + prefix, err := netip.ParsePrefix(proxy) + if err != nil { + // Try parsing as a single IP and convert to /32 or /128 + addr, err := netip.ParseAddr(proxy) + if err != nil { + return nil, fmt.Errorf("invalid proxy IP/CIDR: %s", proxy) + } + if addr.Is4() { + prefix = netip.PrefixFrom(addr, 32) + } else { + prefix = netip.PrefixFrom(addr, 128) + } + } + proxies = append(proxies, prefix) + } + return proxies, nil +} + const keyLength = 32 type APIServer struct { @@ -37,81 +63,13 @@ type APIServer struct { dbClient *database.Client controller *controllers.Controller flushScheduler *gocron.Scheduler - router *gin.Engine + router *router.Router httpServer *http.Server apic *apic papi *Papi httpServerTomb tomb.Tomb } -func isBrokenConnection(maybeError any) bool { - err, ok := maybeError.(error) - if !ok { - return false - } - - var netOpError *net.OpError - if errors.As(err, &netOpError) { - var syscallError *os.SyscallError - if errors.As(netOpError.Err, &syscallError) { - if strings.Contains(strings.ToLower(syscallError.Error()), "broken pipe") || strings.Contains(strings.ToLower(syscallError.Error()), "connection reset by peer") { - return true - } - } - } - - // because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go - // and because it seems gin doesn't handle those neither, we need to "hand define" some errors to properly catch them - // stolen from http2/server.go in x/net - var ( - errClientDisconnected = errors.New("client disconnected") - errClosedBody = errors.New("body closed by handler") - errHandlerComplete = errors.New("http2: request body closed due to handler exiting") - errStreamClosed = errors.New("http2: stream closed") - ) - - if errors.Is(err, errClientDisconnected) || - errors.Is(err, errClosedBody) || - errors.Is(err, errHandlerComplete) || - errors.Is(err, errStreamClosed) { - return true - } - - return false -} - -func recoverFromPanic(c *gin.Context) { - err := recover() //nolint:revive - if err == nil { - return - } - - // Check for a broken connection, as it is not really a - // condition that warrants a panic stack trace. - if isBrokenConnection(err) { - log.Warningf("client %s disconnected: %s", c.ClientIP(), err) - c.Abort() - } else { - log.Warningf("client %s error: %s", c.ClientIP(), err) - - filename, err := trace.WriteStackTrace(err) - if err != nil { - log.Errorf("also while writing stacktrace: %s", err) - } - - log.Warningf("stacktrace written to %s, please join to your issue", filename) - c.AbortWithStatus(http.StatusInternalServerError) - } -} - -// CustomRecoveryWithWriter returns a middleware for a writer that recovers from any panics and writes a 500 if there was one. -func CustomRecoveryWithWriter() gin.HandlerFunc { - return func(c *gin.Context) { - defer recoverFromPanic(c) - c.Next() - } -} - // NewServer creates a LAPI server. // It sets up a gin router, a database client, and a controller. func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg, accessLogger *log.Entry) (*APIServer, error) { @@ -133,59 +91,37 @@ func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg, accessLo } } - if !log.IsLevelEnabled(log.DebugLevel) { - gin.SetMode(gin.ReleaseMode) - } - - router := gin.New() - - router.ForwardedByClientIP = false - - // set the remore address of the request to 127.0.0.1 if it comes from a unix socket - router.Use(func(c *gin.Context) { - if c.Request.RemoteAddr == "@" { - c.Request.RemoteAddr = "127.0.0.1:65535" - } - }) + // Create new router + httpRouter := router.New() + // Set up middleware + var trustedProxies []netip.Prefix if config.TrustedProxies != nil && config.UseForwardedForHeaders { - if err = router.SetTrustedProxies(*config.TrustedProxies); err != nil { - return nil, fmt.Errorf("while setting trusted_proxies: %w", err) + var err error + trustedProxies, err = convertTrustedProxies(*config.TrustedProxies) + if err != nil { + return nil, fmt.Errorf("while parsing trusted proxies: %w", err) } - - router.ForwardedByClientIP = true } - gin.DefaultErrorWriter = accessLogger.WriterLevel(log.ErrorLevel) - gin.DefaultWriter = accessLogger.Writer() - - router.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { - return fmt.Sprintf("%s - [%s] \"%s %s %s %d %s %q %s\"\n", - param.ClientIP, - param.TimeStamp.Format(time.RFC1123), - param.Method, - param.Path, - param.Request.Proto, - param.StatusCode, - param.Latency, - param.Request.UserAgent(), - param.ErrorMessage, - ) - })) + // Apply middleware in order + httpRouter.Use(middlewares.ClientIPMiddleware(trustedProxies, config.UseForwardedForHeaders)) + httpRouter.Use(middlewares.LoggingMiddleware(accessLogger)) + httpRouter.Use(middlewares.RecoveryMiddleware()) + httpRouter.Use(middlewares.GzipDecompressMiddleware()) - router.NoRoute(func(c *gin.Context) { - c.JSON(http.StatusNotFound, gin.H{"message": "Page or Method not found"}) - }) - router.Use(CustomRecoveryWithWriter()) + // Handle 404 - http.ServeMux handles this, but we can add custom handler + // Note: Method not allowed is handled by http.ServeMux automatically controller := &controllers.Controller{ DBClient: dbClient, - Router: router, + Router: httpRouter, Profiles: config.Profiles, Log: accessLogger, ConsoleConfig: config.ConsoleConfig, DisableRemoteLapiRegistration: config.DisableRemoteLapiRegistration, AutoRegisterCfg: config.AutoRegister, + // TrustedIPs will be set later from config.GetTrustedIPs() - don't confuse with trustedProxies } var ( @@ -232,18 +168,18 @@ func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg, accessLo controller.TrustedIPs = trustedIPs return &APIServer{ - cfg: config, + cfg: config, dbClient: dbClient, controller: controller, flushScheduler: flushScheduler, - router: router, + router: httpRouter, apic: apiClient, papi: papiClient, httpServerTomb: tomb.Tomb{}, }, nil } -func (s *APIServer) Router() (*gin.Engine, error) { +func (s *APIServer) Router() (*router.Router, error) { return s.router, nil } @@ -324,7 +260,7 @@ func (s *APIServer) Run(ctx context.Context, apiReady chan bool) error { s.httpServer = &http.Server{ Addr: s.cfg.ListenURI, - Handler: s.router, + Handler: s.router, // Use router directly so middleware is applied (404/405 will be logged) TLSConfig: tlsCfg, Protocols: &http.Protocols{}, } @@ -473,14 +409,8 @@ func (s *APIServer) Shutdown(ctx context.Context) error { } } - // close io.writer logger given to gin - if pipe, ok := gin.DefaultErrorWriter.(*io.PipeWriter); ok { - pipe.Close() - } - - if pipe, ok := gin.DefaultWriter.(*io.PipeWriter); ok { - pipe.Close() - } + // Note: We no longer use gin.DefaultErrorWriter/DefaultWriter + // If we add custom writers in the future, close them here s.httpServerTomb.Kill(nil) @@ -546,7 +476,7 @@ func (s *APIServer) InitController() error { cacheExpiration = *s.cfg.TLS.CacheExpiration } - s.controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.cfg.TLS.AllowedAgentsOU, s.cfg.TLS.CRLPath, + tlsAuthAgents, err := v1.NewTLSAuth(s.cfg.TLS.AllowedAgentsOU, s.cfg.TLS.CRLPath, cacheExpiration, log.WithFields(log.Fields{ "component": "tls-auth", @@ -555,8 +485,9 @@ func (s *APIServer) InitController() error { if err != nil { return fmt.Errorf("while creating TLS auth for agents: %w", err) } + s.controller.HandlerV1.Middlewares.JWT.SetTlsAuth(tlsAuthAgents) - s.controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.cfg.TLS.AllowedBouncersOU, s.cfg.TLS.CRLPath, + tlsAuthBouncers, err := v1.NewTLSAuth(s.cfg.TLS.AllowedBouncersOU, s.cfg.TLS.CRLPath, cacheExpiration, log.WithFields(log.Fields{ "component": "tls-auth", @@ -565,6 +496,7 @@ func (s *APIServer) InitController() error { if err != nil { return fmt.Errorf("while creating TLS auth for bouncers: %w", err) } + s.controller.HandlerV1.Middlewares.APIKey.TlsAuth = tlsAuthBouncers return nil } diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index e376cdff8a6..598f5b7cb0e 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -23,6 +22,7 @@ import ( "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/version" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" @@ -148,24 +148,23 @@ func NewAPIServer(t *testing.T, ctx context.Context) (*APIServer, csconfig.Confi require.NoError(t, err) log.Info("Creating new API server") - gin.SetMode(gin.TestMode) return apiServer, config } -func NewAPITest(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) { +func NewAPITest(t *testing.T, ctx context.Context) (*router.Router, csconfig.Config) { apiServer, config := NewAPIServer(t, ctx) err := apiServer.InitController() require.NoError(t, err) - router, err := apiServer.Router() + rtr, err := apiServer.Router() require.NoError(t, err) - return router, config + return rtr, config } -func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) { +func NewAPITestForwardedFor(t *testing.T) (*router.Router, csconfig.Config) { ctx := t.Context() config := LoadTestConfigForwardedFor(t) @@ -179,12 +178,11 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) { require.NoError(t, err) log.Info("Creating new API server") - gin.SetMode(gin.TestMode) - router, err := apiServer.Router() + rtr, err := apiServer.Router() require.NoError(t, err) - return router, config + return rtr, config } func ValidateMachine(t *testing.T, ctx context.Context, machineID string, config *csconfig.DatabaseCfg) { @@ -286,7 +284,7 @@ func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map return response, resp.Code } -func CreateTestMachine(t *testing.T, ctx context.Context, router *gin.Engine, token string) string { +func CreateTestMachine(t *testing.T, ctx context.Context, rtr *router.Router, token string) string { regReq := MachineTest regReq.RegistrationToken = token b, err := json.Marshal(regReq) @@ -298,7 +296,7 @@ func CreateTestMachine(t *testing.T, ctx context.Context, router *gin.Engine, to req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) require.NoError(t, err) req.Header.Set("User-Agent", UserAgent) - router.ServeHTTP(w, req) + rtr.ServeHTTP(w, req) return body } diff --git a/pkg/apiserver/controllers/controller.go b/pkg/apiserver/controllers/controller.go index ab8dc6501fb..12e33824701 100644 --- a/pkg/apiserver/controllers/controller.go +++ b/pkg/apiserver/controllers/controller.go @@ -6,10 +6,9 @@ import ( "strings" "github.com/alexliesenfeld/health" - "github.com/gin-contrib/gzip" - "github.com/gin-gonic/gin" v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers/v1" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/logging" @@ -18,7 +17,7 @@ import ( type Controller struct { DBClient *database.Client - Router *gin.Engine + Router *router.Router Profiles []*csconfig.ProfileCfg AlertsAddChan chan []*models.Alert DecisionDeleteChan chan []*models.Decision @@ -44,6 +43,9 @@ func (c *Controller) Init() error { */ + // Build router after all routes are registered + c.Router.Build() + return nil } @@ -59,20 +61,26 @@ func serveHealth() http.HandlerFunc { return health.NewHandler(checker) } -func eitherAuthMiddleware(jwtMiddleware gin.HandlerFunc, apiKeyMiddleware gin.HandlerFunc) gin.HandlerFunc { - return func(c *gin.Context) { - switch { - case c.GetHeader("X-Api-Key") != "": - apiKeyMiddleware(c) - case c.GetHeader("Authorization") != "": - jwtMiddleware(c) - // uh no auth header. is this TLS with mutual authentication? - case strings.HasPrefix(c.Request.UserAgent(), "crowdsec/"): - // guess log processors by sniffing user-agent - jwtMiddleware(c) - default: - apiKeyMiddleware(c) - } +// eitherAuthMiddleware creates a middleware that uses JWT or API key based on request headers +func eitherAuthMiddleware(jwtMiddleware router.Middleware, apiKeyMiddleware router.Middleware) router.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Determine which auth method to use based on headers + switch { + case r.Header.Get("X-Api-Key") != "": + // Use API key middleware + apiKeyMiddleware(next).ServeHTTP(w, r) + case r.Header.Get("Authorization") != "": + // Use JWT middleware + jwtMiddleware(next).ServeHTTP(w, r) + case strings.HasPrefix(r.UserAgent(), "crowdsec/"): + // Guess log processors by sniffing user-agent - use JWT + jwtMiddleware(next).ServeHTTP(w, r) + default: + // Default to API key + apiKeyMiddleware(next).ServeHTTP(w, r) + } + }) } } @@ -95,61 +103,58 @@ func (c *Controller) NewV1() error { return err } - c.Router.GET("/health", gin.WrapF(serveHealth())) - c.Router.Use(v1.PrometheusMiddleware()) - // We don't want to compress the response body as it would likely break some existing bouncers - // But we do want to automatically uncompress incoming requests - c.Router.Use(gzip.Gzip(gzip.NoCompression, gzip.WithDecompressOnly(), gzip.WithDecompressFn(gzip.DefaultDecompressHandle))) - c.Router.HandleMethodNotAllowed = true - c.Router.UnescapePathValues = true - c.Router.UseRawPath = true - c.Router.NoRoute(func(ctx *gin.Context) { - ctx.AbortWithStatus(http.StatusNotFound) - }) - c.Router.NoMethod(func(ctx *gin.Context) { - ctx.AbortWithStatus(http.StatusMethodNotAllowed) + // Register health endpoint + c.Router.HandleFunc("/health", http.MethodGet, func(w http.ResponseWriter, r *http.Request) { + serveHealth()(w, r) }) + // Apply global middleware + c.Router.Use(v1.PrometheusMiddleware()) + // Note: Gzip decompression middleware is already applied in apiserver.NewServer() + // Note: Method not allowed and 404 handling will be done by http.ServeMux + groupV1 := c.Router.Group("/v1") - groupV1.POST("/watchers", c.HandlerV1.AbortRemoteIf(c.DisableRemoteLapiRegistration), c.HandlerV1.CreateMachine) - groupV1.POST("/watchers/login", c.HandlerV1.Middlewares.JWT.Middleware.LoginHandler) + + // Apply AbortRemoteIf middleware for /watchers endpoint + abortMiddleware := c.HandlerV1.AbortRemoteIf(c.DisableRemoteLapiRegistration) + watchersGroup := groupV1.Group("") + watchersGroup.Use(abortMiddleware) + watchersGroup.HandleFunc("/watchers", http.MethodPost, c.HandlerV1.CreateMachine) + + groupV1.HandleFunc("/watchers/login", http.MethodPost, c.HandlerV1.Middlewares.JWT.LoginHandler) jwtAuth := groupV1.Group("") - jwtAuth.GET("/refresh_token", c.HandlerV1.Middlewares.JWT.Middleware.RefreshHandler) - jwtAuth.Use(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), v1.PrometheusMachinesMiddleware()) - { - jwtAuth.POST("/alerts", c.HandlerV1.CreateAlert) - jwtAuth.GET("/alerts", c.HandlerV1.FindAlerts) - jwtAuth.HEAD("/alerts", c.HandlerV1.FindAlerts) - jwtAuth.GET("/alerts/:alert_id", c.HandlerV1.FindAlertByID) - jwtAuth.HEAD("/alerts/:alert_id", c.HandlerV1.FindAlertByID) - jwtAuth.DELETE("/alerts/:alert_id", c.HandlerV1.DeleteAlertByID) - jwtAuth.DELETE("/alerts", c.HandlerV1.DeleteAlerts) - jwtAuth.DELETE("/decisions", c.HandlerV1.DeleteDecisions) - jwtAuth.DELETE("/decisions/:decision_id", c.HandlerV1.DeleteDecisionById) - jwtAuth.GET("/heartbeat", c.HandlerV1.HeartBeat) - jwtAuth.GET("/allowlists", c.HandlerV1.GetAllowlists) - jwtAuth.GET("/allowlists/:allowlist_name", c.HandlerV1.GetAllowlist) - jwtAuth.GET("/allowlists/check/:ip_or_range", c.HandlerV1.CheckInAllowlist) - jwtAuth.HEAD("/allowlists/check/:ip_or_range", c.HandlerV1.CheckInAllowlist) - jwtAuth.POST("/allowlists/check", c.HandlerV1.CheckInAllowlistBulk) - jwtAuth.DELETE("/watchers/self", c.HandlerV1.DeleteMachine) - } + jwtAuth.HandleFunc("/refresh_token", http.MethodGet, c.HandlerV1.Middlewares.JWT.RefreshHandler) + jwtAuth.Use(c.HandlerV1.Middlewares.JWT.MiddlewareFunc(), v1.PrometheusMachinesMiddleware()) + + // JWT authenticated routes - convert :param to {param} format for Go 1.22+ + jwtAuth.HandleFunc("/alerts", http.MethodPost, c.HandlerV1.CreateAlert) + jwtAuth.HandleFunc("/alerts", http.MethodGet, c.HandlerV1.FindAlerts) + jwtAuth.HandleFunc("/alerts", http.MethodHead, c.HandlerV1.FindAlerts) + jwtAuth.HandleFunc("/alerts/{alert_id}", http.MethodGet, c.HandlerV1.FindAlertByID) + jwtAuth.HandleFunc("/alerts/{alert_id}", http.MethodHead, c.HandlerV1.FindAlertByID) + jwtAuth.HandleFunc("/alerts/{alert_id}", http.MethodDelete, c.HandlerV1.DeleteAlertByID) + jwtAuth.HandleFunc("/alerts", http.MethodDelete, c.HandlerV1.DeleteAlerts) + jwtAuth.HandleFunc("/decisions", http.MethodDelete, c.HandlerV1.DeleteDecisions) + jwtAuth.HandleFunc("/decisions/{decision_id}", http.MethodDelete, c.HandlerV1.DeleteDecisionById) + jwtAuth.HandleFunc("/heartbeat", http.MethodGet, c.HandlerV1.HeartBeat) + jwtAuth.HandleFunc("/allowlists", http.MethodGet, c.HandlerV1.GetAllowlists) + jwtAuth.HandleFunc("/allowlists/{allowlist_name}", http.MethodGet, c.HandlerV1.GetAllowlist) + jwtAuth.HandleFunc("/allowlists/check/{ip_or_range}", http.MethodGet, c.HandlerV1.CheckInAllowlist) + jwtAuth.HandleFunc("/allowlists/check/{ip_or_range}", http.MethodHead, c.HandlerV1.CheckInAllowlist) + jwtAuth.HandleFunc("/allowlists/check", http.MethodPost, c.HandlerV1.CheckInAllowlistBulk) + jwtAuth.HandleFunc("/watchers/self", http.MethodDelete, c.HandlerV1.DeleteMachine) apiKeyAuth := groupV1.Group("") apiKeyAuth.Use(c.HandlerV1.Middlewares.APIKey.MiddlewareFunc(), v1.PrometheusBouncersMiddleware()) - { - apiKeyAuth.GET("/decisions", c.HandlerV1.GetDecision) - apiKeyAuth.HEAD("/decisions", c.HandlerV1.GetDecision) - apiKeyAuth.GET("/decisions/stream", c.HandlerV1.StreamDecision) - apiKeyAuth.HEAD("/decisions/stream", c.HandlerV1.StreamDecision) - } + apiKeyAuth.HandleFunc("/decisions", http.MethodGet, c.HandlerV1.GetDecision) + apiKeyAuth.HandleFunc("/decisions", http.MethodHead, c.HandlerV1.GetDecision) + apiKeyAuth.HandleFunc("/decisions/stream", http.MethodGet, c.HandlerV1.StreamDecision) + apiKeyAuth.HandleFunc("/decisions/stream", http.MethodHead, c.HandlerV1.StreamDecision) eitherAuth := groupV1.Group("") - eitherAuth.Use(eitherAuthMiddleware(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), c.HandlerV1.Middlewares.APIKey.MiddlewareFunc())) - { - eitherAuth.POST("/usage-metrics", c.HandlerV1.UsageMetrics) - } + eitherAuth.Use(eitherAuthMiddleware(c.HandlerV1.Middlewares.JWT.MiddlewareFunc(), c.HandlerV1.Middlewares.APIKey.MiddlewareFunc())) + eitherAuth.HandleFunc("/usage-metrics", http.MethodPost, c.HandlerV1.UsageMetrics) return nil } diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index 7b59140d192..c77607b0083 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -6,14 +6,15 @@ import ( "fmt" "net" "net/http" + "net/netip" "strconv" "time" - "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" "github.com/google/uuid" log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -145,19 +146,19 @@ func (c *Controller) isAllowListed(ctx context.Context, alert *models.Alert) (bo } // CreateAlert writes the alerts received in the body to the database -func (c *Controller) CreateAlert(gctx *gin.Context) { +func (c *Controller) CreateAlert(w http.ResponseWriter, r *http.Request) { var input models.AddAlertsRequest - ctx := gctx.Request.Context() - machineID, _ := getMachineIDFromContext(gctx) + ctx := r.Context() + machineID, _ := getMachineIDFromContext(r) - if err := gctx.ShouldBindJSON(&input); err != nil { - gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + if err := router.BindJSON(r, &input); err != nil { + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": err.Error()}) return } if err := input.Validate(strfmt.Default); err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } @@ -238,7 +239,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { case "ignore": profile.Logger.Warningf("ignoring error: %s", err) default: - gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusInternalServerError, map[string]string{"message": err.Error()}) return } } @@ -275,7 +276,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { c.DBClient.CanFlush = true if err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } @@ -288,103 +289,105 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { } } - gctx.JSON(http.StatusCreated, alerts) + router.WriteJSON(w, http.StatusCreated, alerts) } // FindAlerts returns alerts from the database based on the specified filter -func (c *Controller) FindAlerts(gctx *gin.Context) { - ctx := gctx.Request.Context() +func (c *Controller) FindAlerts(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() - result, err := c.DBClient.QueryAlertWithFilter(ctx, gctx.Request.URL.Query()) + result, err := c.DBClient.QueryAlertWithFilter(ctx, r.URL.Query()) if err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } data := FormatAlerts(result) - if gctx.Request.Method == http.MethodHead { - gctx.String(http.StatusOK, "") + if r.Method == http.MethodHead { + router.String(w, http.StatusOK, "") return } - gctx.JSON(http.StatusOK, data) + router.WriteJSON(w, http.StatusOK, data) } // FindAlertByID returns the alert associated with the ID -func (c *Controller) FindAlertByID(gctx *gin.Context) { - ctx := gctx.Request.Context() - alertIDStr := gctx.Param("alert_id") +func (c *Controller) FindAlertByID(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + alertIDStr := router.PathValue(r, "alert_id") alertID, err := strconv.Atoi(alertIDStr) if err != nil { - gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": "alert_id must be valid integer"}) return } result, err := c.DBClient.GetAlertByID(ctx, alertID) if err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } data := FormatOneAlert(result) - if gctx.Request.Method == http.MethodHead { - gctx.String(http.StatusOK, "") + if r.Method == http.MethodHead { + router.String(w, http.StatusOK, "") return } - gctx.JSON(http.StatusOK, data) + router.WriteJSON(w, http.StatusOK, data) } -func authIP(gctx *gin.Context, trustedIPs []net.IPNet) (ip string, trusted bool) { - ip = gctx.ClientIP() - return ip, ip == "127.0.0.1" || ip == "::1" || networksContainIP(trustedIPs, ip) || isUnixSocket(gctx) +func authIP(r *http.Request, trustedIPs []net.IPNet) (ip string, trusted bool) { + // Get client IP from context (resolved by ClientIPMiddleware) + // trustedIPs parameter is the ACL allowlist, not proxy networks + ip = router.GetClientIP(r) + return ip, ip == "127.0.0.1" || ip == "::1" || networksContainIP(trustedIPs, ip) || isUnixSocket(r) } // DeleteAlertByID delete the alert associated to the ID -func (c *Controller) DeleteAlertByID(gctx *gin.Context) { +func (c *Controller) DeleteAlertByID(w http.ResponseWriter, r *http.Request) { var err error - ctx := gctx.Request.Context() + ctx := r.Context() - if incomingIP, trusted := authIP(gctx, c.TrustedIPs); !trusted { - gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) + if incomingIP, trusted := authIP(r, c.TrustedIPs); !trusted { + router.WriteJSON(w, http.StatusForbidden, map[string]string{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) return } - decisionIDStr := gctx.Param("alert_id") + decisionIDStr := router.PathValue(r, "alert_id") decisionID, err := strconv.Atoi(decisionIDStr) if err != nil { - gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": "alert_id must be valid integer"}) return } err = c.DBClient.DeleteAlertByID(ctx, decisionID) if err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } deleteAlertResp := models.DeleteAlertsResponse{NbDeleted: "1"} - gctx.JSON(http.StatusOK, deleteAlertResp) + router.WriteJSON(w, http.StatusOK, deleteAlertResp) } // DeleteAlerts deletes alerts from the database based on the specified filter -func (c *Controller) DeleteAlerts(gctx *gin.Context) { - ctx := gctx.Request.Context() +func (c *Controller) DeleteAlerts(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() - if incomingIP, trusted := authIP(gctx, c.TrustedIPs); !trusted { - gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) + if incomingIP, trusted := authIP(r, c.TrustedIPs); !trusted { + router.WriteJSON(w, http.StatusForbidden, map[string]string{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) return } - nbDeleted, err := c.DBClient.DeleteAlertWithFilter(ctx, gctx.Request.URL.Query()) + nbDeleted, err := c.DBClient.DeleteAlertWithFilter(ctx, r.URL.Query()) if err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } @@ -392,13 +395,21 @@ func (c *Controller) DeleteAlerts(gctx *gin.Context) { NbDeleted: strconv.Itoa(nbDeleted), } - gctx.JSON(http.StatusOK, deleteAlertsResp) + router.WriteJSON(w, http.StatusOK, deleteAlertsResp) } func networksContainIP(networks []net.IPNet, ip string) bool { - parsedIP := net.ParseIP(ip) + addr, err := netip.ParseAddr(ip) + if err != nil { + return false + } + for _, network := range networks { - if network.Contains(parsedIP) { + prefix, err := netip.ParsePrefix(network.String()) + if err != nil { + continue + } + if prefix.Contains(addr) { return true } } diff --git a/pkg/apiserver/controllers/v1/allowlist.go b/pkg/apiserver/controllers/v1/allowlist.go index e35354ff330..253a26b5f06 100644 --- a/pkg/apiserver/controllers/v1/allowlist.go +++ b/pkg/apiserver/controllers/v1/allowlist.go @@ -4,22 +4,22 @@ import ( "net/http" "time" - "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/models" ) -func (c *Controller) CheckInAllowlistBulk(gctx *gin.Context) { +func (c *Controller) CheckInAllowlistBulk(w http.ResponseWriter, r *http.Request) { var req models.BulkCheckAllowlistRequest - if err := gctx.ShouldBindJSON(&req); err != nil { - gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + if err := router.BindJSON(r, &req); err != nil { + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": err.Error()}) return } if len(req.Targets) == 0 { - gctx.JSON(http.StatusBadRequest, gin.H{"message": "targets list cannot be empty"}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": "targets list cannot be empty"}) return } @@ -28,9 +28,9 @@ func (c *Controller) CheckInAllowlistBulk(gctx *gin.Context) { } for _, target := range req.Targets { - lists, err := c.DBClient.IsAllowlistedBy(gctx.Request.Context(), target) + lists, err := c.DBClient.IsAllowlistedBy(r.Context(), target) if err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } @@ -44,28 +44,28 @@ func (c *Controller) CheckInAllowlistBulk(gctx *gin.Context) { }) } - gctx.JSON(http.StatusOK, resp) + router.WriteJSON(w, http.StatusOK, resp) } -func (c *Controller) CheckInAllowlist(gctx *gin.Context) { - value := gctx.Param("ip_or_range") +func (c *Controller) CheckInAllowlist(w http.ResponseWriter, r *http.Request) { + value := router.PathValue(r, "ip_or_range") if value == "" { - gctx.JSON(http.StatusBadRequest, gin.H{"message": "value is required"}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": "value is required"}) return } - allowlisted, reason, err := c.DBClient.IsAllowlisted(gctx.Request.Context(), value) + allowlisted, reason, err := c.DBClient.IsAllowlisted(r.Context(), value) if err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } - if gctx.Request.Method == http.MethodHead { + if r.Method == http.MethodHead { if allowlisted { - gctx.Status(http.StatusOK) + w.WriteHeader(http.StatusOK) } else { - gctx.Status(http.StatusNoContent) + w.WriteHeader(http.StatusNoContent) } return @@ -76,17 +76,17 @@ func (c *Controller) CheckInAllowlist(gctx *gin.Context) { Reason: reason, } - gctx.JSON(http.StatusOK, resp) + router.WriteJSON(w, http.StatusOK, resp) } -func (c *Controller) GetAllowlists(gctx *gin.Context) { - params := gctx.Request.URL.Query() +func (c *Controller) GetAllowlists(w http.ResponseWriter, r *http.Request) { + params := r.URL.Query() withContent := params.Get("with_content") == "true" - allowlists, err := c.DBClient.ListAllowLists(gctx.Request.Context(), withContent) + allowlists, err := c.DBClient.ListAllowLists(r.Context(), withContent) if err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } @@ -121,18 +121,18 @@ func (c *Controller) GetAllowlists(gctx *gin.Context) { }) } - gctx.JSON(http.StatusOK, resp) + router.WriteJSON(w, http.StatusOK, resp) } -func (c *Controller) GetAllowlist(gctx *gin.Context) { - allowlist := gctx.Param("allowlist_name") +func (c *Controller) GetAllowlist(w http.ResponseWriter, r *http.Request) { + allowlist := router.PathValue(r, "allowlist_name") - params := gctx.Request.URL.Query() + params := r.URL.Query() withContent := params.Get("with_content") == "true" - allowlistModel, err := c.DBClient.GetAllowList(gctx.Request.Context(), allowlist, withContent) + allowlistModel, err := c.DBClient.GetAllowList(r.Context(), allowlist, withContent) if err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } @@ -163,5 +163,5 @@ func (c *Controller) GetAllowlist(gctx *gin.Context) { Items: items, } - gctx.JSON(http.StatusOK, resp) + router.WriteJSON(w, http.StatusOK, resp) } diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 86dd9845071..f4d08624a84 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -3,13 +3,14 @@ package v1 import ( "context" "encoding/json" + "maps" "net/http" "strconv" "time" - "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/fflag" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -37,24 +38,24 @@ func FormatDecisions(decisions []*ent.Decision) []*models.Decision { return results } -func (c *Controller) GetDecision(gctx *gin.Context) { +func (c *Controller) GetDecision(w http.ResponseWriter, r *http.Request) { var ( results []*models.Decision data []*ent.Decision ) - ctx := gctx.Request.Context() + ctx := r.Context() - bouncerInfo, err := getBouncerFromContext(gctx) + bouncerInfo, err := getBouncerFromContext(r) if err != nil { - gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) + router.WriteJSON(w, http.StatusUnauthorized, map[string]string{"message": "not allowed"}) return } - data, err = c.DBClient.QueryDecisionWithFilter(ctx, gctx.Request.URL.Query()) + data, err = c.DBClient.QueryDecisionWithFilter(ctx, r.URL.Query()) if err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } @@ -63,13 +64,13 @@ func (c *Controller) GetDecision(gctx *gin.Context) { /*let's follow a naive logic : when a bouncer queries /decisions, if the answer is empty, we assume there is no decision for this ip/user/..., but if it's non-empty, it means that there is one or more decisions for this target*/ if len(results) > 0 { - PrometheusBouncersHasNonEmptyDecision(gctx) + PrometheusBouncersHasNonEmptyDecision(r) } else { - PrometheusBouncersHasEmptyDecision(gctx) + PrometheusBouncersHasEmptyDecision(r) } - if gctx.Request.Method == http.MethodHead { - gctx.String(http.StatusOK, "") + if r.Method == http.MethodHead { + router.String(w, http.StatusOK, "") return } @@ -80,24 +81,24 @@ func (c *Controller) GetDecision(gctx *gin.Context) { } } - gctx.JSON(http.StatusOK, results) + router.WriteJSON(w, http.StatusOK, results) } -func (c *Controller) DeleteDecisionById(gctx *gin.Context) { - decisionIDStr := gctx.Param("decision_id") +func (c *Controller) DeleteDecisionById(w http.ResponseWriter, r *http.Request) { + decisionIDStr := router.PathValue(r, "decision_id") decisionID, err := strconv.Atoi(decisionIDStr) if err != nil { - gctx.JSON(http.StatusBadRequest, gin.H{"message": "decision_id must be valid integer"}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": "decision_id must be valid integer"}) return } - ctx := gctx.Request.Context() + ctx := r.Context() nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(ctx, decisionID) if err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } @@ -113,15 +114,15 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { NbDeleted: strconv.Itoa(nbDeleted), } - gctx.JSON(http.StatusOK, deleteDecisionResp) + router.WriteJSON(w, http.StatusOK, deleteDecisionResp) } -func (c *Controller) DeleteDecisions(gctx *gin.Context) { - ctx := gctx.Request.Context() +func (c *Controller) DeleteDecisions(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() - nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(ctx, gctx.Request.URL.Query()) + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(ctx, r.URL.Query()) if err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } @@ -137,27 +138,35 @@ func (c *Controller) DeleteDecisions(gctx *gin.Context) { NbDeleted: strconv.Itoa(nbDeleted), } - gctx.JSON(http.StatusOK, deleteDecisionResp) + router.WriteJSON(w, http.StatusOK, deleteDecisionResp) } -func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(context.Context, map[string][]string) ([]*ent.Decision, error)) error { +func writeStartupDecisions(w http.ResponseWriter, r *http.Request, filters map[string][]string, dbFunc func(context.Context, map[string][]string) ([]*ent.Decision, error)) error { // respBuffer := bytes.NewBuffer([]byte{}) limit := 30000 // FIXME : make it configurable needComma := false lastId := 0 - ctx := gctx.Request.Context() + ctx := r.Context() + flusher, hasFlusher := w.(http.Flusher) + + // Work on a copy of filters to avoid mutating the shared map + filtersCopy := make(map[string][]string, len(filters)+2) + maps.Copy(filtersCopy, filters) limitStr := strconv.Itoa(limit) - filters["limit"] = []string{limitStr} + filtersCopy["limit"] = []string{limitStr} for { if lastId > 0 { lastIdStr := strconv.Itoa(lastId) - filters["id_gt"] = []string{lastIdStr} + filtersCopy["id_gt"] = []string{lastIdStr} + } else { + // Clear id_gt if it exists from previous iteration + delete(filtersCopy, "id_gt") } - data, err := dbFunc(ctx, filters) + data, err := dbFunc(ctx, filtersCopy) if err != nil { return err } @@ -170,27 +179,29 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun decisionJSON, _ := json.Marshal(decision) if needComma { - // respBuffer.Write([]byte(",")) - gctx.Writer.WriteString(",") + if _, err := w.Write([]byte(",")); err != nil { + return err + } } else { needComma = true } - //respBuffer.Write(decisionJSON) - //_, err := gctx.Writer.Write(respBuffer.Bytes()) - _, err := gctx.Writer.Write(decisionJSON) + _, err := w.Write(decisionJSON) if err != nil { - gctx.Writer.Flush() + if hasFlusher { + flusher.Flush() + } return err } - // respBuffer.Reset() } } log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) if len(data) < limit { - gctx.Writer.Flush() + if hasFlusher { + flusher.Flush() + } break } @@ -199,24 +210,32 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun return nil } -func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull *time.Time, dbFunc func(context.Context, *time.Time, map[string][]string) ([]*ent.Decision, error)) error { +func writeDeltaDecisions(w http.ResponseWriter, r *http.Request, filters map[string][]string, lastPull *time.Time, dbFunc func(context.Context, *time.Time, map[string][]string) ([]*ent.Decision, error)) error { // respBuffer := bytes.NewBuffer([]byte{}) limit := 30000 // FIXME : make it configurable needComma := false lastId := 0 - ctx := gctx.Request.Context() + ctx := r.Context() + flusher, hasFlusher := w.(http.Flusher) + + // Work on a copy of filters to avoid mutating the shared map + filtersCopy := make(map[string][]string, len(filters)+2) + maps.Copy(filtersCopy, filters) limitStr := strconv.Itoa(limit) - filters["limit"] = []string{limitStr} + filtersCopy["limit"] = []string{limitStr} for { if lastId > 0 { lastIdStr := strconv.Itoa(lastId) - filters["id_gt"] = []string{lastIdStr} + filtersCopy["id_gt"] = []string{lastIdStr} + } else { + // Clear id_gt if it exists from previous iteration + delete(filtersCopy, "id_gt") } - data, err := dbFunc(ctx, lastPull, filters) + data, err := dbFunc(ctx, lastPull, filtersCopy) if err != nil { return err } @@ -229,27 +248,29 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul decisionJSON, _ := json.Marshal(decision) if needComma { - // respBuffer.Write([]byte(",")) - gctx.Writer.WriteString(",") + if _, err := w.Write([]byte(",")); err != nil { + return err + } } else { needComma = true } - //respBuffer.Write(decisionJSON) - //_, err := gctx.Writer.Write(respBuffer.Bytes()) - _, err := gctx.Writer.Write(decisionJSON) + _, err := w.Write(decisionJSON) if err != nil { - gctx.Writer.Flush() + if hasFlusher { + flusher.Flush() + } return err } - // respBuffer.Reset() } } log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) if len(data) < limit { - gctx.Writer.Flush() + if hasFlusher { + flusher.Flush() + } break } @@ -258,85 +279,115 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul return nil } -func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.Bouncer, streamStartTime time.Time, filters map[string][]string) error { - var err error - - gctx.Writer.Header().Set("Content-Type", "application/json") - gctx.Writer.Header().Set("Transfer-Encoding", "chunked") - gctx.Writer.WriteHeader(http.StatusOK) - gctx.Writer.WriteString(`{"new": [`) // No need to check for errors, the doc says it always returns nil +// writeStartupResponse writes startup decisions (both active and expired) to the response +func (c *Controller) writeStartupResponse(w http.ResponseWriter, r *http.Request, filters map[string][]string, flusher http.Flusher, hasFlusher bool) error { + // Active decisions + err := writeStartupDecisions(w, r, filters, c.DBClient.QueryAllDecisionsWithFilters) + if err != nil { + log.Errorf("failed sending new decisions for startup: %v", err) + _, _ = w.Write([]byte(`], "deleted": []}`)) + if hasFlusher { + flusher.Flush() + } + return err + } - // if the blocker just started, return all decisions - if val, ok := gctx.Request.URL.Query()["startup"]; ok && val[0] == "true" { - // Active decisions - err := writeStartupDecisions(gctx, filters, c.DBClient.QueryAllDecisionsWithFilters) - if err != nil { - log.Errorf("failed sending new decisions for startup: %v", err) - gctx.Writer.WriteString(`], "deleted": []}`) - gctx.Writer.Flush() + if _, err := w.Write([]byte(`], "deleted": [`)); err != nil { + return err + } - return err + // Expired decisions + err = writeStartupDecisions(w, r, filters, c.DBClient.QueryExpiredDecisionsWithFilters) + if err != nil { + log.Errorf("failed sending expired decisions for startup: %v", err) + _, _ = w.Write([]byte(`]}`)) + if hasFlusher { + flusher.Flush() } + return err + } - gctx.Writer.WriteString(`], "deleted": [`) - // Expired decisions - err = writeStartupDecisions(gctx, filters, c.DBClient.QueryExpiredDecisionsWithFilters) - if err != nil { - log.Errorf("failed sending expired decisions for startup: %v", err) - gctx.Writer.WriteString(`]}`) - gctx.Writer.Flush() + if _, err := w.Write([]byte(`]}`)); err != nil { + return err + } + if hasFlusher { + flusher.Flush() + } + return nil +} - return err +// writeDeltaResponse writes delta decisions (both new and expired) to the response +func (c *Controller) writeDeltaResponse(w http.ResponseWriter, r *http.Request, bouncerInfo *ent.Bouncer, filters map[string][]string, flusher http.Flusher, hasFlusher bool) error { + err := writeDeltaDecisions(w, r, filters, bouncerInfo.LastPull, c.DBClient.QueryNewDecisionsSinceWithFilters) + if err != nil { + log.Errorf("failed sending new decisions for delta: %v", err) + _, _ = w.Write([]byte(`], "deleted": []}`)) + if hasFlusher { + flusher.Flush() } + return err + } - gctx.Writer.WriteString(`]}`) - gctx.Writer.Flush() - } else { - err = writeDeltaDecisions(gctx, filters, bouncerInfo.LastPull, c.DBClient.QueryNewDecisionsSinceWithFilters) - if err != nil { - log.Errorf("failed sending new decisions for delta: %v", err) - gctx.Writer.WriteString(`], "deleted": []}`) - gctx.Writer.Flush() + if _, err := w.Write([]byte(`], "deleted": [`)); err != nil { + return err + } - return err + err = writeDeltaDecisions(w, r, filters, bouncerInfo.LastPull, c.DBClient.QueryExpiredDecisionsSinceWithFilters) + if err != nil { + log.Errorf("failed sending expired decisions for delta: %v", err) + _, _ = w.Write([]byte("]}")) + if hasFlusher { + flusher.Flush() } + return err + } - gctx.Writer.WriteString(`], "deleted": [`) + if _, err := w.Write([]byte("]}")); err != nil { + return err + } + if hasFlusher { + flusher.Flush() + } + return nil +} - err = writeDeltaDecisions(gctx, filters, bouncerInfo.LastPull, c.DBClient.QueryExpiredDecisionsSinceWithFilters) - if err != nil { - log.Errorf("failed sending expired decisions for delta: %v", err) - gctx.Writer.WriteString("]}") - gctx.Writer.Flush() +func (c *Controller) StreamDecisionChunked(w http.ResponseWriter, r *http.Request, bouncerInfo *ent.Bouncer, streamStartTime time.Time, filters map[string][]string) error { + flusher, hasFlusher := w.(http.Flusher) - return err - } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Transfer-Encoding", "chunked") + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(`{"new": [`)); err != nil { // Write initial JSON structure + return err + } - gctx.Writer.WriteString("]}") - gctx.Writer.Flush() + // if the blocker just started, return all decisions + val, ok := r.URL.Query()["startup"] + if ok && val[0] == "true" { + return c.writeStartupResponse(w, r, filters, flusher, hasFlusher) } - return nil + return c.writeDeltaResponse(w, r, bouncerInfo, filters, flusher, hasFlusher) } -func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *ent.Bouncer, streamStartTime time.Time, filters map[string][]string) error { +func (c *Controller) StreamDecisionNonChunked(w http.ResponseWriter, r *http.Request, bouncerInfo *ent.Bouncer, streamStartTime time.Time, filters map[string][]string) error { var ( data []*ent.Decision err error ) - ctx := gctx.Request.Context() + ctx := r.Context() ret := make(map[string][]*models.Decision, 0) ret["new"] = []*models.Decision{} ret["deleted"] = []*models.Decision{} - if val, ok := gctx.Request.URL.Query()["startup"]; ok { + if val, ok := r.URL.Query()["startup"]; ok { if val[0] == "true" { data, err = c.DBClient.QueryAllDecisionsWithFilters(ctx, filters) if err != nil { log.Errorf("failed querying decisions: %v", err) - gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusInternalServerError, map[string]string{"message": err.Error()}) return err } @@ -347,14 +398,14 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en data, err = c.DBClient.QueryExpiredDecisionsWithFilters(ctx, filters) if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) - gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusInternalServerError, map[string]string{"message": err.Error()}) return err } ret["deleted"] = FormatDecisions(data) - gctx.JSON(http.StatusOK, ret) + router.WriteJSON(w, http.StatusOK, ret) return nil } @@ -364,7 +415,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(ctx, bouncerInfo.LastPull, filters) if err != nil { log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err) - gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusInternalServerError, map[string]string{"message": err.Error()}) return err } @@ -380,51 +431,52 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(ctx, &since, filters) // do we want to give exactly lastPull time ? if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) - gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusInternalServerError, map[string]string{"message": err.Error()}) return err } ret["deleted"] = FormatDecisions(data) - gctx.JSON(http.StatusOK, ret) + router.WriteJSON(w, http.StatusOK, ret) return nil } -func (c *Controller) StreamDecision(gctx *gin.Context) { +func (c *Controller) StreamDecision(w http.ResponseWriter, r *http.Request) { var err error streamStartTime := time.Now().UTC() - bouncerInfo, err := getBouncerFromContext(gctx) + bouncerInfo, err := getBouncerFromContext(r) if err != nil { - gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) + router.WriteJSON(w, http.StatusUnauthorized, map[string]string{"message": "not allowed"}) return } - if gctx.Request.Method == http.MethodHead { + if r.Method == http.MethodHead { // For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db // We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true) - gctx.String(http.StatusOK, "") + router.String(w, http.StatusOK, "") return } - filters := gctx.Request.URL.Query() + filters := r.URL.Query() if _, ok := filters["scopes"]; !ok { filters["scopes"] = []string{"ip,range"} } if fflag.ChunkedDecisionsStream.IsEnabled() { - err = c.StreamDecisionChunked(gctx, bouncerInfo, streamStartTime, filters) + err = c.StreamDecisionChunked(w, r, bouncerInfo, streamStartTime, filters) } else { - err = c.StreamDecisionNonChunked(gctx, bouncerInfo, streamStartTime, filters) + err = c.StreamDecisionNonChunked(w, r, bouncerInfo, streamStartTime, filters) } if err == nil { // Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions - // Do not reuse the context provided by gin because we already have sent the response to the client, so there's a chance for it to already be canceled + // Use a background context since we've already sent the response and the request context may be canceled + //nolint:contextcheck // We intentionally use context.Background() here since the response is already sent if err := c.DBClient.UpdateBouncerLastPull(context.Background(), streamStartTime, bouncerInfo.ID); err != nil { log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err) } diff --git a/pkg/apiserver/controllers/v1/errors.go b/pkg/apiserver/controllers/v1/errors.go index e58785ccbc0..c4928de8076 100644 --- a/pkg/apiserver/controllers/v1/errors.go +++ b/pkg/apiserver/controllers/v1/errors.go @@ -5,24 +5,23 @@ import ( "net/http" "strings" - "github.com/gin-gonic/gin" - + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/database" ) -func (*Controller) HandleDBErrors(gctx *gin.Context, err error) { +func (*Controller) HandleDBErrors(w http.ResponseWriter, err error) { switch { case errors.Is(err, database.ItemNotFound): - gctx.JSON(http.StatusNotFound, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusNotFound, map[string]string{"message": err.Error()}) return case errors.Is(err, database.UserExists): - gctx.JSON(http.StatusForbidden, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusForbidden, map[string]string{"message": err.Error()}) return case errors.Is(err, database.HashError): - gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": err.Error()}) return default: - gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusInternalServerError, map[string]string{"message": err.Error()}) return } } diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index 799b736ccfe..f2e448c30b9 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -1,20 +1,16 @@ package v1 -import ( - "net/http" +import "net/http" - "github.com/gin-gonic/gin" -) +func (c *Controller) HeartBeat(w http.ResponseWriter, r *http.Request) { + machineID, _ := getMachineIDFromContext(r) -func (c *Controller) HeartBeat(gctx *gin.Context) { - machineID, _ := getMachineIDFromContext(gctx) - - ctx := gctx.Request.Context() + ctx := r.Context() if err := c.DBClient.UpdateMachineLastHeartBeat(ctx, machineID); err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } - gctx.Status(http.StatusOK) + w.WriteHeader(http.StatusOK) } diff --git a/pkg/apiserver/controllers/v1/machines.go b/pkg/apiserver/controllers/v1/machines.go index 8c799fa5114..72f3d92eabe 100644 --- a/pkg/apiserver/controllers/v1/machines.go +++ b/pkg/apiserver/controllers/v1/machines.go @@ -2,27 +2,31 @@ package v1 import ( "errors" - "net" "net/http" + "net/netip" - "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" ) -func (c *Controller) shouldAutoRegister(token string, gctx *gin.Context) (bool, error) { +func (c *Controller) shouldAutoRegister(token string, r *http.Request) (bool, error) { if !*c.AutoRegisterCfg.Enable { return false, nil } - clientIP := net.ParseIP(gctx.ClientIP()) + // Get client IP from context (resolved by ClientIPMiddleware) + clientIPStr := router.GetClientIP(r) + clientIP, err := netip.ParseAddr(clientIPStr) - // Can probaby happen if using unix socket ? - if clientIP == nil { - log.Warnf("Failed to parse client IP for watcher self registration: %s", gctx.ClientIP()) + // Can probably happen if using unix socket ? + if err != nil { + log.Warnf("Failed to parse client IP for watcher self registration: %s", clientIPStr) + // Return false, nil to indicate IP is not in range (can't parse = not in range) + //nolint:nilerr // Returning false, nil is correct here - can't parse IP means not in range, not an error return false, nil } @@ -37,7 +41,12 @@ func (c *Controller) shouldAutoRegister(token string, gctx *gin.Context) (bool, // Check the source IP for _, ipRange := range c.AutoRegisterCfg.AllowedRangesParsed { - if ipRange.Contains(clientIP) { + // Convert net.IPNet to netip.Prefix for comparison + prefix, err := netip.ParsePrefix(ipRange.String()) + if err != nil { + continue + } + if prefix.Contains(clientIP) { return true, nil } } @@ -45,62 +54,68 @@ func (c *Controller) shouldAutoRegister(token string, gctx *gin.Context) (bool, return false, errors.New("IP not in allowed range for auto registration") } -func (c *Controller) CreateMachine(gctx *gin.Context) { - ctx := gctx.Request.Context() +func (c *Controller) CreateMachine(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() var input models.WatcherRegistrationRequest - if err := gctx.ShouldBindJSON(&input); err != nil { - gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + if err := router.BindJSON(r, &input); err != nil { + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": err.Error()}) return } if err := input.Validate(strfmt.Default); err != nil { - gctx.JSON(http.StatusUnprocessableEntity, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusUnprocessableEntity, map[string]string{"message": err.Error()}) return } - autoRegister, err := c.shouldAutoRegister(input.RegistrationToken, gctx) + // Get client IP from context (resolved by ClientIPMiddleware) + // c.TrustedIPs is the ACL allowlist, not proxy networks + clientIP := router.GetClientIP(r) + autoRegister, err := c.shouldAutoRegister(input.RegistrationToken, r) if err != nil { - log.WithFields(log.Fields{"ip": gctx.ClientIP(), "machine_id": *input.MachineID}).Errorf("Auto-register failed: %s", err) - gctx.JSON(http.StatusUnauthorized, gin.H{"message": err.Error()}) + log.WithFields(log.Fields{"ip": clientIP, "machine_id": *input.MachineID}).Errorf("Auto-register failed: %s", err) + router.WriteJSON(w, http.StatusUnauthorized, map[string]string{"message": err.Error()}) return } - if _, err := c.DBClient.CreateMachine(ctx, input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil { - c.HandleDBErrors(gctx, err) + if _, err := c.DBClient.CreateMachine(ctx, input.MachineID, input.Password, clientIP, autoRegister, false, types.PasswordAuthType); err != nil { + c.HandleDBErrors(w, err) return } if autoRegister { - log.WithFields(log.Fields{"ip": gctx.ClientIP(), "machine_id": *input.MachineID}).Info("Auto-registered machine") - gctx.Status(http.StatusAccepted) + log.WithFields(log.Fields{"ip": clientIP, "machine_id": *input.MachineID}).Info("Auto-registered machine") + w.WriteHeader(http.StatusAccepted) } else { - gctx.Status(http.StatusCreated) + w.WriteHeader(http.StatusCreated) } } -func (c *Controller) DeleteMachine(gctx *gin.Context) { - ctx := gctx.Request.Context() +func (c *Controller) DeleteMachine(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() - machineID, err := getMachineIDFromContext(gctx) + machineID, err := getMachineIDFromContext(r) if err != nil { - gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": err.Error()}) return } if machineID == "" { - gctx.JSON(http.StatusBadRequest, gin.H{"message": "machineID not found in claims"}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": "machineID not found in claims"}) return } if err := c.DBClient.DeleteWatcher(ctx, machineID); err != nil { - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } - log.WithFields(log.Fields{"ip": gctx.ClientIP(), "machine_id": machineID}).Info("Deleted machine") + // Get client IP from context (resolved by ClientIPMiddleware) + // c.TrustedIPs is the ACL allowlist, not proxy networks + clientIP := router.GetClientIP(r) + log.WithFields(log.Fields{"ip": clientIP, "machine_id": machineID}).Info("Deleted machine") - gctx.Status(http.StatusNoContent) + w.WriteHeader(http.StatusNoContent) } diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index a6ae8613b5a..ba5b6986946 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -1,15 +1,16 @@ package v1 import ( + "net/http" "time" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/metrics" - "github.com/gin-gonic/gin" "github.com/prometheus/client_golang/prometheus" ) -func PrometheusBouncersHasEmptyDecision(c *gin.Context) { - bouncer, _ := getBouncerFromContext(c) +func PrometheusBouncersHasEmptyDecision(r *http.Request) { + bouncer, _ := getBouncerFromContext(r) if bouncer != nil { metrics.LapiNilDecisions.With(prometheus.Labels{ "bouncer": bouncer.Name, @@ -17,8 +18,8 @@ func PrometheusBouncersHasEmptyDecision(c *gin.Context) { } } -func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) { - bouncer, _ := getBouncerFromContext(c) +func PrometheusBouncersHasNonEmptyDecision(r *http.Request) { + bouncer, _ := getBouncerFromContext(r) if bouncer != nil { metrics.LapiNonNilDecisions.With(prometheus.Labels{ "bouncer": bouncer.Name, @@ -26,60 +27,56 @@ func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) { } } -func PrometheusMachinesMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - machineID, _ := getMachineIDFromContext(c) - if machineID != "" { - fullPath := c.FullPath() - if fullPath == "" { - fullPath = "invalid-endpoint" +func PrometheusMachinesMiddleware() router.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + machineID, _ := getMachineIDFromContext(r) + if machineID != "" { + route := router.GetRoutePattern(r) + // routePatternMiddleware always sets a pattern (UnknownRoutePattern for unmatched routes) + metrics.LapiMachineHits.With(prometheus.Labels{ + "machine": machineID, + "route": route, + "method": r.Method, + }).Inc() } - metrics.LapiMachineHits.With(prometheus.Labels{ - "machine": machineID, - "route": fullPath, - "method": c.Request.Method, - }).Inc() - } - - c.Next() + next.ServeHTTP(w, r) + }) } } -func PrometheusBouncersMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - bouncer, _ := getBouncerFromContext(c) - if bouncer != nil { - fullPath := c.FullPath() - if fullPath == "" { - fullPath = "invalid-endpoint" +func PrometheusBouncersMiddleware() router.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bouncer, _ := getBouncerFromContext(r) + if bouncer != nil { + route := router.GetRoutePattern(r) + // routePatternMiddleware always sets a pattern (UnknownRoutePattern for unmatched routes) + metrics.LapiBouncerHits.With(prometheus.Labels{ + "bouncer": bouncer.Name, + "route": route, + "method": r.Method, + }).Inc() } - metrics.LapiBouncerHits.With(prometheus.Labels{ - "bouncer": bouncer.Name, - "route": fullPath, - "method": c.Request.Method, - }).Inc() - } - - c.Next() + next.ServeHTTP(w, r) + }) } } -func PrometheusMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - startTime := time.Now() - - fullPath := c.FullPath() - if fullPath == "" { - fullPath = "invalid-endpoint" - } - - metrics.LapiRouteHits.With(prometheus.Labels{ - "route": fullPath, - "method": c.Request.Method, - }).Inc() - c.Next() - - elapsed := time.Since(startTime) - metrics.LapiResponseTime.With(prometheus.Labels{"method": c.Request.Method, "endpoint": c.FullPath()}).Observe(elapsed.Seconds()) +func PrometheusMiddleware() router.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + route := router.GetRoutePattern(r) + // routePatternMiddleware always sets a pattern (UnknownRoutePattern for unmatched routes) + // Start timing just before handler execution to avoid including pattern lookup overhead + startTime := time.Now() + next.ServeHTTP(w, r) + elapsed := time.Since(startTime) + metrics.LapiRouteHits.With(prometheus.Labels{ + "route": route, + "method": r.Method, + }).Inc() + metrics.LapiResponseTime.With(prometheus.Labels{"method": r.Method, "endpoint": route}).Observe(elapsed.Seconds()) + }) } } diff --git a/pkg/apiserver/controllers/v1/usagemetrics.go b/pkg/apiserver/controllers/v1/usagemetrics.go index 5b2c3e3b1a9..b3166b386e8 100644 --- a/pkg/apiserver/controllers/v1/usagemetrics.go +++ b/pkg/apiserver/controllers/v1/usagemetrics.go @@ -7,12 +7,12 @@ import ( "net/http" "time" - "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -31,16 +31,16 @@ func (c *Controller) updateBaseMetrics(ctx context.Context, machineID string, bo } // UsageMetrics receives metrics from log processors and remediation components -func (c *Controller) UsageMetrics(gctx *gin.Context) { +func (c *Controller) UsageMetrics(w http.ResponseWriter, r *http.Request) { var input models.AllMetrics logger := log.WithField("func", "UsageMetrics") // parse the payload - if err := gctx.ShouldBindJSON(&input); err != nil { + if err := router.BindJSON(r, &input); err != nil { logger.Errorf("Failed to bind json: %s", err) - gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": err.Error()}) return } @@ -52,7 +52,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { Prefix: "validation failure list:\n", } logger.Errorf("Failed to validate usage metrics: %s", cleanErr) - gctx.JSON(http.StatusUnprocessableEntity, gin.H{"message": cleanErr.Error()}) + router.WriteJSON(w, http.StatusUnprocessableEntity, map[string]string{"message": cleanErr.Error()}) return } @@ -62,7 +62,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { generatedBy string ) - bouncer, _ := getBouncerFromContext(gctx) + bouncer, _ := getBouncerFromContext(r) if bouncer != nil { logger.Tracef("Received usage metris for bouncer: %s", bouncer.Name) @@ -70,7 +70,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { generatedBy = bouncer.Name } - machineID, _ := getMachineIDFromContext(gctx) + machineID, _ := getMachineIDFromContext(r) if machineID != "" { logger.Tracef("Received usage metrics for log processor: %s", machineID) @@ -81,14 +81,14 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { if generatedBy == "" { // how did we get here? logger.Error("No machineID or bouncer in request context after authentication") - gctx.JSON(http.StatusInternalServerError, gin.H{"message": "No machineID or bouncer in request context after authentication"}) + router.WriteJSON(w, http.StatusInternalServerError, map[string]string{"message": "No machineID or bouncer in request context after authentication"}) return } if machineID != "" && bouncer != nil { logger.Errorf("Payload has both machineID and bouncer") - gctx.JSON(http.StatusBadRequest, gin.H{"message": "Payload has both LP and RC data"}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": "Payload has both LP and RC data"}) return } @@ -104,7 +104,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { case 0: if machineID != "" { logger.Errorf("Missing log processor data") - gctx.JSON(http.StatusBadRequest, gin.H{"message": "Missing log processor data"}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": "Missing log processor data"}) return } @@ -116,7 +116,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { err := item0.Validate(strfmt.Default) if err != nil { logger.Errorf("Failed to validate log processor data: %s", err) - gctx.JSON(http.StatusUnprocessableEntity, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusUnprocessableEntity, map[string]string{"message": err.Error()}) return } @@ -130,7 +130,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { default: logger.Errorf("Payload has more than one log processor") // this is not checked in the swagger schema - gctx.JSON(http.StatusBadRequest, gin.H{"message": "Payload has more than one log processor"}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": "Payload has more than one log processor"}) return } @@ -139,7 +139,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { case 0: if bouncer != nil { logger.Errorf("Missing remediation component data") - gctx.JSON(http.StatusBadRequest, gin.H{"message": "Missing remediation component data"}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": "Missing remediation component data"}) return } @@ -149,7 +149,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { err := item0.Validate(strfmt.Default) if err != nil { logger.Errorf("Failed to validate remediation component data: %s", err) - gctx.JSON(http.StatusUnprocessableEntity, gin.H{"message": err.Error()}) + router.WriteJSON(w, http.StatusUnprocessableEntity, map[string]string{"message": err.Error()}) return } @@ -160,7 +160,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { } baseMetrics = item0.BaseMetrics default: - gctx.JSON(http.StatusBadRequest, gin.H{"message": "Payload has more than one remediation component"}) + router.WriteJSON(w, http.StatusBadRequest, map[string]string{"message": "Payload has more than one remediation component"}) return } @@ -171,12 +171,12 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { } } - ctx := gctx.Request.Context() + ctx := r.Context() err := c.updateBaseMetrics(ctx, machineID, bouncer, baseMetrics, hubItems, datasources) if err != nil { logger.Errorf("Failed to update base metrics: %s", err) - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } @@ -184,7 +184,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { jsonPayload, err := json.Marshal(payload) if err != nil { logger.Errorf("Failed to serialize usage metrics: %s", err) - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } @@ -193,7 +193,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { if _, err := c.DBClient.CreateMetric(ctx, generatedType, generatedBy, receivedAt, string(jsonPayload)); err != nil { logger.Error(err) - c.HandleDBErrors(gctx, err) + c.HandleDBErrors(w, err) return } @@ -201,5 +201,5 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { // if CreateMetrics() returned nil, the metric was already there, we're good // and don't split hair about 201 vs 200/204 - gctx.Status(http.StatusCreated) + w.WriteHeader(http.StatusCreated) } diff --git a/pkg/apiserver/controllers/v1/utils.go b/pkg/apiserver/controllers/v1/utils.go index b93371b2caa..71b47a75b38 100644 --- a/pkg/apiserver/controllers/v1/utils.go +++ b/pkg/apiserver/controllers/v1/utils.go @@ -6,16 +6,14 @@ import ( "net/http" "strings" - jwt "github.com/appleboy/gin-jwt/v2" - "github.com/gin-gonic/gin" - middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/database/ent" ) -func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) { - bouncerInterface, exist := ctx.Get(middlewares.BouncerContextKey) - if !exist { +func getBouncerFromContext(r *http.Request) (*ent.Bouncer, error) { + bouncerInterface := router.GetContextValue(r, middlewares.BouncerContextKey) + if bouncerInterface == nil { return nil, errors.New("bouncer not found") } @@ -27,48 +25,40 @@ func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) { return bouncerInfo, nil } -func isUnixSocket(c *gin.Context) bool { - if localAddr, ok := c.Request.Context().Value(http.LocalAddrContextKey).(net.Addr); ok { +func isUnixSocket(r *http.Request) bool { + if localAddr, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr); ok { return strings.HasPrefix(localAddr.Network(), "unix") } return false } -func getMachineIDFromContext(ctx *gin.Context) (string, error) { - claims := jwt.ExtractClaims(ctx) - if claims == nil { - return "", errors.New("failed to extract claims") - } - - rawID, ok := claims[middlewares.MachineIDKey] - if !ok { - return "", errors.New("MachineID not found in claims") - } - - id, ok := rawID.(string) - if !ok { - // should never happen - return "", errors.New("failed to cast machineID to string") - } - - return id, nil +func getMachineIDFromContext(r *http.Request) (string, error) { + // Use the helper from jwt.go + return middlewares.GetMachineIDFromRequest(r) } -func (*Controller) AbortRemoteIf(option bool) gin.HandlerFunc { - return func(gctx *gin.Context) { - if !option { - return - } +// AbortRemoteIf creates a middleware that aborts remote requests if the option is enabled +func (*Controller) AbortRemoteIf(option bool) router.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !option { + next.ServeHTTP(w, r) + return + } + + if isUnixSocket(r) { + next.ServeHTTP(w, r) + return + } - if isUnixSocket(gctx) { - return - } + incomingIP := router.GetClientIP(r) // Gets IP from context (resolved by ClientIPMiddleware) + if incomingIP != "127.0.0.1" && incomingIP != "::1" { + router.AbortWithJSON(w, http.StatusForbidden, map[string]string{"message": "access forbidden"}) + return + } - incomingIP := gctx.ClientIP() - if incomingIP != "127.0.0.1" && incomingIP != "::1" { - gctx.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - gctx.Abort() - } + next.ServeHTTP(w, r) + }) } } diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index 0ffcfa2beb8..844761866b6 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -34,7 +34,7 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.JSONEq(t, `{"code":401,"message":"ent: machine not found"}`, w.Body.String()) + assert.JSONEq(t, `{"code":401,"message":"incorrect Username or Password"}`, w.Body.String()) // Login with invalid body w = httptest.NewRecorder() diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index 291eef5873b..167ecbc20dd 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -56,7 +56,7 @@ func TestCreateMachine(t *testing.T) { func TestCreateMachineWithForwardedFor(t *testing.T) { ctx := t.Context() router, config := NewAPITestForwardedFor(t) - router.TrustedPlatform = "X-Real-IP" + // Trusted proxies are configured via middleware in NewAPITestForwardedFor // Create machine b, err := json.Marshal(MachineTest) diff --git a/pkg/apiserver/middlewares/clientip.go b/pkg/apiserver/middlewares/clientip.go new file mode 100644 index 00000000000..f5633b20f98 --- /dev/null +++ b/pkg/apiserver/middlewares/clientip.go @@ -0,0 +1,43 @@ +package middlewares + +import ( + "net" + "net/http" + "net/netip" + + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" +) + +// ClientIPMiddleware creates a middleware that extracts and sets the client IP from trusted proxy headers +// It resolves the real client IP once using trusted proxy configuration and stores it in the request context +// Downstream handlers can then use router.GetClientIP() to retrieve the resolved IP +// If useForwardedForHeaders is false, only RemoteAddr is used (forwarded headers are ignored) +func ClientIPMiddleware(trustedProxies []netip.Prefix, useForwardedForHeaders bool) router.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Resolve the client IP using trusted proxy configuration + // Only check forwarded headers if the flag is enabled + var clientIP string + if useForwardedForHeaders { + clientIP = router.ResolveClientIP(r, trustedProxies) + } else { + // Only use RemoteAddr, ignore forwarded headers + if r.RemoteAddr == "@" { + clientIP = "127.0.0.1" + } else { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + clientIP = r.RemoteAddr + } else { + clientIP = host + } + } + } + + // Store the resolved IP in context for downstream handlers + r = router.SetClientIP(r, clientIP) + + next.ServeHTTP(w, r) + }) + } +} diff --git a/pkg/apiserver/middlewares/gzip.go b/pkg/apiserver/middlewares/gzip.go new file mode 100644 index 00000000000..4f583404b38 --- /dev/null +++ b/pkg/apiserver/middlewares/gzip.go @@ -0,0 +1,70 @@ +package middlewares + +import ( + "compress/gzip" + "io" + "net/http" + "strings" + + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" +) + +// gzipReadCloser wraps a gzip.Reader and ensures both the gzip reader +// and the original body are closed when Close() is called +type gzipReadCloser struct { + *gzip.Reader + originalBody io.ReadCloser +} + +func (g *gzipReadCloser) Close() error { + // Close the gzip reader first + err1 := g.Reader.Close() + // Then close the original body to allow connection reuse + err2 := g.originalBody.Close() + // Return the first error if any + if err1 != nil { + return err1 + } + return err2 +} + +// GzipDecompressMiddleware creates a middleware that automatically decompresses gzip-encoded request bodies +// It does NOT compress responses (to avoid breaking existing bouncers) +func GzipDecompressMiddleware() router.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if the request body is gzip-encoded (case-insensitive) + contentEncoding := strings.ToLower(r.Header.Get("Content-Encoding")) + if strings.Contains(contentEncoding, "gzip") { + // Keep reference to original body for proper cleanup + originalBody := r.Body + + // Create a gzip reader from the request body + gzipReader, err := gzip.NewReader(originalBody) + if err != nil { + // Close original body on error + originalBody.Close() + router.AbortWithJSON(w, http.StatusBadRequest, map[string]string{ + "message": "invalid gzip encoding", + }) + return + } + + // Wrap in custom ReadCloser that closes both gzip reader and original body + wrapped := &gzipReadCloser{ + Reader: gzipReader, + originalBody: originalBody, + } + + // Replace the request body with the decompressed reader + r.Body = wrapped + // Remove Content-Encoding header since we've decompressed + r.Header.Del("Content-Encoding") + } + + next.ServeHTTP(w, r) + // Note: The standard library will close r.Body after the handler completes. + // Our gzipReadCloser.Close() ensures both the gzip reader and original body are closed. + }) + } +} diff --git a/pkg/apiserver/middlewares/logging.go b/pkg/apiserver/middlewares/logging.go new file mode 100644 index 00000000000..73d169c2b1e --- /dev/null +++ b/pkg/apiserver/middlewares/logging.go @@ -0,0 +1,120 @@ +package middlewares + +import ( + "bufio" + "errors" + "fmt" + "io" + "net" + "net/http" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" +) + +// LoggingMiddleware creates a middleware that logs HTTP requests using the provided logger +// It logs: client IP, timestamp, method, path, protocol, status code, latency, user agent, and error message +// Matches the format used by Gin's LoggerWithFormatter +// If logger is nil, it falls back to the standard logger +func LoggingMiddleware(logger *log.Entry) router.Middleware { + if logger == nil { + logger = log.StandardLogger().WithFields(nil) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // Create a response writer wrapper to capture status code + wrapped := &responseWriter{ + ResponseWriter: w, + statusCode: http.StatusOK, + } + + // Process the request + next.ServeHTTP(wrapped, r) + + // Calculate latency + latency := time.Since(start) + + // Get client IP from context (resolved by ClientIPMiddleware) + // Falls back to RemoteAddr if not set (shouldn't happen if middleware is properly configured) + clientIP := router.GetClientIP(r) + + // Format latency as string (matches Gin's format) + latencyStr := latency.String() + + // Log the request in the same format as Gin's LoggerWithFormatter + // Use the provided logger which writes to the access log file + // Format the log message and write it directly to the logger's output writer + // This bypasses logrus formatting since access logs use a specific plain text format + // Only write if the logger level allows info-level messages or more verbose (debug/trace) + // IsLevelEnabled returns true if logger level >= InfoLevel (i.e., Info, Debug, or Trace) + if logger.Logger.IsLevelEnabled(log.InfoLevel) { + // Use concrete path (r.URL.Path) for access logs to show actual requests + // This is useful for debugging, e.g., seeing which IPs bouncers check in live mode + // Prometheus metrics use route templates to keep cardinality bounded + logMsg := fmt.Sprintf("%s - [%s] \"%s %s %s %d %s %q %s\"\n", + clientIP, + start.Format(time.RFC1123), + r.Method, + r.URL.Path, + r.Proto, + wrapped.statusCode, + latencyStr, + r.UserAgent(), + "", // Error message (empty for now, could be enhanced) + ) + // Ignore write errors - we don't want logging failures to affect request handling + _, _ = logger.Logger.Out.Write([]byte(logMsg)) + } + }) + } +} + +// responseWriter wraps http.ResponseWriter to capture status code +// It also forwards optional interfaces (Flusher, Hijacker, Pusher, ReaderFrom) +// to ensure streaming and connection upgrades work correctly +type responseWriter struct { + http.ResponseWriter + statusCode int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) +} + +// Flush implements http.Flusher if the underlying ResponseWriter supports it +func (rw *responseWriter) Flush() { + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +// Hijack implements http.Hijacker if the underlying ResponseWriter supports it +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, errors.New("underlying ResponseWriter does not implement http.Hijacker") +} + +// Push implements http.Pusher if the underlying ResponseWriter supports it +func (rw *responseWriter) Push(target string, opts *http.PushOptions) error { + if pusher, ok := rw.ResponseWriter.(http.Pusher); ok { + return pusher.Push(target, opts) + } + return http.ErrNotSupported +} + +// ReadFrom implements io.ReaderFrom if the underlying ResponseWriter supports it +func (rw *responseWriter) ReadFrom(src io.Reader) (int64, error) { + if readerFrom, ok := rw.ResponseWriter.(io.ReaderFrom); ok { + return readerFrom.ReadFrom(src) + } + // Fallback to standard implementation if not supported + return io.Copy(rw.ResponseWriter, src) +} diff --git a/pkg/apiserver/middlewares/recovery.go b/pkg/apiserver/middlewares/recovery.go new file mode 100644 index 00000000000..edde5825686 --- /dev/null +++ b/pkg/apiserver/middlewares/recovery.go @@ -0,0 +1,36 @@ +package middlewares + +import ( + "net/http" + "runtime/debug" + + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" +) + +// RecoveryMiddleware creates a middleware that recovers from panics and logs the error +func RecoveryMiddleware() router.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + // Log the panic with stack trace + log.WithFields(log.Fields{ + "error": err, + "path": r.URL.Path, + "method": r.Method, + "stack": string(debug.Stack()), + }).Error("Panic recovered") + + // Write 500 error response + router.AbortWithJSON(w, http.StatusInternalServerError, map[string]string{ + "message": "Internal server error", + }) + } + }() + + next.ServeHTTP(w, r) + }) + } +} diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 2f23da7bb62..e2ae77bb427 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -9,9 +9,9 @@ import ( "net/netip" "strings" - "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -84,15 +84,15 @@ func HashSHA512(str string) string { return hashStr } -func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { +func (a *APIKey) authTLS(r *http.Request, clientIP string, logger *log.Entry) *ent.Bouncer { if a.TlsAuth == nil { logger.Warn("TLS Auth is not configured but client presented a certificate") return nil } - ctx := c.Request.Context() + ctx := r.Context() - extractedCN, err := a.TlsAuth.ValidateCert(c) + extractedCN, err := a.TlsAuth.ValidateCertFromRequest(r) if err != nil { logger.Warn(err) return nil @@ -100,7 +100,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger = logger.WithField("cn", extractedCN) - bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) + bouncerName := fmt.Sprintf("%s@%s", extractedCN, clientIP) bouncer, err := a.DbClient.SelectBouncerByName(ctx, bouncerName) // This is likely not the proper way, but isNotFound does not seem to work @@ -115,7 +115,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Infof("Creating bouncer %s", bouncerName) - bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType, true) + bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, clientIP, HashSHA512(apiKey), types.TlsAuthType, true) if err != nil { logger.Errorf("while creating bouncer db entry: %s", err) return nil @@ -133,22 +133,20 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { return bouncer } -func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { - val, ok := c.Request.Header[APIKeyHeader] +func (a *APIKey) authPlain(r *http.Request, clientIP string, logger *log.Entry) *ent.Bouncer { + val, ok := r.Header[APIKeyHeader] if !ok { logger.Errorf("API key not found") return nil } - clientIP := c.ClientIP() - - ctx := c.Request.Context() + ctx := r.Context() hashStr := HashSHA512(val[0]) // Appsec case, we only care if the key is valid // No content is returned, no last_pull update or anything - if c.Request.Method == http.MethodHead { + if r.Method == http.MethodHead { bouncer, err := a.DbClient.SelectBouncers(ctx, hashStr, types.ApiKeyAuthType) if err != nil { logger.Errorf("while fetching bouncer info: %s", err) @@ -216,65 +214,64 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { return bouncer } -func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { - return func(c *gin.Context) { - var bouncer *ent.Bouncer - - ctx := c.Request.Context() - - clientIP := c.ClientIP() - - logger := log.WithField("ip", clientIP) - - if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { - bouncer = a.authTLS(c, logger) - } else { - bouncer = a.authPlain(c, logger) - } +// MiddlewareFunc returns a middleware that validates API keys +func (a *APIKey) MiddlewareFunc() router.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var bouncer *ent.Bouncer - if bouncer == nil { - // XXX: StatusUnauthorized? - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() + ctx := r.Context() + clientIP := router.GetClientIP(r) // Gets IP from context (resolved by ClientIPMiddleware) - return - } + logger := log.WithField("ip", clientIP) - // Appsec request, return immediately if we found something - if c.Request.Method == http.MethodHead { - c.Set(BouncerContextKey, bouncer) - return - } - - logger = logger.WithField("name", bouncer.Name) + if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { + bouncer = a.authTLS(r, clientIP, logger) + } else { + bouncer = a.authPlain(r, clientIP, logger) + } - // 1st time we see this bouncer, we update its IP - if bouncer.IPAddress == "" { - if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { - logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() + if bouncer == nil { + router.AbortWithJSON(w, http.StatusForbidden, map[string]string{"message": "access forbidden"}) + return + } + // Appsec request, return immediately if we found something + if r.Method == http.MethodHead { + // Store bouncer in context and continue + r = router.SetContextValue(r, BouncerContextKey, bouncer) + next.ServeHTTP(w, r) return } - } - useragent := strings.Split(c.Request.UserAgent(), "/") - if len(useragent) != 2 { - logger.Warningf("bad user agent '%s'", c.Request.UserAgent()) - useragent = []string{c.Request.UserAgent(), "N/A"} - } + logger = logger.WithField("name", bouncer.Name) - if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { - if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil { - logger.Errorf("failed to update bouncer version and type: %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) - c.Abort() + // 1st time we see this bouncer, we update its IP + if bouncer.IPAddress == "" { + if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { + logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) + router.AbortWithJSON(w, http.StatusForbidden, map[string]string{"message": "access forbidden"}) + return + } + } - return + useragent := strings.Split(r.UserAgent(), "/") + if len(useragent) != 2 { + logger.Warningf("bad user agent '%s'", r.UserAgent()) + useragent = []string{r.UserAgent(), "N/A"} + } + + if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { + if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil { + logger.Errorf("failed to update bouncer version and type: %s", err) + router.AbortWithJSON(w, http.StatusForbidden, map[string]string{"message": "bad user agent"}) + return + } } - } - c.Set(BouncerContextKey, bouncer) + // Store bouncer in context + r = router.SetContextValue(r, BouncerContextKey, bouncer) + next.ServeHTTP(w, r) + }) } } diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index d3dc6ec22ea..12667403c2c 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -1,19 +1,21 @@ package v1 import ( + "context" "crypto/rand" "errors" "fmt" + "net/http" "os" "strings" "time" - jwt "github.com/appleboy/gin-jwt/v2" - "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" + jwtv4 "github.com/golang-jwt/jwt/v4" log "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/router" "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" @@ -21,78 +23,226 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -const MachineIDKey = "id" +type machineIDKey struct{} +var MachineIDKey = machineIDKey{} + +type authInput struct { + machineID string + clientMachine *ent.Machine + scenariosInput []string +} + +// randomSecret generates a cryptographically secure random secret +func randomSecret() ([]byte, error) { + size := 64 + secret := make([]byte, size) + n, err := rand.Read(secret) + if err != nil { + return nil, errors.New("unable to generate a new random seed for JWT generation") + } + if n != size { + return nil, errors.New("not enough entropy at random seed generation for JWT generation") + } + return secret, nil +} + +// JWT is the JWT middleware implementation using golang-jwt/jwt/v4 type JWT struct { - Middleware *jwt.GinJWTMiddleware - DbClient *database.Client - TlsAuth *TLSAuth + secret []byte + dbClient *database.Client + tlsAuth *TLSAuth + timeout time.Duration + maxRefresh time.Duration + tokenLookup []string // e.g., ["header: Authorization", "query: token", "cookie: jwt"] + tokenHeadName string // e.g., "Bearer" } -func PayloadFunc(data any) jwt.MapClaims { - if value, ok := data.(*models.WatcherAuthRequest); ok { - return jwt.MapClaims{ - MachineIDKey: &value.MachineID, +type jwtClaims struct { + jwtv4.RegisteredClaims + MachineID *string `json:"id"` +} + +// NewJWT creates a new JWT middleware using golang-jwt/jwt/v4 +func NewJWT(dbClient *database.Client) (*JWT, error) { + var ( + secret []byte + err error + ) + + // Get secret from environment variable + secretString := os.Getenv("CS_LAPI_SECRET") + secret = []byte(secretString) + + switch l := len(secret); { + case l == 0: + secret, err = randomSecret() + if err != nil { + return nil, err } + case l < 64: + return nil, errors.New("CS_LAPI_SECRET not strong enough") } - return jwt.MapClaims{} + return &JWT{ + secret: secret, + dbClient: dbClient, + tlsAuth: &TLSAuth{}, + timeout: time.Hour, + maxRefresh: time.Hour, + tokenLookup: []string{"header: Authorization", "query: token", "cookie: jwt"}, + tokenHeadName: "Bearer", + }, nil } -func IdentityHandler(c *gin.Context) any { - claims := jwt.ExtractClaims(c) - machineID := claims[MachineIDKey].(string) +// SetTlsAuth sets the TLS auth instance for the JWT middleware +func (j *JWT) SetTlsAuth(tlsAuth *TLSAuth) { + j.tlsAuth = tlsAuth +} - return &models.WatcherAuthRequest{ - MachineID: &machineID, +// extractToken extracts the JWT token from the request +// It checks header, query parameter, and cookie as configured +func (j *JWT) extractToken(r *http.Request) (string, error) { + for _, lookup := range j.tokenLookup { + parts := strings.Split(lookup, ":") + if len(parts) != 2 { + continue + } + + source := strings.TrimSpace(parts[0]) + name := strings.TrimSpace(parts[1]) + + switch source { + case "header": + token := r.Header.Get(name) + if token != "" { + // Remove token head name (e.g., "Bearer ") + if j.tokenHeadName != "" && strings.HasPrefix(token, j.tokenHeadName+" ") { + return strings.TrimPrefix(token, j.tokenHeadName+" "), nil + } + return token, nil + } + case "query": + token := r.URL.Query().Get(name) + if token != "" { + return token, nil + } + case "cookie": + cookie, err := r.Cookie(name) + if err == nil { + if cookie.Value == "" { + return "", errors.New("cookie token is empty") + } + return cookie.Value, nil + } + } } + + return "", errors.New("token not found") } -type authInput struct { - machineID string - clientMachine *ent.Machine - scenariosInput []string +// parseToken parses and validates a JWT token +func (j *JWT) parseToken(tokenString string) (*jwtClaims, error) { + token, err := jwtv4.ParseWithClaims(tokenString, &jwtClaims{}, func(token *jwtv4.Token) (any, error) { + // Validate signing method + if _, ok := token.Method.(*jwtv4.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return j.secret, nil + }) + + if err != nil { + return nil, err + } + + if claims, ok := token.Claims.(*jwtClaims); ok && token.Valid { + return claims, nil + } + + return nil, errors.New("invalid token claims") } -func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { - ctx := c.Request.Context() - ret := authInput{} +// generateToken generates a new JWT token for the given machine ID +func (j *JWT) generateToken(machineID string) (string, time.Time, error) { + now := time.Now() + expiresAt := now.Add(j.timeout) + + claims := &jwtClaims{ + RegisteredClaims: jwtv4.RegisteredClaims{ + ExpiresAt: jwtv4.NewNumericDate(expiresAt), + IssuedAt: jwtv4.NewNumericDate(now), + NotBefore: jwtv4.NewNumericDate(now), + }, + MachineID: &machineID, + } - if j.TlsAuth == nil { - err := errors.New("tls authentication required") - log.Warn(err) + token := jwtv4.NewWithClaims(jwtv4.SigningMethodHS256, claims) + tokenString, err := token.SignedString(j.secret) + if err != nil { + return "", time.Time{}, err + } - return nil, err + return tokenString, expiresAt, nil +} + +// refreshToken refreshes an existing token if it's within the refresh window +func (j *JWT) refreshToken(tokenString string) (string, time.Time, error) { + claims, err := j.parseToken(tokenString) + if err != nil { + return "", time.Time{}, err + } + + // Check if token is within refresh window + now := time.Now() + if claims.ExpiresAt != nil { + expiresAt := claims.ExpiresAt.Time + refreshDeadline := expiresAt.Add(j.maxRefresh) + if now.After(refreshDeadline) { + return "", time.Time{}, errors.New("token refresh deadline exceeded") + } } - extractedCN, err := j.TlsAuth.ValidateCert(c) + // Generate new token with same machine ID + if claims.MachineID == nil { + return "", time.Time{}, errors.New("token missing machine ID") + } + + return j.generateToken(*claims.MachineID) +} + +// authTLS handles TLS-based authentication +func (j *JWT) authTLS(r *http.Request, clientIP string) (*authInput, error) { + if j.tlsAuth == nil { + return nil, errors.New("tls authentication required") + } + + extractedCN, err := j.tlsAuth.ValidateCertFromRequest(r) if err != nil { log.Warn(err) return nil, err } - logger := log.WithField("ip", c.ClientIP()) + logger := log.WithField("ip", clientIP) + ret := authInput{} - ret.machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) + ret.machineID = fmt.Sprintf("%s@%s", extractedCN, clientIP) - ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). + ctx := r.Context() + ret.clientMachine, err = j.dbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). First(ctx) if ent.IsNotFound(err) { // Machine was not found, let's create it logger.Infof("machine %s not found, create it", ret.machineID) - // let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli) pwd, err := GenerateAPIKey(dummyAPIKeySize) if err != nil { logger.WithField("cn", extractedCN). Errorf("error generating password: %s", err) - return nil, errors.New("error generating password") } password := strfmt.Password(pwd) - - ret.clientMachine, err = j.DbClient.CreateMachine(ctx, &ret.machineID, &password, "", true, true, types.TlsAuthType) + ret.clientMachine, err = j.dbClient.CreateMachine(ctx, &ret.machineID, &password, "", true, true, types.TlsAuthType) if err != nil { return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err) } @@ -102,7 +252,6 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { if ret.clientMachine.AuthType != types.TlsAuthType { return nil, fmt.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType) } - ret.machineID = ret.clientMachine.MachineId } @@ -112,31 +261,24 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { Scenarios: []string{}, } - err = c.ShouldBindJSON(&loginInput) - if err != nil { + if err := router.BindJSON(r, &loginInput); err != nil { return nil, fmt.Errorf("missing scenarios list in login request for TLS auth: %w", err) } ret.scenariosInput = loginInput.Scenarios - return &ret, nil } -func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { - var ( - loginInput models.WatcherAuthRequest - err error - ) - - ctx := c.Request.Context() - +// authPlain handles password-based authentication +func (j *JWT) authPlain(r *http.Request) (*authInput, error) { + var loginInput models.WatcherAuthRequest ret := authInput{} - if err = c.ShouldBindJSON(&loginInput); err != nil { + if err := router.BindJSON(r, &loginInput); err != nil { return nil, fmt.Errorf("missing: %w", err) } - if err = loginInput.Validate(strfmt.Default); err != nil { + if err := loginInput.Validate(strfmt.Default); err != nil { return nil, err } @@ -144,17 +286,23 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { password := *loginInput.Password ret.scenariosInput = loginInput.Scenarios - ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). + ctx := r.Context() + var err error + ret.clientMachine, err = j.dbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). First(ctx) if err != nil { log.Infof("Error machine login for %s : %+v ", ret.machineID, err) + if ent.IsNotFound(err) { + // Return generic error for security (don't reveal if machine exists) + return nil, errors.New("incorrect Username or Password") + } return nil, err } if ret.clientMachine == nil { log.Errorf("Nothing for '%s'", ret.machineID) - return nil, jwt.ErrFailedAuthentication + return nil, errors.New("incorrect Username or Password") } if ret.clientMachine.AuthType != types.PasswordAuthType { @@ -166,165 +314,185 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { } if err := bcrypt.CompareHashAndPassword([]byte(ret.clientMachine.Password), []byte(password)); err != nil { - return nil, jwt.ErrFailedAuthentication + return nil, errors.New("incorrect Username or Password") } return &ret, nil } -func (j *JWT) Authenticator(c *gin.Context) (any, error) { +// authenticator performs authentication and returns the authenticated machine ID +func (j *JWT) authenticator(r *http.Request, clientIP string) (string, error) { var ( err error auth *authInput ) - ctx := c.Request.Context() + ctx := r.Context() - if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { - auth, err = j.authTLS(c) + if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { + auth, err = j.authTLS(r, clientIP) if err != nil { - return nil, err + return "", err } } else { - auth, err = j.authPlain(c) + auth, err = j.authPlain(r) if err != nil { - return nil, err + return "", err } } var scenarios string - if len(auth.scenariosInput) > 0 { - for _, scenario := range auth.scenariosInput { - if scenarios == "" { - scenarios = scenario - } else { - scenarios += "," + scenario - } - } - - err = j.DbClient.UpdateMachineScenarios(ctx, scenarios, auth.clientMachine.ID) + scenarios = strings.Join(auth.scenariosInput, ",") + err = j.dbClient.UpdateMachineScenarios(ctx, scenarios, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err) - return nil, jwt.ErrFailedAuthentication + return "", errors.New("failed authentication") } } - clientIP := c.ClientIP() - if auth.clientMachine.IpAddress == "" { - err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) + err = j.dbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err) - return nil, jwt.ErrFailedAuthentication + return "", errors.New("failed authentication") } } if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" { log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress) - - err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) + err = j.dbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) - return nil, jwt.ErrFailedAuthentication + return "", errors.New("failed authentication") } } - useragent := strings.Split(c.Request.UserAgent(), "/") + useragent := strings.Split(r.UserAgent(), "/") if len(useragent) != 2 { - log.Warningf("bad user agent '%s' from '%s'", c.Request.UserAgent(), clientIP) - return nil, jwt.ErrFailedAuthentication + log.Warningf("bad user agent '%s' from '%s'", r.UserAgent(), clientIP) + return "", errors.New("failed authentication") } - if err := j.DbClient.UpdateMachineVersion(ctx, useragent[1], auth.clientMachine.ID); err != nil { + if err := j.dbClient.UpdateMachineVersion(ctx, useragent[1], auth.clientMachine.ID); err != nil { log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err) - return nil, jwt.ErrFailedAuthentication + return "", errors.New("failed authentication") } - return &models.WatcherAuthRequest{ - MachineID: &auth.machineID, - }, nil -} - -func Authorizator(data any, c *gin.Context) bool { - return true + return auth.machineID, nil } -func Unauthorized(c *gin.Context, code int, message string) { - c.JSON(code, gin.H{ - "code": code, - "message": message, - }) -} +// LoginHandler handles login requests and returns a JWT token +func (j *JWT) LoginHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + router.AbortWithStatus(w, http.StatusMethodNotAllowed) + return + } -func randomSecret() ([]byte, error) { - size := 64 - secret := make([]byte, size) + clientIP := router.GetClientIP(r) // Gets IP from context (resolved by ClientIPMiddleware) + machineID, err := j.authenticator(r, clientIP) + if err != nil { + router.AbortWithJSON(w, http.StatusUnauthorized, map[string]any{ + "code": http.StatusUnauthorized, + "message": err.Error(), + }) + return + } - n, err := rand.Read(secret) + tokenString, expiresAt, err := j.generateToken(machineID) if err != nil { - return nil, errors.New("unable to generate a new random seed for JWT generation") + router.AbortWithJSON(w, http.StatusInternalServerError, map[string]any{ + "code": http.StatusInternalServerError, + "message": "failed to generate token", + }) + return } - if n != size { - return nil, errors.New("not enough entropy at random seed generation for JWT generation") + response := models.WatcherAuthResponse{ + Code: 200, + Token: tokenString, + Expire: expiresAt.Format(time.RFC3339), } - return secret, nil + router.WriteJSON(w, http.StatusOK, response) } -func NewJWT(dbClient *database.Client) (*JWT, error) { - // Get secret from environment variable "SECRET" - var ( - secret []byte - err error - ) +// RefreshHandler handles token refresh requests +func (j *JWT) RefreshHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + router.AbortWithStatus(w, http.StatusMethodNotAllowed) + return + } - // Please be aware that brute force HS256 is possible. - // PLEASE choose a STRONG secret - secretString := os.Getenv("CS_LAPI_SECRET") - secret = []byte(secretString) + tokenString, err := j.extractToken(r) + if err != nil { + router.AbortWithJSON(w, http.StatusUnauthorized, map[string]any{ + "code": http.StatusUnauthorized, + "message": err.Error(), + }) + return + } - switch l := len(secret); { - case l == 0: - secret, err = randomSecret() - if err != nil { - return &JWT{}, err - } - case l < 64: - return &JWT{}, errors.New("CS_LAPI_SECRET not strong enough") - } - - jwtMiddleware := &JWT{ - DbClient: dbClient, - TlsAuth: &TLSAuth{}, - } - - ret, err := jwt.New(&jwt.GinJWTMiddleware{ - Realm: "Crowdsec API local", - Key: secret, - Timeout: time.Hour, - MaxRefresh: time.Hour, - IdentityKey: MachineIDKey, - PayloadFunc: PayloadFunc, - IdentityHandler: IdentityHandler, - Authenticator: jwtMiddleware.Authenticator, - Authorizator: Authorizator, - Unauthorized: Unauthorized, - TokenLookup: "header: Authorization, query: token, cookie: jwt", - TokenHeadName: "Bearer", - TimeFunc: time.Now, - }) + newTokenString, expiresAt, err := j.refreshToken(tokenString) if err != nil { - return &JWT{}, err + router.AbortWithJSON(w, http.StatusUnauthorized, map[string]any{ + "code": http.StatusUnauthorized, + "message": err.Error(), + }) + return } - errInit := ret.MiddlewareInit() - if errInit != nil { - return &JWT{}, errors.New("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) + response := models.WatcherAuthResponse{ + Code: 200, + Token: newTokenString, + Expire: expiresAt.Format(time.RFC3339), } - jwtMiddleware.Middleware = ret + router.WriteJSON(w, http.StatusOK, response) +} + +// MiddlewareFunc returns a middleware that validates JWT tokens +func (j *JWT) MiddlewareFunc() router.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenString, err := j.extractToken(r) + if err != nil { + router.AbortWithJSON(w, http.StatusUnauthorized, map[string]any{ + "code": http.StatusUnauthorized, + "message": err.Error(), + }) + return + } - return jwtMiddleware, nil + claims, err := j.parseToken(tokenString) + if err != nil { + router.AbortWithJSON(w, http.StatusUnauthorized, map[string]any{ + "code": http.StatusUnauthorized, + "message": "invalid token", + }) + return + } + + if claims.MachineID == nil { + router.AbortWithJSON(w, http.StatusUnauthorized, map[string]any{ + "code": http.StatusUnauthorized, + "message": "token missing machine ID", + }) + return + } + + // Store machine ID in request context + ctx := context.WithValue(r.Context(), MachineIDKey, *claims.MachineID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// GetMachineIDFromRequest extracts the machine ID from the request context +func GetMachineIDFromRequest(r *http.Request) (string, error) { + machineID, ok := r.Context().Value(MachineIDKey).(string) + if !ok || machineID == "" { + return "", errors.New("machine ID not found in request context") + } + return machineID, nil } diff --git a/pkg/apiserver/middlewares/v1/tls_auth.go b/pkg/apiserver/middlewares/v1/tls_auth.go index fadda5309fe..b9d786fce52 100644 --- a/pkg/apiserver/middlewares/v1/tls_auth.go +++ b/pkg/apiserver/middlewares/v1/tls_auth.go @@ -5,10 +5,10 @@ import ( "crypto/x509" "errors" "fmt" + "net/http" "slices" "time" - "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" ) @@ -37,7 +37,8 @@ func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool { } // checkRevocationPath checks a single chain against OCSP and CRL -func (ta *TLSAuth) checkRevocationPath(ctx context.Context, chain []*x509.Certificate) (error, bool) { //nolint:revive +//nolint:revive // error-return: error should be last, but bool indicates if check was possible +func (ta *TLSAuth) checkRevocationPath(ctx context.Context, chain []*x509.Certificate) (error, bool) { // if we ever fail to check OCSP or CRL, we should not cache the result couldCheck := true @@ -97,21 +98,50 @@ func (ta *TLSAuth) checkAllowedOU(ous []string) error { return fmt.Errorf("client certificate OU %v doesn't match expected OU %v", ous, ta.AllowedOUs) } -func (ta *TLSAuth) ValidateCert(c *gin.Context) (string, error) { +func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Duration, logger *log.Entry) (*TLSAuth, error) { + var err error + + cache := NewRevocationCache(cacheExpiration, logger) + + ta := &TLSAuth{ + revocationCache: cache, + ocspChecker: NewOCSPChecker(logger), + logger: logger, + } + + switch crlPath { + case "": + logger.Info("no crl_path, skipping CRL checks") + default: + ta.crlChecker, err = NewCRLChecker(crlPath, cache.Empty, logger) + if err != nil { + return nil, err + } + } + + if err := ta.setAllowedOu(allowedOus); err != nil { + return nil, err + } + + return ta, nil +} + +// ValidateCertFromRequest is like ValidateCert but takes http.Request instead of gin.Context +func (ta *TLSAuth) ValidateCertFromRequest(r *http.Request) (string, error) { // Checks cert validity, Returns true + CN if client cert matches requested OU var leaf *x509.Certificate - if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 { + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { return "", errors.New("no certificate in request") } - if len(c.Request.TLS.VerifiedChains) == 0 { + if len(r.TLS.VerifiedChains) == 0 { return "", errors.New("no verified cert in request") } // although there can be multiple chains, the leaf certificate is the same // we take the first one - leaf = c.Request.TLS.VerifiedChains[0][0] + leaf = r.TLS.VerifiedChains[0][0] if err := ta.checkAllowedOU(leaf.Subject.OrganizationalUnit); err != nil { return "", err @@ -136,8 +166,8 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (string, error) { couldCheck bool ) - for _, chain := range c.Request.TLS.VerifiedChains { - validErr, couldCheck = ta.checkRevocationPath(c.Request.Context(), chain) + for _, chain := range r.TLS.VerifiedChains { + validErr, couldCheck = ta.checkRevocationPath(r.Context(), chain) okToCache = okToCache && couldCheck if validErr != nil { @@ -155,31 +185,3 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (string, error) { return leaf.Subject.CommonName, nil } - -func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Duration, logger *log.Entry) (*TLSAuth, error) { - var err error - - cache := NewRevocationCache(cacheExpiration, logger) - - ta := &TLSAuth{ - revocationCache: cache, - ocspChecker: NewOCSPChecker(logger), - logger: logger, - } - - switch crlPath { - case "": - logger.Info("no crl_path, skipping CRL checks") - default: - ta.crlChecker, err = NewCRLChecker(crlPath, cache.Empty, logger) - if err != nil { - return nil, err - } - } - - if err := ta.setAllowedOu(allowedOus); err != nil { - return nil, err - } - - return ta, nil -} diff --git a/pkg/apiserver/router/helpers.go b/pkg/apiserver/router/helpers.go new file mode 100644 index 00000000000..6fc36d5f5e4 --- /dev/null +++ b/pkg/apiserver/router/helpers.go @@ -0,0 +1,316 @@ +package router + +import ( + "context" + "encoding/json" + "net" + "net/http" + "net/netip" + "strings" + + log "github.com/sirupsen/logrus" +) + +// Context key for client IP +type clientIPKey struct{} + +// Context key for route pattern +type routePatternKey struct{} + +// SetClientIP stores the resolved client IP in the request context +// This should be called by middleware that resolves the IP from trusted proxy headers +func SetClientIP(r *http.Request, ip string) *http.Request { + return r.WithContext(context.WithValue(r.Context(), clientIPKey{}, ip)) +} + +// SetRoutePattern stores the route pattern (template) in the request context +// This is used for metrics to avoid high cardinality from concrete paths +func SetRoutePattern(r *http.Request, pattern string) *http.Request { + return r.WithContext(context.WithValue(r.Context(), routePatternKey{}, pattern)) +} + +// GetRoutePattern retrieves the route pattern from the request context +// Falls back to r.URL.Path if not set (for backwards compatibility) +func GetRoutePattern(r *http.Request) string { + if pattern, ok := r.Context().Value(routePatternKey{}).(string); ok && pattern != "" { + return pattern + } + // If pattern not set, return empty string so invalid endpoint can be used + return "" +} + +// GetClientIP retrieves the client IP from the request context +// If not set in context, falls back to extracting from RemoteAddr +// This ensures backwards compatibility if middleware hasn't set it +func GetClientIP(r *http.Request) string { + if ip, ok := r.Context().Value(clientIPKey{}).(string); ok && ip != "" { + return ip + } + + // Fallback to RemoteAddr if not in context + if r.RemoteAddr == "@" { + return "127.0.0.1" + } + + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + + return host +} + +// PathValue extracts a path parameter from the request +// This is a convenience wrapper around Request.PathValue() (Go 1.22+) +func PathValue(r *http.Request, key string) string { + return r.PathValue(key) +} + +// Query extracts a query parameter from the request URL +// Returns empty string if not found +func Query(r *http.Request, key string) string { + return r.URL.Query().Get(key) +} + +// QueryAll returns all values for a query parameter +func QueryAll(r *http.Request, key string) []string { + return r.URL.Query()[key] +} + +// BindJSON decodes the request body as JSON into the provided value +// The value must be a pointer to the target struct +// Unknown fields are allowed for backwards compatibility with clients that may send extra metadata +func BindJSON(r *http.Request, v any) error { + decoder := json.NewDecoder(r.Body) + // Removed DisallowUnknownFields() to maintain backwards compatibility + // Gin's ShouldBindJSON allowed unknown fields, and existing bouncers/log processors + // may ship extra metadata, especially when clients are upgraded at different cadences + return decoder.Decode(v) +} + +// JSON writes a JSON response with the given status code +// If encoding fails, it logs the error but does not return an error +// to match the behavior of common web frameworks where JSON encoding errors are rare +// Use WriteJSON if you need error handling +func JSON(w http.ResponseWriter, code int, v any) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + data, err := json.Marshal(v) + if err != nil { + log.Errorf("Failed to encode JSON response: %v", err) + return err + } + _, err = w.Write(data) + return err +} + +// WriteJSON writes a JSON response without returning an error +// This is a convenience function that discards encoding errors (which are extremely rare) +func WriteJSON(w http.ResponseWriter, code int, v any) { + _ = JSON(w, code, v) +} + +// SetContextValue stores a value in the request context with the given key +// Returns a new request with the updated context +func SetContextValue(r *http.Request, key, value any) *http.Request { + return r.WithContext(context.WithValue(r.Context(), key, value)) +} + +// GetContextValue retrieves a value from the request context by key +// Returns nil if the key is not found +func GetContextValue(r *http.Request, key any) any { + return r.Context().Value(key) +} + +// ClientIP is a convenience wrapper around GetClientIP for backwards compatibility +// The trustedProxies parameter is ignored since IP resolution is handled by middleware +func ClientIP(r *http.Request, trustedProxies []net.IPNet) string { + return GetClientIP(r) +} + +// isTrustedProxy checks if an IP address is in the trusted proxy list +func isTrustedProxy(addr netip.Addr, trustedProxies []netip.Prefix) bool { + for _, prefix := range trustedProxies { + if prefix.Contains(addr) { + return true + } + } + return false +} + +// findFirstUntrustedIP iterates backwards through a comma-separated list of IPs +// and returns the first (rightmost) IP that is not in the trusted proxy list +func findFirstUntrustedIP(ipList string, trustedProxies []netip.Prefix) string { + if ipList == "" { + return "" + } + + ips := strings.Split(ipList, ",") + for i := len(ips) - 1; i >= 0; i-- { + ipStr := strings.TrimSpace(ips[i]) + addr, err := netip.ParseAddr(ipStr) + if err != nil { + continue + } + + if !isTrustedProxy(addr, trustedProxies) { + return ipStr + } + } + + return "" +} + +// isRemoteAddrTrusted checks if RemoteAddr is from a trusted proxy +func isRemoteAddrTrusted(r *http.Request, trustedProxies []netip.Prefix) bool { + if r.RemoteAddr == "" { + return true + } + + // Treat Unix sockets as 127.0.0.1 for trusted proxy checking + if r.RemoteAddr == "@" { + addr, err := netip.ParseAddr("127.0.0.1") + if err != nil { + return false + } + return isTrustedProxy(addr, trustedProxies) + } + + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + + remoteAddr, err := netip.ParseAddr(host) + if err != nil { + return false + } + + return isTrustedProxy(remoteAddr, trustedProxies) +} + +// resolveFromForwardedFor extracts client IP from X-Forwarded-For header +func resolveFromForwardedFor(r *http.Request, trustedProxies []netip.Prefix) string { + forwardedFor := r.Header.Get("X-Forwarded-For") + if forwardedFor == "" { + return "" + } + + return findFirstUntrustedIP(forwardedFor, trustedProxies) +} + +// resolveFromRealIP extracts client IP from X-Real-IP header +func resolveFromRealIP(r *http.Request, trustedProxies []netip.Prefix) string { + realIPHeader := r.Header.Get("X-Real-IP") + if realIPHeader == "" { + return "" + } + + // Check if RemoteAddr is from a trusted proxy + if !isRemoteAddrTrusted(r, trustedProxies) { + return "" + } + + if ip := findFirstUntrustedIP(realIPHeader, trustedProxies); ip != "" { + return ip + } + + realIPs := strings.Split(realIPHeader, ",") + if len(realIPs) == 0 { + return "" + } + + realIP := strings.TrimSpace(realIPs[0]) + if _, err := netip.ParseAddr(realIP); err == nil { + return realIP + } + + return "" +} + +// ResolveClientIP resolves the client IP from the request, respecting trusted proxy headers +func ResolveClientIP(r *http.Request, trustedProxies []netip.Prefix) string { + isUnixSocket := r.RemoteAddr == "@" + + if len(trustedProxies) == 0 { + if isUnixSocket { + return "127.0.0.1" + } + return extractIPFromRemoteAddr(r.RemoteAddr) + } + + if ip := resolveFromForwardedFor(r, trustedProxies); ip != "" { + return ip + } + + if ip := resolveFromRealIP(r, trustedProxies); ip != "" { + return ip + } + + if isUnixSocket { + return "127.0.0.1" + } + + return extractIPFromRemoteAddr(r.RemoteAddr) +} + +// extractIPFromRemoteAddr extracts the IP from RemoteAddr, handling port splitting +func extractIPFromRemoteAddr(remoteAddr string) string { + if remoteAddr == "" { + return "" + } + + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return remoteAddr + } + + return host +} + +// AbortWithStatus writes a status code and stops further processing +// This is a helper to mimic Gin's AbortWithStatus behavior +// Note: In standard HTTP handlers, you can't truly "abort" - you just return early +// This function sets the status code and can be used before returning +func AbortWithStatus(w http.ResponseWriter, code int) { + w.WriteHeader(code) +} + +// AbortWithJSON writes a JSON response with status code and stops further processing +func AbortWithJSON(w http.ResponseWriter, code int, v any) { + WriteJSON(w, code, v) +} + +// String writes a plain text response +func String(w http.ResponseWriter, code int, s string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(code) + _, _ = w.Write([]byte(s)) +} + +// GetHeader retrieves a header value from the request +func GetHeader(r *http.Request, key string) string { + return r.Header.Get(key) +} + +// SetHeader sets a header on the response +func SetHeader(w http.ResponseWriter, key, value string) { + w.Header().Set(key, value) +} + +// IsUnixSocket checks if the request came from a Unix socket +func IsUnixSocket(r *http.Request) bool { + return r.RemoteAddr == "@" +} + +// LogError logs an error with request context +func LogError(r *http.Request, err error, msg string) { + logger := log.WithError(err) + if r != nil { + logger = logger.WithFields(log.Fields{ + "method": r.Method, + "path": r.URL.Path, + }) + } + logger.Error(msg) +} diff --git a/pkg/apiserver/router/middleware.go b/pkg/apiserver/router/middleware.go new file mode 100644 index 00000000000..6d75a48bac6 --- /dev/null +++ b/pkg/apiserver/router/middleware.go @@ -0,0 +1,43 @@ +package router + +import ( + "net/http" + "strings" +) + +// Middleware is a function that wraps an http.Handler +// Standard Go middleware pattern +type Middleware func(http.Handler) http.Handler + +// ChainMiddleware chains multiple middlewares together +// The first middleware in the slice is the outermost (executes first) +// The last middleware in the slice is the innermost (executes last) +func ChainMiddleware(middlewares ...Middleware) Middleware { + return func(next http.Handler) http.Handler { + for i := len(middlewares) - 1; i >= 0; i-- { + next = middlewares[i](next) + } + return next + } +} + +// AdaptHandlerFunc converts an http.HandlerFunc to an http.Handler +// This is useful when mixing Handler and HandlerFunc types +func AdaptHandlerFunc(fn http.HandlerFunc) http.Handler { + return fn +} + +// MethodNotAllowedHandler returns a handler that responds with 405 Method Not Allowed +func MethodNotAllowedHandler(allowedMethods ...string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + AbortWithStatus(w, http.StatusMethodNotAllowed) + }) +} + +// NotFoundHandler returns a handler that responds with 404 Not Found +func NotFoundHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = JSON(w, http.StatusNotFound, map[string]string{"message": "Page or Method not found"}) + }) +} diff --git a/pkg/apiserver/router/router.go b/pkg/apiserver/router/router.go new file mode 100644 index 00000000000..f7ddf0eaab8 --- /dev/null +++ b/pkg/apiserver/router/router.go @@ -0,0 +1,273 @@ +package router + +import ( + "net/http" + "strings" +) + +// Router wraps http.ServeMux with additional functionality for route groups and middleware +type Router struct { + mux *http.ServeMux + middleware []Middleware + wrappedHandler http.Handler // Cached handler with middleware chain (built once via Build()) + patternMap map[string]string // Maps fullPattern (method + path) to route template + built bool // Whether Build() has been called +} + +const ( + // UnknownRoutePattern is the sentinel pattern used for unmatched routes (404/405) + // This ensures metrics cardinality stays bounded for unknown routes + UnknownRoutePattern = "invalid-endpoint" +) + +// New creates a new Router instance +func New() *Router { + return &Router{ + mux: http.NewServeMux(), + middleware: []Middleware{}, + patternMap: make(map[string]string), + built: false, + } +} + +// ServeMux returns the underlying http.ServeMux for use with http.Server +func (r *Router) ServeMux() *http.ServeMux { + return r.mux +} + +// ServeHTTP implements http.Handler so Router can be used directly as a handler +func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Use cached wrapped handler (built via Build()) + r.wrappedHandler.ServeHTTP(w, req) + // Note: http.ServeMux will handle 404/405 with its default responses, but they still go through + // our middleware chain above, so logging/recovery/gzip all work correctly +} + +// matchPattern finds the route template that matches the given method and path +// Supports path variables (e.g., /v1/alerts/{alert_id} matches /v1/alerts/123) +// Returns UnknownRoutePattern if no match is found +// Safe to call without locks since patternMap is read-only after Build() +func (r *Router) matchPattern(method, path string) string { + // Try exact match first: "GET /v1/alerts/{alert_id}" + methodPath := method + " " + path + if pattern, ok := r.patternMap[methodPath]; ok { + return pattern + } + + // Try path-only match: "/v1/alerts/{alert_id}" + if pattern, ok := r.patternMap[path]; ok { + return pattern + } + + // Try to match against patterns with variables + // Compare segment by segment using matchesPattern helper + for storedPattern, template := range r.patternMap { + // Extract method and path from stored pattern + storedMethod := "" + patternPath := storedPattern + if idx := strings.Index(storedPattern, " "); idx > 0 { + storedMethod = storedPattern[:idx] + patternPath = storedPattern[idx+1:] + } + + // Skip if method doesn't match (unless stored pattern has no method) + if storedMethod != "" && storedMethod != method { + continue + } + + // Check if pattern has variables and matches path + if strings.Contains(patternPath, "{") && matchesPattern(path, patternPath) { + return template + } + } + + // Return sentinel pattern for unknown routes to bound metrics cardinality + return UnknownRoutePattern +} + +// matchesPattern checks if a concrete path matches a pattern with variables +// e.g., /v1/alerts/123 matches /v1/alerts/{alert_id} +func matchesPattern(path, pattern string) bool { + pathParts := strings.Split(strings.Trim(path, "/"), "/") + patternParts := strings.Split(strings.Trim(pattern, "/"), "/") + + if len(pathParts) != len(patternParts) { + return false + } + + for i := range pathParts { + // If pattern part is a variable {something}, it matches any path part + if strings.HasPrefix(patternParts[i], "{") && strings.HasSuffix(patternParts[i], "}") { + continue + } + // Otherwise, parts must match exactly + if pathParts[i] != patternParts[i] { + return false + } + } + + return true +} + +// routePatternMiddleware sets the route pattern in context before other middleware runs +// This must be the first middleware so metrics see the template pattern +// Unknown routes get UnknownRoutePattern to bound metrics cardinality +func (r *Router) routePatternMiddleware() Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + pattern := r.matchPattern(req.Method, req.URL.Path) + req = SetRoutePattern(req, pattern) + next.ServeHTTP(w, req) + }) + } +} + +// Build finalizes the router by building the wrapped handler with all middleware +// This should be called once after all routes and middleware are registered +// After Build(), the router is read-only and safe for concurrent use +func (r *Router) Build() { + if r.built { + return // Already built, no-op + } + + middlewares := make([]Middleware, 0, len(r.middleware)+1) + middlewares = append(middlewares, r.routePatternMiddleware()) + middlewares = append(middlewares, r.middleware...) + r.wrappedHandler = ChainMiddleware(middlewares...)(r.mux) + r.built = true +} + +// Use adds middleware to the router that will be applied to all routes +// Must be called before Build() +func (r *Router) Use(middlewares ...Middleware) { + if r.built { + panic("router: cannot add middleware after Build()") + } + r.middleware = append(r.middleware, middlewares...) +} + +// Group creates a route group with a path prefix +// All routes registered in the group will have the prefix prepended +func (r *Router) Group(prefix string) *Group { + // Ensure prefix starts with / and doesn't end with / + prefix = "/" + strings.Trim(prefix, "/") + if prefix == "/" { + prefix = "" + } + + return &Group{ + router: r, + prefix: prefix, + } +} + +// HandleFunc registers a handler for the given pattern with optional method restriction +// Pattern can use Go 1.22+ path variables like /users/{id} +// If method is empty, it matches all methods +// Note: Router-level middleware is applied in ServeHTTP, not here, to ensure 404/405 also go through middleware +// The pattern is stored for routePatternMiddleware to use for metrics +// Must be called before Build() +func (r *Router) HandleFunc(pattern, method string, handler http.HandlerFunc) { + if r.built { + panic("router: cannot register routes after Build()") + } + fullPattern := pattern + if method != "" { + fullPattern = method + " " + pattern + } + // Store pattern mapping for routePatternMiddleware + r.patternMap[fullPattern] = pattern + // Only store bare path entry for method-less handlers to avoid template conflicts + if method == "" { + r.patternMap[pattern] = pattern + } + r.mux.HandleFunc(fullPattern, handler) +} + +// Handle registers a handler for the given pattern with optional method restriction +// Note: Router-level middleware is applied in ServeHTTP, not here, to ensure 404/405 also go through middleware +// The pattern is stored for routePatternMiddleware to use for metrics +// Must be called before Build() +func (r *Router) Handle(pattern, method string, handler http.Handler) { + if r.built { + panic("router: cannot register routes after Build()") + } + fullPattern := pattern + if method != "" { + fullPattern = method + " " + pattern + } + // Store pattern mapping for routePatternMiddleware + r.patternMap[fullPattern] = pattern + // Only store bare path entry for method-less handlers to avoid template conflicts + if method == "" { + r.patternMap[pattern] = pattern + } + r.mux.Handle(fullPattern, handler) +} + +// Group represents a route group with a common prefix +type Group struct { + router *Router + prefix string + middleware []Middleware +} + +// Group creates a sub-group with an additional path prefix +func (g *Group) Group(prefix string) *Group { + // Combine prefixes + fullPrefix := g.prefix + if prefix != "" { + if fullPrefix != "" { + fullPrefix = fullPrefix + "/" + strings.Trim(prefix, "/") + } else { + fullPrefix = "/" + strings.Trim(prefix, "/") + } + } + // Copy parent middleware to child group + parentMiddleware := make([]Middleware, len(g.middleware)) + copy(parentMiddleware, g.middleware) + + return &Group{ + router: g.router, + prefix: fullPrefix, + middleware: parentMiddleware, + } +} + +// Use adds middleware to the group that will be applied to all routes in the group +func (g *Group) Use(middlewares ...Middleware) { + g.middleware = append(g.middleware, middlewares...) +} + +// HandleFunc registers a handler for the given pattern in the group +// The group prefix is automatically prepended to the pattern +func (g *Group) HandleFunc(pattern, method string, handler http.HandlerFunc) { + fullPattern := g.prefix + pattern + + // Apply group middleware + var wrapped http.Handler = handler + if len(g.middleware) > 0 { + wrapped = ChainMiddleware(g.middleware...)(handler) + } + + // Register with router (router-level middleware is applied in ServeHTTP) + // Router.HandleFunc will store the pattern for routePatternMiddleware + g.router.HandleFunc(fullPattern, method, func(w http.ResponseWriter, r *http.Request) { + wrapped.ServeHTTP(w, r) + }) +} + +// Handle registers a handler for the given pattern in the group +func (g *Group) Handle(pattern, method string, handler http.Handler) { + fullPattern := g.prefix + pattern + + // Apply group middleware + wrapped := handler + if len(g.middleware) > 0 { + wrapped = ChainMiddleware(g.middleware...)(handler) + } + + // Register with router (router-level middleware is applied in ServeHTTP) + // Router.Handle will store the pattern for routePatternMiddleware + g.router.Handle(fullPattern, method, wrapped) +} diff --git a/test/bats/11_bouncers_tls.bats b/test/bats/11_bouncers_tls.bats index 19a2a23ede9..151c19d5b8b 100644 --- a/test/bats/11_bouncers_tls.bats +++ b/test/bats/11_bouncers_tls.bats @@ -185,7 +185,7 @@ teardown() { --user-agent "crowdsec/someversion" \ https://localhost:8080/v1/usage-metrics -X POST --data "$payload" assert_stderr --partial 'error: 401' - assert_json '{code:401, message: "cookie token is empty"}' + assert_json '{code:401, message: "token not found"}' rune cscli bouncers delete localhost@127.0.0.1 }