Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .testcoverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ exclude:
- mocks # exclude generated mock files
- ^test/
- ^logger/testlogger
- middleware/test_support

19 changes: 19 additions & 0 deletions middleware/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package middleware

import (
"net/http"

"github.com/go-http-utils/headers"
"github.com/platform-mesh/golang-commons/context"
)

// StoreAuthHeader stores the Authorization header within the request context
func StoreAuthHeader() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) {
auth := request.Header.Get(headers.Authorization)
ctx := context.AddAuthHeaderToContext(request.Context(), auth)
next.ServeHTTP(responseWriter, request.WithContext(ctx))
})
}
}
101 changes: 101 additions & 0 deletions middleware/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/go-http-utils/headers"
"github.com/platform-mesh/golang-commons/context"
"github.com/stretchr/testify/assert"
)

func TestStoreAuthHeader_WithAuthHeader(t *testing.T) {
expectedAuth := "Bearer token123"

nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify auth header is stored in context
authFromContext, err := context.GetAuthHeaderFromContext(r.Context())
assert.NoError(t, err)
assert.Equal(t, expectedAuth, authFromContext)

w.WriteHeader(http.StatusOK)
})

middleware := StoreAuthHeader()
handlerToTest := middleware(nextHandler)

req := httptest.NewRequest("GET", "http://testing", nil)
req.Header.Set(headers.Authorization, expectedAuth)
recorder := httptest.NewRecorder()

handlerToTest.ServeHTTP(recorder, req)

assert.Equal(t, http.StatusOK, recorder.Code)
}

func TestStoreAuthHeader_WithoutAuthHeader(t *testing.T) {
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify empty auth header returns error when not set
_, err := context.GetAuthHeaderFromContext(r.Context())
assert.Error(t, err) // Should return error when no auth header is set

w.WriteHeader(http.StatusOK)
})

middleware := StoreAuthHeader()
handlerToTest := middleware(nextHandler)

req := httptest.NewRequest("GET", "http://testing", nil)
// No authorization header set
recorder := httptest.NewRecorder()

handlerToTest.ServeHTTP(recorder, req)

assert.Equal(t, http.StatusOK, recorder.Code)
}

func TestStoreAuthHeader_WithEmptyAuthHeader(t *testing.T) {
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify empty auth header returns error when empty
_, err := context.GetAuthHeaderFromContext(r.Context())
assert.Error(t, err) // Should return error when auth header is empty

w.WriteHeader(http.StatusOK)
})

middleware := StoreAuthHeader()
handlerToTest := middleware(nextHandler)

req := httptest.NewRequest("GET", "http://testing", nil)
req.Header.Set(headers.Authorization, "")
recorder := httptest.NewRecorder()

handlerToTest.ServeHTTP(recorder, req)

assert.Equal(t, http.StatusOK, recorder.Code)
}

func TestStoreAuthHeader_MultipleAuthHeaders(t *testing.T) {
// Test behavior when multiple authorization headers are present
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Should get the first/combined value
authFromContext, err := context.GetAuthHeaderFromContext(r.Context())
assert.NoError(t, err)
assert.NotEmpty(t, authFromContext)

w.WriteHeader(http.StatusOK)
})

middleware := StoreAuthHeader()
handlerToTest := middleware(nextHandler)

req := httptest.NewRequest("GET", "http://testing", nil)
req.Header.Add(headers.Authorization, "Bearer token1")
req.Header.Add(headers.Authorization, "Bearer token2")
recorder := httptest.NewRecorder()

handlerToTest.ServeHTTP(recorder, req)

assert.Equal(t, http.StatusOK, recorder.Code)
}
22 changes: 22 additions & 0 deletions middleware/authzMiddlewares.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package middleware

import (
"net/http"

"github.com/platform-mesh/golang-commons/policy_services"
)

// Deprecated: CreateMiddlewares use CreateAuthMiddleware instead.
func CreateMiddlewares(retriever policy_services.TenantRetriever) []func(http.Handler) http.Handler {
return CreateAuthMiddleware(retriever)
}

func CreateAuthMiddleware(retriever policy_services.TenantRetriever) []func(http.Handler) http.Handler {
mws := make([]func(http.Handler) http.Handler, 0, 5)

mws = append(mws, StoreWebToken())
mws = append(mws, StoreAuthHeader())
mws = append(mws, StoreSpiffeHeader())

return mws
}
104 changes: 104 additions & 0 deletions middleware/authzMiddlewares_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package middleware

