diff --git a/client/client.go b/client/client.go index 68d2252..763a2b6 100644 --- a/client/client.go +++ b/client/client.go @@ -197,6 +197,7 @@ func (c *Client) useInsecureHTTPClient(insecure bool) *http.Transport { } func (c *Client) MakeRestRequest(method string, path string, body *container.Container, authenticated bool) (*http.Request, error) { + origPath := path if c.platform == "nd" && path != "/login" { if strings.HasPrefix(path, "/") { path = path[1:] @@ -222,6 +223,10 @@ func (c *Client) MakeRestRequest(method string, path string, body *container.Con if err != nil { return nil, err } + if method == "PATCH" || method == "PUT" || method == "DELETE" || method == "POST" { + c.updateCacheForWrite(origPath) + log.Printf("[DEBUG] updating cache for write methods, endpoint %v", origPath) + } req.Header.Set("Content-Type", "application/json") log.Printf("[DEBUG] HTTP request %s %s", method, path) diff --git a/client/client_service.go b/client/client_service.go index 8659348..ffe3ef3 100644 --- a/client/client_service.go +++ b/client/client_service.go @@ -5,13 +5,112 @@ import ( "fmt" "log" "net/url" + "regexp" + "strings" + "sync" + "time" "github.com/ciscoecosystem/mso-go-client/container" "github.com/ciscoecosystem/mso-go-client/models" ) +const ( + CACHE_TIMEOUT = 60 // 60 seconds +) + +type msoApi struct { + readTs time.Time + resp *container.Container + writeTs time.Time +} + +var msoApiCache map[string]msoApi +var muApiCache sync.RWMutex // mutex lock for upating the map + +// init of the package +func init() { + msoApiCache = make(map[string]msoApi) +} + +// getFromCache: check the API cache and return the stored resp +// if it is with the timeout +func (c *Client) getFromCache(endpoint string) *container.Container { + defer muApiCache.RUnlock() + muApiCache.RLock() + updEndpoint := strings.Replace(endpoint, "mso/", "", 1) + if api, ok := msoApiCache[updEndpoint]; ok { + curTs := time.Now() + rDiff := curTs.Sub(api.readTs) + wDiff := curTs.Sub(api.writeTs) + log.Printf("[DEBUG] getFromCache readTs %v writeTs: %v rDiff %v wDiff %v\n", api.readTs, api.writeTs, rDiff.Seconds(), wDiff.Seconds()) + if rDiff.Seconds() >= CACHE_TIMEOUT || wDiff.Seconds() <= CACHE_TIMEOUT { + return nil + } + log.Printf("[DEBUG] Found GET response in cache for schema endpoint: %v\n", updEndpoint) + return api.resp + } + return nil +} + +// storeInCache: store the given response in the API cache +func (c *Client) storeInCache(endpoint string, resp *container.Container) { + updEndpoint := strings.Replace(endpoint, "mso/", "", 1) + var re = regexp.MustCompile(`^api/v1/schemas/(.*)$`) + matches := re.FindStringSubmatch(updEndpoint) + + if len(matches) != 2 { + return + } + + defer muApiCache.Unlock() + + muApiCache.Lock() + if api, ok := msoApiCache[updEndpoint]; ok { + curTs := time.Now() + wDiff := curTs.Sub(api.writeTs) + if wDiff.Seconds() <= CACHE_TIMEOUT { + log.Printf("[DEBUG] Skip storing endpoint %v due to recent writeTs: %v\n", updEndpoint, api.writeTs) + return + } + } + + api := msoApi{ + readTs: time.Now(), + resp: resp, + writeTs: time.Now().Add(-180 * time.Second), + } + + log.Printf("[DEBUG] Caching GET endpoint:: %s readTs %v writeTs %v", updEndpoint, api.readTs, api.writeTs) + msoApiCache[updEndpoint] = api +} + +// invalidateCache: invalidate the cache +func (c *Client) updateCacheForWrite(endpoint string) { + updEndpoint := strings.Replace(endpoint, "mso/", "", 1) + var re = regexp.MustCompile(`^api/v1/schemas/(.*)(\?)?`) + matches := re.FindStringSubmatch(updEndpoint) + if len(matches) != 2 && len(matches) != 3 { + return + } + + defer muApiCache.Unlock() + schEndPoint := "api/v1/schemas/" + matches[1] + muApiCache.Lock() + if api, ok := msoApiCache[schEndPoint]; ok { + api.writeTs = time.Now() + api.resp = nil + msoApiCache[schEndPoint] = api + log.Printf("[DEBUG] Update writeTs %v in cache for schema endpoint: %v\n", api.writeTs, schEndPoint) + } +} + func (c *Client) GetViaURL(endpoint string) (*container.Container, error) { + cobj := c.getFromCache(endpoint) + if cobj != nil { + c.storeInCache(endpoint, cobj) + return cobj, nil + } req, err := c.MakeRestRequest("GET", endpoint, nil, true) if err != nil { @@ -26,8 +125,15 @@ func (c *Client) GetViaURL(endpoint string) (*container.Container, error) { if obj == nil { return nil, errors.New("Empty response body") } - return obj, CheckForErrors(obj, "GET") + err = CheckForErrors(obj, "GET") + + if err != nil { + return obj, err + } + + c.storeInCache(endpoint, obj) + return obj, nil } func (c *Client) GetPlatform() string { diff --git a/go.mod b/go.mod index d16e33a..13f2f78 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/ciscoecosystem/mso-go-client go 1.12 -require github.com/hashicorp/go-version v1.6.0 // indirect +require ( + github.com/hashicorp/go-version v1.6.0 // indirect + github.com/stretchr/testify v1.10.0 // indirect +) diff --git a/tests/client_test.go b/tests/client_test.go index b9cced6..100f371 100644 --- a/tests/client_test.go +++ b/tests/client_test.go @@ -1,9 +1,15 @@ package tests import ( + "encoding/json" + "fmt" + "sync" "testing" + "time" "github.com/ciscoecosystem/mso-go-client/client" + "github.com/ciscoecosystem/mso-go-client/container" + "github.com/stretchr/testify/assert" ) func TestClientAuthenticate(t *testing.T) { @@ -14,13 +20,202 @@ func TestClientAuthenticate(t *testing.T) { t.Error(err) } + fmt.Printf("err is %v", err) + if client.AuthToken.Token == "{}" { t.Error("Token is empty") } - t.Error("all wrong") + + fmt.Printf("Got Token %v", client.AuthToken.Token) } func GetTestClient() *client.Client { return client.GetClient("https://173.36.219.193", "admin", client.Password("ins3965!ins3965!"), client.Insecure(true)) +} + +func TestParallelGetSchemas(t *testing.T) { + cl := GetTestClient() + err := cl.Authenticate() + if err != nil { + t.Error(err) + } + schId := "6878807a072d2d88bec9b3b3" // Test_Schema + schUrl := "api/v1/schemas/" + schId + _, err = cl.GetViaURL(schUrl) + + assert := assert.New(t) + assert.Equal(err, nil) + + numRequests := 6 + resps := make(map[int]*container.Container) + errs := []error{} + + numObjs := 100 + numBatches := numObjs / numRequests + + fmt.Printf("Requesting %v objects in %v batches in %v requests per batch", numObjs, numBatches, numRequests) + + for b := 1; b <= numBatches; b++ { + wgReqs := sync.WaitGroup{} + // Create the workers + for w := 1; w <= numRequests; w++ { + wgReqs.Add(numRequests) + go func(reqN int) { + defer wgReqs.Done() + var err error + resps[reqN], err = cl.GetViaURL(schUrl) + fmt.Printf("Batch: %v Request: %v GetViaURL err = [%v]\n", b, reqN, err) + errs = append(errs, err) + }(w) + } + wgReqs.Wait() + // time.Sleep(2 * time.Second) + time.Sleep(200000000) // 2*10^8 nano seconds = 200 ms + } + assert.Equal(err, nil) + fmt.Printf("len(resps) = %v\n", len(resps)) +} + +func TestParallelGetSchemasMso(t *testing.T) { + cl := GetTestClient() + err := cl.Authenticate() + if err != nil { + t.Error(err) + } + schId := "6878807a072d2d88bec9b3b3" // for Test_Schema + schUrl := "mso/api/v1/schemas/" + schId + _, err = cl.GetViaURL(schUrl) + + assert := assert.New(t) + assert.Equal(err, nil) + + numRequests := 6 + resps := make(map[int]*container.Container) + errs := []error{} + + numObjs := 120 + numBatches := numObjs / numRequests + + fmt.Printf("Requesting %v objects in %v batches in %v requests per batch", numObjs, numBatches, numRequests) + + for b := 1; b <= numBatches; b++ { + wgReqs := sync.WaitGroup{} + // Create the workers + for w := 1; w <= numRequests; w++ { + wgReqs.Add(1) + go func(reqN int) { + defer wgReqs.Done() + var err error + resps[reqN], err = cl.GetViaURL(schUrl) + fmt.Printf("Batch: %v Request: %v GetViaURL err = [%v]\n", b, reqN, err) + errs = append(errs, err) + }(w) + } + wgReqs.Wait() + // time.Sleep(2 * time.Second) + time.Sleep(200000000) // 2*10^8 nano seconds = 200 ms + } + assert.Equal(err, nil) + fmt.Printf("len(resps) = %v\n", len(resps)) +} + +func TestParallelPatchSchemas(t *testing.T) { + cl := GetTestClient() + err := cl.Authenticate() + if err != nil { + t.Error(err) + } + + return + + schemaID := "6878807a072d2d88bec9b3b3" + schUrl := "api/v1/schemas/" + schemaID + + assert := assert.New(t) + + _, err = cl.GetViaURL(schUrl) + + numBatches := 3 + numRequests := 3 + for b := 0; b < numBatches; b++ { + wgReqs := sync.WaitGroup{} + // Create the workers + for w := 1; w <= numRequests; w++ { + wgReqs.Add(1) + bdNum := b*numRequests + w + go func(bdN int) { + defer wgReqs.Done() + var err error + bdName := fmt.Sprintf("BD%v", bdN) + desc := fmt.Sprintf("new descr %v", 300+bdN) + err = patchBDDescr(cl, schemaID, "Tmpl1", bdName, desc) + assert.Equal(err, nil) + _, err = cl.GetViaURL(schUrl) + fmt.Printf("Batch: %v Request: %v GetViaURL err = [%v]\n", b, w, err) + }(bdNum) + } + wgReqs.Wait() + // time.Sleep(2 * time.Second) + time.Sleep(200000000) // 2*10^8 nano seconds = 200 ms + } + assert.Equal(err, nil) +} + +func doPatchRequest(msoClient *client.Client, path string, payloadCon *container.Container) error { + req, err := msoClient.MakeRestRequest("PATCH", path, payloadCon, true) + if err != nil { + return err + } + + cont, _, err := msoClient.Do(req) + if err != nil { + return err + } + + err = client.CheckForErrors(cont, "PATCH") + if err != nil { + return err + } + + return nil +} + +func addPatchPayloadToContainer(payloadContainer *container.Container, op, path string, value interface{}) error { + + payloadMap := map[string]interface{}{"op": op, "path": path, "value": value} + + payload, err := json.Marshal(payloadMap) + if err != nil { + return err + } + + jsonContainer, err := container.ParseJSON([]byte(payload)) + if err != nil { + return err + } + + err = payloadContainer.ArrayAppend(jsonContainer.Data()) + if err != nil { + return err + } + + return nil +} + +func patchBDDescr(cl *client.Client, schemaID string, templateName string, bdName string, desc string) error { + basePath := fmt.Sprintf("/templates/%s/bds/%s", templateName, bdName) + payloadCon := container.New() + payloadCon.Array() + + err := addPatchPayloadToContainer(payloadCon, "replace", fmt.Sprintf("%s/description", basePath), desc) + if err != nil { + return err + } + + err = doPatchRequest(cl, fmt.Sprintf("api/v1/schemas/%s", schemaID), payloadCon) + if err != nil { + return err + } + return nil }