| 
 | 1 | +/*  | 
 | 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates.  | 
 | 3 | + * All rights reserved.  | 
 | 4 | + *  | 
 | 5 | + * This source code is licensed under the BSD-style license found in the  | 
 | 6 | + * LICENSE file in the root directory of this source tree.  | 
 | 7 | + */  | 
 | 8 | + | 
 | 9 | +#include <executorch/extension/data_loader/file_data_loader.h>  | 
 | 10 | +#include <executorch/extension/tensor/tensor.h>  | 
 | 11 | +#include <executorch/extension/training/module/training_module.h>  | 
 | 12 | +#include <executorch/extension/training/optimizer/sgd.h>  | 
 | 13 | +#include <gflags/gflags.h>  | 
 | 14 | +#include <random>  | 
 | 15 | + | 
 | 16 | +#pragma clang diagnostic ignored \  | 
 | 17 | +    "-Wbraced-scalar-init" // {0} below upsets clang.  | 
 | 18 | + | 
 | 19 | +using executorch::extension::FileDataLoader;  | 
 | 20 | +using executorch::extension::training::optimizer::SGD;  | 
 | 21 | +using executorch::extension::training::optimizer::SGDOptions;  | 
 | 22 | +using executorch::runtime::Error;  | 
 | 23 | +using executorch::runtime::Result;  | 
 | 24 | +DEFINE_string(model_path, "xor.pte", "Model serialized in flatbuffer format.");  | 
 | 25 | + | 
 | 26 | +int main(int argc, char** argv) {  | 
 | 27 | +  gflags::ParseCommandLineFlags(&argc, &argv, true);  | 
 | 28 | +  if (argc != 1) {  | 
 | 29 | +    std::string msg = "Extra commandline args: ";  | 
 | 30 | +    for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) {  | 
 | 31 | +      msg += argv[i];  | 
 | 32 | +    }  | 
 | 33 | +    ET_LOG(Error, "%s", msg.c_str());  | 
 | 34 | +    return 1;  | 
 | 35 | +  }  | 
 | 36 | + | 
 | 37 | +  // Load the model file.  | 
 | 38 | +  executorch::runtime::Result<executorch::extension::FileDataLoader>  | 
 | 39 | +      loader_res =  | 
 | 40 | +          executorch::extension::FileDataLoader::from(FLAGS_model_path.c_str());  | 
 | 41 | +  if (loader_res.error() != Error::Ok) {  | 
 | 42 | +    ET_LOG(Error, "Failed to open model file: %s", FLAGS_model_path.c_str());  | 
 | 43 | +    return 1;  | 
 | 44 | +  }  | 
 | 45 | +  auto loader = std::make_unique<executorch::extension::FileDataLoader>(  | 
 | 46 | +      std::move(loader_res.get()));  | 
 | 47 | + | 
 | 48 | +  auto mod = executorch::extension::training::TrainingModule(std::move(loader));  | 
 | 49 | + | 
 | 50 | +  // Create full data set of input and labels.  | 
 | 51 | +  std::vector<std::pair<  | 
 | 52 | +      executorch::extension::TensorPtr,  | 
 | 53 | +      executorch::extension::TensorPtr>>  | 
 | 54 | +      data_set;  | 
 | 55 | +  data_set.push_back( // XOR(1, 1) = 0  | 
 | 56 | +      {executorch::extension::make_tensor_ptr<float>({1, 2}, {1, 1}),  | 
 | 57 | +       executorch::extension::make_tensor_ptr<long>({1}, {0})});  | 
 | 58 | +  data_set.push_back( // XOR(0, 0) = 0  | 
 | 59 | +      {executorch::extension::make_tensor_ptr<float>({1, 2}, {0, 0}),  | 
 | 60 | +       executorch::extension::make_tensor_ptr<long>({1}, {0})});  | 
 | 61 | +  data_set.push_back( // XOR(1, 0) = 1  | 
 | 62 | +      {executorch::extension::make_tensor_ptr<float>({1, 2}, {1, 0}),  | 
 | 63 | +       executorch::extension::make_tensor_ptr<long>({1}, {1})});  | 
 | 64 | +  data_set.push_back( // XOR(0, 1) = 1  | 
 | 65 | +      {executorch::extension::make_tensor_ptr<float>({1, 2}, {0, 1}),  | 
 | 66 | +       executorch::extension::make_tensor_ptr<long>({1}, {1})});  | 
 | 67 | + | 
 | 68 | +  // Create optimizer.  | 
 | 69 | +  // Get the params and names  | 
 | 70 | +  auto param_res = mod.named_parameters("forward");  | 
 | 71 | +  if (param_res.error() != Error::Ok) {  | 
 | 72 | +    ET_LOG(Error, "Failed to get named parameters");  | 
 | 73 | +    return 1;  | 
 | 74 | +  }  | 
 | 75 | + | 
 | 76 | +  SGDOptions options{0.1};  | 
 | 77 | +  SGD optimizer(param_res.get(), options);  | 
 | 78 | + | 
 | 79 | +  // Randomness to sample the data set.  | 
 | 80 | +  std::default_random_engine URBG{std::random_device{}()};  | 
 | 81 | +  std::uniform_int_distribution<int> dist{  | 
 | 82 | +      0, static_cast<int>(data_set.size()) - 1};  | 
 | 83 | + | 
 | 84 | +  // Train the model.  | 
 | 85 | +  size_t num_epochs = 5000;  | 
 | 86 | +  for (int i = 0; i < num_epochs; i++) {  | 
 | 87 | +    int index = dist(URBG);  | 
 | 88 | +    auto& data = data_set[index];  | 
 | 89 | +    const auto& results = mod.execute_forward_backward(  | 
 | 90 | +        "forward", {*data.first.get(), *data.second.get()});  | 
 | 91 | +    if (results.error() != Error::Ok) {  | 
 | 92 | +      ET_LOG(Error, "Failed to execute forward_backward");  | 
 | 93 | +      return 1;  | 
 | 94 | +    }  | 
 | 95 | +    if (i % 500 == 0 || i == num_epochs - 1) {  | 
 | 96 | +      ET_LOG(  | 
 | 97 | +          Info,  | 
 | 98 | +          "Step %d, Loss %f, Input [%.0f, %.0f], Prediction %ld, Label %ld",  | 
 | 99 | +          i,  | 
 | 100 | +          results.get()[0].toTensor().const_data_ptr<float>()[0],  | 
 | 101 | +          data.first->const_data_ptr<float>()[0],  | 
 | 102 | +          data.first->const_data_ptr<float>()[1],  | 
 | 103 | +          results.get()[1].toTensor().const_data_ptr<int64_t>()[0],  | 
 | 104 | +          data.second->const_data_ptr<int64_t>()[0]);  | 
 | 105 | +    }  | 
 | 106 | +    optimizer.step(mod.named_gradients("forward").get());  | 
 | 107 | +  }  | 
 | 108 | +}  | 
0 commit comments