Skip to content
This repository was archived by the owner on Jun 17, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion cmd/static/model_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func modelFlags() []cli.Flag {
}},
&cli.StringFlag{Name: consts.DBType, Usage: "Specify database type. (mysql or sqlserver or sqlite or postgres)", Value: string(consts.MySQL), DefaultText: string(consts.MySQL), Action: func(context *cli.Context, s string) error {
if _, ok := config.OpenTypeFuncMap[consts.DataBaseType(strings.ToLower(s))]; !ok {
return fmt.Errorf("unknow db type %s (support mysql || postgres || sqlite || sqlserver for now)", s)
return fmt.Errorf("unknown db type %s (support mysql || postgres || sqlite || sqlserver for now)", s)
}
return nil
}},
Expand All @@ -51,5 +51,15 @@ func modelFlags() []cli.Flag {
&cli.BoolFlag{Name: consts.TypeTag, Usage: "Specify generate field with gorm column type tag", Value: false, DefaultText: "false"},
&cli.BoolFlag{Name: consts.IndexTag, Usage: "Specify generate field with gorm index tag", Value: false, DefaultText: "false"},
&cli.StringFlag{Name: consts.SQLDir, Usage: "Specify a sql file or directory", Value: "", DefaultText: ""},
&cli.StringFlag{Name: consts.Mode, Usage: modeUsage, Value: "noctx,defaultquery,queryinterface", DefaultText: "noctx,defaultquery,queryinterface"},
}
}

const modeUsage = `Specify gorm/gen generator mode
(https://gorm.io/gen/dao.html#Generator-Modes).
There is no need to follow strict case, we offer some abbreviations

- WithoutContext (noctx)
- WithDefaultQuery (defaultquery)
- WithQueryInterface (queryinterface)
`
2 changes: 2 additions & 0 deletions config/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type ModelArgument struct {
FieldWithIndexTag bool
FieldWithTypeTag bool
SQLDir string
Mode string
}

