Skip to content

Commit b54990e

Browse files
authored
Merge pull request #5053 from helinwang/serialization
Fix parameter server checkpoint serialization crash
2 parents dd0008d + f28b4d6 commit b54990e

17 files changed

+129
-42
lines changed

go/pserver/optimizer.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,21 +72,34 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
7272
}
7373

7474
o.config = c
75-
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
76-
C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s)))
75+
o.opt = C.paddle_create_optimizer(
76+
(*C.uchar)(&c[0]),
77+
C.int(len(c)),
78+
C.paddle_element_type(p.ElementType),
79+
cbuffer,
80+
C.int(paramBufferSize),
81+
(*C.char)(cstate),
82+
C.int(len(s)),
83+
)
7784
return o
7885
}
7986

8087
func (o *optimizer) GetWeights() []byte {
8188
var buffer unsafe.Pointer
89+
// we do not own the buffer, no need to free later.
8290
bufferLen := C.paddle_optimizer_get_weights(o.opt, &buffer)
8391
return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float)
8492
}
8593

8694
func (o *optimizer) GetStates() []byte {
8795
var cbuffer *C.char
96+
// we own the state buffer, need to free it later.
8897
cbufferLen := C.paddle_optimizer_get_state(o.opt, &cbuffer)
89-
return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen))
98+
buf := cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen))
99+
cpy := make([]byte, len(buf))
100+
copy(cpy, buf)
101+
C.free(unsafe.Pointer(cbuffer))
102+
return cpy
90103
}
91104

92105
func (o *optimizer) UpdateParameter(g Gradient) error {

go/pserver/optimizer_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
package pserver
1616

1717
import (
18+
"encoding/binary"
1819
"io/ioutil"
20+
"math"
1921
"testing"
22+
23+
"github.com/stretchr/testify/assert"
2024
)
2125

2226
func TestOptimizerCreateRelease(t *testing.T) {
@@ -36,3 +40,39 @@ func TestOptimizerCreateRelease(t *testing.T) {
3640
o := newOptimizer(param, nil)
3741
o.Cleanup()
3842
}
43+
44+
func float32Bytes(float float32) []byte {
45+
bits := math.Float32bits(float)
46+
bytes := make([]byte, 4)
47+
binary.LittleEndian.PutUint32(bytes, bits)
48+
return bytes
49+
}
50+
51+
func TestOptimizerState(t *testing.T) {
52+
p := Parameter{
53+
Name: "a",
54+
ElementType: Int32,
55+
}
56+
weights := float32Bytes(100)
57+
p.Content = weights
58+
config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb")
59+
if err != nil {
60+
t.Fatalf("read optimizer proto failed")
61+
}
62+
param := ParameterWithConfig{
63+
Param: p,
64+
Config: config,
65+
}
66+
o := newOptimizer(param, nil)
67+
s := o.GetStates()
68+
69+
// overwrite param content and check if the state is restored.
70+
param.Param.Content = float32Bytes(300)
71+
o1 := newOptimizer(param, s)
72+
s1 := o1.GetStates()
73+
assert.Equal(t, s, s1)
74+
assert.Equal(t, weights, o.GetWeights())
75+
assert.Equal(t, weights, o1.GetWeights())
76+
o.Cleanup()
77+
o1.Cleanup()
78+
}

go/pserver/service.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,13 @@ func (s *Service) checkpoint() (err error) {
297297
return
298298
}
299299

300+
if _, err = os.Stat(s.checkpointPath); os.IsNotExist(err) {
301+
err = os.MkdirAll(s.checkpointPath, os.ModePerm)
302+
if err != nil {
303+
return
304+
}
305+
}
306+
300307
id := uuid.NewV4().String()
301308
p := path.Join(s.checkpointPath, id)
302309
f, err := os.Create(p)

paddle/optimizer/adadelta_optimizer.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,17 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
2525
}
2626
}
2727

