Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion cmd/hz/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package app
import (
"errors"
"fmt"
"github.com/urfave/cli/v2"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

format,挪到下面

"os"
"path/filepath"
"strings"
Expand All @@ -30,7 +31,6 @@ import (
"github.com/cloudwego/hertz/cmd/hz/thrift"
"github.com/cloudwego/hertz/cmd/hz/util"
"github.com/cloudwego/hertz/cmd/hz/util/logs"
"github.com/urfave/cli/v2"
)

// global args. MUST fork it when use
Expand Down Expand Up @@ -368,6 +368,22 @@ func GenerateLayout(args *config.Argument) error {
NeedGoMod: args.NeedGoMod,
}

for _, p := range args.ProtobufPlugins {
pluginParams := strings.Split(p, ":")
if pluginParams[0] == "http-swagger" {
layout.EnableSwagger = true
layout.SwaggerOutDir = pluginParams[2]
}
}
for _, p := range args.ThriftPlugins {
pluginName, pluginPath, params := util.ParseThriftPluginString(p)
if pluginName == "http-swagger" {
layout.EnableSwagger = true
layout.SwaggerOutDir = params["OutputDir"]
}
pluginPath = filepath.Join(args.OutDir, pluginPath)
}

if args.CustomizeLayout == "" {
// generate by default
err := lg.GenerateByService(layout)
Expand Down
22 changes: 22 additions & 0 deletions cmd/hz/generator/layout.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type Layout struct {
ModelDir string
HandlerDir string
RouterDir string
SwaggerOutDir string
EnableSwagger bool
}

// LayoutGenerator contains the information generated by generating the layout template
Expand Down Expand Up @@ -122,6 +124,20 @@ func (lg *LayoutGenerator) GenerateByService(service Layout) error {
}
}

if service.EnableSwagger {
defaultSwaggerDir := defaultSwaggerDir + sp + swaggerTplName
if tpl, exist := lg.tpls[defaultSwaggerDir]; exist {
delete(lg.tpls, defaultSwaggerDir)
newSwaggerFile := filepath.Clean(service.SwaggerOutDir + sp + "swagger.go")
lg.tpls[newSwaggerFile] = tpl
}
} else {
defaultSwaggerDir := defaultSwaggerDir + sp + swaggerTplName
if _, exist := lg.tpls[defaultSwaggerDir]; exist {
delete(lg.tpls, defaultSwaggerDir)
}
}

if !service.NeedGoMod {
gomodFile := "go.mod"
if _, exist := lg.tpls[gomodFile]; exist {
Expand Down Expand Up @@ -192,23 +208,29 @@ func serviceToLayoutData(service Layout) (map[string]interface{}, error) {
"UseApacheThrift": service.UseApacheThrift,
"HandlerPkg": handlerPkg,
"RouterPkg": routerPkg,
"EnableSwagger": service.EnableSwagger,
}, nil
}

// serviceToRouterData stores the registers function, router import path, handler import path
func serviceToRouterData(service Layout) (map[string]interface{}, error) {
routerDir := sp + defaultRouterDir
handlerDir := sp + defaultHandlerDir
swaggerDir := sp + defaultSwaggerDir
if len(service.RouterDir) != 0 {
routerDir = sp + service.RouterDir
}
if len(service.HandlerDir) != 0 {
handlerDir = sp + service.HandlerDir
}
if len(service.SwaggerOutDir) != 0 {
swaggerDir = filepath.Clean(sp + service.SwaggerOutDir)
}
return map[string]interface{}{
"Registers": []string{},
"RouterPkgPath": service.GoModule + util.PathToImport(routerDir, ""),
"HandlerPkgPath": service.GoModule + util.PathToImport(handlerDir, ""),
"SwaggerPkgPath": service.GoModule + util.PathToImport(swaggerDir, ""),
}, nil
}

Expand Down
47 changes: 47 additions & 0 deletions cmd/hz/generator/layout_tpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ const (
defaultConfDir = "conf"
defaultRouterDir = "biz" + sp + "router"
defaultClientDir = "biz" + sp + "client"
defaultSwaggerDir = "swagger"
)

const (
Expand Down Expand Up @@ -153,11 +154,18 @@ package main
import (
"github.com/cloudwego/hertz/pkg/app/server"
router "{{.RouterPkgPath}}"
{{- if .EnableSwagger}}
swagger "{{.SwaggerPkgPath}}"
{{- end}}
)

// register registers all routers.
func register(r *server.Hertz) {

{{- if .EnableSwagger}}
swagger.BindSwagger(r)
{{- end}}

router.GeneratedRegister(r)

customizedRegister(r)
Expand Down Expand Up @@ -216,5 +224,44 @@ BinaryName={{.ServiceName}}
echo "$CURDIR/bin/${BinaryName}"
exec $CURDIR/bin/${BinaryName}`,
},
{
Path: defaultSwaggerDir + sp + "swagger.go",
Body: `// Code generated by hertz generator.
package swagger

import (
"context"
_ "embed"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/hertz-contrib/cors"
"github.com/hertz-contrib/swagger"
swaggerFiles "github.com/swaggo/files"
)

//go:embed openapi.yaml
var openapiYAML []byte

func BindSwagger(h *server.Hertz) {
h.Use(cors.New(cors.Config{
AllowOrigins: []string{"*"},
AllowMethods: []string{"*"},
AllowHeaders: []string{"*"},
ExposeHeaders: []string{"Content-Length", "Authorization"},
AllowCredentials: true,
}))

h.GET("/swagger/*any", swagger.WrapHandler(
swaggerFiles.Handler,
swagger.URL("/openapi.yaml"),
))

h.GET("/openapi.yaml", func(c context.Context, ctx *app.RequestContext) {
ctx.Header("Content-Type", "application/x-yaml")
ctx.Write(openapiYAML)
})
}
`,
},
},
}
6 changes: 3 additions & 3 deletions cmd/hz/generator/package_tpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ var (
clientTplName = "client.go" // generate a default client for server
hertzClientTplName = "hertz_client.go" // underlying client for client command
idlClientName = "idl_client.go" // client of service for quick call

insertPointNew = "//INSERT_POINT: DO NOT DELETE THIS LINE!"
insertPointPatternNew = `//INSERT_POINT\: DO NOT DELETE THIS LINE\!`
swaggerTplName = "swagger.go"
insertPointNew = "//INSERT_POINT: DO NOT DELETE THIS LINE!"
insertPointPatternNew = `//INSERT_POINT\: DO NOT DELETE THIS LINE\!`
)

var templateNameSet = map[string]string{
Expand Down
37 changes: 37 additions & 0 deletions cmd/hz/util/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,3 +434,40 @@ func ToGoFuncName(s string) string {
}
return string(ss)
}

// ParseThriftPluginString parses a string in the format "pluginName[:pluginPath][key=value]*".
func ParseThriftPluginString(input string) (pluginName, pluginPath string, params map[string]string) {
// Initialize the output variables
pluginName = ""
pluginPath = ""
params = make(map[string]string)

// Split the input on the first ':'
parts := strings.SplitN(input, ":", 2)
if len(parts) > 0 {
// Handle the plugin name and path
pluginPart := parts[0]
pluginNameAndPath := strings.SplitN(pluginPart, "=", 2)
pluginName = pluginNameAndPath[0]

if len(pluginNameAndPath) > 1 {
pluginPath = pluginNameAndPath[1]
}
}

// If there's a parameters part, parse it
if len(parts) > 1 {
paramsPart := parts[1]
paramPairs := strings.Split(paramsPart, ",")
for _, pair := range paramPairs {
keyValue := strings.SplitN(pair, "=", 2)
if len(keyValue) == 2 {
params[keyValue[0]] = keyValue[1]
} else if len(keyValue) == 1 {
params[keyValue[0]] = ""
}
}
}

return pluginName, pluginPath, params
}