Skip to content

Commit 683fb59

Browse files
committed
improve code formatting
1 parent 19bfb87 commit 683fb59

File tree

1 file changed

+76
-73
lines changed

1 file changed

+76
-73
lines changed

tools/binary_converter/main.cpp

Lines changed: 76 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ const string PATTERN_BIAS = "bias";
5454
const string PATTERN_FORWARD = "fwd";
5555
const string PATTERN_ARG = "arg";
5656
const string PATTERN_PAD = "pad";
57+
5758
// symbol json related
5859
const char* PREFIX_SYM_JSON_NODES = "nodes";
5960
const char* PREFIX_SYM_JSON_NODE_ROW_PTR = "node_row_ptr";
@@ -64,13 +65,14 @@ const char* PREFIX_SYM_JSON_ARG_NODES = "arg_nodes";
6465
// name of standard convolution and dense layer
6566
const string PREFIX_DENSE = "FullyConnected";
6667
const string PREFIX_CONVOLUTION = "Convolution";
68+
6769
// use this to distinguish arg_nodes : op = 'null'
6870
const string ARG_NODES_OP_PATTERN = "null";
6971

7072
const string PREFIX_BINARY_INFERENCE_CONV_LAYER = "BinaryInferenceConvolution";
7173
const string PREFIX_BINARY_INFERENCE_DENSE_LAYER = "BinaryInferenceFullyConnected";
7274

73-
bool _verbose = false;
75+
bool VERBOSE = false;
7476
//==============================================//
7577

