@@ -23,12 +23,18 @@ using executorch::extension::training::optimizer::SGDOptions;
2323using executorch::runtime::Error;
2424using executorch::runtime::Result;
2525DEFINE_string (model_path, " xor.pte" , " Model serialized in flatbuffer format." );
26+ DEFINE_string (ptd_path, " " , " Model weights serialized in flatbuffer format." );
2627
2728int main (int argc, char ** argv) {
2829 gflags::ParseCommandLineFlags (&argc, &argv, true );
29- if (argc != 1 ) {
30+ if (argc == 0 ) {
31+ ET_LOG (Error, " Please provide a model path." );
32+ return 1 ;
33+ } else if (argc > 2 ) {
3034 std::string msg = " Extra commandline args: " ;
31- for (int i = 1 /* skip argv[0] (program name) */ ; i < argc; i++) {
35+ for (int i = 2 /* skip argv[0] (pte path) and argv[1] (ptd path) */ ;
36+ i < argc;
37+ i++) {
3238 msg += argv[i];
3339 }
3440 ET_LOG (Error, " %s" , msg.c_str ());
@@ -46,7 +52,21 @@ int main(int argc, char** argv) {
4652 auto loader = std::make_unique<executorch::extension::FileDataLoader>(
4753 std::move (loader_res.get ()));
4854
49- auto mod = executorch::extension::training::TrainingModule (std::move (loader));
55+ std::unique_ptr<executorch::extension::FileDataLoader> ptd_loader = nullptr ;
56+ if (!FLAGS_ptd_path.empty ()) {
57+ executorch::runtime::Result<executorch::extension::FileDataLoader>
58+ ptd_loader_res =
59+ executorch::extension::FileDataLoader::from (FLAGS_ptd_path.c_str ());
60+ if (ptd_loader_res.error () != Error::Ok) {
61+ ET_LOG (Error, " Failed to open ptd file: %s" , FLAGS_ptd_path.c_str ());
62+ return 1 ;
63+ }
64+ ptd_loader = std::make_unique<executorch::extension::FileDataLoader>(
65+ std::move (ptd_loader_res.get ()));
66+ }
67+
68+ auto mod = executorch::extension::training::TrainingModule (
69+ std::move (loader), nullptr , nullptr , nullptr , std::move (ptd_loader));
5070
5171 // Create full data set of input and labels.
5272 std::vector<std::pair<
@@ -70,7 +90,10 @@ int main(int argc, char** argv) {
7090 // Get the params and names
7191 auto param_res = mod.named_parameters (" forward" );
7292 if (param_res.error () != Error::Ok) {
73- ET_LOG (Error, " Failed to get named parameters" );
93+ ET_LOG (
94+ Error,
95+ " Failed to get named parameters, error: %d" ,
96+ static_cast <int >(param_res.error ()));
7497 return 1 ;
7598 }
7699
@@ -112,5 +135,6 @@ int main(int argc, char** argv) {
112135 std::string (param.first .data ()), param.second });
113136 }
114137
115- executorch::extension::flat_tensor::save_ptd (" xor.ptd" , param_map, 16 );
138+ executorch::extension::flat_tensor::save_ptd (
139+ " trained_xor.ptd" , param_map, 16 );
116140}
0 commit comments