import (
"context"
"testing"

"github.com/platform-mesh/golang-commons/policy_services"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

// MockTenantRetriever is a mock implementation of TenantRetriever
type MockTenantRetriever struct {
mock.Mock
}

func (m *MockTenantRetriever) RetrieveTenant(ctx context.Context) (string, error) {
args := m.Called(ctx)
return args.String(0), args.Error(1)
}

func TestCreateMiddlewares_Deprecated(t *testing.T) {
mockRetriever := &MockTenantRetriever{}

middlewares := CreateMiddlewares(mockRetriever)

// Should return 3 middlewares (same as CreateAuthMiddleware)
assert.Len(t, middlewares, 3)

// Each middleware should be a valid function
for _, mw := range middlewares {
assert.NotNil(t, mw)
}
}

func TestCreateAuthMiddleware(t *testing.T) {
mockRetriever := &MockTenantRetriever{}

middlewares := CreateAuthMiddleware(mockRetriever)

// Should return 3 middlewares: StoreWebToken, StoreAuthHeader, StoreSpiffeHeader
assert.Len(t, middlewares, 3)

// Each middleware should be a valid function
for _, mw := range middlewares {
assert.NotNil(t, mw)
}
}

func TestCreateAuthMiddleware_WithNilRetriever(t *testing.T) {
middlewares := CreateAuthMiddleware(nil)

// Should still return 3 middlewares even with nil retriever
assert.Len(t, middlewares, 3)

// Each middleware should be a valid function
for _, mw := range middlewares {
assert.NotNil(t, mw)
}
}

func TestCreateAuthMiddleware_ReturnsCorrectMiddlewares(t *testing.T) {
mockRetriever := &MockTenantRetriever{}

middlewares := CreateAuthMiddleware(mockRetriever)

// Verify we get exactly 3 middlewares
assert.Len(t, middlewares, 3)

// We can't easily test the exact middleware functions returned without more complex setup,
// but we can verify they're all valid middleware functions by checking their signatures
for _, mw := range middlewares {
assert.NotNil(t, mw)
// Each middleware should be a function that takes an http.Handler and returns an http.Handler
// This is implicitly tested by the fact that the function compiles and returns without error
}
}

func TestCreateMiddlewares_Equivalence(t *testing.T) {
mockRetriever := &MockTenantRetriever{}

deprecatedMiddlewares := CreateMiddlewares(mockRetriever)
newMiddlewares := CreateAuthMiddleware(mockRetriever)

// Both functions should return the same number of middlewares
assert.Equal(t, len(deprecatedMiddlewares), len(newMiddlewares))

// Both should return 3 middlewares
assert.Len(t, deprecatedMiddlewares, 3)
assert.Len(t, newMiddlewares, 3)
}

// Test that implements policy_services.TenantRetriever interface
func TestTenantRetrieverInterface(t *testing.T) {
mockRetriever := &MockTenantRetriever{}

// Verify our mock implements the interface
var retriever policy_services.TenantRetriever = mockRetriever
assert.NotNil(t, retriever)

// Test that we can use it with the middleware functions
middlewares := CreateAuthMiddleware(retriever)
assert.Len(t, middlewares, 3)
}
17 changes: 17 additions & 0 deletions middleware/logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package middleware

import (
"net/http"

"github.com/platform-mesh/golang-commons/logger"
)

// StoreLoggerMiddleware is a middleware that stores a given Logger in the request context
func StoreLoggerMiddleware(log *logger.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := logger.SetLoggerInContext(r.Context(), log)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
56 changes: 56 additions & 0 deletions middleware/logger_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/platform-mesh/golang-commons/logger"
"github.com/platform-mesh/golang-commons/logger/testlogger"
"github.com/stretchr/testify/assert"
)

func TestStoreLoggerMiddleware(t *testing.T) {
testLog := testlogger.New()

nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify logger is stored in context
logFromContext := logger.LoadLoggerFromContext(r.Context())
assert.NotNil(t, logFromContext)

// The logger should be the same instance we passed
assert.Equal(t, testLog.Logger, logFromContext)

w.WriteHeader(http.StatusOK)
})

middleware := StoreLoggerMiddleware(testLog.Logger)
handlerToTest := middleware(nextHandler)

req := httptest.NewRequest("GET", "http://testing", nil)
recorder := httptest.NewRecorder()

handlerToTest.ServeHTTP(recorder, req)

assert.Equal(t, http.StatusOK, recorder.Code)
}

func TestStoreLoggerMiddleware_NilLogger(t *testing.T) {
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Even with nil logger, the middleware should not panic
w.WriteHeader(http.StatusOK)
})

middleware := StoreLoggerMiddleware(nil)
handlerToTest := middleware(nextHandler)

req := httptest.NewRequest("GET", "http://testing", nil)
recorder := httptest.NewRecorder()

// Should not panic
assert.NotPanics(t, func() {
handlerToTest.ServeHTTP(recorder, req)
})

assert.Equal(t, http.StatusOK, recorder.Code)
}
17 changes: 17 additions & 0 deletions middleware/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package middleware

import (
"net/http"

"github.com/platform-mesh/golang-commons/logger"
)

func CreateMiddleware(log *logger.Logger) []func(http.Handler) http.Handler {
return []func(http.Handler) http.Handler{
SetOtelTracingContext(),
SentryRecoverer,
StoreLoggerMiddleware(log),
SetRequestId(),
SetRequestIdInLogger(),
}
}
41 changes: 41 additions & 0 deletions middleware/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/platform-mesh/golang-commons/logger/testlogger"
"github.com/stretchr/testify/assert"
)

func TestCreateMiddleware(t *testing.T) {
log := testlogger.New()
middlewares := CreateMiddleware(log.Logger)

// Should return 5 middlewares
assert.Len(t, middlewares, 5)

// Each middleware should be a valid function
for _, mw := range middlewares {
assert.NotNil(t, mw)
}

// Test that middlewares can be chained
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

// Apply all middlewares
var finalHandler http.Handler = handler
for i := len(middlewares) - 1; i >= 0; i-- {
finalHandler = middlewares[i](finalHandler)
}

req := httptest.NewRequest("GET", "http://testing", nil)
recorder := httptest.NewRecorder()

finalHandler.ServeHTTP(recorder, req)

assert.Equal(t, http.StatusOK, recorder.Code)
}
Loading
Loading