Skip to content

Commit cf98861

Browse files
Support rewriting Sasl capability when server does not support it
1 parent 317f1d7 commit cf98861

File tree

2 files changed

+89
-2
lines changed

2 files changed

+89
-2
lines changed

proxy/processor_default.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ import (
44
"bytes"
55
"errors"
66
"fmt"
7-
"github.com/grepplabs/kafka-proxy/proxy/protocol"
8-
"github.com/sirupsen/logrus"
97
"io"
108
"strconv"
119
"time"
10+
11+
"github.com/grepplabs/kafka-proxy/proxy/protocol"
12+
"github.com/sirupsen/logrus"
1213
)
1314

1415
type DefaultRequestHandler struct {
@@ -235,6 +236,7 @@ func (handler *DefaultResponseHandler) handleResponse(dst DeadlineWriter, src De
235236
if _, err = io.ReadFull(src, resp); err != nil {
236237
return true, err
237238
}
239+
238240
newResponseBuf, err := responseModifier.Apply(resp)
239241
if err != nil {
240242
return true, err

proxy/protocol/responses.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
const (
1111
apiKeyMetadata = 3
1212
apiKeyFindCoordinator = 10
13+
apiKeySaslHandshake = 17
14+
apiKeyApiApiVersions = 18
1315

1416
brokersKeyName = "brokers"
1517
hostKeyName = "host"
@@ -23,8 +25,65 @@ const (
2325
var (
2426
metadataResponseSchemaVersions = createMetadataResponseSchemaVersions()
2527
findCoordinatorResponseSchemaVersions = createFindCoordinatorResponseSchemaVersions()
28+
apiVersionsResponseSchemaVersions = createApiVersionsResponseSchemaVersions()
29+
apiVersionSchema = createApiVersionSchema()
2630
)
2731

32+
func createApiVersionSchema() Schema {
33+
return NewSchema("api_version",
34+
&Mfield{Name: "api_key", Ty: TypeInt16},
35+
&Mfield{Name: "min_version", Ty: TypeInt16},
36+
&Mfield{Name: "max_version", Ty: TypeInt16},
37+
)
38+
}
39+
40+
func createApiVersionsResponseSchemaVersions() []Schema {
41+
// Version 0: error_code + api_keys
42+
apiVersionsResponseV0 := NewSchema("api_versions_response_v0",
43+
&Mfield{Name: "error_code", Ty: TypeInt16},
44+
&Array{Name: "api_keys", Ty: apiVersionSchema},
45+
)
46+
47+
// Version 1: error_code + api_keys + throttle_time_ms
48+
apiVersionsResponseV1 := NewSchema("api_versions_response_v1",
49+
&Mfield{Name: "error_code", Ty: TypeInt16},
50+
&Array{Name: "api_keys", Ty: apiVersionSchema},
51+
&Mfield{Name: "throttle_time_ms", Ty: TypeInt32},
52+
)
53+
54+
// Version 2: Same as version 1
55+
apiVersionsResponseV2 := apiVersionsResponseV1
56+
57+
// ApiVersion struct for flexible versions (v3+) with compact arrays
58+
apiVersionV3 := NewSchema("api_version_v3",
59+
&Mfield{Name: "api_key", Ty: TypeInt16},
60+
&Mfield{Name: "min_version", Ty: TypeInt16},
61+
&Mfield{Name: "max_version", Ty: TypeInt16},
62+
&SchemaTaggedFields{Name: "api_version_tagged_fields"},
63+
)
64+
65+
// Version 3: Flexible version with tagged fields
66+
// Tagged fields: supported_features (tag 0), finalized_features_epoch (tag 1),
67+
// finalized_features (tag 2), zk_migration_ready (tag 3)
68+
apiVersionsResponseV3 := NewSchema("api_versions_response_v3",
69+
&Mfield{Name: "error_code", Ty: TypeInt16},
70+
&CompactArray{Name: "api_keys", Ty: apiVersionV3},
71+
&Mfield{Name: "throttle_time_ms", Ty: TypeInt32},
72+
&SchemaTaggedFields{Name: "response_tagged_fields"},
73+
)
74+
75+
// Version 4: Same as version 3
76+
apiVersionsResponseV4 := apiVersionsResponseV3
77+
78+
return []Schema{
79+
apiVersionsResponseV0,
80+
apiVersionsResponseV1,
81+
apiVersionsResponseV2,
82+
apiVersionsResponseV3,
83+
apiVersionsResponseV4,
84+
}
85+
}
86+
2887
func createMetadataResponseSchemaVersions() []Schema {
2988
metadataBrokerV0 := NewSchema("metadata_broker_v0",
3089
&Mfield{Name: nodeKeyName, Ty: TypeInt32},
@@ -325,6 +384,30 @@ func createFindCoordinatorResponseSchemaVersions() []Schema {
325384
return []Schema{findCoordinatorResponseV0, findCoordinatorResponseV1, findCoordinatorResponseV2, findCoordinatorResponseV3, findCoordinatorResponseV4, findCoordinatorResponseV5, findCoordinatorResponseV6}
326385
}
327386

387+
func modifyApiVersionsResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc) error {
388+
if decodedStruct == nil {
389+
return errors.New("decoded struct must not be nil")
390+
}
391+
392+
versions, ok := decodedStruct.Get("api_keys").([]interface{})
393+
if !ok {
394+
return errors.New("versions not found")
395+
}
396+
for _, versionElement := range versions {
397+
version := versionElement.(*Struct)
398+
if version.Get("api_key").(int16) == apiKeySaslHandshake {
399+
return nil
400+
}
401+
}
402+
403+
versions = append(versions, &Struct{
404+
Schema: apiVersionSchema,
405+
Values: []any{int16(17), int16(0), int16(0)},
406+
})
407+
408+
return decodedStruct.Replace("api_keys", versions)
409+
}
410+
328411
func modifyMetadataResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc) error {
329412
if decodedStruct == nil {
330413
return errors.New("decoded struct must not be nil")
@@ -467,6 +550,8 @@ func (f *responseModifier) Apply(resp []byte) ([]byte, error) {
467550

468551
func GetResponseModifier(apiKey int16, apiVersion int16, addressMappingFunc config.NetAddressMappingFunc) (ResponseModifier, error) {
469552
switch apiKey {
553+
case apiKeyApiApiVersions:
554+
return newResponseModifier(apiKey, apiVersion, addressMappingFunc, apiVersionsResponseSchemaVersions, modifyApiVersionsResponse)
470555
case apiKeyMetadata:
471556
return newResponseModifier(apiKey, apiVersion, addressMappingFunc, metadataResponseSchemaVersions, modifyMetadataResponse)
472557
case apiKeyFindCoordinator:

0 commit comments

Comments
 (0)