Skip to content

Commit 1d7c03e

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into lstm_bp
2 parents b50c33f + 03136f6 commit 1d7c03e

File tree

201 files changed

+7806
-1459
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

201 files changed

+7806
-1459
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,4 @@ cmake_install.cmake
2828
paddle/.timestamp
2929
python/paddlepaddle.egg-info/
3030
paddle/pybind/pybind.h
31+
python/paddle/v2/framework/tests/tmp/*

go/cmd/pserver/pserver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func main() {
6767
cp, err = pserver.LoadCheckpoint(e, idx)
6868
if err != nil {
6969
if err == pserver.ErrCheckpointNotFound {
70-
log.Info("Could not find the pserver checkpoint.")
70+
log.Info("load checkpoint error", "error", err)
7171
} else {
7272
panic(err)
7373
}
@@ -99,7 +99,7 @@ func main() {
9999
candy.Must(err)
100100

101101
go func() {
102-
log.Info("starting pserver", log.Ctx{"port": *port})
102+
log.Info("serving pserver", log.Ctx{"port": *port})
103103
err = http.Serve(l, nil)
104104
candy.Must(err)
105105
}()

go/master/c/client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
123123
}
124124
err := c.SetDataset(paths)
125125
if err != nil {
126-
log.Error("error set dataset", log.Ctx{"error": err})
126+
log.Error("error set dataset",
127+
log.Ctx{"error": err, "paths": paths})
127128
return C.PADDLE_MASTER_ERROR
128129
}
129130

go/master/client.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ func (c *Client) StartGetRecords(passID int) {
121121
}
122122

123123
func (c *Client) getRecords(passID int) {
124+
i := 0
124125
for {
125126
t, err := c.getTask(passID)
126127
if err != nil {
@@ -130,12 +131,20 @@ func (c *Client) getRecords(passID int) {
130131
c.ch <- record{nil, err}
131132
break
132133
}
133-
if err.Error() == ErrPassAfter.Error() {
134-
// wait util last pass finishes
135-
time.Sleep(time.Second * 3)
136-
continue
134+
135+
if i%60 == 0 {
136+
log.Debug("getTask of passID error.",
137+
log.Ctx{"error": err, "passID": passID})
138+
i = 0
137139
}
138-
log.Error("getTask error.", log.Ctx{"error": err})
140+
141+
// if err.Error() == ErrPassAfter.Error()
142+
// wait until last pass finishes
143+
// if other error such as network error
144+
// wait to reconnect or task time out
145+
time.Sleep(time.Second * 3)
146+
i += 3
147+
continue
139148
}
140149

141150
for _, chunk := range t.Chunks {

go/master/client_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ func TestNextRecord(t *testing.T) {
117117
if e != nil {
118118
panic(e)
119119
}
120+
120121
// test for n passes
121122
for pass := 0; pass < 10; pass++ {
122123
c.StartGetRecords(pass)

go/pserver/optimizer.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
7171
cstate = unsafe.Pointer(&s[0])
7272
}
7373

74+
var cptr (*C.uchar)
75+
if len(c) > 0 {
76+
cptr = (*C.uchar)(&c[0])
77+
} else {
78+
log.Error("empty config", "param name", paramWithConfigs.Param.Name)
79+
}
7480
o.config = c
7581
o.opt = C.paddle_create_optimizer(
76-
(*C.uchar)(&c[0]),
82+
cptr,
7783
C.int(len(c)),
7884
C.paddle_element_type(p.ElementType),
7985
cbuffer,

go/pserver/service.go

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@ package pserver
1717
import (
1818
"bufio"
1919
"bytes"
20-
"crypto/md5"
2120
"encoding/gob"
22-
"encoding/hex"
2321
"encoding/json"
2422
"errors"
2523
"fmt"
24+
"hash/crc32"
2625
"io/ioutil"
2726
"os"
2827
"path"
@@ -40,7 +39,7 @@ type ElementType int
4039

4140
// ErrCheckpointNotFound indicates that the pserver checkpoint could
4241
// not be found.
43-
var ErrCheckpointNotFound = errors.New("checkpoint not found")
42+
var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd")
4443

4544
// RPC error message.
4645
const (
@@ -76,7 +75,7 @@ type ParameterWithConfig struct {
7675
type checkpointMeta struct {
7776
UUID string `json:"uuid"`
7877
Path string `json:"path"`
79-
MD5 string `json:"md5"`
78+
CRC32 uint32 `json:"crc32"`
8079
Timestamp int64 `json:"timestamp"`
8180
}
8281

@@ -92,7 +91,7 @@ type Service struct {
9291
idx int
9392
checkpointInterval time.Duration
9493
checkpointPath string
95-
client *EtcdClient
94+
client KVStore
9695

9796
mu sync.Mutex
9897
optMap map[string]*optimizer
@@ -104,7 +103,12 @@ type parameterCheckpoint struct {
104103
State []byte
105104
}
106105

107-
func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
106+
type KVStore interface {
107+
GetKey(key string, timeout time.Duration) ([]byte, error)
108+
PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
109+
}
110+
111+
func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) {
108112
v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
109113
if err != nil {
110114
return
@@ -123,7 +127,7 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
123127
}
124128

125129
// LoadCheckpoint loads checkpoint from file.
126-
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
130+
func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) {
127131
log.Info("Loading checkpoint", "pserver index", idx)
128132
defer traceTime(time.Now(), "load checkpoint")
129133

@@ -137,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
137141
return nil, err
138142
}
139143

140-
// TODO(helin): change MD5 to CRC since CRC is better for file
141-
// checksum in our use case (emphasize speed over security).
142-
h := md5.New()
143-
md5 := hex.EncodeToString(h.Sum(content))
144-
if md5 != cpMeta.MD5 {
144+
crc32 := crc32.ChecksumIEEE(content)
145+
if crc32 != cpMeta.CRC32 {
145146
return nil, errors.New(WrongChecksum)
146147
}
147148

@@ -150,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
150151
if err = dec.Decode(&cp); err != nil {
151152
return nil, err
152153
}
154+
153155
return cp, nil
154156
}
155157

156158
// NewService creates a new service, will bypass etcd registration if no
157159
// endpoints specified. It will recover from the checkpoint file if a specified checkpoint exists.
158-
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
160+
func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) {
159161
s := &Service{
160162
idx: idx,
161163
checkpointInterval: interval,
@@ -173,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
173175
}
174176
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
175177
}
178+
close(s.initialized)
176179
}
177180
return s, nil
178181
}
@@ -221,7 +224,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
221224
for range t {
222225
err := s.checkpoint()
223226
if err != nil {
224-
log.Error("finish init params error", log.Ctx{"error": err})
227+
log.Error("checkpoint error", log.Ctx{"error": err})
225228
}
226229
}
227230
}()
@@ -274,6 +277,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
274277
parameter.Name = name
275278
parameter.ElementType = opt.elementType
276279
parameter.Content = opt.GetWeights()
280+
277281
log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType)
278282
return nil
279283
}
@@ -354,20 +358,29 @@ func (s *Service) checkpoint() (err error) {
354358

355359
oldMeta, err := loadMeta(s.client, s.idx)
356360
if err == ErrCheckpointNotFound {
357-
log.Info("Do not have existing checkpoint.")
361+
log.Info("old meta not found, skip removing old meta")
358362
err = nil
363+
} else if err == nil {
364+
log.Info("removing old meta")
365+
if oldMeta.Path != "" {
366+
rmErr := os.Remove(oldMeta.Path)
367+
if rmErr != nil {
368+
// log error, but still treat checkpoint as
369+
// successful.
370+
log.Error("remove old meta file error", log.Ctx{"error": rmErr})
371+
}
372+
}
359373
}
360374

361375
if err != nil {
362376
return
363377
}
364378

365-
h := md5.New()
366-
md5 := hex.EncodeToString(h.Sum(buf.Bytes()))
379+
crc32 := crc32.ChecksumIEEE(buf.Bytes())
367380
cpMeta := checkpointMeta{
368381
UUID: id,
369382
Timestamp: time.Now().UnixNano(),
370-
MD5: md5,
383+
CRC32: crc32,
371384
Path: p,
372385
}
373386

@@ -381,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
381394
return
382395
}
383396

384-
if oldMeta.Path != "" {
385-
rmErr := os.Remove(oldMeta.Path)
386-
if rmErr != nil {
387-
// log error, but still treat checkpoint as
388-
// successful.
389-
log.Error("remove old meta file error", log.Ctx{"error": rmErr})
390-
}
391-
}
392-
393397
return
394398
}

go/pserver/service_internal_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package pserver
2+
3+
import (
4+
"bytes"
5+
"encoding/binary"
6+
"fmt"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
const testDir = "./test_data"
14+
15+
type myKV struct {
16+
m map[string][]byte
17+
}
18+
19+
func (m *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) {
20+
if m.m == nil {
21+
m.m = make(map[string][]byte)
22+
}
23+
return m.m[key], nil
24+
}
25+
26+
func (m *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
27+
if m.m == nil {
28+
m.m = make(map[string][]byte)
29+
}
30+
m.m[key] = value
31+
return nil
32+
}
33+
34+
func TestCheckpoint(t *testing.T) {
35+
kv := &myKV{}
36+
s, err := NewService(0, time.Hour, testDir, kv, nil)
37+
assert.Nil(t, err)
38+
err = s.checkpoint()
39+
assert.Nil(t, err)
40+
_, err = LoadCheckpoint(kv, 0)
41+
assert.Nil(t, err)
42+
}
43+
44+
func float32ToByte(f float32) []byte {
45+
var buf bytes.Buffer
46+
err := binary.Write(&buf, binary.LittleEndian, f)
47+
if err != nil {
48+
fmt.Println("binary.Write failed:", err)
49+
}
50+
return buf.Bytes()
51+
}
52+
53+
func TestCheckpointWithData(t *testing.T) {
54+
kv := &myKV{}
55+
s, err := NewService(0, time.Hour, testDir, kv, nil)
56+
assert.Nil(t, err)
57+
58+
var content []byte
59+
for i := 0; i < 50000; i++ {
60+
content = append(content, float32ToByte(float32(i))...)
61+
}
62+
63+
p1 := Parameter{Name: "p1", ElementType: 1, Content: content}
64+
err = s.InitParam(ParameterWithConfig{Param: p1}, nil)
65+
assert.Nil(t, err)
66+
67+
err = s.FinishInitParams(0, nil)
68+
assert.Nil(t, err)
69+
70+
var p2 Parameter
71+
err = s.GetParam(p1.Name, &p2)
72+
assert.Nil(t, err)
73+
assert.Equal(t, p1, p2)
74+
75+
err = s.checkpoint()
76+
assert.Nil(t, err)
77+
cp, err := LoadCheckpoint(kv, 0)
78+
assert.Nil(t, err)
79+
s1, err := NewService(0, time.Hour, testDir, kv, cp)
80+
assert.Nil(t, err)
81+
82+
var p3 Parameter
83+
err = s1.GetParam(p1.Name, &p3)
84+
assert.Nil(t, err)
85+
assert.Equal(t, p1, p3)
86+
}

go/pserver/service_test.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {
178178

179179
wg.Wait()
180180
}
181-
182-
func TestCheckpointSpeed(t *testing.T) {
183-
//TODO(zhihong): test speed
184-
}

paddle/framework/CMakeLists.txt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
1515

1616
cc_test(variable_test SRCS variable_test.cc)
1717

18-
cc_library(scope SRCS scope.cc)
18+
cc_library(scope SRCS scope.cc DEPS glog)
1919
cc_test(scope_test SRCS scope_test.cc DEPS scope)
2020

2121

@@ -24,9 +24,10 @@ cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc)
2424
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
2525
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
2626
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
27-
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog)
27+
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
28+
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference)
2829
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
29-
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator)
30+
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
3031

3132
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
3233
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
@@ -42,7 +43,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
4243
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
4344

4445
cc_library(backward SRCS backward.cc DEPS net_op)
45-
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
46+
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op)
4647

4748
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog)
4849

0 commit comments

Comments
 (0)