Skip to content

Commit 2b6a0b9

Browse files
committed
flownet2 patch: fix the error in hdf5 format model weights load.
1 parent db72704 commit 2b6a0b9

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/caffe/net.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ void Net<Dtype>::CopyTrainedLayersFromHDF5(const string trained_filename) {
849849
}
850850
}
851851
hdf5_load_nd_dataset(layer_hid, dataset_name.c_str(), 0, kMaxBlobAxes,
852-
target_blobs[j].get());
852+
target_blobs[j].get(), true); //Allow reshape here, as we are loading data not params
853853
}
854854
H5Gclose(layer_hid);
855855
}

src/caffe/solvers/sgd_solver.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
380380
ostringstream oss;
381381
oss << i;
382382
hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
383-
kMaxBlobAxes, history_[i].get());
383+
kMaxBlobAxes, history_[i].get(), true); //Allow reshape here, as we are loading data not params
384384
}
385385
H5Gclose(history_hid);
386386
H5Fclose(file_hid);

0 commit comments

Comments
 (0)