28-
const char* AdadeltaOptimizer::SerializeState(int* state_len) {
28+
std::string AdadeltaOptimizer::SerializeState() {
2929
AdadeltaOptimizerState state;
3030
state.set_num_sample_passed(num_sample_passed_);
31-
std::string lr_str = this->lr_policy_->SerializeState(state_len);
31+
std::string lr_str = this->lr_policy_->SerializeState();
3232
state.mutable_lr_state()->ParseFromString(lr_str);
3333

3434
TensorToProto(*parameter_, state.mutable_parameter());
3535
TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
3636
TensorToProto(*accum_delta_, state.mutable_accum_delta());
3737
TensorToProto(*update_delta_, state.mutable_update_delta());
38-
auto str = state.SerializeAsString();
39-
*state_len += str.size();
40-
return str.c_str();
38+
return state.SerializeAsString();
4139
}
4240

4341
void AdadeltaOptimizer::DeserializeState(const std::string& str) {

paddle/optimizer/adadelta_optimizer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class AdadeltaOptimizer : public ParameterOptimizer {
2323
if (update_delta_) delete update_delta_;
2424
}
2525
void Update(const Tensor *gradient);
26-
const char *SerializeState(int *state_len);
26+
std::string SerializeState();
2727
void DeserializeState(const std::string &state);
2828

2929
private:

paddle/optimizer/adagrad_optimizer.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,15 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
1717
learning_rate * decay_ * param[i];
1818
}
1919
}
20-
const char* AdagradOptimizer::SerializeState(int* state_len) {
20+
std::string AdagradOptimizer::SerializeState() {
2121
AdagradOptimizerState state;
2222
state.set_num_sample_passed(num_sample_passed_);
23-
std::string lr_str = this->lr_policy_->SerializeState(state_len);
23+
std::string lr_str = this->lr_policy_->SerializeState();
2424
state.mutable_lr_state()->ParseFromString(lr_str);
2525

2626
TensorToProto(*parameter_, state.mutable_parameter());
2727
TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
28-
auto str = state.SerializeAsString();
29-
*state_len += str.size();
30-
return str.c_str();
28+
return state.SerializeAsString();
3129
}
3230

3331
void AdagradOptimizer::DeserializeState(const std::string& str) {

paddle/optimizer/adagrad_optimizer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class AdagradOptimizer : public ParameterOptimizer {
1919
if (accum_gradient_) delete accum_gradient_;
2020
}
2121
void Update(const Tensor *gradient);
22-
const char *SerializeState(int *state_len);
22+
std::string SerializeState();
2323
void DeserializeState(const std::string &state);
2424

2525
private:

paddle/optimizer/adam_optimizer.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,16 @@ void AdamOptimizer::Update(const Tensor *gradient) {
2222
}
2323
}
2424

25-
const char *AdamOptimizer::SerializeState(int *state_len) {
25+
std::string AdamOptimizer::SerializeState() {
2626
AdamOptimizerState state;
27-
std::string lr_str = this->lr_policy_->SerializeState(state_len);
27+
std::string lr_str = this->lr_policy_->SerializeState();
2828
state.mutable_lr_state()->ParseFromString(lr_str);
2929
state.set_num_sample_passed(num_sample_passed_);
3030

3131
TensorToProto(*parameter_, state.mutable_parameter());
3232
TensorToProto(*momentums_, state.mutable_momentums());
3333
TensorToProto(*velocitys_, state.mutable_velocitys());
34-
auto str = state.SerializeAsString();
35-
*state_len += str.size();
36-
return str.c_str();
34+
return state.SerializeAsString();
3735
}
3836

3937
void AdamOptimizer::DeserializeState(const std::string &str) {

paddle/optimizer/adam_optimizer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class AdamOptimizer : public ParameterOptimizer {
2525
if (velocitys_) delete velocitys_;
2626
}
2727
void Update(const Tensor *gradient);
28-
const char *SerializeState(int *state_len);
28+
std::string SerializeState();
2929
void DeserializeState(const std::string &state);
3030

3131
private:

paddle/optimizer/lr_policy.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class LrPolicy {
1010
public:
1111
virtual ~LrPolicy() {}
1212
virtual double LearningRate(const uint64_t num_sample_passed) = 0;
13-
virtual const char *SerializeState(int *state_len) = 0;
13+
virtual std::string SerializeState() = 0;
1414
virtual void DeserializeState(const std::string &state) = 0;
1515
};
1616

@@ -21,12 +21,10 @@ class ConstLr final : public LrPolicy {
2121
double LearningRate(const uint64_t num_sample_passed) {
2222
return learning_rate_;
2323
}
24-
const char *SerializeState(int *state_len) {
24+
std::string SerializeState() {
2525
LrPolicyState state;
2626
state.set_learning_rate(learning_rate_);
27-
auto str = state.SerializeAsString();
28-
*state_len = str.size();
29-
return str.c_str();
27+
return state.SerializeAsString();
3028
}
3129
void DeserializeState(const std::string &str) {
3230
LrPolicyState state;
@@ -46,14 +44,12 @@ class LinearLr final : public LrPolicy {
4644
return std::max(learning_rate_ - lr_decay_a_ * num_sample_passed,
4745
lr_decay_b_);
4846
}
49-
const char *SerializeState(int *state_len) {
47+
std::string SerializeState() {
5048
LrPolicyState state;
5149
state.set_learning_rate(learning_rate_);
5250
state.set_lr_decay_a(lr_decay_a_);
5351
state.set_lr_decay_b(lr_decay_b_);
54-
auto str = state.SerializeAsString();
55-
*state_len = str.size();
56-
return str.c_str();
52+
return state.SerializeAsString();
5753
}
5854
void DeserializeState(const std::string &str) {
5955
LrPolicyState state;

0 commit comments

Comments
 (0)