diff --git a/pkg/apis/acl.go b/pkg/apis/acl.go new file mode 100644 index 00000000..3b69d0af --- /dev/null +++ b/pkg/apis/acl.go @@ -0,0 +1,175 @@ +package apis + +import ( + "context" + "strings" + "time" +) + +// ResourceType represents the type of Kafka resource +type ResourceType string + +const ( + ResourceTypeTopic ResourceType = "TOPIC" + ResourceTypeGroup ResourceType = "GROUP" + ResourceTypeCluster ResourceType = "CLUSTER" + ResourceTypeTransactionalID ResourceType = "TRANSACTIONAL_ID" +) + +// Operation represents the type of operation allowed/denied +type Operation string + +const ( + OperationRead Operation = "READ" + OperationWrite Operation = "WRITE" + OperationCreate Operation = "CREATE" + OperationDelete Operation = "DELETE" + OperationAlter Operation = "ALTER" + OperationDescribe Operation = "DESCRIBE" + OperationClusterAction Operation = "CLUSTER_ACTION" + OperationAll Operation = "ALL" +) + +// PermissionType represents whether the ACL allows or denies access +type PermissionType string + +const ( + PermissionAllow PermissionType = "ALLOW" + PermissionDeny PermissionType = "DENY" +) + +// ACLEntry represents a single ACL rule +type ACLEntry struct { + Principal string `json:"principal"` // The user or group + ResourceType ResourceType `json:"resourceType"` // Type of resource + ResourceName string `json:"resourceName"` // Name of resource + PatternType string `json:"patternType"` // LITERAL or PREFIXED + Operation Operation `json:"operation"` // Type of operation + PermissionType PermissionType `json:"permissionType"` // Allow or Deny + Host string `json:"host"` // Host from which access is allowed + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +// ACLCollection represents a collection of ACL entries +type ACLCollection struct { + ACLs []ACLEntry +} + +// ACLChecker interface for ACL plugins. +type ACLChecker interface { + // CheckACL checks if a given request is allowed based on configured ACL rules. + CheckACL(ctx context.Context, APIKey int, topics []string) (bool, []string, error) +} + +// ACLCheckerFactory interface for creating ACL checkers. 
+type ACLCheckerFactory interface {
+	// New creates a new ACLChecker from the plugin command-line parameters.
+	New(params []string) (ACLChecker, error)
+}
+
+// ACLDecision represents the result of ACL evaluation
+type ACLDecision struct {
+	Allowed bool
+	Reason  string
+}
+
+// ACLRequest represents a request to check permissions
+type ACLRequest struct {
+	Principal    string
+	ResourceType ResourceType
+	ResourceName string
+	Operation    Operation
+	Host         string
+}
+
+// EvaluateAccess checks if the requested operation is allowed
+func (ac *ACLCollection) EvaluateAccess(req ACLRequest) ACLDecision {
+	// Quick check if there are no ACLs
+	if len(ac.ACLs) == 0 {
+		return ACLDecision{
+			Allowed: false,
+			Reason:  "no ACLs defined",
+		}
+	}
+
+	// First pass: look for explicit DENY rules (these take precedence)
+	for _, acl := range ac.ACLs {
+		if isDenyMatch(acl, req) {
+			return ACLDecision{
+				Allowed: false,
+				Reason:  "explicitly denied by ACL",
+			}
+		}
+	}
+
+	// Second pass: look for ALLOW rules
+	for _, acl := range ac.ACLs {
+		if isAllowMatch(acl, req) {
+			return ACLDecision{
+				Allowed: true,
+				Reason:  "explicitly allowed by ACL",
+			}
+		}
+	}
+
+	// Default deny if no matching rules were found
+	return ACLDecision{
+		Allowed: false,
+		Reason:  "no matching allow rules found",
+	}
+}
+
+// isDenyMatch checks if an ACL entry explicitly denies access
+func isDenyMatch(acl ACLEntry, req ACLRequest) bool {
+	if acl.PermissionType != PermissionDeny {
+		return false
+	}
+	return isMatch(acl, req)
+}
+
+// isAllowMatch checks if an ACL entry allows access
+func isAllowMatch(acl ACLEntry, req ACLRequest) bool {
+	if acl.PermissionType != PermissionAllow {
+		return false
+	}
+	return isMatch(acl, req)
+}
+
+// isMatch performs the actual matching logic
+func isMatch(acl ACLEntry, req ACLRequest) bool {
+	// Check principal (exact match or wildcard)
+	if acl.Principal != "*" && acl.Principal != req.Principal {
+		return false
+	}
+
+	// Check resource type
+	if acl.ResourceType != req.ResourceType {
+		return false
+	}
+
+	// Check resource name (pattern matching)
+	switch acl.PatternType {
+	case "LITERAL":
+		if acl.ResourceName != req.ResourceName {
+			return false
+		}
+	case "PREFIXED":
+		if !strings.HasPrefix(req.ResourceName, acl.ResourceName) {
+			return false
+		}
+	}
+
+	// Check operation (ALL matches any operation)
+	if acl.Operation != OperationAll && acl.Operation != req.Operation {
+		return false
+	}
+
+	// Check host (exact match or wildcard)
+	if acl.Host != "*" && acl.Host != req.Host {
+		return false
+	}
+
+	return true
+}
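For reference, the evaluation order implemented above is deny-first with a default deny. A minimal illustration of the semantics (names and values are made up, not part of the patch):

// A prefixed DENY on "orders" wins over a literal ALLOW on "orders-eu":
acls := apis.ACLCollection{ACLs: []apis.ACLEntry{
	{Principal: "*", ResourceType: apis.ResourceTypeTopic, ResourceName: "orders",
		PatternType: "PREFIXED", Operation: apis.OperationAll, PermissionType: apis.PermissionDeny, Host: "*"},
	{Principal: "alice", ResourceType: apis.ResourceTypeTopic, ResourceName: "orders-eu",
		PatternType: "LITERAL", Operation: apis.OperationRead, PermissionType: apis.PermissionAllow, Host: "*"},
}}

d := acls.EvaluateAccess(apis.ACLRequest{
	Principal:    "alice",
	ResourceType: apis.ResourceTypeTopic,
	ResourceName: "orders-eu",
	Operation:    apis.OperationRead,
	Host:         "10.0.0.1",
})
// d.Allowed == false, d.Reason == "explicitly denied by ACL"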
diff --git a/pkg/libs/acl-plugin/factory.go b/pkg/libs/acl-plugin/factory.go
new file mode 100644
index 00000000..fe634df1
--- /dev/null
+++ b/pkg/libs/acl-plugin/factory.go
@@ -0,0 +1,64 @@
+package aclplugin
+
+import (
+	"flag"
+	"fmt"
+	"strconv"
+	"strings"
+
+	"github.com/grepplabs/kafka-proxy/pkg/apis"
+	"github.com/grepplabs/kafka-proxy/pkg/registry"
+)
+
+func init() {
+	registry.NewComponentInterface(new(apis.ACLCheckerFactory))
+	registry.Register(new(Factory), "acl-plugin")
+}
+
+// ACLRule is a single plugin rule: an operation, a topic regex and an allow/deny flag.
+type ACLRule struct {
+	Operation apis.Operation
+	Topic     string
+	Allow     bool
+}
+
+type pluginMeta struct {
+	rules []string
+}
+
+type Factory struct{}
+
+func (f *Factory) New(params []string) (apis.ACLChecker, error) {
+	meta := &pluginMeta{}
+	fs := flag.NewFlagSet("acl-plugin settings", flag.ContinueOnError)
+	fs.Var(&stringArrayValue{&meta.rules}, "rule", "ACL rule (Operation,TopicPattern,Allow)")
+
+	if err := fs.Parse(params); err != nil {
+		return nil, err
+	}
+
+	rules, err := parseRules(meta.rules)
+	if err != nil {
+		return nil, err
+	}
+
+	return NewACLChecker(rules)
+}
+
+type stringArrayValue struct {
+	target *[]string
+}
+
+func (s *stringArrayValue) String() string {
+	return strings.Join(*s.target, ",")
+}
+
+func (s *stringArrayValue) Set(value string) error {
+	*s.target = append(*s.target, value)
+	return nil
+}
+
+func parseRules(ruleStrings []string) ([]ACLRule, error) {
+	rules := make([]ACLRule, len(ruleStrings))
+	for i, ruleStr := range ruleStrings {
+		parts := strings.Split(ruleStr, ",")
+		if len(parts) != 3 {
+			return nil, fmt.Errorf("invalid rule format: %s", ruleStr)
+		}
+		allow, err := strconv.ParseBool(parts[2])
+		if err != nil {
+			return nil, fmt.Errorf("invalid allow flag in rule %q: %v", ruleStr, err)
+		}
+		rules[i] = ACLRule{Operation: apis.Operation(parts[0]), Topic: parts[1], Allow: allow}
+	}
+	return rules, nil
+}
diff --git a/pkg/libs/acl-plugin/plugin.go b/pkg/libs/acl-plugin/plugin.go
new file mode 100644
index 00000000..034a46fa
--- /dev/null
+++ b/pkg/libs/acl-plugin/plugin.go
@@ -0,0 +1,62 @@
+package aclplugin
+
+import (
+	"context"
+	"regexp"
+
+	"github.com/grepplabs/kafka-proxy/pkg/apis"
+)
+
+// compiledRule pairs a rule with its pre-compiled topic regex (nil when Topic is empty).
+type compiledRule struct {
+	ACLRule
+	re *regexp.Regexp
+}
+
+type ACLCheckerImpl struct {
+	rules []compiledRule
+}
+
+func NewACLChecker(rules []ACLRule) (apis.ACLChecker, error) {
+	compiledRules := make([]compiledRule, len(rules))
+	for i, rule := range rules {
+		compiledRules[i] = compiledRule{ACLRule: rule}
+		if rule.Topic != "" {
+			re, err := regexp.Compile(rule.Topic)
+			if err != nil {
+				return nil, err
+			}
+			compiledRules[i].re = re
+		}
+	}
+	return &ACLCheckerImpl{rules: compiledRules}, nil
+}
+
+// CheckACL implements apis.ACLChecker: every topic must be matched by at least
+// one rule for its operation, and no matching rule may deny it.
+func (a *ACLCheckerImpl) CheckACL(ctx context.Context, apiKey int, topics []string) (bool, []string, error) {
+	op := getOperationFromKey(int16(apiKey))
+	for _, topic := range topics {
+		anyMatched := false
+		for _, rule := range a.rules {
+			if (rule.Operation == apis.OperationAll || rule.Operation == op) &&
+				(rule.Topic == "" || (topic != "" && rule.re.MatchString(topic))) {
+				anyMatched = true
+				if !rule.Allow {
+					return false, nil, nil
+				}
+			}
+		}
+		if !anyMatched {
+			return false, nil, nil
+		}
+	}
+	return true, nil, nil
+}
+
+// getOperationFromKey maps a Kafka API key to the closest ACL operation:
+// Produce writes, Fetch reads, Metadata describes; everything else falls back to ALL.
+func getOperationFromKey(apiKey int16) apis.Operation {
+	switch apiKey {
+	case 0: // Produce
+		return apis.OperationWrite
+	case 1: // Fetch
+		return apis.OperationRead
+	case 3: // Metadata
+		return apis.OperationDescribe
+	default:
+		return apis.OperationAll
+	}
+}
diff --git a/proxy/processor.go b/proxy/processor.go
index 1f61bf77..05493f0b 100644
--- a/proxy/processor.go
+++ b/proxy/processor.go
@@ -2,9 +2,11 @@ package proxy
 
 import (
 	"errors"
+	"time"
+
 	"github.com/grepplabs/kafka-proxy/config"
+	"github.com/grepplabs/kafka-proxy/pkg/apis"
 	"github.com/grepplabs/kafka-proxy/proxy/protocol"
-	"time"
 )
 
 const (
@@ -16,9 +18,16 @@ const (
 	defaultReadTimeout = 30 * time.Second
 	minOpenRequests    = 16
 
-	apiKeyProduce        = int16(0)
-	apiKeySaslHandshake  = int16(17)
-	apiKeyApiApiVersions = int16(18)
+	apiKeyProduce            = int16(0)
+	apiKeyFetch              = int16(1)
+	apiKeyListOffsets        = int16(2)
+	apiKeySaslHandshake      = int16(17)
+	apiKeyApiApiVersions     = int16(18)
+	apiKeyCreateTopics       = int16(19)
+	apiKeyDeleteTopics       = int16(20)
+	apiKeyDeleteRecords      = int16(21)
+	apiKeyAddPartitionsToTxn = int16(24)
+	apiKeyCreatePartitions   = int16(37)
 
 	minRequestApiKey = int16(0)   // 0 - Produce
 	maxRequestApiKey = int16(120) // so far 67 is the last (reserve some for the future)
@@ -63,6 +72,8 @@ type processor struct {
 	brokerAddress string
 	// producer will never send request with acks=0
 	producerAcks0Disabled bool
+
+	acl *apis.ACLCollection
 }
 
 func newProcessor(cfg ProcessorConfig, brokerAddress string) *processor {
@@ -130,6
+141,7 @@ func (p *processor) RequestsLoop(dst DeadlineWriter, src DeadlineReaderWriter) ( localSasl: p.localSasl, localSaslDone: false, // sequential processing - mutex is required producerAcks0Disabled: p.producerAcks0Disabled, + acl: p.acl, } return ctx.requestsLoop(dst, src) @@ -148,6 +160,9 @@ type RequestsLoopContext struct { localSasl *LocalSasl localSaslDone bool + aclChecker apis.ACLChecker + acl *apis.ACLCollection + producerAcks0Disabled bool } @@ -218,6 +233,7 @@ func (p *processor) ResponsesLoop(dst DeadlineWriter, src DeadlineReader) (readE timeout: p.readTimeout, brokerAddress: p.brokerAddress, buf: make([]byte, p.responseBufferSize), + acl: p.acl, } return ctx.responsesLoop(dst, src) } @@ -229,6 +245,7 @@ type ResponsesLoopContext struct { timeout time.Duration brokerAddress string buf []byte // bufSize + acl *apis.ACLCollection } type ResponseHandler interface { diff --git a/proxy/processor_default.go b/proxy/processor_default.go index 291073a3..9070dcf8 100644 --- a/proxy/processor_default.go +++ b/proxy/processor_default.go @@ -2,15 +2,155 @@ package proxy import ( "bytes" + "context" + "encoding/binary" "errors" "fmt" - "github.com/grepplabs/kafka-proxy/proxy/protocol" - "github.com/sirupsen/logrus" "io" "strconv" "time" + + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/grepplabs/kafka-proxy/proxy/protocol" + "github.com/sirupsen/logrus" ) +// Map of API keys to their operation names +var apiKeyNames = map[int16]string{ + 0: "Produce", + 1: "Fetch", + 2: "ListOffsets", + 3: "Metadata", + 4: "LeaderAndIsr", + 5: "StopReplica", + 6: "UpdateMetadata", + 7: "ControlledShutdown", + 8: "OffsetCommit", + 9: "OffsetFetch", + 10: "FindCoordinator", + 11: "JoinGroup", + 12: "Heartbeat", + 13: "LeaveGroup", + 14: "SyncGroup", + 15: "DescribeGroups", + 16: "ListGroups", + 17: "SaslHandshake", + 18: "ApiVersions", + 19: "CreateTopics", + 20: "DeleteTopics", + 21: "DeleteRecords", + 22: "InitProducerId", + 23: "OffsetForLeaderEpoch", + 24: "AddPartitionsToTxn", + 25: "AddOffsetsToTxn", + 26: "EndTxn", + 27: "WriteTxnMarkers", + 28: "TxnOffsetCommit", + 29: "DescribeAcls", + 30: "CreateAcls", + 31: "DeleteAcls", + 32: "DescribeConfigs", + 33: "AlterConfigs", + 34: "AlterReplicaLogDirs", + 35: "DescribeLogDirs", + 36: "SaslAuthenticate", + 37: "CreatePartitions", + 38: "CreateDelegationToken", + 39: "RenewDelegationToken", + 40: "ExpireDelegationToken", + 41: "DescribeDelegationToken", + 42: "DeleteGroups", + 43: "ElectLeaders", + 44: "IncrementalAlterConfigs", + 45: "AlterPartitionReassignments", + 46: "ListPartitionReassignments", + 47: "OffsetDelete", + 48: "DescribeClientQuotas", + 49: "AlterClientQuotas", + 50: "DescribeUserScramCredentials", + 51: "AlterUserScramCredentials", + 52: "Vote", + 53: "BeginQuorumEpoch", + 54: "EndQuorumEpoch", + 55: "DescribeQuorum", + 56: "AlterIsr", + 57: "UpdateFeatures", + 58: "Envelope", + 59: "FetchSnapshot", + 60: "DescribeCluster", + 61: "DescribeProducers", + 62: "BrokerRegistration", + 63: "BrokerHeartbeat", +} + +// Map API keys to their resource types and operations +// Based on https://docs.confluent.io/platform/current/security/authorization/acls/overview.html +var apiKeyResources = map[int16]struct { + ResourceType string + Operation string +}{ + // Cluster operations + 3: {"Cluster", "Create"}, // Metadata + 4: {"Cluster", "ClusterAction"}, // LeaderAndIsr + 5: {"Cluster", "ClusterAction"}, // StopReplica + 6: {"Cluster", "ClusterAction"}, // UpdateMetadata + 7: {"Cluster", "ClusterAction"}, // ControlledShutdown + 19: 
{"Cluster", "Create"}, // CreateTopics + 23: {"Cluster", "ClusterAction"}, // OffsetForLeaderEpoch + 27: {"Cluster", "ClusterAction"}, // WriteTxnMarkers + 29: {"Cluster", "Describe"}, // DescribeAcls + 30: {"Cluster", "Alter"}, // CreateAcls + 31: {"Cluster", "Alter"}, // DeleteAcls + 32: {"Cluster", "DescribeConfigs"}, // DescribeConfigs + 33: {"Cluster", "AlterConfigs"}, // AlterConfigs + 34: {"Cluster", "Alter"}, // AlterReplicaLogDirs + 35: {"Cluster", "Describe"}, // DescribeLogDirs + 16: {"Cluster", "Describe"}, // ListGroups + + // Topic operations + 0: {"Topic", "Write"}, // Produce + 1: {"Topic", "Read"}, // Fetch + 2: {"Topic", "Describe"}, // ListOffsets + 20: {"Topic", "Delete"}, // DeleteTopics + 21: {"Topic", "Delete"}, // DeleteRecords + 24: {"Topic", "Write"}, // AddPartitionsToTxn + 37: {"Topic", "Alter"}, // CreatePartitions + 9: {"Topic", "Describe"}, // OffsetFetch + 8: {"Topic", "Read"}, // OffsetCommit + 28: {"Topic", "Read"}, // TxnOffsetCommit + + // Group operations + 10: {"Group", "Describe"}, // FindCoordinator + 11: {"Group", "Read"}, // JoinGroup + 12: {"Group", "Read"}, // Heartbeat + 13: {"Group", "Read"}, // LeaveGroup + 14: {"Group", "Read"}, // SyncGroup + 15: {"Group", "Describe"}, // DescribeGroups + 25: {"Group", "Read"}, // AddOffsetsToTxn + 42: {"Group", "Delete"}, // DeleteGroups + + // TransactionalId operations + 26: {"TransactionalId", "Write"}, // EndTxn + 22: {"TransactionalId", "Write"}, // InitProducerId + + // DelegationToken operations + 41: {"DelegationToken", "Describe"}, // DescribeTokens +} + +func getOperationName(apiKey int16) string { + if name, ok := apiKeyNames[apiKey]; ok { + return name + } + return "Unknown" +} + +func getResourceOperation(apiKey int16) string { + if resource, ok := apiKeyResources[apiKey]; ok { + return fmt.Sprintf("%s/%s", resource.ResourceType, resource.Operation) + } + return "Unknown/Unknown" +} + type DefaultRequestHandler struct { } @@ -38,7 +178,12 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead if err = protocol.Decode(keyVersionBuf, requestKeyVersion); err != nil { return true, err } - logrus.Debugf("Kafka request key %v, version %v, length %v", requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion, requestKeyVersion.Length) + logrus.Debugf("Kafka request operation: %s (key: %v, version: %v, length: %v), resource operation: %s", + getOperationName(requestKeyVersion.ApiKey), + requestKeyVersion.ApiKey, + requestKeyVersion.ApiVersion, + requestKeyVersion.Length, + getResourceOperation(requestKeyVersion.ApiKey)) if requestKeyVersion.ApiKey < minRequestApiKey || requestKeyVersion.ApiKey > maxRequestApiKey { return true, fmt.Errorf("api key %d is invalid, possible cause: using plain connection instead of TLS", requestKeyVersion.ApiKey) @@ -48,7 +193,7 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead proxyRequestsBytes.WithLabelValues(ctx.brokerAddress).Add(float64(requestKeyVersion.Length + 4)) if _, ok := ctx.forbiddenApiKeys[requestKeyVersion.ApiKey]; ok { - return true, fmt.Errorf("api key %d is forbidden", requestKeyVersion.ApiKey) + return true, fmt.Errorf("api key %d (%s) is forbidden", requestKeyVersion.ApiKey, getOperationName(requestKeyVersion.ApiKey)) } if ctx.localSasl.enabled { @@ -85,11 +230,33 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead } } - mustReply, readBytes, err := handler.mustReply(requestKeyVersion, src, ctx) + mustReply, readBytes, topicNames, err := 
handler.mustReply(requestKeyVersion, src, ctx)
 	if err != nil {
 		return true, err
 	}
 
+	for _, topicName := range topicNames {
+		if topicName != "" {
+			logrus.Debugf("Topic name: %s", topicName)
+		}
+	}
+
+	if ctx.aclChecker != nil { // only when an ACL plugin is configured
+		var allowed bool
+		allowed, _, err = ctx.aclChecker.CheckACL(context.Background(), int(requestKeyVersion.ApiKey), topicNames)
+		if err != nil {
+			return true, err
+		}
+		if !allowed {
+			return true, fmt.Errorf("access denied for operation %s (api key %d) on topics %v",
+				getOperationName(requestKeyVersion.ApiKey),
+				requestKeyVersion.ApiKey,
+				topicNames)
+		}
+	}
+
 	// send inFlightRequest to channel before myCopyN to prevent race condition in proxyResponses
 	if mustReply {
 		if err = sendRequestKeyVersion(ctx.openRequestsChannel, openRequestSendTimeout, requestKeyVersion); err != nil {
@@ -117,6 +284,7 @@
 			return false, err
 		}
 	}
+
 	// 4 bytes were written as keyVersionBuf (ApiKey, ApiVersion)
 	if readErr, err = myCopyN(dst, src, int64(requestKeyVersion.Length-int32(4+len(readBytes))), ctx.buf); err != nil {
 		return readErr, err
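The byte accounting around myCopyN relies on mustReply reading through an io.TeeReader: every header byte it consumes is captured (returned here as readBytes), replayed to the broker, and only then is the unread remainder of the request streamed. A self-contained sketch of that pattern, where parse and remaining are stand-ins for mustReply's header parsing and the request length:

// replayParsedPrefix illustrates the TeeReader pattern used above: parse consumes
// bytes from tee, the same bytes land in buffered so they can be re-sent to dst,
// and the unread remainder of src is then copied straight through.
func replayParsedPrefix(dst io.Writer, src io.Reader, remaining int64, parse func(io.Reader) error) error {
	var buffered bytes.Buffer
	tee := io.TeeReader(src, &buffered)
	if err := parse(tee); err != nil {
		return err
	}
	if _, err := dst.Write(buffered.Bytes()); err != nil {
		return err
	}
	// mirrors myCopyN(dst, src, Length-(4+len(readBytes)), ...) in handleRequest
	_, err := io.CopyN(dst, src, remaining-int64(buffered.Len()))
	return err
}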
@@ -133,102 +301,1493 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead
 	}
 }
 
-func (handler *DefaultRequestHandler) mustReply(requestKeyVersion *protocol.RequestKeyVersion, src io.Reader, ctx *RequestsLoopContext) (bool, []byte, error) {
-	if requestKeyVersion.ApiKey == apiKeyProduce {
-		if ctx.producerAcks0Disabled {
-			return true, nil, nil
-		}
-		// header version for produce [0..8] is 1 (request_api_key,request_api_version,correlation_id (INT32),client_id, NULLABLE_STRING )
-		acksReader := protocol.RequestAcksReader{}
-		var (
-			acks int16
-			err  error
-		)
-		var bufferRead bytes.Buffer
-		reader := io.TeeReader(src, &bufferRead)
-		switch requestKeyVersion.ApiVersion {
-		case 0, 1, 2:
-			// CorrelationID + ClientID
-			if err = acksReader.ReadAndDiscardHeaderV1Part(reader); err != nil {
-				return false, nil, err
-			}
-			// acks (INT16)
-			acks, err = acksReader.ReadAndDiscardProduceAcks(reader)
-			if err != nil {
-				return false, nil, err
-			}
-		case 3, 4, 5, 6, 7, 8, 9, 10:
-			// CorrelationID + ClientID
-			if err = acksReader.ReadAndDiscardHeaderV1Part(reader); err != nil {
-				return false, nil, err
-			}
-			// transactional_id (NULLABLE_STRING),acks (INT16)
-			acks, err = acksReader.ReadAndDiscardProduceTxnAcks(reader)
-			if err != nil {
-				return false, nil, err
-			}
+func (handler *DefaultRequestHandler) mustReply(
+	requestKeyVersion *protocol.RequestKeyVersion,
+	src io.Reader,
+	ctx *RequestsLoopContext,
+) (bool, []byte, []string, error) {
+	var (
+		needReply  = true
+		bufferRead bytes.Buffer
+		topicNames []string
+		err        error
+	)
+
+	reader := io.TeeReader(src, &bufferRead)
+
+	logrus.Debugf("ResponseHeaderVersion: %v", requestKeyVersion.ResponseHeaderVersion())
+
+	// Only parse headers for supported ApiKeys that need topic information
+	switch requestKeyVersion.ApiKey {
+	case apiKeyProduce, apiKeyFetch, apiKeyListOffsets, apiKeyCreateTopics, apiKeyDeleteTopics,
+		apiKeyDeleteRecords, apiKeyAddPartitionsToTxn, apiKeyCreatePartitions:
+		// Read CorrelationID (INT32)
+		if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil {
+			return false, nil, nil, err
+		}
+
+		// Read ClientID (NULLABLE_STRING)
+		if _, err = readNullableString(reader); err != nil {
+			return false, nil, nil, err
+		}
+
+		// Flexible request versions carry tagged fields in the request header
+		// (response header version 1 implies a flexible request version).
+		if requestKeyVersion.ResponseHeaderVersion() == 1 {
+			if err = readTaggedFields(reader); err != nil {
+				return false, nil, nil, err
+			}
 		}
+	default:
+		return true, nil, nil, nil
+	}
+
+	switch requestKeyVersion.ApiKey {
+	case apiKeyProduce:
+		needReply, topicNames, err = handler.handleProduce(reader, requestKeyVersion, ctx)
+	case apiKeyFetch:
+		needReply, topicNames, err = handler.handleFetch(reader, requestKeyVersion)
+	case apiKeyListOffsets:
+		needReply, topicNames, err = handler.handleListOffsets(reader, requestKeyVersion)
+	case apiKeyCreateTopics:
+		needReply, topicNames, err = handler.handleCreateTopics(reader, requestKeyVersion)
+	case apiKeyDeleteTopics:
+		needReply, topicNames, err = handler.handleDeleteTopics(reader, requestKeyVersion)
+	case apiKeyDeleteRecords:
+		needReply, topicNames, err = handler.handleDeleteRecords(reader, requestKeyVersion)
+	case apiKeyAddPartitionsToTxn:
+		needReply, topicNames, err = handler.handleAddPartitionsToTxn(reader, requestKeyVersion)
+	case apiKeyCreatePartitions:
+		needReply, topicNames, err = handler.handleCreatePartitions(reader, requestKeyVersion)
+	default:
+		return true, nil, nil, nil
+	}
+
+	if err != nil {
+		logrus.Errorf("Error processing request: %v", err)
+		return false, nil, nil, err
+	}
+
+	return needReply, bufferRead.Bytes(), topicNames, nil
+}
 
+func (handler *DefaultRequestHandler) handleProduce(
+	reader io.Reader,
+	requestKeyVersion *protocol.RequestKeyVersion,
+	ctx *RequestsLoopContext,
+) (bool, []string, error) {
+	var (
+		acks       int16
+		topicNames []string
+		err        error
+	)
+
+	// Read transactional_id (only present from v3 on)
+	if requestKeyVersion.ApiVersion >= 3 {
+		if requestKeyVersion.ApiVersion >= 9 {
+			if _, err = readCompactNullableString(reader); err != nil {
+				return false, nil, err
+			}
+		} else {
+			if _, err = readNullableString(reader); err != nil {
+				return false, nil, err
+			}
+		}
+	}
+
+	// Read acks (INT16) and timeout_ms (INT32)
+	if err = binary.Read(reader, binary.BigEndian, &acks); err != nil {
+		return false, nil, err
+	}
+	if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil {
+		return false, nil, err
+	}
+
+	// Read topics array
+	var topicsCount int32
+	if requestKeyVersion.ApiVersion >= 9 {
+		tc, err := readCompactArrayLength(reader)
+		if err != nil {
+			return false, nil, err
+		}
+		topicsCount = tc
+	} else {
+		if err = binary.Read(reader, binary.BigEndian, &topicsCount); err != nil {
+			return false, nil, err
+		}
+	}
+	logrus.Debugf("Topics count: %d", topicsCount)
+
+	topicNames = make([]string, 0, topicsCount)
+
+	for i := int32(0); i < topicsCount; i++ {
+		var currentTopicName string
+		if requestKeyVersion.ApiVersion >= 9 {
+			if currentTopicName, err = readCompactString(reader); err != nil {
+				return false, nil, err
+			}
+			logrus.Debugf("Current topic name: %s", currentTopicName)
+
+			// Read partition data
+			partitionCount, err := readCompactArrayLength(reader)
+			if err != nil {
+				logrus.Errorf("Failed to read partition count for topic %s: %v", currentTopicName, err)
+				return false, nil, err
+			}
+
+			for j := int32(0); j < partitionCount; j++ {
+				// Read partition index (INT32)
+				if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil {
+					logrus.Errorf("Failed to read partition index: %v", err)
+					return false, nil, err
+				}
+
+				// Read records (COMPACT_RECORDS: UVARINT of length+1, 0 encodes null)
+				recordsLength, err := readUVarint(reader)
+				if err != nil {
+					logrus.Errorf("Failed to read records length: %v", err)
+					return false, nil, err
+				}
+				if recordsLength > 0 {
+					if _, err := io.CopyN(io.Discard, reader, int64(recordsLength-1)); err != nil {
+						logrus.Errorf("Failed to read records: %v", err)
+						return false, nil, err
+					}
+				}
+
+				// Read tagged fields for partition
+				if err = readTaggedFields(reader); err != nil {
+					logrus.Errorf("Failed to read tagged fields for partition: %v", err)
+					return false, nil, err
+				}
+			}
+
+			logrus.Debugf("Processed topic %s", currentTopicName)
+
+			// Read tagged fields for topic
+			if err = readTaggedFields(reader); err != nil {
+				logrus.Errorf("Failed to read tagged fields for topic: %v", err)
+				return false, nil, err
+			}
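Worth noting for the compact encodings used above: COMPACT_STRING, COMPACT_ARRAY and COMPACT_RECORDS all store length+1 as an unsigned varint, with 0 reserved for null — so N bytes of record data are preceded by the varint N+1, which is why the skip above discards recordsLength-1 bytes. A helper equivalent to that skip (illustrative; readUVarint is the helper defined later in this file):

// skipCompactBytes discards a COMPACT_RECORDS / compact-bytes value:
// an unsigned varint of (length+1), where 0 encodes null.
func skipCompactBytes(r io.Reader) error {
	n, err := readUVarint(r)
	if err != nil || n == 0 {
		return err
	}
	_, err = io.CopyN(io.Discard, r, int64(n-1))
	return err
}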
+		} else {
+			if currentTopicName, err = readString(reader); err != nil {
+				return false, nil, err
+			}
+			// Older versions: partitions are an ARRAY of (partition INT32, record_set BYTES)
+			var partitionCount int32
+			if err = binary.Read(reader, binary.BigEndian, &partitionCount); err != nil {
+				return false, nil, err
+			}
+			for j := int32(0); j < partitionCount; j++ {
+				// Skip partition index (INT32)
+				if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil {
+					return false, nil, err
+				}
+				// Skip record set (BYTES: INT32 size followed by that many bytes)
+				var recordsSize int32
+				if err = binary.Read(reader, binary.BigEndian, &recordsSize); err != nil {
+					return false, nil, err
+				}
+				if recordsSize > 0 {
+					if _, err := io.CopyN(io.Discard, reader, int64(recordsSize)); err != nil {
+						return false, nil, err
+					}
+				}
+			}
+		}
+
+		topicNames = append(topicNames, currentTopicName)
+	}
+
+	// producerAcks0Disabled forces a broker reply for every produce request;
+	// otherwise a reply is expected only when acks != 0.
+	return ctx.producerAcks0Disabled || acks != 0, topicNames, nil
+}
-		default:
-			return false, nil, fmt.Errorf("produce version %d is not supported", requestKeyVersion.ApiVersion)
-		}
-		return acks != 0, bufferRead.Bytes(), nil
-	}
-	return true, nil, nil
-}
 
+func (handler *DefaultRequestHandler) handleFetch(
+	reader io.Reader,
+	requestKeyVersion *protocol.RequestKeyVersion,
+) (bool, []string, error) {
+	var (
+		topicNames []string
+		err        error
+	)
+
+	// Read initial fields based on version
+	// Read replica_id (INT32) for versions <= 14 and version 17
+	if requestKeyVersion.ApiVersion <= 14 || requestKeyVersion.ApiVersion == 17 {
+		if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil {
+			return false, nil, err
+		}
+	}
+
+	// Read max_wait_ms (INT32)
+	if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil {
+		return false, nil, err
+	}
+
+	// Read min_bytes (INT32)
+	if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil {
+		return false, nil, err
+	}
+
+	// Read max_bytes (INT32) if version >= 3
+	if requestKeyVersion.ApiVersion >= 3 {
+		if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil {
+			return false, nil, err
+		}
+	}
+
-func (handler *DefaultResponseHandler) handleResponse(dst DeadlineWriter, src DeadlineReader, ctx *ResponsesLoopContext) (readErr bool, err error) {
-	//logrus.Println("Await Kafka response")
-
-	// waiting for first bytes or EOF - reset deadlines
-	if err = src.SetReadDeadline(time.Time{}); err != nil {
-		return true, err
-	}
-	if err = dst.SetWriteDeadline(time.Time{}); err != nil {
-		return true, err
-	}
-
-	responseHeaderBuf := make([]byte, 8) // Size => int32, CorrelationId => int32
-	if _, err = io.ReadFull(src, responseHeaderBuf); err != nil {
-		return true, err
-	}
-
-	var responseHeader protocol.ResponseHeader
-	if err = protocol.Decode(responseHeaderBuf, &responseHeader); err != nil {
-		return true, err
-	}
-
-	// Read the inFlightRequests channel after header is read. Otherwise the channel would block and socket EOF from remote would not be received.
- requestKeyVersion, err := receiveRequestKeyVersion(ctx.openRequestsChannel, openRequestReceiveTimeout) - if err != nil { - return true, err + // Read isolation_level (INT8) if version >= 4 + if requestKeyVersion.ApiVersion >= 4 { + if err = binary.Read(reader, binary.BigEndian, new(int8)); err != nil { + return false, nil, err + } } - proxyResponsesBytes.WithLabelValues(ctx.brokerAddress).Add(float64(responseHeader.Length + 4)) - logrus.Debugf("Kafka response key %v, version %v, length %v", requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion, responseHeader.Length) - responseDeadline := time.Now().Add(ctx.timeout) - err = dst.SetWriteDeadline(responseDeadline) - if err != nil { - return false, err + // Read session_id (INT32) and session_epoch (INT32) if version >= 7 + if requestKeyVersion.ApiVersion >= 7 { + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } } - err = src.SetReadDeadline(responseDeadline) + + // Read topics and collect topic names + topicNames, err = readFetchTopics(reader, requestKeyVersion) if err != nil { - return true, err + return false, nil, err } - responseHeaderTaggedFields, err := protocol.NewResponseHeaderTaggedFields(requestKeyVersion) - if err != nil { - return true, err + + // Skip forgotten_topics_data if version >= 7 + if requestKeyVersion.ApiVersion >= 7 { + if err = skipForgottenTopicsData(reader, requestKeyVersion); err != nil { + return false, nil, err + } } - unknownTaggedFields, err := responseHeaderTaggedFields.MaybeRead(src) - if err != nil { - return true, err + + // Read rack_id if version >= 11 + if requestKeyVersion.ApiVersion >= 11 { + if requestKeyVersion.ApiVersion >= 12 { + // Read rack_id (COMPACT_STRING) + if _, err = readCompactString(reader); err != nil { + return false, nil, err + } + // Read tagged fields + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } else { + // Read rack_id (STRING) + if _, err = readString(reader); err != nil { + return false, nil, err + } + } } - readResponsesHeaderLength := int32(4 + len(unknownTaggedFields)) // 4 = Length + CorrelationID - responseModifier, err := protocol.GetResponseModifier(requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion, ctx.netAddressMappingFunc) - if err != nil { - return true, err + // Read tagged fields for request if version >= 12 + if requestKeyVersion.ApiVersion >= 12 { + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + + return true, topicNames, nil +} + +func readFetchTopics(reader io.Reader, requestKeyVersion *protocol.RequestKeyVersion) ([]string, error) { + var ( + topicNames []string + err error + ) + + if requestKeyVersion.ApiVersion >= 12 { + // Topics are COMPACT_ARRAY + topicsCount, err := readCompactArrayLength(reader) + if err != nil { + return nil, err + } + topicNames = make([]string, 0, topicsCount) + + for i := int32(0); i < topicsCount; i++ { + if requestKeyVersion.ApiVersion >= 13 { + // Skip topic_id (UUID) + if _, err := io.CopyN(io.Discard, reader, 16); err != nil { + return nil, err + } + } else { + // Read topic name (COMPACT_STRING) + var topicName string + if topicName, err = readCompactString(reader); err != nil { + return nil, err + } + topicNames = append(topicNames, topicName) + } + + // Read partitions + if err = skipPartitions(reader, requestKeyVersion); err != nil { + return nil, err + } + + // Read tagged fields for topic 
+ if err = readTaggedFields(reader); err != nil { + return nil, err + } + } + } else { + // Topics are ARRAY of STRING + var topicsCount int32 + if err = binary.Read(reader, binary.BigEndian, &topicsCount); err != nil { + return nil, err + } + topicNames = make([]string, 0, topicsCount) + + for i := int32(0); i < topicsCount; i++ { + // Read topic name (STRING) + var topicName string + if topicName, err = readString(reader); err != nil { + return nil, err + } + topicNames = append(topicNames, topicName) + + // Read partitions + if err = skipPartitions(reader, requestKeyVersion); err != nil { + return nil, err + } + } + } + + return topicNames, nil +} + +func skipPartitions(reader io.Reader, requestKeyVersion *protocol.RequestKeyVersion) error { + var partitionsCount int32 + var err error + + if requestKeyVersion.ApiVersion >= 12 { + // Read partitions (COMPACT_ARRAY) + if partitionsCount, err = readCompactArrayLength(reader); err != nil { + return err + } + } else { + // Read partitions (ARRAY) + if err = binary.Read(reader, binary.BigEndian, &partitionsCount); err != nil { + return err + } + } + + for j := int32(0); j < partitionsCount; j++ { + // Skip partition index (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return err + } + + // Read current_leader_epoch (INT32) if version >= 9 + if requestKeyVersion.ApiVersion >= 9 { + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return err + } + } + + // Skip fetch_offset (INT64) + if err = binary.Read(reader, binary.BigEndian, new(int64)); err != nil { + return err + } + + // Read last_fetched_epoch (INT32) if version >= 12 + if requestKeyVersion.ApiVersion >= 12 { + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return err + } + } + + // Read log_start_offset (INT64) if version >= 5 + if requestKeyVersion.ApiVersion >= 5 { + if err = binary.Read(reader, binary.BigEndian, new(int64)); err != nil { + return err + } + } + + // Skip partition_max_bytes (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return err + } + + // Read tagged fields if version >= 12 + if requestKeyVersion.ApiVersion >= 12 { + if err = readTaggedFields(reader); err != nil { + return err + } + } + } + + return nil +} + +func skipForgottenTopicsData(reader io.Reader, requestKeyVersion *protocol.RequestKeyVersion) error { + var err error + var forgottenTopicsCount int32 + + if requestKeyVersion.ApiVersion >= 12 { + // Read forgotten_topics_data (COMPACT_ARRAY) + if forgottenTopicsCount, err = readCompactArrayLength(reader); err != nil { + return err + } + } else { + // Read forgotten_topics_data (ARRAY) + if err = binary.Read(reader, binary.BigEndian, &forgottenTopicsCount); err != nil { + return err + } } + + for i := int32(0); i < forgottenTopicsCount; i++ { + if requestKeyVersion.ApiVersion >= 13 { + // Skip topic_id (UUID) + if _, err := io.CopyN(io.Discard, reader, 16); err != nil { + return err + } + } else if requestKeyVersion.ApiVersion >= 12 { + // Skip topic name (COMPACT_STRING) + if _, err = readCompactString(reader); err != nil { + return err + } + } else { + // Skip topic name (STRING) + if _, err = readString(reader); err != nil { + return err + } + } + + // Skip partitions + if err = skipForgottenPartitions(reader, requestKeyVersion); err != nil { + return err + } + + // Read tagged fields if version >= 12 + if requestKeyVersion.ApiVersion >= 12 { + if err = readTaggedFields(reader); err != nil { + return err + } + } + } + + return 
nil +} + +func skipForgottenPartitions(reader io.Reader, requestKeyVersion *protocol.RequestKeyVersion) error { + var partitionsCount int32 + var err error + + if requestKeyVersion.ApiVersion >= 12 { + // Read partitions (COMPACT_ARRAY) + if partitionsCount, err = readCompactArrayLength(reader); err != nil { + return err + } + } else { + // Read partitions (ARRAY) + if err = binary.Read(reader, binary.BigEndian, &partitionsCount); err != nil { + return err + } + } + + // Skip partitions data + if _, err = io.CopyN(io.Discard, reader, int64(partitionsCount*4)); err != nil { + return err + } + + return nil +} + +func (handler *DefaultRequestHandler) handleListOffsets( + reader io.Reader, + requestKeyVersion *protocol.RequestKeyVersion, +) (bool, []string, error) { + var ( + topicNames []string + err error + ) + + // Read ReplicaID (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + if requestKeyVersion.ApiVersion >= 2 { + // Read IsolationLevel (INT8) + if err = binary.Read(reader, binary.BigEndian, new(int8)); err != nil { + return false, nil, err + } + } + + if requestKeyVersion.ApiVersion >= 6 { + // Read Topics (COMPACT_ARRAY) + topicsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + topicNames = make([]string, 0, topicsCount) + + for i := int32(0); i < topicsCount; i++ { + // Read TopicName (COMPACT_STRING) + topicName, err := readCompactString(reader) + if err != nil { + return false, nil, err + } + topicNames = append(topicNames, topicName) + + // Read Partitions (COMPACT_ARRAY) + partitionCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + // Skip partitions + for j := int32(0); j < partitionCount; j++ { + // Read PartitionIndex (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Read CurrentLeaderEpoch (INT32) if version >= 4 + if requestKeyVersion.ApiVersion >= 4 { + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + } + + // Read Timestamp (INT64) + if err = binary.Read(reader, binary.BigEndian, new(int64)); err != nil { + return false, nil, err + } + + // Read TaggedFields (TAG_BUFFER) + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + + // Read TaggedFields for Topic (TAG_BUFFER) + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + + // Read TaggedFields for Request (TAG_BUFFER) + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + + } else { + // Read TopicsCount (INT32) + var topicsCount int32 + if err = binary.Read(reader, binary.BigEndian, &topicsCount); err != nil { + return false, nil, err + } + + topicNames = make([]string, 0, topicsCount) + + for i := int32(0); i < topicsCount; i++ { + // Read TopicName (STRING) + topicName, err := readString(reader) + if err != nil { + return false, nil, err + } + topicNames = append(topicNames, topicName) + + // Read Partitions (ARRAY) + var partitionCount int32 + if err = binary.Read(reader, binary.BigEndian, &partitionCount); err != nil { + return false, nil, err + } + + // Skip partitions + for j := int32(0); j < partitionCount; j++ { + // Read PartitionIndex (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Read Timestamp (INT64) + if err = binary.Read(reader, binary.BigEndian, new(int64)); err != nil { 
+ return false, nil, err + } + + if requestKeyVersion.ApiVersion == 0 { + // Read MaxNumOffsets (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + } + } + } + } + + return true, topicNames, nil +} + +func (handler *DefaultRequestHandler) handleCreateTopics( + reader io.Reader, + requestKeyVersion *protocol.RequestKeyVersion, +) (bool, []string, error) { + var ( + topicNames []string + err error + ) + + if requestKeyVersion.ApiVersion >= 5 { + // Read topics as COMPACT_ARRAY + topicsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + topicNames = make([]string, 0, topicsCount) + + for i := int32(0); i < topicsCount; i++ { + // Read TopicName (COMPACT_STRING) + var topicName string + if topicName, err = readCompactString(reader); err != nil { + return false, nil, err + } + + topicNames = append(topicNames, topicName) + + // Skip num_partitions (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Skip replication_factor (INT16) + if err = binary.Read(reader, binary.BigEndian, new(int16)); err != nil { + return false, nil, err + } + + // Skip assignments + assignmentsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + for j := int32(0); j < assignmentsCount; j++ { + // Skip partition_index (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + // Skip broker_ids (COMPACT_ARRAY of INT32) + brokerIdsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + if _, err = io.CopyN(io.Discard, reader, int64(brokerIdsCount*4)); err != nil { + return false, nil, err + } + // Read tagged fields for assignment + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + + // Skip configs + configsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + for j := int32(0); j < configsCount; j++ { + // Skip name (COMPACT_STRING) + if _, err = readCompactString(reader); err != nil { + return false, nil, err + } + // Skip value (COMPACT_NULLABLE_STRING) + if _, err = readCompactNullableString(reader); err != nil { + return false, nil, err + } + // Read tagged fields for config + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + + // Read tagged fields for topic + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + + // Read timeout_ms (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Read validate_only (BOOLEAN) + if err = binary.Read(reader, binary.BigEndian, new(bool)); err != nil { + return false, nil, err + } + + // Read tagged fields for request + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } else { + // For versions <5, read topics as ARRAY + var topicsCount int32 + if err = binary.Read(reader, binary.BigEndian, &topicsCount); err != nil { + return false, nil, err + } + + topicNames = make([]string, 0, topicsCount) + + for i := int32(0); i < topicsCount; i++ { + // Read TopicName (STRING) + var topicName string + if topicName, err = readString(reader); err != nil { + return false, nil, err + } + + topicNames = append(topicNames, topicName) + + // Skip num_partitions (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + 
} + + // Skip replication_factor (INT16) + if err = binary.Read(reader, binary.BigEndian, new(int16)); err != nil { + return false, nil, err + } + + // Skip assignments + var assignmentsCount int32 + if err = binary.Read(reader, binary.BigEndian, &assignmentsCount); err != nil { + return false, nil, err + } + for j := int32(0); j < assignmentsCount; j++ { + // Skip partition_index (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + // Skip broker_ids (ARRAY of INT32) + var brokerIdsCount int32 + if err = binary.Read(reader, binary.BigEndian, &brokerIdsCount); err != nil { + return false, nil, err + } + if _, err = io.CopyN(io.Discard, reader, int64(brokerIdsCount*4)); err != nil { + return false, nil, err + } + } + + // Skip configs + var configsCount int32 + if err = binary.Read(reader, binary.BigEndian, &configsCount); err != nil { + return false, nil, err + } + for j := int32(0); j < configsCount; j++ { + // Skip name (STRING) + if _, err = readString(reader); err != nil { + return false, nil, err + } + // Skip value (NULLABLE_STRING) + if _, err = readNullableString(reader); err != nil { + return false, nil, err + } + } + } + + // Read timeout_ms (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Read validate_only (BOOLEAN) for versions >= 1 + if requestKeyVersion.ApiVersion >= 1 { + if err = binary.Read(reader, binary.BigEndian, new(bool)); err != nil { + return false, nil, err + } + } + } + + return true, topicNames, nil +} + +func (handler *DefaultRequestHandler) handleDeleteTopics( + reader io.Reader, + requestKeyVersion *protocol.RequestKeyVersion, +) (bool, []string, error) { + var ( + topicNames []string + err error + ) + + if requestKeyVersion.ApiVersion >= 6 { + // Read topics as COMPACT_ARRAY + topicsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + if topicsCount > 0 { + for i := int32(0); i < topicsCount; i++ { + // Read topic name (COMPACT_NULLABLE_STRING) + topicName, err := readCompactNullableString(reader) + if err != nil { + return false, nil, err + } + topicNames = append(topicNames, topicName) + + // Skip topic_id UUID (16 bytes) + uuidBytes := make([]byte, 16) + if _, err := io.ReadFull(reader, uuidBytes); err != nil { + return false, nil, err + } + + // Read tagged fields for topic + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + } + + // Read timeout_ms (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Read tagged fields for request + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + + } else if requestKeyVersion.ApiVersion >= 4 { + // Read topics as COMPACT_ARRAY + topicsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + logrus.Debugf("Topics count: %d", topicsCount) + + if topicsCount > 0 { + for i := int32(0); i < topicsCount; i++ { + // Read TopicName (COMPACT_STRING) + topicName, err := readCompactNullableString(reader) + if err != nil { + return false, nil, err + } + topicNames = append(topicNames, topicName) + } + } + + // Read timeout_ms (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Read tagged fields + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + + } else { + // Versions 0-3 use array of STRINGs + var 
topicsCount int32 + if err = binary.Read(reader, binary.BigEndian, &topicsCount); err != nil { + return false, nil, err + } + + if topicsCount > 0 { + for i := int32(0); i < topicsCount; i++ { + // Read TopicName (STRING) + topicName, err := readString(reader) + if err != nil { + return false, nil, err + } + topicNames = append(topicNames, topicName) + } + } + + // Read timeout_ms (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + } + + return true, topicNames, nil +} + +func (handler *DefaultRequestHandler) handleDeleteRecords( + reader io.Reader, + requestKeyVersion *protocol.RequestKeyVersion, +) (bool, []string, error) { + var ( + topicNames []string + err error + ) + + if requestKeyVersion.ApiVersion >= 2 { + // Read topics as COMPACT_ARRAY + topicsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + topicNames = make([]string, 0, topicsCount) + + for i := int32(0); i < topicsCount; i++ { + // Read topic name (COMPACT_STRING) + var topicName string + if topicName, err = readCompactString(reader); err != nil { + return false, nil, err + } + topicNames = append(topicNames, topicName) + + // Read partitions array + partitionsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + // Skip partition data + for j := int32(0); j < partitionsCount; j++ { + // Skip partition_index (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Skip offset (INT64) + if err = binary.Read(reader, binary.BigEndian, new(int64)); err != nil { + return false, nil, err + } + + // Read tagged fields for partition + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + + // Read tagged fields for topic + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + + // Read timeout_ms (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Read tagged fields for request + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + + } else { + // Read topics array (normal array for versions 0-1) + var topicsCount int32 + if err = binary.Read(reader, binary.BigEndian, &topicsCount); err != nil { + return false, nil, err + } + + topicNames = make([]string, 0, topicsCount) + + for i := int32(0); i < topicsCount; i++ { + // Read topic name (STRING) + var topicName string + if topicName, err = readString(reader); err != nil { + return false, nil, err + } + topicNames = append(topicNames, topicName) + + // Read partitions array + var partitionsCount int32 + if err = binary.Read(reader, binary.BigEndian, &partitionsCount); err != nil { + return false, nil, err + } + + // Skip partition data + for j := int32(0); j < partitionsCount; j++ { + // Skip partition_index (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Skip offset (INT64) + if err = binary.Read(reader, binary.BigEndian, new(int64)); err != nil { + return false, nil, err + } + } + } + + // Read timeout_ms (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + } + + return true, topicNames, nil +} + +func (handler *DefaultRequestHandler) handleAddPartitionsToTxn( + reader io.Reader, + requestKeyVersion *protocol.RequestKeyVersion, +) (bool, []string, error) { + var ( + topicNames []string + err 
error + ) + + if requestKeyVersion.ApiVersion >= 4 { + // Read transactions array as COMPACT_ARRAY + transactionsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + topicNames = make([]string, 0) + + for i := int32(0); i < transactionsCount; i++ { + // Read transactional_id (COMPACT_STRING) + if _, err = readCompactString(reader); err != nil { + return false, nil, err + } + + // Skip producer_id (INT64) + if err = binary.Read(reader, binary.BigEndian, new(int64)); err != nil { + return false, nil, err + } + + // Skip producer_epoch (INT16) + if err = binary.Read(reader, binary.BigEndian, new(int16)); err != nil { + return false, nil, err + } + + // Skip verify_only (BOOLEAN) + if err = binary.Read(reader, binary.BigEndian, new(bool)); err != nil { + return false, nil, err + } + + // Read topics array + topicsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + for j := int32(0); j < topicsCount; j++ { + // Read topic name + topicName, err := readCompactString(reader) + if err != nil { + return false, nil, err + } + topicNames = append(topicNames, topicName) + + // Skip partitions array + partitionsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + // Skip partition indexes + if _, err = io.CopyN(io.Discard, reader, int64(partitionsCount*4)); err != nil { + return false, nil, err + } + + // Read tagged fields for partitions + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + + // Read tagged fields for topics + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + + // Read tagged fields for request + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + + } else { + // Read transactional_id + if requestKeyVersion.ApiVersion >= 3 { + if _, err = readCompactString(reader); err != nil { + return false, nil, err + } + } else { + if _, err = readString(reader); err != nil { + return false, nil, err + } + } + + // Skip producer_id (INT64) + if err = binary.Read(reader, binary.BigEndian, new(int64)); err != nil { + return false, nil, err + } + + // Skip producer_epoch (INT16) + if err = binary.Read(reader, binary.BigEndian, new(int16)); err != nil { + return false, nil, err + } + + // Read topics array + var topicsCount int32 + if requestKeyVersion.ApiVersion >= 3 { + if topicsCount, err = readCompactArrayLength(reader); err != nil { + return false, nil, err + } + } else { + if err = binary.Read(reader, binary.BigEndian, &topicsCount); err != nil { + return false, nil, err + } + } + + topicNames = make([]string, 0, topicsCount) + + for i := int32(0); i < topicsCount; i++ { + // Read topic name + var topicName string + if requestKeyVersion.ApiVersion >= 3 { + if topicName, err = readCompactString(reader); err != nil { + return false, nil, err + } + } else { + if topicName, err = readString(reader); err != nil { + return false, nil, err + } + } + topicNames = append(topicNames, topicName) + + // Skip partitions array + var partitionsCount int32 + if err = binary.Read(reader, binary.BigEndian, &partitionsCount); err != nil { + return false, nil, err + } + + // Skip partition indexes + if _, err = io.CopyN(io.Discard, reader, int64(partitionsCount*4)); err != nil { + return false, nil, err + } + + if requestKeyVersion.ApiVersion >= 3 { + // Read tagged fields for topic + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + } + + if 
requestKeyVersion.ApiVersion >= 3 { + // Read tagged fields for request + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + } + + return true, topicNames, nil +} + +func (handler *DefaultRequestHandler) handleCreatePartitions( + reader io.Reader, + requestKeyVersion *protocol.RequestKeyVersion, +) (bool, []string, error) { + var ( + topicNames []string + err error + ) + + if requestKeyVersion.ApiVersion >= 2 { + // Read topics array (COMPACT_ARRAY) + topicsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + topicNames = make([]string, 0, topicsCount) + + for i := int32(0); i < topicsCount; i++ { + // Read topic name (COMPACT_STRING) + var topicName string + if topicName, err = readCompactString(reader); err != nil { + return false, nil, err + } + topicNames = append(topicNames, topicName) + + // Skip count (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Read assignments array + assignmentsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + // Skip assignments data + for j := int32(0); j < assignmentsCount; j++ { + // Read broker_ids array + brokerIdsCount, err := readCompactArrayLength(reader) + if err != nil { + return false, nil, err + } + + // Skip broker IDs + if _, err = io.CopyN(io.Discard, reader, int64(brokerIdsCount*4)); err != nil { + return false, nil, err + } + + // Read tagged fields for broker_ids array + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + + // Read tagged fields for topic + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + } + + // Skip timeout_ms (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Skip validate_only (BOOLEAN) + if err = binary.Read(reader, binary.BigEndian, new(bool)); err != nil { + return false, nil, err + } + + // Read tagged fields for request + if err = readTaggedFields(reader); err != nil { + return false, nil, err + } + + } else { + // Read topics array + var topicsCount int32 + if err = binary.Read(reader, binary.BigEndian, &topicsCount); err != nil { + return false, nil, err + } + + topicNames = make([]string, 0, topicsCount) + + for i := int32(0); i < topicsCount; i++ { + // Read topic name (STRING) + var topicName string + if topicName, err = readString(reader); err != nil { + return false, nil, err + } + topicNames = append(topicNames, topicName) + + // Skip count (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Read assignments array + var assignmentsCount int32 + if err = binary.Read(reader, binary.BigEndian, &assignmentsCount); err != nil { + return false, nil, err + } + + // Skip assignments data + for j := int32(0); j < assignmentsCount; j++ { + // Read broker_ids array + var brokerIdsCount int32 + if err = binary.Read(reader, binary.BigEndian, &brokerIdsCount); err != nil { + return false, nil, err + } + + // Skip broker IDs + if _, err = io.CopyN(io.Discard, reader, int64(brokerIdsCount*4)); err != nil { + return false, nil, err + } + } + } + + // Skip timeout_ms (INT32) + if err = binary.Read(reader, binary.BigEndian, new(int32)); err != nil { + return false, nil, err + } + + // Skip validate_only (BOOLEAN) + if err = binary.Read(reader, binary.BigEndian, new(bool)); err != nil { + return false, nil, err + } + } + + return true, 
topicNames, nil +} + +func readString(reader io.Reader) (string, error) { + var length int16 + if err := binary.Read(reader, binary.BigEndian, &length); err != nil { + return "", err + } + if length < 0 { + return "", fmt.Errorf("Invalid string length %d", length) + } else if length == 0 { + return "", nil + } else { + strBytes := make([]byte, length) + if _, err := io.ReadFull(reader, strBytes); err != nil { + return "", err + } + return string(strBytes), nil + } +} + +func readCompactNullableString(reader io.Reader) (string, error) { + // Read the length as an unsigned VarInt + length, err := readUVarint(reader) + if err != nil { + return "", err + } + + if length == 0 { + // Null string + return "", nil + } + + strLen := length - 1 + + // Ensure the string length is valid + if strLen < 0 { + return "", errors.New("invalid string length") + } + + // Read the string bytes + strBytes := make([]byte, strLen) + if _, err := io.ReadFull(reader, strBytes); err != nil { + return "", err + } + + return string(strBytes), nil +} + +func readUVarint(reader io.Reader) (uint64, error) { + var value uint64 + var shift uint + for { + var b [1]byte + if _, err := io.ReadFull(reader, b[:]); err != nil { + return 0, err + } + value |= uint64(b[0]&0x7F) << shift + if (b[0] & 0x80) == 0 { + break + } + shift += 7 + if shift > 63 { + return 0, fmt.Errorf("varint too long") + } + } + return value, nil +} + +func readCompactString(reader io.Reader) (string, error) { + length, err := readUVarint(reader) + if err != nil { + return "", err + } + if length == 0 { + return "", nil + } + length-- // Adjust for compact encoding (length includes one extra byte) + buf := make([]byte, length) + if _, err := io.ReadFull(reader, buf); err != nil { + return "", err + } + return string(buf), nil +} + +func readCompactArrayLength(reader io.Reader) (int32, error) { + length, err := readUVarint(reader) + if err != nil { + return 0, err + } + if length == 0 { + return 0, nil + } + return int32(length - 1), nil // Adjust for compact encoding +} + +func readTaggedFields(reader io.Reader) error { + numTags, err := readUVarint(reader) + if err != nil { + return err + } + for i := uint64(0); i < numTags; i++ { + // Read tag (UVarint) + _, err := readUVarint(reader) + if err != nil { + return err + } + // Read size (UVarint) + size, err := readUVarint(reader) + if err != nil { + return err + } + // Skip over the tag data + if _, err := io.CopyN(io.Discard, reader, int64(size)); err != nil { + return err + } + } + return nil +} + +func readNullableString(reader io.Reader) (string, error) { + var length int16 + if err := binary.Read(reader, binary.BigEndian, &length); err != nil { + return "", err + } + if length < 0 { + // Null string + return "", nil + } else if length == 0 { + return "", nil + } else { + strBytes := make([]byte, length) + if _, err := io.ReadFull(reader, strBytes); err != nil { + return "", err + } + return string(strBytes), nil + } +} + +func (handler *DefaultResponseHandler) handleResponse(dst DeadlineWriter, src DeadlineReader, ctx *ResponsesLoopContext) (readErr bool, err error) { + //logrus.Println("Await Kafka response") + + // waiting for first bytes or EOF - reset deadlines + if err = src.SetReadDeadline(time.Time{}); err != nil { + return true, err + } + if err = dst.SetWriteDeadline(time.Time{}); err != nil { + return true, err + } + + responseHeaderBuf := make([]byte, 8) // Size => int32, CorrelationId => int32 + if _, err = io.ReadFull(src, responseHeaderBuf); err != nil { + return true, err + } + + var 
responseHeader protocol.ResponseHeader + if err = protocol.Decode(responseHeaderBuf, &responseHeader); err != nil { + return true, err + } + + // Read the inFlightRequests channel after header is read. Otherwise the channel would block and socket EOF from remote would not be received. + requestKeyVersion, err := receiveRequestKeyVersion(ctx.openRequestsChannel, openRequestReceiveTimeout) + if err != nil { + return true, err + } + proxyResponsesBytes.WithLabelValues(ctx.brokerAddress).Add(float64(responseHeader.Length + 4)) + logrus.Debugf("Kafka response operation: %s (key: %v, version: %v, length: %v)", + getOperationName(requestKeyVersion.ApiKey), + requestKeyVersion.ApiKey, + requestKeyVersion.ApiVersion, + responseHeader.Length) + + responseDeadline := time.Now().Add(ctx.timeout) + err = dst.SetWriteDeadline(responseDeadline) + if err != nil { + return false, err + } + err = src.SetReadDeadline(responseDeadline) + if err != nil { + return true, err + } + responseHeaderTaggedFields, err := protocol.NewResponseHeaderTaggedFields(requestKeyVersion) + if err != nil { + return true, err + } + unknownTaggedFields, err := responseHeaderTaggedFields.MaybeRead(src) + if err != nil { + return true, err + } + readResponsesHeaderLength := int32(4 + len(unknownTaggedFields)) // 4 = Length + CorrelationID + + responseModifier, err := protocol.GetResponseModifier(requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion, ctx.netAddressMappingFunc, ctx.acl) + if err != nil { + return true, err + } + + // TODO: implement filtering of topics in Metadata and ListTopics API Keys if responseModifier != nil { if responseHeader.Length > protocol.MaxResponseSize { return true, protocol.PacketDecodingError{Info: fmt.Sprintf("message of length %d too large", responseHeader.Length)} diff --git a/proxy/protocol/responses.go b/proxy/protocol/responses.go index 3309719b..a747ddb5 100644 --- a/proxy/protocol/responses.go +++ b/proxy/protocol/responses.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/grepplabs/kafka-proxy/config" + "github.com/grepplabs/kafka-proxy/pkg/apis" ) const ( @@ -296,7 +297,7 @@ func createFindCoordinatorResponseSchemaVersions() []Schema { return []Schema{findCoordinatorResponseV0, findCoordinatorResponseV1, findCoordinatorResponseV2, findCoordinatorResponseV3, findCoordinatorResponseV4} } -func modifyMetadataResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc) error { +func modifyMetadataResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc, acl *apis.ACLCollection) error { if decodedStruct == nil { return errors.New("decoded struct must not be nil") } @@ -339,10 +340,35 @@ func modifyMetadataResponse(decodedStruct *Struct, fn config.NetAddressMappingFu } } } + + topicMetadataArray, ok := decodedStruct.Get("topic_metadata").([]interface{}) + if !ok { + return errors.New("topic metadata list not found") + } + + for _, topicElement := range topicMetadataArray { + topic := topicElement.(*Struct) + + // Get topic name - try both "topic" and "name" fields based on version + topicName, ok := topic.Get("topic").(string) + if !ok { + topicName, ok = topic.Get("name").(string) + if !ok { + continue + } + } + + fmt.Printf("METADATA TOPICS!: %s\n", topicName) + } + + // TODO:! 
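The TODO above marks where ACL-based topic filtering belongs. One way the apis types already introduced could complete it (a sketch, not part of this patch — principal and host handling are elided, and Describe is assumed as the operation for metadata visibility):

if acl != nil {
	filtered := make([]interface{}, 0, len(topicMetadataArray))
	for _, topicElement := range topicMetadataArray {
		topic := topicElement.(*Struct)
		name, ok := topic.Get("topic").(string)
		if !ok {
			name, _ = topic.Get("name").(string)
		}
		decision := acl.EvaluateAccess(apis.ACLRequest{
			ResourceType: apis.ResourceTypeTopic,
			ResourceName: name,
			Operation:    apis.OperationDescribe,
		})
		if decision.Allowed {
			filtered = append(filtered, topicElement)
		}
	}
	decodedStruct.Replace("topic_metadata", filtered)
}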
+	// Leave topic_metadata unchanged until ACL-based filtering lands here.
+
 	return nil
 }
 
-func modifyFindCoordinatorResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc) error {
+func modifyFindCoordinatorResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc, acl *apis.ACLCollection) error {
 	if decodedStruct == nil {
 		return errors.New("decoded struct must not be nil")
 	}
@@ -408,12 +434,13 @@ type ResponseModifier interface {
 	Apply(resp []byte) ([]byte, error)
 }
 
-type modifyResponseFunc func(decodedStruct *Struct, fn config.NetAddressMappingFunc) error
+type modifyResponseFunc func(decodedStruct *Struct, fn config.NetAddressMappingFunc, acl *apis.ACLCollection) error
 
 type responseModifier struct {
 	schema                Schema
 	modifyResponseFunc    modifyResponseFunc
 	netAddressMappingFunc config.NetAddressMappingFunc
+	acl                   *apis.ACLCollection
 }
 
 func (f *responseModifier) Apply(resp []byte) ([]byte, error) {
@@ -421,25 +448,25 @@ func (f *responseModifier) Apply(resp []byte) ([]byte, error) {
 	if err != nil {
 		return nil, err
 	}
-	err = f.modifyResponseFunc(decodedStruct, f.netAddressMappingFunc)
+	err = f.modifyResponseFunc(decodedStruct, f.netAddressMappingFunc, f.acl)
 	if err != nil {
 		return nil, err
 	}
 	return EncodeSchema(decodedStruct, f.schema)
 }
 
-func GetResponseModifier(apiKey int16, apiVersion int16, addressMappingFunc config.NetAddressMappingFunc) (ResponseModifier, error) {
+func GetResponseModifier(apiKey int16, apiVersion int16, addressMappingFunc config.NetAddressMappingFunc, acl *apis.ACLCollection) (ResponseModifier, error) {
 	switch apiKey {
 	case apiKeyMetadata:
-		return newResponseModifier(apiKey, apiVersion, addressMappingFunc, metadataResponseSchemaVersions, modifyMetadataResponse)
+		return newResponseModifier(apiKey, apiVersion, addressMappingFunc, metadataResponseSchemaVersions, modifyMetadataResponse, acl)
 	case apiKeyFindCoordinator:
-		return newResponseModifier(apiKey, apiVersion, addressMappingFunc, findCoordinatorResponseSchemaVersions, modifyFindCoordinatorResponse)
+		return newResponseModifier(apiKey, apiVersion, addressMappingFunc, findCoordinatorResponseSchemaVersions, modifyFindCoordinatorResponse, acl)
 	default:
 		return nil, nil
 	}
 }
 
-func newResponseModifier(apiKey int16, apiVersion int16, netAddressMappingFunc config.NetAddressMappingFunc, schemas []Schema, modifyResponseFunc modifyResponseFunc) (ResponseModifier, error) {
+func newResponseModifier(apiKey int16, apiVersion int16, netAddressMappingFunc config.NetAddressMappingFunc, schemas []Schema, modifyResponseFunc modifyResponseFunc, acl *apis.ACLCollection) (ResponseModifier, error) {
 	schema, err := getResponseSchema(apiKey, apiVersion, schemas)
 	if err != nil {
 		return nil, err
@@ -448,6 +475,7 @@ func newResponseModifier(apiKey int16, apiVersion int16, netAddressMappingFunc c
 		schema:                schema,
 		modifyResponseFunc:    modifyResponseFunc,
 		netAddressMappingFunc: netAddressMappingFunc,
+		acl:                   acl,
 	}, nil
 }
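Taken together, the plugin side of this change parses rules of the form Operation,TopicPattern,Allow and evaluates them per topic with deny-by-default semantics. An end-to-end illustration of how the pieces compose (illustrative only — the proxy flag that forwards these params to the factory is not part of this diff):

// Build a checker that allows reads on topics matching "logs-.*";
// anything it does not explicitly match is denied.
rules, err := parseRules([]string{"READ,logs-.*,true"})
if err != nil {
	log.Fatal(err)
}
checker, err := NewACLChecker(rules)
if err != nil {
	log.Fatal(err)
}

// Fetch (api key 1) maps to READ, so this is allowed:
ok, _, _ := checker.CheckACL(context.Background(), 1, []string{"logs-app"})
// ok == true; a Produce (api key 0 maps to WRITE) on the same topic would be denied.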