diff --git a/generator/go_client.stoneg.py b/generator/go_client.stoneg.py index 7e35749..4c37325 100644 --- a/generator/go_client.stoneg.py +++ b/generator/go_client.stoneg.py @@ -34,6 +34,7 @@ def _generate_client(self, namespace): for route in namespace.routes: generate_doc(self, route) self.emit(self._generate_route_signature(namespace, route)) + self.emit(self._generate_route_signature_context(namespace, route)) self.emit() self.emit('type apiImpl dropbox.Context') @@ -44,7 +45,7 @@ def _generate_client(self, namespace): self.emit('ctx := apiImpl(dropbox.NewContext(c))') self.emit('return &ctx') - def _generate_route_signature(self, namespace, route): + def _generate_route_signature(self, namespace, route, name_suffix="", initial_args=None): req = fmt_type(route.arg_data_type, namespace) res = fmt_type(route.result_data_type, namespace, use_interface=True) fn = fmt_var(route.name) @@ -52,20 +53,29 @@ def _generate_route_signature(self, namespace, route): fn += 'V%d' % route.version style = route.attrs.get('style', 'rpc') - arg = '' if is_void_type(route.arg_data_type) else 'arg {req}' - ret = '(err error)' if is_void_type(route.result_data_type) else \ - '(res {res}, err error)' - signature = '{fn}(' + arg + ') ' + ret + args = [] + if initial_args: + args.extend(initial_args) + if not is_void_type(route.arg_data_type): + args.append('arg {req}') + if style == 'upload': + args.append('content io.Reader') + + rets = [] + if not is_void_type(route.result_data_type): + rets.append('res {res}') if style == 'download': - signature = '{fn}(' + arg + \ - ') (res {res}, content io.ReadCloser, err error)' - elif style == 'upload': - signature = '{fn}(' + arg + ', content io.Reader) ' + ret - if is_void_type(route.arg_data_type): - signature = '{fn}(content io.Reader) ' + ret + rets.append('content io.ReadCloser') + rets.append('err error') + + signature = '{fn}' + name_suffix + '(' + ", ".join(args) + ') (' + ", ".join(rets) + ')' return signature.format(fn=fn, req=req, res=res) + def _generate_route_signature_context(self, namespace, route): + return self._generate_route_signature(namespace, route, name_suffix="Context", initial_args=['ctx context.Context']) + + def _generate_route(self, namespace, route): out = self.emit @@ -73,6 +83,8 @@ def _generate_route(self, namespace, route): if route.version != 1: route_name += '_v%d' % route.version + route_style = route.attrs.get('style', '') + fn = fmt_var(route.name) if route.version != 1: fn += 'V%d' % route.version @@ -85,9 +97,9 @@ def _generate_route(self, namespace, route): out('EndpointError {err} `json:"error"`'.format(err=err)) out() - signature = 'func (dbx *apiImpl) ' + self._generate_route_signature( + signature_context = 'func (dbx *apiImpl) ' + self._generate_route_signature_context( namespace, route) - with self.block(signature): + with self.block(signature_context): if route.deprecated is not None: out('log.Printf("WARNING: API `%s` is deprecated")' % fn) if route.deprecated.by is not None: @@ -116,8 +128,8 @@ def _generate_route(self, namespace, route): out("var resp []byte") out("var respBody io.ReadCloser") - out("resp, respBody, err = (*dropbox.Context)(dbx).Execute(req, {body})".format( - body="content" if route.attrs.get('style', '') == 'upload' else "nil")) + out("resp, respBody, err = (*dropbox.Context)(dbx).Execute(ctx, req, {body})".format( + body="content" if route_style == 'upload' else "nil")) with self.block("if err != nil"): out("var appErr {fn}APIError".format(fn=fn)) out("err = {auth}ParseError(err, &appErr)".format( @@ -144,9 +156,20 @@ def _generate_route(self, namespace, route): else: out("_ = resp") - if route.attrs.get('style', 'rpc') == "download": + if route_style == "download": out("content = respBody") else: out("_ = respBody") out('return') out() + + signature = 'func (dbx *apiImpl) ' + self._generate_route_signature( + namespace, route) + with self.block(signature): + args = ["context.Background()"] + if not is_void_type(route.arg_data_type): + args.append('arg') + if route_style == "upload": + args.append('content') + out('return dbx.' + fn + 'Context(' + ", ".join(args) + ');') + out('') diff --git a/generator/go_rsrc/sdk.go b/generator/go_rsrc/sdk.go index 3de0415..e40d74c 100644 --- a/generator/go_rsrc/sdk.go +++ b/generator/go_rsrc/sdk.go @@ -167,9 +167,9 @@ type Request struct { ExtraHeaders map[string]string } -func (c *Context) Execute(req Request, body io.Reader) ([]byte, io.ReadCloser, error) { +func (c *Context) Execute(ctx context.Context, req Request, body io.Reader) ([]byte, io.ReadCloser, error) { url := c.URLGenerator(req.Host, req.Namespace, req.Route) - httpReq, err := http.NewRequest("POST", url, body) + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, body) if err != nil { return nil, nil, err }