Skip to content

Commit db11a30

Browse files
Merge pull request #51 from puppetlabs/upstream/match-vendor-media-types
Match any media type that seems like JSON or YAML
2 parents 79bc0ab + 7f1ead3 commit db11a30

File tree

2 files changed

+259
-141
lines changed

2 files changed

+259
-141
lines changed

apisprout.go

Lines changed: 167 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,27 @@ var (
4343
ErrMissingAuth = errors.New("Missing auth")
4444
)
4545

46+
var (
47+
marshalJSONMatcher = regexp.MustCompile(`^application/(vnd\..+\+)?json$`)
48+
marshalYAMLMatcher = regexp.MustCompile(`^(application|text)/(x-|vnd\..+\+)?yaml$`)
49+
)
50+
51+
type RefreshableRouter struct {
52+
router *openapi3filter.Router
53+
}
54+
55+
func (rr *RefreshableRouter) Set(router *openapi3filter.Router) {
56+
rr.router = router
57+
}
58+
59+
func (rr *RefreshableRouter) Get() *openapi3filter.Router {
60+
return rr.router
61+
}
62+
63+
func NewRefreshableRouter() *RefreshableRouter {
64+
return &RefreshableRouter{}
65+
}
66+
4667
// ContentNegotiator is used to match a media type during content negotiation
4768
// of HTTP requests.
4869
type ContentNegotiator struct {
@@ -393,142 +414,8 @@ func mapContainsKey(dict map[string]string, key string) bool {
393414
return false
394415
}
395416

396-
// server loads an OpenAPI file and runs a mock server using the paths and
397-
// examples defined in the file.
398-
func server(cmd *cobra.Command, args []string) {
399-
var swagger *openapi3.Swagger
400-
var router *openapi3filter.Router
401-
402-
uri := args[0]
403-
404-
var err error
405-
var data []byte
406-
407-
// Load either from an HTTP URL or from a local file depending on the passed
408-
// in value.
409-
if strings.HasPrefix(uri, "http") {
410-
req, err := http.NewRequest("GET", uri, nil)
411-
if err != nil {
412-
log.Fatal(err)
413-
}
414-
if customHeader := viper.GetString("header"); customHeader != "" {
415-
header := strings.Split(customHeader, ":")
416-
if len(header) != 2 {
417-
log.Fatal("Header format is invalid.")
418-
}
419-
req.Header.Add(strings.TrimSpace(header[0]), strings.TrimSpace(header[1]))
420-
}
421-
client := &http.Client{}
422-
resp, err := client.Do(req)
423-
if err != nil {
424-
log.Fatal(err)
425-
}
426-
427-
data, err = ioutil.ReadAll(resp.Body)
428-
resp.Body.Close()
429-
if err != nil {
430-
log.Fatal(err)
431-
}
432-
433-
if viper.GetBool("watch") {
434-
log.Fatal("Watching a URL is not supported.")
435-
}
436-
} else {
437-
data, err = ioutil.ReadFile(uri)
438-
if err != nil {
439-
log.Fatal(err)
440-
}
441-
442-
if viper.GetBool("watch") {
443-
// Set up a new filesystem watcher and reload the router every time
444-
// the file has changed on disk.
445-
watcher, err := fsnotify.NewWatcher()
446-
if err != nil {
447-
log.Fatal(err)
448-
}
449-
defer watcher.Close()
450-
451-
go func() {
452-
// Since waiting for events or errors is blocking, we do this in a
453-
// goroutine. It loops forever here but will exit when the process
454-
// is finished, e.g. when you `ctrl+c` to exit.
455-
for {
456-
select {
457-
case event, ok := <-watcher.Events:
458-
if !ok {
459-
return
460-
}
461-
if event.Op&fsnotify.Write == fsnotify.Write {
462-
fmt.Printf("🌙 Reloading %s\n", uri)
463-
data, err = ioutil.ReadFile(uri)
464-
if err != nil {
465-
log.Fatal(err)
466-
}
467-
468-
if s, r, err := load(uri, data); err == nil {
469-
swagger = s
470-
router = r
471-
} else {
472-
log.Printf("ERROR: Unable to load OpenAPI document: %s", err)
473-
}
474-
}
475-
case err, ok := <-watcher.Errors:
476-
if !ok {
477-
return
478-
}
479-
fmt.Println("error:", err)
480-
}
481-
}
482-
}()
483-
484-
watcher.Add(uri)
485-
}
486-
}
487-
488-
swagger, router, err = load(uri, data)
489-
if err != nil {
490-
log.Fatal(err)
491-
}
492-
493-
if strings.HasPrefix(uri, "http") {
494-
http.HandleFunc("/__reload", func(w http.ResponseWriter, r *http.Request) {
495-
resp, err := http.Get(uri)
496-
if err != nil {
497-
log.Printf("ERROR: %v", err)
498-
w.WriteHeader(http.StatusBadRequest)
499-
w.Write([]byte("error while reloading"))
500-
return
501-
}
502-
503-
data, err = ioutil.ReadAll(resp.Body)
504-
resp.Body.Close()
505-
if err != nil {
506-
log.Printf("ERROR: %v", err)
507-
w.WriteHeader(http.StatusBadRequest)
508-
w.Write([]byte("error while parsing"))
509-
return
510-
}
511-
512-
if s, r, err := load(uri, data); err == nil {
513-
swagger = s
514-
router = r
515-
}
516-
517-
w.WriteHeader(200)
518-
w.Write([]byte("reloaded"))
519-
log.Printf("Reloaded from %s", uri)
520-
})
521-
}
522-
523-
// Add a health check route which returns 200
524-
http.HandleFunc("/__health", func(w http.ResponseWriter, r *http.Request) {
525-
w.WriteHeader(200)
526-
log.Printf("Health check")
527-
})
528-
529-
// Register our custom HTTP handler that will use the router to find
530-
// the appropriate OpenAPI operation and try to return an example.
531-
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
417+
var handler = func(rr *RefreshableRouter) http.Handler {
418+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
532419
if !viper.GetBool("disable-cors") {
533420
corsOrigin := req.Header.Get("Origin")
534421
if corsOrigin == "" {
@@ -582,7 +469,7 @@ func server(cmd *cobra.Command, args []string) {
582469
info = fmt.Sprintf("%s %v", req.Method, req.URL)
583470
}
584471

585-
route, pathParams, err := router.FindRoute(req.Method, req.URL)
472+
route, pathParams, err := rr.Get().FindRoute(req.Method, req.URL)
586473
if err != nil {
587474
log.Printf("ERROR: %s => %v", info, err)
588475
w.WriteHeader(http.StatusNotFound)
@@ -647,12 +534,11 @@ func server(cmd *cobra.Command, args []string) {
647534
} else if _, ok := example.([]byte); ok {
648535
encoded = example.([]byte)
649536
} else {
650-
switch mediatype {
651-
case "application/json", "application/vnd.api+json":
537+
if marshalJSONMatcher.MatchString(mediatype) {
652538
encoded, err = json.MarshalIndent(example, "", " ")
653-
case "application/x-yaml", "application/yaml", "text/x-yaml", "text/yaml", "text/vnd.yaml":
539+
} else if marshalYAMLMatcher.MatchString(mediatype) {
654540
encoded, err = yaml.Marshal(example)
655-
default:
541+
} else {
656542
log.Printf("Cannot marshal as '%s'!", mediatype)
657543
err = ErrCannotMarshal
658544
}
@@ -689,6 +575,146 @@ func server(cmd *cobra.Command, args []string) {
689575
w.WriteHeader(status)
690576
w.Write(encoded)
691577
})
578+
}
579+
580+
// server loads an OpenAPI file and runs a mock server using the paths and
581+
// examples defined in the file.
582+
func server(cmd *cobra.Command, args []string) {
583+
var swagger *openapi3.Swagger
584+
rr := NewRefreshableRouter()
585+
586+
uri := args[0]
587+
588+
var err error
589+
var data []byte
590+
591+
// Load either from an HTTP URL or from a local file depending on the passed
592+
// in value.
593+
if strings.HasPrefix(uri, "http") {
594+
req, err := http.NewRequest("GET", uri, nil)
595+
if err != nil {
596+
log.Fatal(err)
597+
}
598+
if customHeader := viper.GetString("header"); customHeader != "" {
599+
header := strings.Split(customHeader, ":")
600+
if len(header) != 2 {
601+
log.Fatal("Header format is invalid.")
602+
}
603+
req.Header.Add(strings.TrimSpace(header[0]), strings.TrimSpace(header[1]))
604+
}
605+
client := &http.Client{}
606+
resp, err := client.Do(req)
607+
if err != nil {
608+
log.Fatal(err)
609+
}
610+
611+
data, err = ioutil.ReadAll(resp.Body)
612+
resp.Body.Close()
613+
if err != nil {
614+
log.Fatal(err)
615+
}
616+
617+
if viper.GetBool("watch") {
618+
log.Fatal("Watching a URL is not supported.")
619+
}
620+
} else {
621+
data, err = ioutil.ReadFile(uri)
622+
if err != nil {
623+
log.Fatal(err)
624+
}
625+
626+
if viper.GetBool("watch") {
627+
// Set up a new filesystem watcher and reload the router every time
628+
// the file has changed on disk.
629+
watcher, err := fsnotify.NewWatcher()
630+
if err != nil {
631+
log.Fatal(err)
632+
}
633+
defer watcher.Close()
634+
635+
go func() {
636+
// Since waiting for events or errors is blocking, we do this in a
637+
// goroutine. It loops forever here but will exit when the process
638+
// is finished, e.g. when you `ctrl+c` to exit.
639+
for {
640+
select {
641+
case event, ok := <-watcher.Events:
642+
if !ok {
643+
return
644+
}
645+
if event.Op&fsnotify.Write == fsnotify.Write {
646+
fmt.Printf("🌙 Reloading %s\n", uri)
647+
data, err = ioutil.ReadFile(uri)
648+
if err != nil {
649+
log.Fatal(err)
650+
}
651+
652+
if s, r, err := load(uri, data); err == nil {
653+
swagger = s
654+
rr.Set(r)
655+
} else {
656+
log.Printf("ERROR: Unable to load OpenAPI document: %s", err)
657+
}
658+
}
659+
case err, ok := <-watcher.Errors:
660+
if !ok {
661+
return
662+
}
663+
fmt.Println("error:", err)
664+
}
665+
}
666+
}()
667+
668+
watcher.Add(uri)
669+
}
670+
}
671+
672+
swagger, router, err := load(uri, data)
673+
if err != nil {
674+
log.Fatal(err)
675+
}
676+
677+
rr.Set(router)
678+
679+
if strings.HasPrefix(uri, "http") {
680+
http.HandleFunc("/__reload", func(w http.ResponseWriter, r *http.Request) {
681+
resp, err := http.Get(uri)
682+
if err != nil {
683+
log.Printf("ERROR: %v", err)
684+
w.WriteHeader(http.StatusBadRequest)
685+
w.Write([]byte("error while reloading"))
686+
return
687+
}
688+
689+
data, err = ioutil.ReadAll(resp.Body)
690+
resp.Body.Close()
691+
if err != nil {
692+
log.Printf("ERROR: %v", err)
693+
w.WriteHeader(http.StatusBadRequest)
694+
w.Write([]byte("error while parsing"))
695+
return
696+
}
697+
698+
if s, r, err := load(uri, data); err == nil {
699+
swagger = s
700+
rr.Set(r)
701+
}
702+
703+
w.WriteHeader(200)
704+
w.Write([]byte("reloaded"))
705+
log.Printf("Reloaded from %s", uri)
706+
})
707+
}
708+
709+
// Add a health check route which returns 200
710+
http.HandleFunc("/__health", func(w http.ResponseWriter, r *http.Request) {
711+
w.WriteHeader(200)
712+
log.Printf("Health check")
713+
})
714+
715+
// Register our custom HTTP handler that will use the router to find
716+
// the appropriate OpenAPI operation and try to return an example.
717+
http.Handle("/", handler(rr))
692718

693719
format := "🌱 Sprouting %s on port %d"
694720
if viper.GetBool("https") {

0 commit comments

Comments
 (0)