7678
/**
@@ -85,8 +87,8 @@ int convert_to_binary_row(mxnet::NDArray& array) {
8587
CHECK(array.shape().ndim() >= 2); // second dimension is input depth from prev. layer, needed for next line
8688

8789
if (array.shape()[1] % BITS_PER_BINARY_WORD != 0){
88-
cout << "Error:" << "the operator has an invalid input dim: " << array.shape()[1];
89-
cout << ", which is not divisible by " << BITS_PER_BINARY_WORD << endl;
90+
cerr << "Error:" << "the operator has an invalid input dim: " << array.shape()[1];
91+
cerr << ", which is not divisible by " << BITS_PER_BINARY_WORD << endl;
9092
return -1;
9193
}
9294

@@ -98,8 +100,8 @@ int convert_to_binary_row(mxnet::NDArray& array) {
98100
binarized_shape[3] = array.shape()[3];
99101

100102
mxnet::NDArray temp(binarized_shape, mxnet::Context::CPU(), false, mxnet::op::xnor::corresponding_dtype());
101-
mxnet::op::xnor::get_binary_row((float*) array.data().dptr_,
102-
(BINARY_WORD*) temp.data().dptr_,
103+
mxnet::op::xnor::get_binary_row(static_cast<float *>(array.data().dptr_),
104+
static_cast<BINARY_WORD*> (temp.data().dptr_),
103105
size);
104106
array = temp;
105107

@@ -123,7 +125,8 @@ void transpose(mxnet::NDArray& array) {
123125
MSHADOW_REAL_TYPE_SWITCH(array.dtype(), DType, {
124126
for (int row = 0; row < rows; row++) {
125127
for (int col = 0; col < cols; col++) {
126-
((DType*)temp.data().dptr_)[col * rows + row] = ((DType*)array.data().dptr_)[row * cols + col];
128+
(static_cast<DType *> (temp.data().dptr_))[col * rows + row] =
129+
(static_cast<DType *> (array.data().dptr_))[row * cols + col];
127130
}
128131
}
129132
})
@@ -142,9 +145,9 @@ int transpose_and_convert_to_binary_col(mxnet::NDArray& array) {
142145

143146
CHECK(array.shape().ndim() == 2); // since we binarize column wise, we need to know no of rows and columns
144147

145-
if (array.shape()[0] % BITS_PER_BINARY_WORD != 0){
146-
cout << "Error:" << "the operator has an invalid input dim: " << array.shape()[0];
147-
cout << ", which is not divisible by " << BITS_PER_BINARY_WORD << endl;
148+
if (array.shape()[0] % BITS_PER_BINARY_WORD != 0) {
149+
cerr << "Error:" << "the operator has an invalid input dim: " << array.shape()[0];
150+
cerr << ", which is not divisible by " << BITS_PER_BINARY_WORD << endl;
148151
return -1;
149152
}
150153

@@ -153,10 +156,10 @@ int transpose_and_convert_to_binary_col(mxnet::NDArray& array) {
153156
binarized_shape[1] = array.shape()[0] / BITS_PER_BINARY_WORD;
154157

155158
mxnet::NDArray temp(binarized_shape, mxnet::Context::CPU(), false, mxnet::op::xnor::corresponding_dtype());
156-
mxnet::op::xnor::get_binary_col_unrolled((float*) array.data().dptr_,
157-
(BINARY_WORD*) temp.data().dptr_,
158-
array.shape()[0],
159-
array.shape()[1]);
159+
mxnet::op::xnor::get_binary_col_unrolled(static_cast<float *>(array.data().dptr_),
160+
static_cast<BINARY_WORD*>(temp.data().dptr_),
161+
array.shape()[0],
162+
array.shape()[1]);
160163
array = temp;
161164

162165
return 0;
@@ -175,41 +178,41 @@ int transpose_and_convert_to_binary_col(mxnet::NDArray& array) {
175178
* data: ndarray storing weight params
176179
* keys: the corresponding keys to the weights array
177180
*/
178-
void convert_params(vector<mxnet::NDArray>& data, const vector<string>& keys){
179-
string delimiter = ":";
181+
void convert_params(vector<mxnet::NDArray>& data, const vector<string>& keys) {
182+
const string delimiter = ":";
180183

181-
for (int i = 0; i < keys.size(); ++i)
182-
{
183-
string tp = keys[i].substr(0, keys[i].find(delimiter));
184-
string name = keys[i].substr(keys[i].find(delimiter)+1, keys[i].length()-1);
184+
for (int i = 0; i < keys.size(); ++i){
185+
string type = keys[i].substr(0, keys[i].find(delimiter));
186+
string name = keys[i].substr(keys[i].find(delimiter) + 1, keys[i].length() - 1);
185187

186-
if (_verbose){
187-
//logging
188-
cout << "Info: " << '\t' << "type:" << tp << "; ";
188+
if (VERBOSE) {
189+
// logging
190+
cout << "Info: " << '\t' << "type:" << type << "; ";
189191
cout << "name:" << name << "; ";
190192
cout << "shape:" << data[i].shape() << endl;
191193
}
192194

193195
// concatenate the weights of qconv layer
194-
if (tp == PATTERN_ARG
196+
if (type == PATTERN_ARG
195197
&& name.find(PATTERN_Q_CONV) != string::npos
196-
&& name.find("_"+PATTERN_WEIGHT) != string::npos){
197-
// concatenates binary row
198-
if(convert_to_binary_row(data[i]) < 0){
199-
cout << "Error: weights concatenation FAILED for operator '" << name <<"'" << endl;
200-
}else
198+
&& name.find("_"+PATTERN_WEIGHT) != string::npos) {
199+
// concatenates binary row
200+
if (convert_to_binary_row(data[i]) < 0) {
201+
cerr << "Error: weights concatenation FAILED for operator '" << name << "'" << endl;
202+
} else {
201203
cout << "Info: CONCATENATED layer: '" << name << "'" << endl;
204+
}
202205
}
203206

204207
// concatenate the weights of qfc layer
205-
if (tp == PATTERN_ARG
208+
if (type == PATTERN_ARG
206209
&& name.find(PATTERN_Q_DENSE) != string::npos
207-
&& name.find("_"+PATTERN_WEIGHT) != string::npos)
208-
{
209-
if (transpose_and_convert_to_binary_col(data[i]) < 0){
210-
cout << "Error: weights concatenation FAILED for operator '" << name <<"'" << endl;
211-
}else
210+
&& name.find("_" + PATTERN_WEIGHT) != string::npos) {
211+
if (transpose_and_convert_to_binary_col(data[i]) < 0) {
212+
cerr << "Error: weights concatenation FAILED for operator '" << name << "'" << endl;
213+
} else {
212214
cout << "Info: CONCATENATED layer: '" << name << "'" << endl;
215+
}
213216
}
214217
}
215218
}
@@ -225,7 +228,7 @@ void convert_params(vector<mxnet::NDArray>& data, const vector<string>& keys){
225228
int convert_params_file(const string& input_file, const string& output_file) {
226229
vector<mxnet::NDArray> data;
227230
vector<string> keys;
228-
231+
229232
{ // loading params file into data and keys
230233
// logging
231234
cout << "Info: " <<"LOADING input '.params' file: "<< input_file << endl;
@@ -241,7 +244,7 @@ int convert_params_file(const string& input_file, const string& output_file) {
241244
mxnet::NDArray::Save(fo.get(), data, keys);
242245
cout << "Info: " << "converted .params file saved!" << endl;
243246
}
244-
247+
245248
return 0;
246249
}
247250

@@ -253,18 +256,18 @@ int convert_params_file(const string& input_file, const string& output_file) {
253256
* @param a json file
254257
*/
255258
void print_rapidjson_doc(string json, string log_prefix="") {
256-
Document d;
259+
Document d;
257260
d.Parse(json.c_str());
258261

259262
// print heads
260263
CHECK(d.HasMember(PREFIX_SYM_JSON_HEADS));
261264
rapidjson::Value& heads = d[PREFIX_SYM_JSON_HEADS];
262265
CHECK(heads.IsArray() && heads.Capacity() > 0);
263266
// logging
264-
cout << "Info: " << log_prefix << "'heads' of input json: " << "[" << "["
267+
cout << "Info: " << log_prefix << "'heads' of input json: " << "[" << "["
265268
<< heads[0][0].GetInt() << ", "
266269
<< heads[0][1].GetInt() << ", "
267-
<< heads[0][2].GetInt()
270+
<< heads[0][2].GetInt()
268271
<< "]" << "]" << endl;
269272

270273
// print arg_nodes
@@ -273,12 +276,10 @@ void print_rapidjson_doc(string json, string log_prefix="") {
273276
CHECK(arg_nodes.IsArray());
274277
CHECK(!arg_nodes.Empty());
275278
// logging
276-
cout << "Info: " << log_prefix << "'arg_nodes' of input json: " << "[";
277-
for (int i = 0; i < arg_nodes.Capacity(); ++i)
278-
{
279-
cout << arg_nodes[i].GetInt();
280-
if (i < arg_nodes.Capacity()-1)
281-
{
279+
cout << "Info: " << log_prefix << "'arg_nodes' of input json: " << "[";
280+
for (int i = 0; i < arg_nodes.Capacity(); ++i) {
281+
cout << arg_nodes[i].GetInt();
282+
if (i < arg_nodes.Capacity()-1) {
282283
cout << ",";
283284
}
284285
}
@@ -290,10 +291,9 @@ void print_rapidjson_doc(string json, string log_prefix="") {
290291
CHECK(nodes.IsArray());
291292
CHECK(!nodes.Empty());
292293

293-
cout <<"Info: " << log_prefix << "number of nodes:" << nodes.Capacity() << endl;
294-
if (_verbose){
295-
for (int i = 0; i < nodes.Capacity(); ++i)
296-
{
294+
cout <<"Info: " << log_prefix << "number of nodes:" << nodes.Capacity() << endl;
295+
if (VERBOSE) {
296+
for (int i = 0; i < nodes.Capacity(); ++i) {
297297
cout <<"Info: \t" << log_prefix << "node index " << i << " : " << nodes[i]["name"].GetString() << endl;
298298
}
299299
}
@@ -328,7 +328,7 @@ void adjustIdsForRemovalOf(Value::ValueIterator& itr, uint currentId,
328328
- nodes: all operators
329329
- heads: head node
330330
- arg_nodes: arg nodes, usually 'null' operators.
331-
- node_row_ptr: not yet found detailed information about this item,
331+
- node_row_ptr: not yet found detailed information about this item,
332332
but it seems not affecting the inference
333333
* @param input_file path to mxnet symbol file with qconv and qdense layers
334334
* @param output_file path to converted symbol file
@@ -360,11 +360,11 @@ int convert_symbol_json(const string& input_fname, const string& output_fname) {
360360
CHECK(heads.IsArray() && heads.Capacity() > 0);
361361

362362
// update arg_nodes : contains indices of all "null" op
363-
CHECK(d.HasMember(PREFIX_SYM_JSON_ARG_NODES));
363+
CHECK(d.HasMember(PREFIX_SYM_JSON_ARG_NODES));
364364
Value& arg_nodes = d[PREFIX_SYM_JSON_ARG_NODES];
365365
CHECK(arg_nodes.IsArray());
366366
CHECK(!arg_nodes.Empty());
367-
367+
368368
// check, create nodes
369369
CHECK(d.HasMember(PREFIX_SYM_JSON_NODES));
370370
Value& nodes = d[PREFIX_SYM_JSON_NODES];
@@ -427,17 +427,17 @@ int convert_symbol_json(const string& input_fname, const string& output_fname) {
427427
}
428428

429429
// replace convolution and dense operators with binary inference layer
430-
if ((*itr)["op"].IsString() &&
430+
if ((*itr)["op"].IsString() &&
431431
string((*itr)["op"].GetString()) == PREFIX_CONVOLUTION) {
432432
(*itr)["op"].SetString(PREFIX_BINARY_INFERENCE_CONV_LAYER.c_str(), allocator);
433433
//logging
434434
cout << "Info: " <<"CONVERTING op: '" << (*itr)["name"].GetString() << "' from '"
435435
<< PREFIX_CONVOLUTION << "' to '" << PREFIX_BINARY_INFERENCE_CONV_LAYER << "'" << endl;
436-
}
437-
438-
if ((*itr)["op"].IsString() &&
436+
}
437+
438+
if ((*itr)["op"].IsString() &&
439439
string((*itr)["op"].GetString()) == PREFIX_DENSE){
440-
(*itr)["op"].SetString(PREFIX_BINARY_INFERENCE_DENSE_LAYER.c_str(), allocator);
440+
(*itr)["op"].SetString(PREFIX_BINARY_INFERENCE_DENSE_LAYER.c_str(), allocator);
441441
//logging
442442
cout << "Info: " <<"CONVERTING op: '" << (*itr)["name"].GetString() << "' from '"
443443
<< PREFIX_DENSE << "' to '" << PREFIX_BINARY_INFERENCE_DENSE_LAYER << "'" << endl;
@@ -477,20 +477,20 @@ int convert_symbol_json(const string& input_fname, const string& output_fname) {
477477
}
478478

479479
// add arg_node
480-
if ( string((*itr)["op"].GetString()) == ARG_NODES_OP_PATTERN) {
480+
if (string((*itr)["op"].GetString()) == ARG_NODES_OP_PATTERN) {
481481
arg_nodes.PushBack(Value().SetInt(currentNewId), allocator);
482482
}
483483
}
484484

485485

486-
// update heads
486+
// update heads
487487
for (Value::ValueIterator itr = heads.Begin(); itr != heads.End(); ++itr) {
488488
uint formerId = (*itr)[0].GetUint();
489489
CHECK(newIds.count(formerId) > 0);
490490
(*itr)[0].SetUint(newIds[formerId]);
491491
}
492492

493-
// update nodes
493+
// update nodes
494494
nodes = nodes_new;
495495

496496
// Save output json file
@@ -502,15 +502,15 @@ int convert_symbol_json(const string& input_fname, const string& output_fname) {
502502
{
503503
ofstream stream(output_fname);
504504
if (!stream.is_open()) {
505-
cout << "Error: " << "cant find json file at " + output_fname << endl;
505+
cerr << "Error: " << "cant find json file at " + output_fname << endl;
506506
return -1;
507507
}
508508
string output = buffer.GetString();
509509
stream << output;
510510
stream.close();
511511

512512
cout << "Info: " << "converted json file saved!" << endl;
513-
513+
514514
// print the current json docu
515515
print_rapidjson_doc(output, "updated ");
516516
}
@@ -522,50 +522,53 @@ int convert_symbol_json(const string& input_fname, const string& output_fname) {
522522
* @brief convert mxnet param and symbol file to use only binarized weights in conv and fc layers
523523
*
524524
*/
525-
int main(int argc, char ** argv){
525+
int main(int argc, char ** argv) {
526526
if (argc < 2 || argc > 4) {
527-
cout << "usage: " + string(argv[0]) + " <mxnet *.params file>" + " <output (optional)>" +
527+
cout << "usage: " + string(argv[0]) + " <mxnet *.params file>" + " <output (optional)>" +
528528
" --verbose" << endl;
529529
cout << " will binarize the weights of the qconv or qdense layers of your model," << endl;
530530
cout << " pack 32(x86 and ARMv7) or 64(x64) values into one and save the result with the prefix 'binarized_'" << endl;
531-
cout << "<output>: specify the location to store the binarized files. If not specified, the same location as the input model will be used." << endl;
531+
cout << "<output>: specify the location to store the binarized files. If not specified, the same location as the input model will be used." << endl;
532532
cout << "--verbose: for more information" << endl;
533533
return -1;
534534
}
535535

536536
// prepare file paths
537537
const string params_file(argv[1]);
538-
char *file_copy_basename = strdup(argv[1]);
539-
char *file_copy_dirname = strdup(argv[1]);
538+
char *file_copy_basename = strdup(argv[1]);
539+
char *file_copy_dirname = strdup(argv[1]);
540540
const string path(dirname(file_copy_dirname));
541541
const string params_file_name(basename(file_copy_basename));
542542
string out_path;
543-
if(argc >= 3)
543+
if (argc >= 3) {
544544
out_path = argv[2];
545+
}
545546

546-
if(out_path.empty() || out_path == "--verbose")
547+
if (out_path.empty() || out_path == "--verbose") {
547548
out_path = path;
549+
}
548550
free(file_copy_basename);
549551
free(file_copy_dirname);
550552

551553
if ( (argc == 3 && string(argv[2]) == "--verbose")
552-
|| (argc == 4 && string(argv[3]) == "--verbose"))
553-
_verbose = true;
554+
|| (argc == 4 && string(argv[3]) == "--verbose")) {
555+
VERBOSE = true;
556+
}
554557

555558
string base_name = params_file_name;
556559
base_name.erase(base_name.rfind('-')); // watchout if no '-'
557560

558561
const string json_file_name(path + "/" + base_name + "-symbol.json");
559562
const string param_out_fname(out_path + "/" + "binarized_" + params_file_name);
560563
const string json_out_fname(out_path + "/" + "binarized_" + base_name + "-symbol.json");
561-
564+
562565
if (int ret = convert_symbol_json(json_file_name, json_out_fname) != 0) {
563566
return ret;
564567
}
565568

566569
if (int ret = convert_params_file(params_file, param_out_fname) != 0) {
567570
return ret;
568571
}
569-
572+
570573
return 0;
571574
}

0 commit comments

Comments
 (0)