diff --git a/proxy/protocol/request_v1.go b/proxy/protocol/request_v1.go index c5ffdc33..e953ff86 100644 --- a/proxy/protocol/request_v1.go +++ b/proxy/protocol/request_v1.go @@ -42,6 +42,9 @@ func (r *Request) decode(pd packetDecoder) (err error) { if version, err = pd.getInt16(); err != nil { return err } + if version == -1 { + version = r.Body.version() + } if r.Body.key() != key || r.Body.version() != version { return PacketDecodingError{fmt.Sprintf("expected request key,version %d,%d but got %d,%d", r.Body.key(), r.Body.version(), key, version)} } diff --git a/proxy/protocol/request_v2.go b/proxy/protocol/request_v2.go index de0d87d7..4fc54f04 100644 --- a/proxy/protocol/request_v2.go +++ b/proxy/protocol/request_v2.go @@ -42,6 +42,9 @@ func (r *RequestV2) decode(pd packetDecoder) (err error) { if version, err = pd.getInt16(); err != nil { return err } + if version == -1 { + version = r.Body.version() + } if r.Body.key() != key || r.Body.version() != version { return PacketDecodingError{fmt.Sprintf("expected request key,version %d,%d but got %d,%d", r.Body.key(), r.Body.version(), key, version)} } diff --git a/proxy/protocol/responses.go b/proxy/protocol/responses.go index 774ff90d..84530505 100644 --- a/proxy/protocol/responses.go +++ b/proxy/protocol/responses.go @@ -10,6 +10,8 @@ import ( const ( apiKeyMetadata = 3 apiKeyFindCoordinator = 10 + apiKeySaslHandshake = 17 + apiKeyApiVersions = 18 brokersKeyName = "brokers" hostKeyName = "host" @@ -23,8 +25,62 @@ const ( var ( metadataResponseSchemaVersions = createMetadataResponseSchemaVersions() findCoordinatorResponseSchemaVersions = createFindCoordinatorResponseSchemaVersions() + apiVersionsResponseSchemaVersions = createApiVersionsResponseSchemaVersions() ) +func createApiVersionsResponseSchemaVersions() []Schema { + apiVersionV0 := NewSchema("api_version", + &Mfield{Name: "api_key", Ty: TypeInt16}, + &Mfield{Name: "min_version", Ty: TypeInt16}, + &Mfield{Name: "max_version", Ty: TypeInt16}, + ) + + // Version 0: error_code + api_keys + apiVersionsResponseV0 := NewSchema("api_versions_response_v0", + &Mfield{Name: "error_code", Ty: TypeInt16}, + &Array{Name: "api_keys", Ty: apiVersionV0}, + ) + + // Version 1: error_code + api_keys + throttle_time_ms + apiVersionsResponseV1 := NewSchema("api_versions_response_v1", + &Mfield{Name: "error_code", Ty: TypeInt16}, + &Array{Name: "api_keys", Ty: apiVersionV0}, + &Mfield{Name: "throttle_time_ms", Ty: TypeInt32}, + ) + + // Version 2: Same as version 1 + apiVersionsResponseV2 := apiVersionsResponseV1 + + // ApiVersion struct for flexible versions (v3+) with compact arrays + apiVersionV3 := NewSchema("api_version_v3", + &Mfield{Name: "api_key", Ty: TypeInt16}, + &Mfield{Name: "min_version", Ty: TypeInt16}, + &Mfield{Name: "max_version", Ty: TypeInt16}, + &SchemaTaggedFields{Name: "api_version_tagged_fields"}, + ) + + // Version 3: Flexible version with tagged fields + // Tagged fields: supported_features (tag 0), finalized_features_epoch (tag 1), + // finalized_features (tag 2), zk_migration_ready (tag 3) + apiVersionsResponseV3 := NewSchema("api_versions_response_v3", + &Mfield{Name: "error_code", Ty: TypeInt16}, + &CompactArray{Name: "api_keys", Ty: apiVersionV3}, + &Mfield{Name: "throttle_time_ms", Ty: TypeInt32}, + &SchemaTaggedFields{Name: "response_tagged_fields"}, + ) + + // Version 4: Same as version 3 + apiVersionsResponseV4 := apiVersionsResponseV3 + + return []Schema{ + apiVersionsResponseV0, + apiVersionsResponseV1, + apiVersionsResponseV2, + apiVersionsResponseV3, + apiVersionsResponseV4, + } +} + func createMetadataResponseSchemaVersions() []Schema { metadataBrokerV0 := NewSchema("metadata_broker_v0", &Mfield{Name: nodeKeyName, Ty: TypeInt32}, @@ -325,6 +381,40 @@ func createFindCoordinatorResponseSchemaVersions() []Schema { return []Schema{findCoordinatorResponseV0, findCoordinatorResponseV1, findCoordinatorResponseV2, findCoordinatorResponseV3, findCoordinatorResponseV4, findCoordinatorResponseV5, findCoordinatorResponseV6} } +func modifyApiVersionsResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc) error { + if decodedStruct == nil { + return errors.New("decoded struct must not be nil") + } + + versions, ok := decodedStruct.Get("api_keys").([]any) + if !ok || len(versions) == 0 { + return errors.New("versions not found") + } + for _, versionElement := range versions { + version := versionElement.(*Struct) + if version.Get("api_key").(int16) == apiKeySaslHandshake { + return nil + } + } + + schema := versions[0].(*Struct).GetSchema() + + // v1 Sasl auth does not seem to work with KafkaJS so pin to v0 + values := []any{int16(17), int16(0), int16(0)} + + // version 3+ of the api versions response + if len(schema.GetFields()) > 3 { + values = append(values, []rawTaggedField{}) + } + + versions = append(versions, &Struct{ + Schema: schema, + Values: values, + }) + + return decodedStruct.Replace("api_keys", versions) +} + func modifyMetadataResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc) error { if decodedStruct == nil { return errors.New("decoded struct must not be nil") @@ -467,6 +557,8 @@ func (f *responseModifier) Apply(resp []byte) ([]byte, error) { func GetResponseModifier(apiKey int16, apiVersion int16, addressMappingFunc config.NetAddressMappingFunc) (ResponseModifier, error) { switch apiKey { + case apiKeyApiVersions: + return newResponseModifier(apiKey, apiVersion, addressMappingFunc, apiVersionsResponseSchemaVersions, modifyApiVersionsResponse) case apiKeyMetadata: return newResponseModifier(apiKey, apiVersion, addressMappingFunc, metadataResponseSchemaVersions, modifyMetadataResponse) case apiKeyFindCoordinator: