Skip to content

Commit db9e9de

Browse files
committed
added GetModelStats
1 parent c1a1452 commit db9e9de

File tree

2 files changed

+135
-43
lines changed

2 files changed

+135
-43
lines changed

suggestionbox/suggestionbox_model.go

Lines changed: 98 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,49 @@ type ModelOptions struct {
8282
Skipgrams int `json:"skipgrams,omitempty"`
8383
}
8484

85+
// CreateModel creates the Model in Suggestionbox.
86+
// If no ID is set, one will be assigned in the return Model.
87+
func (c *Client) CreateModel(ctx context.Context, model Model) (Model, error) {
88+
u, err := url.Parse(c.addr + "/suggestionbox/models")
89+
if err != nil {
90+
return model, err
91+
}
92+
if !u.IsAbs() {
93+
return model, errors.New("box address must be absolute")
94+
}
95+
var buf bytes.Buffer
96+
if err := json.NewEncoder(&buf).Encode(model); err != nil {
97+
return model, errors.Wrap(err, "encoding request body")
98+
}
99+
req, err := http.NewRequest(http.MethodPost, u.String(), &buf)
100+
if err != nil {
101+
return model, err
102+
}
103+
req = req.WithContext(ctx)
104+
req.Header.Set("Accept", "application/json; charset=utf-8")
105+
req.Header.Set("Content-Type", "application/json; charset=utf-8")
106+
resp, err := c.HTTPClient.Do(req)
107+
if err != nil {
108+
return model, err
109+
}
110+
defer resp.Body.Close()
111+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
112+
return model, errors.New(resp.Status)
113+
}
114+
var response struct {
115+
Success bool
116+
Error string
117+
Model
118+
}
119+
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
120+
return model, errors.Wrap(err, "decoding response")
121+
}
122+
if !response.Success {
123+
return model, ErrSuggestionbox(response.Error)
124+
}
125+
return response.Model, nil
126+
}
127+
85128
// ListModels gets a Model by its ID.
86129
func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
87130
u, err := url.Parse(c.addr + "/suggestionbox/models")
@@ -194,49 +237,6 @@ func (c *Client) DeleteModel(ctx context.Context, modelID string) error {
194237
return nil
195238
}
196239

