@@ -23,12 +23,16 @@ using executorch::extension::training::optimizer::SGDOptions;
2323using executorch::runtime::Error;
2424using executorch::runtime::Result;
2525DEFINE_string (model_path, " xor.pte" , " Model serialized in flatbuffer format." );
26+ DEFINE_string (ptd_path, " " , " Model weights serialized in flatbuffer format." );
2627
2728int main (int argc, char ** argv) {
2829 gflags::ParseCommandLineFlags (&argc, &argv, true );
29- if (argc != 1 ) {
30+ if (argc == 0 ) {
31+ ET_LOG (Error, " Please provide a model path." );
32+ return 1 ;
33+ } else if (argc > 2 ) {
3034 std::string msg = " Extra commandline args: " ;
31- for (int i = 1 /* skip argv[0] (program name ) */ ; i < argc; i++) {
35+ for (int i = 2 /* skip argv[0] (pte path) and argv[1] (ptd path ) */ ; i < argc; i++) {
3236 msg += argv[i];
3337 }
3438 ET_LOG (Error, " %s" , msg.c_str ());
@@ -46,7 +50,20 @@ int main(int argc, char** argv) {
4650 auto loader = std::make_unique<executorch::extension::FileDataLoader>(
4751 std::move (loader_res.get ()));
4852
49- auto mod = executorch::extension::training::TrainingModule (std::move (loader));
53+ std::unique_ptr<executorch::extension::FileDataLoader> ptd_loader = nullptr ;
54+ if (!FLAGS_ptd_path.empty ()) {
55+ executorch::runtime::Result<executorch::extension::FileDataLoader>
56+ ptd_loader_res =
57+ executorch::extension::FileDataLoader::from (FLAGS_ptd_path.c_str ());
58+ if (ptd_loader_res.error () != Error::Ok) {
59+ ET_LOG (Error, " Failed to open ptd file: %s" , FLAGS_ptd_path.c_str ());
60+ return 1 ;
61+ }
62+ ptd_loader = std::make_unique<executorch::extension::FileDataLoader>(
63+ std::move (ptd_loader_res.get ()));
64+ }
65+
66+ auto mod = executorch::extension::training::TrainingModule (std::move (loader), nullptr , nullptr , nullptr , std::move (ptd_loader));
5067
5168 // Create full data set of input and labels.
5269 std::vector<std::pair<
@@ -70,7 +87,7 @@ int main(int argc, char** argv) {
7087 // Get the params and names
7188 auto param_res = mod.named_parameters (" forward" );
7289 if (param_res.error () != Error::Ok) {
73- ET_LOG (Error, " Failed to get named parameters" );
90+ ET_LOG (Error, " Failed to get named parameters, error: %d " , static_cast < int >(param_res. error ()) );
7491 return 1 ;
7592 }
7693
@@ -112,5 +129,5 @@ int main(int argc, char** argv) {
112129 std::string (param.first .data ()), param.second });
113130 }
114131
115- executorch::extension::flat_tensor::save_ptd (" xor .ptd" , param_map, 16 );
132+ executorch::extension::flat_tensor::save_ptd (" trained_xor .ptd" , param_map, 16 );
116133}
0 commit comments