Skip to content

Commit a95f18f

Browse files
feat: support for client middleware in Flight configuration and testing
1 parent 1ccd6a7 commit a95f18f

File tree

4 files changed

+103
-1
lines changed

4 files changed

+103
-1
lines changed

influxdb3/config.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"strconv"
3232
"time"
3333

34+
"github.com/apache/arrow-go/v18/arrow/flight"
3435
"github.com/influxdata/line-protocol/v2/lineprotocol"
3536
)
3637

@@ -149,6 +150,9 @@ type ClientConfig struct {
149150

150151
// Proxy URL
151152
Proxy string
153+
154+
// Flight client middleware
155+
Middleware []flight.ClientMiddleware
152156
}
153157

154158
// validate validates the config.

influxdb3/query.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ func (c *Client) initializeQueryClient(hostPortURL string, secure bool, proxyURL
7979
}
8080
}
8181

82-
client, err := flight.NewClientWithMiddleware(hostPortURL, nil, nil, opts...)
82+
client, err := flight.NewClientWithMiddleware(hostPortURL, nil, c.config.Middleware, opts...)
8383
if err != nil {
8484
return fmt.Errorf("flight: %w", err)
8585
}

influxdb3/query_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,60 @@ func TestQueryWithLargeResponsePass(t *testing.T) {
213213
assert.Nil(t, qIter.Err())
214214
}
215215

216+
func TestQueryWithMiddlewareSuccess(t *testing.T) {
217+
s := *testutil.StartCheckMessageFromMiddlewareFlightServer(t)
218+
defer func() {
219+
s.Shutdown()
220+
}()
221+
222+
middlewares := []flight.ClientMiddleware{
223+
flight.CreateClientMiddleware(&testutil.ClientTestMiddleware{}),
224+
}
225+
226+
client, err := New(ClientConfig{
227+
Host: "http://" + s.Addr().String(),
228+
Token: "my_secret_token",
229+
Database: "explore",
230+
Middleware: middlewares,
231+
})
232+
require.NoError(t, err)
233+
defer func(client *Client) {
234+
err := client.Close()
235+
if err != nil {
236+
t.Fatal(err)
237+
}
238+
}(client)
239+
_, qErr := client.Query(context.Background(),
240+
"SELECT name FROM examples",
241+
)
242+
require.NotContains(t, qErr.Error(), "invalid value from middleware")
243+
}
244+
245+
func TestQueryWithMiddlewareFail(t *testing.T) {
246+
s := *testutil.StartCheckMessageFromMiddlewareFlightServer(t)
247+
defer func() {
248+
s.Shutdown()
249+
}()
250+
251+
client, err := New(ClientConfig{
252+
Host: "http://" + s.Addr().String(),
253+
Token: "my_secret_token",
254+
Database: "explore",
255+
Middleware: nil,
256+
})
257+
require.NoError(t, err)
258+
defer func(client *Client) {
259+
err := client.Close()
260+
if err != nil {
261+
t.Fatal(err)
262+
}
263+
}(client)
264+
_, qErr := client.Query(context.Background(),
265+
"SELECT name FROM examples",
266+
)
267+
require.Contains(t, qErr.Error(), "invalid value from middleware")
268+
}
269+
216270
func TestQueryWithQueryTimeoutDeadlineExpired(t *testing.T) {
217271
s := flight.NewServerWithMiddleware(nil)
218272
err := s.Init("localhost:0")

influxdb3/testutil/mocks.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package testutil
33

44
import (
55
"bytes"
6+
"context"
67
"encoding/json"
78
"errors"
89
"fmt"
@@ -20,6 +21,7 @@ import (
2021
"github.com/apache/arrow-go/v18/arrow/memory"
2122
"github.com/stretchr/testify/assert"
2223
"google.golang.org/grpc/codes"
24+
"google.golang.org/grpc/metadata"
2325
"google.golang.org/grpc/status"
2426
)
2527

@@ -461,3 +463,45 @@ func arrayOf(mem memory.Allocator, a any, valids []bool) arrow.Array {
461463
panic(fmt.Errorf("arrdata: invalid data slice type %T", a))
462464
}
463465
}
466+
467+
type ClientTestMiddleware struct {
468+
}
469+
470+
func (c *ClientTestMiddleware) StartCall(ctx context.Context) context.Context {
471+
return metadata.AppendToOutgoingContext(ctx, "sent-from-middleware", "some-value")
472+
}
473+
474+
type CheckMessageFromMiddlewareFlightServer struct {
475+
flight.BaseFlightServer
476+
}
477+
478+
func (f *CheckMessageFromMiddlewareFlightServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error {
479+
ctx := fs.Context()
480+
md, _ := metadata.FromIncomingContext(ctx)
481+
values := md.Get("sent-from-middleware")
482+
if values == nil || values[0] != "some-value" {
483+
return status.Errorf(codes.Internal, "invalid value from middleware")
484+
}
485+
486+
return nil
487+
}
488+
489+
//nolint:all
490+
func StartCheckMessageFromMiddlewareFlightServer(t *testing.T) *flight.Server {
491+
mockServer := CheckMessageFromMiddlewareFlightServer{}
492+
s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{})
493+
err := s.Init("localhost:0")
494+
if err != nil {
495+
assert.Fail(t, err.Error())
496+
}
497+
s.RegisterFlightService(&mockServer)
498+
499+
go func() {
500+
err := s.Serve()
501+
if err != nil {
502+
assert.Fail(t, err.Error())
503+
}
504+
}()
505+
506+
return &s
507+
}

0 commit comments

Comments
 (0)