Skip to content

Commit 642055a

Browse files
committed
添加强制注册http接口功能
1 parent 7dd11b0 commit 642055a

File tree

8 files changed

+240
-126
lines changed

8 files changed

+240
-126
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@ fmt:
2222
@for path in $(SOURCE_PATH); do echo "gofmt -s -l -w $$path"; gofmt -s -l -w $$path; done;
2323

2424
crab: lint
25-
@for path in $(SOURCE_PATH); do echo "go test ./$$path"; go test "./"$$path; done;
25+
@for path in $(SOURCE_PATH); do echo "go test ./$$path"; go test "./"$$path/...; done;
2626

2727

http/server/handler.go

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@ import (
1717
"github.com/dearcode/crab/log"
1818
)
1919

20-
var (
21-
server = newHTTPServer()
22-
)
23-
2420
type handlerRegexp struct {
2521
keys []string
2622
exp *regexp.Regexp
@@ -46,6 +42,19 @@ type httpServer struct {
4642
mu sync.RWMutex
4743
}
4844

45+
var (
46+
server = newHTTPServer()
47+
keysExp *regexp.Regexp
48+
)
49+
50+
func init() {
51+
exp, err := regexp.Compile(`{(\w+)?}`)
52+
if err != nil {
53+
panic(err.Error())
54+
}
55+
keysExp = exp
56+
}
57+
4958
func newHTTPServer() *httpServer {
5059
return &httpServer{
5160
path: make(map[string]handler),
@@ -83,11 +92,26 @@ func Register(obj interface{}) error {
8392
return register(obj, "", false)
8493
}
8594

95+
//RegisterMust 只要struct实现了Get(),Post(),Delete(),Put()接口就可以自动注册, 如果添加失败panic.
96+
func RegisterMust(obj interface{}) {
97+
if err := register(obj, "", false); err != nil {
98+
panic(err.Error())
99+
}
100+
101+
}
102+
86103
//RegisterPath 注册url完全匹配.
87104
func RegisterPath(obj interface{}, path string) error {
88105
return register(obj, path, false)
89106
}
90107

108+
//RegisterPathMust 注册url完全匹配,如果遇到错误panic.
109+
func RegisterPathMust(obj interface{}, path string) {
110+
if err := register(obj, path, false); err != nil {
111+
panic(err.Error())
112+
}
113+
}
114+
91115
//RegisterHandler 注册自定义url完全匹配.
92116
func RegisterHandler(call func(http.ResponseWriter, *http.Request), method, path string) error {
93117
h := handler{
@@ -114,16 +138,11 @@ func RegisterPrefix(obj interface{}, path string) error {
114138
return register(obj, path, true)
115139
}
116140

117-
var (
118-
keysExp *regexp.Regexp
119-
)
120-
121-
func init() {
122-
exp, err := regexp.Compile(`{(\w+)?}`)
123-
if err != nil {
141+
//RegisterPrefixMust 注册url前缀并保证成功.
142+
func RegisterPrefixMust(obj interface{}, path string) {
143+
if err := RegisterPrefix(obj, path); err != nil {
124144
panic(err.Error())
125145
}
126-
keysExp = exp
127146
}
128147

129148
func newHandlerRegexp(h handler) handlerRegexp {

orm/orm.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111

1212
"github.com/dearcode/crab/log"
1313
"github.com/dearcode/crab/meta"
14-
"github.com/dearcode/crab/util"
14+
"github.com/dearcode/crab/util/str"
1515
)
1616

1717
//Stmt db stmt.
@@ -169,9 +169,9 @@ func (s *Stmt) SQLColumn(rt reflect.Type, table string) string {
169169
case reflect.Struct:
170170
if f.Tag.Get("db_table") == "one" {
171171
s.table += ","
172-
s.table += util.FieldEscape(f.Name)
173-
bs.WriteString(s.SQLColumn(f.Type, util.FieldEscape(f.Name)))
174-
field := util.FieldEscape(f.Name)
172+
s.table += str.FieldEscape(f.Name)
173+
bs.WriteString(s.SQLColumn(f.Type, str.FieldEscape(f.Name)))
174+
field := str.FieldEscape(f.Name)
175175
s.addWhere(fmt.Sprintf("%s.%s_id = %s.id", table, field, field))
176176
continue
177177
}
@@ -180,7 +180,7 @@ func (s *Stmt) SQLColumn(rt reflect.Type, table string) string {
180180
}
181181
name := f.Tag.Get("db")
182182
if name == "" {
183-
name = util.FieldEscape(f.Name)
183+
name = str.FieldEscape(f.Name)
184184
}
185185
if !strings.Contains(name, ".") {
186186
fmt.Fprintf(bs, "%s.", table)
@@ -216,16 +216,16 @@ func (s *Stmt) firstTable() string {
216216

217217
// addRelation 添加多表关联条件
218218
func (s *Stmt) addRelation(t1, t2 string, id interface{}) *Stmt {
219-
t1 = util.FieldEscape(t1)
220-
t2 = util.FieldEscape(t2)
219+
t1 = str.FieldEscape(t1)
220+
t2 = str.FieldEscape(t2)
221221
s.addWhere(fmt.Sprintf("id in (select %s_id from %s_%s_relation where %s_id=%d)", t1, t2, t1, t2, id))
222222
return s
223223
}
224224

225225
// addOne2More 添加一对多关联条件
226226
func (s *Stmt) addOne2More(t1, t2 string, id interface{}) *Stmt {
227-
t1 = util.FieldEscape(t1)
228-
t2 = util.FieldEscape(t2)
227+
t1 = str.FieldEscape(t1)
228+
t2 = str.FieldEscape(t2)
229229
s.addWhere(fmt.Sprintf("%s.%s_id=%d", t1, t2, id))
230230
return s
231231
}
@@ -325,14 +325,14 @@ func (s *Stmt) Query(result interface{}) error {
325325
switch f.Tag.Get("db_table") {
326326
case "more":
327327
//填充一对多结果,每次去查询
328-
if err = NewStmt(s.db, util.FieldEscape(f.Name)).addRelation(f.Name, s.firstTable(), id).Query(lr); err != nil {
328+
if err = NewStmt(s.db, str.FieldEscape(f.Name)).addRelation(f.Name, s.firstTable(), id).Query(lr); err != nil {
329329
if errors.Cause(err) != meta.ErrNotFound {
330330
return errors.Trace(err)
331331
}
332332
}
333333
case "one2more":
334334
//填充一对多结果,每次去查询
335-
if err = NewStmt(s.db, util.FieldEscape(f.Name)).addOne2More(f.Name, s.firstTable(), id).Query(lr); err != nil {
335+
if err = NewStmt(s.db, str.FieldEscape(f.Name)).addOne2More(f.Name, s.firstTable(), id).Query(lr); err != nil {
336336
if errors.Cause(err) != meta.ErrNotFound {
337337
return errors.Trace(err)
338338
}
@@ -398,7 +398,7 @@ func (s *Stmt) SQLInsert(rt reflect.Type, rv reflect.Value) (sql string, refs []
398398
}
399399
name := rt.Field(i).Tag.Get("db")
400400
if name == "" {
401-
name = util.FieldEscape(rt.Field(i).Name)
401+
name = str.FieldEscape(rt.Field(i).Name)
402402
}
403403

404404
bs.WriteString(name)
@@ -444,7 +444,7 @@ func (s *Stmt) SQLUpdate(rt reflect.Type, rv reflect.Value) (sql string, refs []
444444

445445
name := rt.Field(i).Tag.Get("db")
446446
if name == "" {
447-
name = util.FieldEscape(rt.Field(i).Name)
447+
name = str.FieldEscape(rt.Field(i).Name)
448448
}
449449

450450
fmt.Fprintf(bs, "`%s`=?, ", name)

project.ini

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#按F9对应Make
2+
Make=make
3+
#按F8对应Test
4+
Test=make test
5+
#按,f对应格式化
6+
Format=go fmt
7+
#生成tags时跳过以下目录
8+
ExcludePath=docs,bin,Godeps,_vendor,vendor,logs
9+
#生成tags时跳过以下类型文件
10+
CtagsExcludeFile=*.cc,*.c,*.h,*.go
11+
#查找文件时跳过以下类型文件
12+
FindExcludeFile=*.cc,*.c,*.h,*_test.go
13+
#扩展tags目录
14+
#CtagsExtendFile=/usr/local/go/src

util/aes/ecb.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package aes
2+
3+
import (
4+
"bytes"
5+
"crypto/aes"
6+
"crypto/cipher"
7+
"encoding/base64"
8+
9+
"github.com/juju/errors"
10+
)
11+
12+
//Decrypt 解密.
13+
func Decrypt(crypted string, key []byte) ([]byte, error) {
14+
raw, err := base64.StdEncoding.DecodeString(crypted)
15+
if err != nil {
16+
return nil, errors.Trace(err)
17+
}
18+
19+
block, err := aes.NewCipher(key)
20+
if err != nil {
21+
return nil, errors.Trace(err)
22+
}
23+
24+
blockMode := NewECBDecrypter(block)
25+
data := make([]byte, len(raw))
26+
blockMode.CryptBlocks(data, raw)
27+
data = pkcs5UnPadding(data)
28+
29+
return data, nil
30+
}
31+
32+
//Encrypt 加密.
33+
func Encrypt(src string, key []byte) (string, error) {
34+
block, err := aes.NewCipher(key)
35+
if err != nil {
36+
return "", err
37+
}
38+
39+
ecb := NewECBEncrypter(block)
40+
content := pkcs5Padding([]byte(src), block.BlockSize())
41+
buf := make([]byte, len(content))
42+
ecb.CryptBlocks(buf, content)
43+
44+
return base64.StdEncoding.EncodeToString(buf), nil
45+
}
46+
47+
func pkcs5Padding(ciphertext []byte, blockSize int) []byte {
48+
padding := blockSize - len(ciphertext)%blockSize
49+
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
50+
return append(ciphertext, padtext...)
51+
}
52+
53+
func pkcs5UnPadding(origData []byte) []byte {
54+
length := len(origData)
55+
// 去掉最后一个字节 unpadding 次
56+
unpadding := int(origData[length-1])
57+
return origData[:(length - unpadding)]
58+
}
59+
60+
type ecb struct {
61+
b cipher.Block
62+
blockSize int
63+
}
64+
65+
func newECB(b cipher.Block) *ecb {
66+
return &ecb{
67+
b: b,
68+
blockSize: b.BlockSize(),
69+
}
70+
}
71+
72+
type ecbEncrypter ecb
73+
74+
// NewECBEncrypter returns a BlockMode which encrypts in electronic code book mode, using the given Block.
75+
func NewECBEncrypter(b cipher.Block) cipher.BlockMode {
76+
return (*ecbEncrypter)(newECB(b))
77+
}
78+
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
79+
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
80+
if len(src)%x.blockSize != 0 {
81+
panic("crypto/cipher: input not full blocks")
82+
}
83+
if len(dst) < len(src) {
84+
panic("crypto/cipher: output smaller than input")
85+
}
86+
for len(src) > 0 {
87+
x.b.Encrypt(dst, src[:x.blockSize])
88+
src = src[x.blockSize:]
89+
dst = dst[x.blockSize:]
90+
}
91+
}
92+
93+
type ecbDecrypter ecb
94+
95+
// NewECBDecrypter returns a BlockMode which decrypts in electronic code book mode, using the given Block.
96+
func NewECBDecrypter(b cipher.Block) cipher.BlockMode {
97+
return (*ecbDecrypter)(newECB(b))
98+
}
99+
func (x *ecbDecrypter) BlockSize() int { return x.blockSize }
100+
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
101+
if len(src)%x.blockSize != 0 {
102+
panic("crypto/cipher: input not full blocks")
103+
}
104+
if len(dst) < len(src) {
105+
panic("crypto/cipher: output smaller than input")
106+
}
107+
for len(src) > 0 {
108+
x.b.Decrypt(dst, src[:x.blockSize])
109+
src = src[x.blockSize:]
110+
dst = dst[x.blockSize:]
111+
}
112+
}

util/aes/ecb_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package aes
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestAes(t *testing.T) {
8+
src := "test"
9+
key := []byte("dbsjdcom\x00\x00\x00\x00\x00\x00\x00\x00")
10+
11+
crypted, err := Encrypt(src, key)
12+
if err != nil {
13+
t.Fatalf("AesEncrypt error:%s", err.Error())
14+
}
15+
deSrc, err := Decrypt(crypted, []byte(key))
16+
if err != nil {
17+
t.Fatalf("AesDecrypt error:%s", err.Error())
18+
}
19+
t.Logf("desc:%v", deSrc)
20+
}
21+
22+
func TestDesPanic(t *testing.T) {
23+
key := []byte("dbsjdcom\x00\x00\x00\x00\x00\x00\x00\x00")
24+
crypted := "fDmxIdK9p3oEyQoL1Bwz4Fakia3Y4Qn1SF8podapMFU="
25+
deSrc, err := Decrypt(crypted, []byte(key))
26+
if err != nil {
27+
t.Logf("err:%v", err)
28+
}
29+
30+
t.Logf("%s", deSrc)
31+
32+
}

0 commit comments

Comments
 (0)