197-
// CreateModel creates the Model in Suggestionbox.
198-
// If no ID is set, one will be assigned in the return Model.
199-
func (c *Client) CreateModel(ctx context.Context, model Model) (Model, error) {
200-
u, err := url.Parse(c.addr + "/suggestionbox/models")
201-
if err != nil {
202-
return model, err
203-
}
204-
if !u.IsAbs() {
205-
return model, errors.New("box address must be absolute")
206-
}
207-
var buf bytes.Buffer
208-
if err := json.NewEncoder(&buf).Encode(model); err != nil {
209-
return model, errors.Wrap(err, "encoding request body")
210-
}
211-
req, err := http.NewRequest(http.MethodPost, u.String(), &buf)
212-
if err != nil {
213-
return model, err
214-
}
215-
req = req.WithContext(ctx)
216-
req.Header.Set("Accept", "application/json; charset=utf-8")
217-
req.Header.Set("Content-Type", "application/json; charset=utf-8")
218-
resp, err := c.HTTPClient.Do(req)
219-
if err != nil {
220-
return model, err
221-
}
222-
defer resp.Body.Close()
223-
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
224-
return model, errors.New(resp.Status)
225-
}
226-
var response struct {
227-
Success bool
228-
Error string
229-
Model
230-
}
231-
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
232-
return model, errors.Wrap(err, "decoding response")
233-
}
234-
if !response.Success {
235-
return model, ErrSuggestionbox(response.Error)
236-
}
237-
return response.Model, nil
238-
}
239-
240240
// FeatureNumber makes a numerical Feature.
241241
func FeatureNumber(key string, value float64) Feature {
242242
return Feature{
@@ -293,3 +293,58 @@ func FeatureImageBase64(key string, data string) Feature {
293293
Value: data,
294294
}
295295
}
296+
297+
// ModelStats are the statistics for a Model.
298+
type ModelStats struct {
299+
// Predictions is the number of predictions this model has made.
300+
Predictions int `json:"predictions"`
301+
// Rewards is the number of rewards the model has received.
302+
Rewards int `json:"rewards"`
303+
// RewardRatio is the ratio between Predictions and Rewards.
304+
RewardRatio float64 `json:"reward_ratio"`
305+
// Explores is the number of times the model has explored,
306+
// to learn new things.
307+
Explores int `json:"explores"`
308+
// Exploits is the number of times the model has exploited learning.
309+
Exploits int `json:"exploits"`
310+
// ExploreRatio is the ratio between exploring and exploiting.
311+
ExploreRatio float64 `json:"explore_ratio"`
312+
}
313+
314+
// GetModelStats gets the statistics for the specified model.
315+
func (c *Client) GetModelStats(ctx context.Context, modelID string) (ModelStats, error) {
316+
var stats ModelStats
317+
u, err := url.Parse(c.addr + "/" + path.Join("suggestionbox", "models", modelID, "stats"))
318+
if err != nil {
319+
return stats, err
320+
}
321+
if !u.IsAbs() {
322+
return stats, errors.New("box address must be absolute")
323+
}
324+
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
325+
if err != nil {
326+
return stats, err
327+
}
328+
req = req.WithContext(ctx)
329+
req.Header.Set("Accept", "application/json; charset=utf-8")
330+
resp, err := c.HTTPClient.Do(req)
331+
if err != nil {
332+
return stats, err
333+
}
334+
defer resp.Body.Close()
335+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
336+
return stats, errors.New(resp.Status)
337+
}
338+
var response struct {
339+
Success bool
340+
Error string
341+
ModelStats
342+
}
343+
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
344+
return stats, errors.Wrap(err, "decoding response")
345+
}
346+
if !response.Success {
347+
return stats, ErrSuggestionbox(response.Error)
348+
}
349+
return response.ModelStats, nil
350+
}

suggestionbox/suggestionbox_model_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,40 @@ func TestFeatureHelpers(t *testing.T) {
180180
is.Equal(f.Value, "pretendthisisimagedata")
181181

182182
}
183+
184+
func TestGetModelStats(t *testing.T) {
185+
is := is.New(t)
186+
var apiCalls int
187+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
188+
apiCalls++
189+
is.Equal(r.Method, http.MethodGet)
190+
is.Equal(r.URL.Path, "/suggestionbox/models/model1/stats")
191+
is.Equal(r.Header.Get("Accept"), "application/json; charset=utf-8")
192+
stats := suggestionbox.ModelStats{
193+
Predictions: 1,
194+
Rewards: 2,
195+
RewardRatio: 3.3,
196+
Explores: 4,
197+
Exploits: 5,
198+
ExploreRatio: 6.6,
199+
}
200+
is.NoErr(json.NewEncoder(w).Encode(struct {
201+
suggestionbox.ModelStats
202+
Success bool `json:"success"`
203+
}{
204+
Success: true,
205+
ModelStats: stats,
206+
}))
207+
}))
208+
defer srv.Close()
209+
sb := suggestionbox.New(srv.URL)
210+
stats, err := sb.GetModelStats(context.Background(), "model1")
211+
is.NoErr(err)
212+
is.Equal(apiCalls, 1) // apiCalls
213+
is.Equal(stats.Predictions, 1)
214+
is.Equal(stats.Rewards, 2)
215+
is.Equal(stats.RewardRatio, 3.3)
216+
is.Equal(stats.Explores, 4)
217+
is.Equal(stats.Exploits, 5)
218+
is.Equal(stats.ExploreRatio, 6.6)
219+
}

0 commit comments

Comments
 (0)