func NewModelArgument() *ModelArgument {
Expand All @@ -62,5 +63,6 @@ func (c *ModelArgument) ParseCli(ctx *cli.Context) error {
c.FieldWithIndexTag = ctx.Bool(consts.IndexTag)
c.FieldWithTypeTag = ctx.Bool(consts.TypeTag)
c.SQLDir = ctx.String(consts.SQLDir)
c.Mode = ctx.String(consts.Mode)
return nil
}
1 change: 1 addition & 0 deletions pkg/consts/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ const (
TypeTag = "type_tag"
HexTag = "hex"
SQLDir = "sql_dir"
Mode = "mode"
)

const (
Expand Down
107 changes: 107 additions & 0 deletions pkg/model/generator_tpl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright 2022 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package model

import (
"bytes"
_ "embed"
"fmt"
"go/format"
"path/filepath"
"text/template"

"github.com/cloudwego/cwgo/config"
"github.com/cloudwego/cwgo/pkg/consts"

"gorm.io/gen"
)

//go:embed model_gen.tpl
var mergedTemplate string

type GenMethodTmpl struct {
GormOpen string
gen.Config
OnlyModel bool
// UseRawSQL indicates whether to use raw SQL for the database connection
UseRawSQL bool
Tables []string
StrategyParams struct {
ExcludeTables []string
Type string
}
}

func execTmpl(c *config.ModelArgument) ([]byte, error) {
tpl, err := template.New("merged").Parse(mergedTemplate)
if err != nil {
return nil, fmt.Errorf("parse template fail: %w", err)
}

absOutPath, _ := filepath.Abs(c.OutPath)
data := GenMethodTmpl{
GormOpen: buildGormOpen(c),
UseRawSQL: c.DSN == "" && c.SQLDir != "",
Tables: c.Tables,
OnlyModel: c.OnlyModel,
Config: gen.Config{
OutPath: absOutPath,
OutFile: c.OutFile,
ModelPkgPath: c.ModelPkgName,
WithUnitTest: c.WithUnitTest,
FieldNullable: c.FieldNullable,
FieldSignable: c.FieldSignable,
FieldWithIndexTag: c.FieldWithIndexTag,
FieldWithTypeTag: c.FieldWithTypeTag,
Mode: buildGenMode(c.Mode),
},
StrategyParams: struct {
ExcludeTables []string
Type string
}{
ExcludeTables: c.ExcludeTables,
Type: c.Type,
},
}

var buf bytes.Buffer
err = tpl.Execute(&buf, data)
if err != nil {
return nil, fmt.Errorf("execute template fail: %w", err)
}
fmtCode, err := format.Source(buf.Bytes())
return fmtCode, err
}

func buildGormOpen(c *config.ModelArgument) string {
abs, _ := filepath.Abs(c.SQLDir)
switch {
case c.SQLDir != "":
return fmt.Sprintf(
"db, err := gorm.Open(rawsql.New(rawsql.Config{FilePath: []string{%q}}))",
abs,
)
case c.DSN != "" && c.Type != "":
return fmt.Sprintf(
"db, err := gorm.Open(%s.Open(%q))",
consts.DataBaseType(c.Type),
c.DSN,
)
default:
return ""
}
}
39 changes: 37 additions & 2 deletions pkg/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ package model

import (
"fmt"
"os"
"path/filepath"
"regexp"
"strings"

"gorm.io/rawsql"
Expand Down Expand Up @@ -57,6 +60,7 @@ func Model(c *config.ModelArgument) error {
FieldSignable: c.FieldSignable,
FieldWithIndexTag: c.FieldWithIndexTag,
FieldWithTypeTag: c.FieldWithTypeTag,
Mode: buildGenMode(c.Mode),
}

if len(c.ExcludeTables) > 0 || c.Type == string(consts.Sqlite) {
Expand Down Expand Up @@ -87,9 +91,22 @@ func Model(c *config.ModelArgument) error {
if !c.OnlyModel {
g.ApplyBasic(models...)
}

g.Execute()
return nil

// generate gen.go to update dal code
getwd, _ := os.Getwd()
outPath := filepath.Join(getwd, c.OutPath)
genMainFileRootDir := filepath.Dir(outPath)
buf, err := execTmpl(c)
if err != nil {
return fmt.Errorf("exec template fail: %w", err)
}

if err := os.MkdirAll(filepath.Join(genMainFileRootDir, "gen_exec"), 0755); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}

return os.WriteFile(filepath.Join(genMainFileRootDir, "gen_exec", "main.go"), buf, 0o644)
}

func genModels(g *gen.Generator, db *gorm.DB, tables []string) (models []interface{}, err error) {
Expand All @@ -109,3 +126,21 @@ func genModels(g *gen.Generator, db *gorm.DB, tables []string) (models []interfa
}
return models, nil
}

func buildGenMode(mode string) gen.GenerateMode {
var generateMode gen.GenerateMode
ss := regexp.MustCompile(`[,;\s]+`).Split(strings.ToLower(mode), -1)
for _, s := range ss {
s = strings.TrimSpace(s)
switch s {
case "withoutcontext", "noctx":
generateMode |= gen.WithoutContext
case "withdefaultquery", "defaultquery":
generateMode |= gen.WithDefaultQuery
case "withqueryinterface", "queryinterface":
generateMode |= gen.WithQueryInterface
default:
}
}
return generateMode
}
88 changes: 88 additions & 0 deletions pkg/model/model_gen.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
{{ define "GORM_OPEN" }}
{{ .GormOpen }}
if err != nil {
log.Fatal("open db fail: ", err)
}
{{ end }}

{{ define "GEN_CONFIG" }}
genConfig := gen.Config{
OutPath: {{ printf "%q" .OutPath }},
OutFile: {{ printf "%q" .OutFile }},
ModelPkgPath: {{ printf "%q" .ModelPkgPath }},
WithUnitTest: {{ .WithUnitTest }},
FieldNullable: {{ .FieldNullable }},
FieldSignable: {{ .FieldSignable }},
FieldWithIndexTag: {{ .FieldWithIndexTag }},
FieldWithTypeTag: {{ .FieldWithTypeTag }},
Mode: {{ .Mode }},
}
{{ end }}

{{ define "TABLE_STRATEGY" }}
{{if or (gt (len .ExcludeTables) 0) (eq .Type "sqlite") }}
genConfig.WithTableNameStrategy(func(tableName string) string {
{{if eq .Type "sqlite" -}}
if strings.HasPrefix(tableName, "sqlite") {
return ""
}
{{end -}}
{{if gt (len .ExcludeTables) 0 -}}
switch tableName {
{{range $table := .ExcludeTables -}}
case "{{ $table }}":
return ""
{{end -}}
}
{{end -}}
return tableName
})
{{end}}
{{ end }}

package main

import (
"fmt"
"log"

"gorm.io/gorm"
"gorm.io/gen"
{{if .UseRawSQL }} "gorm.io/rawsql" {{end}}
)

func main() {
{{ template "GORM_OPEN" . }}
{{ template "GEN_CONFIG" .Config }}
{{ template "TABLE_STRATEGY" .StrategyParams }}

g := gen.NewGenerator(genConfig)
g.UseDB(db)
models, err := genModels(g, db, []string{ {{- range $index, $element := .Tables }}
{{- if $index }}, {{ end }}"{{$element}}"{{- end }} })
if err != nil {
log.Fatal("gen models fail: ", err)
}
{{if not .OnlyModel}}
g.ApplyBasic(models...)
{{end}}
g.Execute()
}

func genModels(g *gen.Generator, db *gorm.DB, tables []string) (models []interface{}, err error) {
var tablesNameList []string
if len(tables) == 0 {
tablesNameList, err = db.Migrator().GetTables()
if err != nil {
return nil, fmt.Errorf("migrator get all tables fail: %w", err)
}
} else {
tablesNameList = tables
}

models = make([]interface{}, len(tablesNameList))
for i, tableName := range tablesNameList {
models[i] = g.GenerateModel(tableName)
}
return models, nil
}
Loading