diff --git a/apisprout.go b/apisprout.go index 44fc633..dfcda34 100644 --- a/apisprout.go +++ b/apisprout.go @@ -597,6 +597,48 @@ var handler = func(rr *RefreshableRouter) http.Handler { }) } +// +func loadSwaggerFromUri(uri string) (data []byte, err error) { + if strings.HasPrefix(uri, "http") { + req, httpErr := http.NewRequest("GET", uri, nil) + if httpErr != nil { + err = httpErr + return + } + if customHeader := viper.GetString("header"); customHeader != "" { + header := strings.Split(customHeader, ":") + if len(header) != 2 { + err = errors.New("Header format is invalid") + } else { + req.Header.Add(strings.TrimSpace(header[0]), strings.TrimSpace(header[1])) + } + } + if err != nil { + return + } + + client := &http.Client{} + resp, httpErr := client.Do(req) + if httpErr != nil { + err = httpErr + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + err = fmt.Errorf("Server at %s reported %d status code", uri, resp.StatusCode) + return + } + data, err = ioutil.ReadAll(resp.Body) + if err != nil { + return + } + } else { + data, err = ioutil.ReadFile(uri) + } + + return data, err +} + // server loads an OpenAPI file and runs a mock server using the paths and // examples defined in the file. func server(cmd *cobra.Command, args []string) { @@ -611,65 +653,40 @@ func server(cmd *cobra.Command, args []string) { // Load either from an HTTP URL or from a local file depending on the passed // in value. - if strings.HasPrefix(uri, "http") { - req, err := http.NewRequest("GET", uri, nil) - if err != nil { - log.Fatal(err) - } - if customHeader := viper.GetString("header"); customHeader != "" { - header := strings.Split(customHeader, ":") - if len(header) != 2 { - log.Fatal("Header format is invalid.") - } - req.Header.Add(strings.TrimSpace(header[0]), strings.TrimSpace(header[1])) - } - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - log.Fatal(err) - } + data, err = loadSwaggerFromUri(uri) + if err != nil { + log.Fatal(err) + } - data, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - log.Fatal(err) + if viper.GetBool("watch") { + if strings.HasPrefix(uri, "http") { + log.Fatal(errors.New("Watching a URL is not supported.")) } - if viper.GetBool("watch") { - log.Fatal("Watching a URL is not supported.") - } - } else { - data, err = ioutil.ReadFile(uri) + // Set up a new filesystem watcher and reload the router every time + // the file has changed on disk. + watcher, err := fsnotify.NewWatcher() if err != nil { log.Fatal(err) } - - if viper.GetBool("watch") { - // Set up a new filesystem watcher and reload the router every time - // the file has changed on disk. - watcher, err := fsnotify.NewWatcher() - if err != nil { - log.Fatal(err) - } - defer watcher.Close() - - go func() { - // Since waiting for events or errors is blocking, we do this in a - // goroutine. It loops forever here but will exit when the process - // is finished, e.g. when you `ctrl+c` to exit. - for { - select { - case event, ok := <-watcher.Events: - if !ok { - return - } - if event.Op&fsnotify.Write == fsnotify.Write { - fmt.Printf("🌙 Reloading %s\n", uri) - data, err = ioutil.ReadFile(uri) - if err != nil { - log.Fatal(err) - } - + defer watcher.Close() + + go func() { + // Since waiting for events or errors is blocking, we do this in a + // goroutine. It loops forever here but will exit when the process + // is finished, e.g. when you `ctrl+c` to exit. + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Op&fsnotify.Write == fsnotify.Write { + fmt.Printf("🌙 Reloading %s\n", uri) + data, err = loadSwaggerFromUri(uri) + if err != nil { + log.Printf("ERROR: %s", err) + } else { if s, r, err := load(uri, data); err == nil { swagger = s rr.Set(r) @@ -677,17 +694,17 @@ func server(cmd *cobra.Command, args []string) { log.Printf("ERROR: Unable to load OpenAPI document: %s", err) } } - case err, ok := <-watcher.Errors: - if !ok { - return - } - fmt.Println("error:", err) } + case err, ok := <-watcher.Errors: + if !ok { + return + } + log.Printf("ERROR: %s", err) } - }() + } + }() - watcher.Add(uri) - } + watcher.Add(uri) } swagger, router, err := load(uri, data) @@ -697,35 +714,25 @@ func server(cmd *cobra.Command, args []string) { rr.Set(router) - if strings.HasPrefix(uri, "http") { - http.HandleFunc("/__reload", func(w http.ResponseWriter, r *http.Request) { - resp, err := http.Get(uri) - if err != nil { - log.Printf("ERROR: %v", err) - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("error while reloading")) - return - } - - data, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - log.Printf("ERROR: %v", err) - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("error while parsing")) - return - } - + http.HandleFunc("/__reload", func(w http.ResponseWriter, r *http.Request) { + log.Printf("🌙 Reloading %s\n", uri) + data, err = loadSwaggerFromUri(uri) + if err == nil { if s, r, err := load(uri, data); err == nil { swagger = s rr.Set(r) } - + } + if err == nil { + log.Printf("Reloaded from %s", uri) w.WriteHeader(200) w.Write([]byte("reloaded")) - log.Printf("Reloaded from %s", uri) - }) - } + } else { + log.Printf("ERROR: %s", err) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("error while reloading")) + } + }) // Add a health check route which returns 200 http.HandleFunc("/__health", func(w http.ResponseWriter, r *http.Request) {