Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions api/v1/server/authn/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ import (
"github.com/labstack/echo/v4"
"github.com/rs/zerolog"

"go.opentelemetry.io/otel/trace"

"github.com/hatchet-dev/hatchet/api/v1/server/middleware"
"github.com/hatchet-dev/hatchet/api/v1/server/middleware/redirect"
"github.com/hatchet-dev/hatchet/pkg/analytics"
"github.com/hatchet-dev/hatchet/pkg/config/server"
"github.com/hatchet-dev/hatchet/pkg/repository/sqlcv1"
"github.com/hatchet-dev/hatchet/pkg/telemetry"
)

type AuthN struct {
Expand Down Expand Up @@ -111,16 +114,18 @@ func (a *AuthN) authenticate(c echo.Context, r *middleware.RouteInfo) error {
func (a *AuthN) handleNoAuth(c echo.Context) error {
store := a.config.SessionStore

ctx := c.Request().Context()

session, err := store.Get(c.Request(), store.GetName())

if err != nil {
a.l.Debug().Err(err).Msg("error getting session")
a.l.Debug().Ctx(ctx).Err(err).Msg("error getting session")

return redirect.GetRedirectWithError(c, a.l, err, "Could not log in. Please try again and make sure cookies are enabled.")
}

if auth, ok := session.Values["authenticated"].(bool); ok && auth {
a.l.Debug().Msgf("user was authenticated when no security schemes permit auth")
a.l.Debug().Ctx(ctx).Msgf("user was authenticated when no security schemes permit auth")

return redirect.GetRedirectNoError(c, a.config.Runtime.ServerURL)
}
Expand All @@ -137,11 +142,12 @@ func (a *AuthN) handleCookieAuth(c echo.Context) error {
store := a.config.SessionStore

session, err := store.Get(c.Request(), store.GetName())
ctx := c.Request().Context()
if err != nil {
err = a.helpers.SaveUnauthenticated(c)

if err != nil {
a.l.Error().Err(err).Msg("error saving unauthenticated session")
a.l.Error().Ctx(ctx).Err(err).Msg("error saving unauthenticated session")
return fmt.Errorf("error saving unauthenticated session")
}

Expand All @@ -152,7 +158,7 @@ func (a *AuthN) handleCookieAuth(c echo.Context) error {
// if the session is new, make sure we write a Set-Cookie header to the response
if session.IsNew {
if saveErr := a.helpers.SaveNewSession(c, session); saveErr != nil {
a.l.Error().Err(saveErr).Msg("error saving unauthenticated session")
a.l.Error().Ctx(ctx).Err(saveErr).Msg("error saving unauthenticated session")
return fmt.Errorf("error saving unauthenticated session")
}

Expand All @@ -166,22 +172,22 @@ func (a *AuthN) handleCookieAuth(c echo.Context) error {
userID, ok := session.Values["user_id"].(string)

if !ok {
a.l.Debug().Msgf("could not cast user_id to string")
a.l.Debug().Ctx(ctx).Msgf("could not cast user_id to string")

return forbidden
}

userIdUUID, err := uuid.Parse(userID)

if err != nil {
a.l.Debug().Err(err).Msg("error parsing user id uuid from session")
a.l.Debug().Ctx(ctx).Err(err).Msg("error parsing user id uuid from session")

return forbidden
}

user, err := a.config.V1.User().GetUserByID(c.Request().Context(), userIdUUID)
if err != nil {
a.l.Debug().Err(err).Msg("error getting user by id")
a.l.Debug().Ctx(ctx).Err(err).Msg("error getting user by id")

if errors.Is(err, pgx.ErrNoRows) {
return forbidden
Expand All @@ -193,8 +199,12 @@ func (a *AuthN) handleCookieAuth(c echo.Context) error {
c.Set("user", user)
c.Set("session", session)

ctx := context.WithValue(c.Request().Context(), analytics.UserIDKey, userIdUUID)
ctx = context.WithValue(ctx, analytics.UserIDKey, userIdUUID)
ctx = context.WithValue(ctx, analytics.SourceKey, analytics.SourceUI)

span := trace.SpanFromContext(ctx)
telemetry.WithAttributes(span, telemetry.AttributeKV{Key: "user.id", Value: userIdUUID})

c.SetRequest(c.Request().WithContext(ctx))

return nil
Expand All @@ -205,18 +215,19 @@ func (a *AuthN) handleBearerAuth(c echo.Context) error {

// a tenant id must exist in the context in order for the bearer auth to succeed, since
// these tokens are tenant-scoped
ctx := c.Request().Context()
queriedTenant, ok := c.Get("tenant").(*sqlcv1.Tenant)

if !ok {
a.l.Debug().Msgf("tenant not found in context")
a.l.Debug().Ctx(ctx).Msgf("tenant not found in context")

return fmt.Errorf("tenant not found in context")
}

token, err := getBearerTokenFromRequest(c.Request())

if err != nil {
a.l.Debug().Err(err).Msg("error getting bearer token from request")
a.l.Debug().Ctx(ctx).Err(err).Msg("error getting bearer token from request")

return forbidden
}
Expand All @@ -225,24 +236,30 @@ func (a *AuthN) handleBearerAuth(c echo.Context) error {
tenantId, tokenUUID, err := a.config.Auth.JWTManager.ValidateTenantToken(c.Request().Context(), token)

if err != nil {
a.l.Debug().Err(err).Msg("error validating tenant token")
a.l.Debug().Ctx(ctx).Err(err).Msg("error validating tenant token")

return forbidden
}

// Verify that the tenant id which exists in the context is the same as the tenant id
// in the token.
if queriedTenant.ID != tenantId {
a.l.Debug().Msgf("tenant id in token does not match tenant id in context")
a.l.Debug().Ctx(ctx).Msgf("tenant id in token does not match tenant id in context")

return forbidden
}

c.Set(string(analytics.APITokenIDKey), tokenUUID)

ctx := context.WithValue(c.Request().Context(), analytics.APITokenIDKey, tokenUUID)
ctx = context.WithValue(ctx, analytics.APITokenIDKey, tokenUUID)
ctx = context.WithValue(ctx, analytics.TenantIDKey, tenantId)
ctx = context.WithValue(ctx, analytics.SourceKey, analytics.SourceAPI)

span := trace.SpanFromContext(ctx)
telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "tenant.id", Value: tenantId},
)

c.SetRequest(c.Request().WithContext(ctx))

return nil
Expand Down
12 changes: 7 additions & 5 deletions api/v1/server/authz/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ func (a *AuthZ) authorize(c echo.Context, r *middleware.RouteInfo) error {
func (a *AuthZ) handleCookieAuth(c echo.Context, r *middleware.RouteInfo) error {
unauthorized := echo.NewHTTPError(http.StatusUnauthorized, "Not authorized to view this resource")

ctx := c.Request().Context()

if err := a.ensureVerifiedEmail(c, r); err != nil {
a.l.Debug().Err(err).Msgf("error ensuring verified email")
a.l.Debug().Ctx(ctx).Err(err).Msgf("error ensuring verified email")
return echo.NewHTTPError(http.StatusUnauthorized, "Please verify your email before continuing")
}

Expand All @@ -76,7 +78,7 @@ func (a *AuthZ) handleCookieAuth(c echo.Context, r *middleware.RouteInfo) error
user, ok := c.Get("user").(*sqlcv1.User)

if !ok {
a.l.Debug().Msgf("user not found in context")
a.l.Debug().Ctx(ctx).Msgf("user not found in context")

return unauthorized
}
Expand All @@ -85,13 +87,13 @@ func (a *AuthZ) handleCookieAuth(c echo.Context, r *middleware.RouteInfo) error
tenantMember, err := a.config.V1.Tenant().GetTenantMemberByUserID(c.Request().Context(), tenant.ID, user.ID)

if err != nil {
a.l.Debug().Err(err).Msgf("error getting tenant member")
a.l.Debug().Ctx(ctx).Err(err).Msgf("error getting tenant member")

return unauthorized
}

if tenantMember == nil {
a.l.Debug().Msgf("user is not a member of the tenant")
a.l.Debug().Ctx(ctx).Msgf("user is not a member of the tenant")

return unauthorized
}
Expand All @@ -101,7 +103,7 @@ func (a *AuthZ) handleCookieAuth(c echo.Context, r *middleware.RouteInfo) error

// authorize tenant operations
if err := a.authorizeTenantOperations(tenantMember.Role, r); err != nil {
a.l.Debug().Err(err).Msgf("error authorizing tenant operations")
a.l.Debug().Ctx(ctx).Err(err).Msgf("error authorizing tenant operations")

return unauthorized
}
Expand Down
6 changes: 4 additions & 2 deletions api/v1/server/handlers/monitoring/probe.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"github.com/hatchet-dev/hatchet/api/v1/server/oas/gen"
)

func (m *MonitoringService) MonitoringPostRunProbe(ctx echo.Context, request gen.MonitoringPostRunProbeRequestObject) (gen.MonitoringPostRunProbeResponseObject, error) {
func (m *MonitoringService) MonitoringPostRunProbe(c echo.Context, request gen.MonitoringPostRunProbeRequestObject) (gen.MonitoringPostRunProbeResponseObject, error) {
ctx := c.Request().Context()

if !m.enabled {
m.l.Error().Msg("monitoring is not enabled")
m.l.Error().Ctx(ctx).Msg("monitoring is not enabled")
return gen.MonitoringPostRunProbe403JSONResponse{}, nil
}

Expand Down
41 changes: 41 additions & 0 deletions api/v1/server/middleware/telemetry/telemetry.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package telemetry

import (
"errors"
"fmt"

"github.com/labstack/echo/v4"
"go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"

"github.com/hatchet-dev/hatchet/pkg/config/server"
)
Expand All @@ -30,3 +35,39 @@ func (m *OTelMiddleware) Middleware() echo.MiddlewareFunc {
otelecho.WithTracerProvider(tracerProvider),
)
}

// ErrorStatusMiddleware marks the current span as Error for any 4xx or 5xx response.
// otelecho only sets Error for 5xx (per OTel semantic conventions). This middleware
// must be registered after otelecho so it runs inside the span. The OTel SDK ignores
// attempts to downgrade from Error to Unset, so otelecho's subsequent status-setting
// for 4xx is a no-op.
func (m *OTelMiddleware) ErrorStatusMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
err := next(c)

span := trace.SpanFromContext(c.Request().Context())

statusCode := 0
if err != nil {
var he *echo.HTTPError
if errors.As(err, &he) {
statusCode = he.Code
}
}

if statusCode == 0 {
statusCode = c.Response().Status
}

if statusCode >= 400 {
span.SetStatus(codes.Error, fmt.Sprintf("HTTP %d", statusCode))
if err != nil {
span.RecordError(err)
}
}

return err
}
}
}
1 change: 1 addition & 0 deletions api/v1/server/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,7 @@ func (t *APIServer) registerSpec(g *echo.Group, spec *openapi3.T) (*populator.Po
middleware.Recover(),
rateLimitMW.Middleware(),
otelMW.Middleware(),
otelMW.ErrorStatusMiddleware(),
allHatchetMiddleware,
)

Expand Down
4 changes: 2 additions & 2 deletions internal/msgqueue/postgres/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (p *PostgresMessageQueue) addTenantExchangeMessage(ctx context.Context, q m
err := p.RegisterTenant(ctx, tenantId)

if err != nil {
p.l.Error().Msgf("error registering tenant exchange: %v", err)
p.l.Error().Ctx(ctx).Msgf("error registering tenant exchange: %v", err)
return err
}

Expand Down Expand Up @@ -51,7 +51,7 @@ func (p *PostgresMessageQueue) pubNonDurableMessages(ctx context.Context, queueN
return p.repo.Notify(ctx, queueName, string(msgBytes))
})
} else {
p.l.Error().Err(err).Msg("error marshalling message")
p.l.Error().Ctx(ctx).Err(err).Msg("error marshalling message")
}
}

Expand Down
12 changes: 6 additions & 6 deletions internal/msgqueue/rabbitmq/rabbitmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,19 +477,19 @@ func (t *MessageQueueImpl) pubMessage(ctx context.Context, q msgqueue.Queue, msg
err = t.RegisterTenant(ctx, msg.TenantID)

if err != nil {
t.l.Error().Msgf("error registering tenant exchange: %v", err)
t.l.Error().Ctx(ctx).Str("tenant_id", msg.TenantID.String()).Msgf("error registering tenant exchange: %v", err)
return err
}
}

t.l.Debug().Msgf("publishing tenant msg to exchange %s", msg.TenantID)
t.l.Debug().Ctx(ctx).Str("tenant_id", msg.TenantID.String()).Msgf("publishing tenant msg to exchange %s", msg.TenantID)

err = pub.PublishWithContext(ctx, msgqueue.GetTenantExchangeName(msg.TenantID), "", false, false, amqp.Publishing{
Body: body,
})

if err != nil {
t.l.Error().Msgf("error publishing tenant msg: %v", err)
t.l.Error().Ctx(ctx).Str("tenant_id", msg.TenantID.String()).Msgf("error publishing tenant msg: %v", err)
return err
}
}
Expand Down Expand Up @@ -558,7 +558,7 @@ func (t *MessageQueueImpl) RegisterTenant(ctx context.Context, tenantId uuid.UUI
poolCh, err := t.pubChannels.Acquire(ctx)

if err != nil {
t.l.Error().Msgf("[RegisterTenant] cannot acquire channel: %v", err)
t.l.Error().Ctx(ctx).Str("tenant_id", tenantId.String()).Msgf("[RegisterTenant] cannot acquire channel: %v", err)
return err
}

Expand All @@ -571,7 +571,7 @@ func (t *MessageQueueImpl) RegisterTenant(ctx context.Context, tenantId uuid.UUI

defer poolCh.Release()

t.l.Debug().Msgf("registering tenant exchange: %s", tenantId)
t.l.Debug().Ctx(ctx).Str("tenant_id", tenantId.String()).Msgf("registering tenant exchange: %s", tenantId)

// create a fanout exchange for the tenant. each consumer of the fanout exchange will get notified
// with the tenant events.
Expand All @@ -586,7 +586,7 @@ func (t *MessageQueueImpl) RegisterTenant(ctx context.Context, tenantId uuid.UUI
)

if err != nil {
t.l.Error().Msgf("cannot declare exchange: %q, %v", tenantId, err)
t.l.Error().Ctx(ctx).Str("tenant_id", tenantId.String()).Msgf("cannot declare exchange: %q, %v", tenantId, err)
return err
}

Expand Down
11 changes: 6 additions & 5 deletions internal/msgqueue/rabbitmq/rabbitmq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,11 @@ func TestDeadLetteringSuccess(t *testing.T) {

// deleteQueue is a helper function for removing durable queues which are used for tests.
func (t *MessageQueueImpl) deleteQueue(q msgqueue.Queue) error {
poolCh, err := t.subChannels.Acquire(context.Background())
ctx := context.Background()
poolCh, err := t.subChannels.Acquire(ctx)

if err != nil {
t.l.Error().Msgf("[deleteQueue] cannot acquire channel for deleting queue: %v", err)
t.l.Error().Ctx(ctx).Msgf("[deleteQueue] cannot acquire channel for deleting queue: %v", err)
return err
}

Expand All @@ -362,7 +363,7 @@ func (t *MessageQueueImpl) deleteQueue(q msgqueue.Queue) error {
_, err = ch.QueueDelete(q.Name(), true, true, false)

if err != nil {
t.l.Error().Msgf("cannot delete queue: %q, %v", q.Name(), err)
t.l.Error().Ctx(ctx).Msgf("cannot delete queue: %q, %v", q.Name(), err)
return err
}

Expand All @@ -373,14 +374,14 @@ func (t *MessageQueueImpl) deleteQueue(q msgqueue.Queue) error {
_, err = ch.QueueDelete(dlq1, true, true, false)

if err != nil {
t.l.Error().Msgf("cannot delete dead letter queue: %q, %v", dlq1, err)
t.l.Error().Ctx(ctx).Msgf("cannot delete dead letter queue: %q, %v", dlq1, err)
return err
}

_, err = ch.QueueDelete(dlq2, true, true, false)

if err != nil {
t.l.Error().Msgf("cannot delete dead letter queue: %q, %v", dlq2, err)
t.l.Error().Ctx(ctx).Msgf("cannot delete dead letter queue: %q, %v", dlq2, err)
return err
}
}
Expand Down
Loading