Skip to content

Commit 65a6655

Browse files
Anaethelionaxw
andauthored
Product check fixes (#303) (#304)
* Product check improvements - don't issue original request if product check fails - check original request error before response header check (avoiding panic) * Don't cache product check errors The product check could have failed because of a transient network error, so repeat the check on subsequent requests as long as there has not been a prior success. Co-authored-by: Andrew Wilkins <[email protected]>
1 parent 3c08ee7 commit 65a6655

File tree

2 files changed

+97
-44
lines changed

2 files changed

+97
-44
lines changed

elasticsearch.go

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,12 @@ type Config struct {
101101
// Client represents the Elasticsearch client.
102102
//
103103
type Client struct {
104-
*esapi.API // Embeds the API methods
105-
Transport estransport.Interface
106-
107-
productCheckOnce sync.Once
108-
responseCheckOnce sync.Once
109-
productCheckError error
104+
*esapi.API // Embeds the API methods
105+
Transport estransport.Interface
110106
useResponseCheckOnly bool
107+
108+
productCheckMu sync.RWMutex
109+
productCheckSuccess bool
111110
}
112111

113112
type esVersion struct {
@@ -280,35 +279,50 @@ func ParseElasticsearchVersion(version string) (int64, int64, int64, error) {
280279
// Perform delegates to Transport to execute a request and return a response.
281280
//
282281
func (c *Client) Perform(req *http.Request) (*http.Response, error) {
283-
// ProductCheck validation
284-
c.productCheckOnce.Do(func() {
285-
// We skip this validation of we only want the header validation.
286-
// ResponseCheck path continues after original request.
287-
if c.useResponseCheckOnly {
288-
return
289-
}
290-
282+
// ProductCheck validation. We skip this validation of we only want the
283+
// header validation. ResponseCheck path continues after original request.
284+
if !c.useResponseCheckOnly {
291285
// Launch product check for 7.x, request info, check header then payload.
292-
c.productCheckError = c.productCheck()
293-
return
294-
})
286+
if err := c.doProductCheck(c.productCheck); err != nil {
287+
return nil, err
288+
}
289+
}
295290

296291
// Retrieve the original request.
297292
res, err := c.Transport.Perform(req)
298293

299-
c.responseCheckOnce.Do(func() {
300-
// ResponseCheck path continues, we run the header check on the first answer from ES.
301-
if c.useResponseCheckOnly {
302-
c.productCheckError = genuineCheckHeader(res.Header)
294+
// ResponseCheck path continues, we run the header check on the first answer from ES.
295+
if err == nil {
296+
checkHeader := func() error { return genuineCheckHeader(res.Header) }
297+
if err := c.doProductCheck(checkHeader); err != nil {
298+
res.Body.Close()
299+
return nil, err
303300
}
304-
})
305-
306-
if c.productCheckError != nil {
307-
return nil, c.productCheckError
308301
}
309302
return res, err
310303
}
311304

305+
// doProductCheck calls f if there as not been a prior successful call to doProductCheck,
306+
// returning nil otherwise.
307+
func (c *Client) doProductCheck(f func() error) error {
308+
c.productCheckMu.RLock()
309+
productCheckSuccess := c.productCheckSuccess
310+
c.productCheckMu.RUnlock()
311+
if productCheckSuccess {
312+
return nil
313+
}
314+
c.productCheckMu.Lock()
315+
defer c.productCheckMu.Unlock()
316+
if c.productCheckSuccess {
317+
return nil
318+
}
319+
if err := f(); err != nil {
320+
return err
321+
}
322+
c.productCheckSuccess = true
323+
return nil
324+
}
325+
312326
// productCheck runs an esapi.Info query to retrieve informations of the current cluster
313327
// decodes the response and decides if the cluster is a genuine Elasticsearch product.
314328
func (c *Client) productCheck() error {

elasticsearch_internal_test.go

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,22 @@ package elasticsearch
2222
import (
2323
"encoding/base64"
2424
"errors"
25-
"github.com/elastic/go-elasticsearch/v7/estransport"
2625
"io/ioutil"
2726
"net/http"
27+
"net/http/httptest"
2828
"net/url"
2929
"os"
30+
"reflect"
3031
"regexp"
3132
"strings"
3233
"testing"
34+
35+
"github.com/elastic/go-elasticsearch/v7/estransport"
3336
)
3437

3538
var called bool
3639

37-
type mockTransp struct{
40+
type mockTransp struct {
3841
RoundTripFunc func(*http.Request) (*http.Response, error)
3942
}
4043

@@ -64,7 +67,6 @@ func (t *mockTransp) RoundTrip(req *http.Request) (*http.Response, error) {
6467
return t.RoundTripFunc(req)
6568
}
6669

67-
6870
func TestClientConfiguration(t *testing.T) {
6971
t.Parallel()
7072

@@ -219,7 +221,7 @@ func TestClientConfiguration(t *testing.T) {
219221
}, nil
220222
},
221223
},
222-
})
224+
})
223225
if err != nil {
224226
t.Errorf("Unexpected error, got: %+v", err)
225227
}
@@ -444,7 +446,7 @@ func TestParseElasticsearchVersion(t *testing.T) {
444446
wantErr: true,
445447
},
446448
}
447-
for _, tt := range tests {
449+
for _, tt := range tests {
448450
t.Run(tt.name, func(t *testing.T) {
449451
got, got1, got2, err := ParseElasticsearchVersion(tt.version)
450452
if (err != nil) != tt.wantErr {
@@ -556,7 +558,7 @@ func TestGenuineCheckHeader(t *testing.T) {
556558
wantErr: true,
557559
},
558560
{
559-
name: "Unavailable product header",
561+
name: "Unavailable product header",
560562
headers: http.Header{},
561563
wantErr: true,
562564
},
@@ -575,49 +577,56 @@ func TestResponseCheckOnly(t *testing.T) {
575577
name string
576578
useResponseCheckOnly bool
577579
response *http.Response
580+
requestErr error
578581
wantErr bool
579-
} {
582+
}{
580583
{
581-
name: "Valid answer with header",
584+
name: "Valid answer with header",
582585
useResponseCheckOnly: false,
583586
response: &http.Response{
584587
Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}},
585-
Body: ioutil.NopCloser(strings.NewReader("{}")),
588+
Body: ioutil.NopCloser(strings.NewReader("{}")),
586589
},
587-
wantErr: false,
590+
wantErr: false,
588591
},
589592
{
590-
name: "Valid answer without header",
593+
name: "Valid answer without header",
591594
useResponseCheckOnly: false,
592595
response: &http.Response{
593596
Body: ioutil.NopCloser(strings.NewReader("{}")),
594597
},
595-
wantErr: true,
598+
wantErr: true,
596599
},
597600
{
598-
name: "Valid answer with header and response check",
601+
name: "Valid answer with header and response check",
599602
useResponseCheckOnly: true,
600603
response: &http.Response{
601604
Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}},
602-
Body: ioutil.NopCloser(strings.NewReader("{}")),
605+
Body: ioutil.NopCloser(strings.NewReader("{}")),
603606
},
604-
wantErr: false,
607+
wantErr: false,
605608
},
606609
{
607-
name: "Valid answer withouth header and response check",
610+
name: "Valid answer without header and response check",
608611
useResponseCheckOnly: true,
609612
response: &http.Response{
610613
Body: ioutil.NopCloser(strings.NewReader("{}")),
611614
},
612-
wantErr: true,
615+
wantErr: true,
616+
},
617+
{
618+
name: "Request failed",
619+
useResponseCheckOnly: true,
620+
response: nil,
621+
requestErr: errors.New("request failed"),
622+
wantErr: true,
613623
},
614-
615624
}
616625
for _, tt := range tests {
617626
t.Run(tt.name, func(t *testing.T) {
618627
c, _ := NewClient(Config{
619628
Transport: &mockTransp{RoundTripFunc: func(request *http.Request) (*http.Response, error) {
620-
return tt.response, nil
629+
return tt.response, tt.requestErr
621630
}},
622631
UseResponseCheckOnly: tt.useResponseCheckOnly,
623632
})
@@ -628,3 +637,33 @@ func TestResponseCheckOnly(t *testing.T) {
628637
})
629638
}
630639
}
640+
641+
func TestProductCheckError(t *testing.T) {
642+
var requestPaths []string
643+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
644+
requestPaths = append(requestPaths, r.URL.Path)
645+
if len(requestPaths) == 1 {
646+
// Simulate transient error from a proxy on the first request.
647+
// This must not be cached by the client.
648+
w.WriteHeader(http.StatusBadGateway)
649+
return
650+
}
651+
w.Header().Set("X-Elastic-Product", "Elasticsearch")
652+
w.Write([]byte("{}"))
653+
}))
654+
defer server.Close()
655+
656+
c, _ := NewClient(Config{Addresses: []string{server.URL}, DisableRetry: true})
657+
if _, err := c.Cat.Indices(); err == nil {
658+
t.Fatal("expected error")
659+
}
660+
if _, err := c.Cat.Indices(); err != nil {
661+
t.Fatalf("unexpected error: %s", err)
662+
}
663+
if n := len(requestPaths); n != 3 {
664+
t.Fatalf("expected 3 requests, got %d", n)
665+
}
666+
if !reflect.DeepEqual(requestPaths, []string{"/", "/", "/_cat/indices"}) {
667+
t.Fatalf("unexpected request paths: %s", requestPaths)
668+
}
669+
}

0 commit comments

Comments
 (0)