diff --git a/flashlight/nn/modules/Container.h b/flashlight/nn/modules/Container.h
index 55af9fe55..013998609 100644
--- a/flashlight/nn/modules/Container.h
+++ b/flashlight/nn/modules/Container.h
@@ -45,7 +45,7 @@ class Container : public Module {
   /**
    * A collection of modules contained within a `Container`.
    */
-  std::vector<ModulePtr> modules_;
+  std::vector<std::pair<std::string, ModulePtr>> modules_;
 
   Container();
 
@@ -68,16 +68,23 @@ class Container : public Module {
    * @param module the module to add.
    */
   template <typename T>
-  void add(std::shared_ptr<T> module) {
+  void add(std::shared_ptr<T> module, std::string const& name) {
+    // TODO: Assert that there is no other module with this name.
     if (!module) {
       throw std::invalid_argument("can't add null Module to Container");
     }
-    modules_.emplace_back(module);
+    // emplace_back forwards constructor arguments directly; a braced
+    // list (emplace_back({name, module})) would not compile.
+    modules_.emplace_back(name, module);
     for (int i = 0; i < module->params().size(); i++) {
       childParamIdx_[params_.size()] = std::make_tuple(modules_.size() - 1, i);
-      params_.push_back(module->param(i));
+      // Qualify the child's parameter name with this module's name.
+      auto pn = module->paramNamed(i);
+      pn.first = name + "." + pn.first;
+      params_.push_back(pn);
     }
   }
+
+  // Unnamed overload: preserves the legacy behavior by naming the module
+  // after its positional index in the container.
+  template <typename T>
+  void add(std::shared_ptr<T> module) {
+    add(module, std::to_string(modules_.size()));
+  }
 
   /**
    * Returns a pointer to the module at the specified index in the container's
diff --git a/flashlight/nn/modules/Module.cpp b/flashlight/nn/modules/Module.cpp
index 3868f54a6..a8d1fe4f6 100644
--- a/flashlight/nn/modules/Module.cpp
+++ b/flashlight/nn/modules/Module.cpp
@@ -16,60 +16,111 @@
 
 #include <algorithm>
 
+#include "flashlight/common/Logging.h"
 #include "flashlight/nn/modules/Module.h"
-
 #include "flashlight/nn/Init.h"
 
 namespace fl {
 
 Module::Module() = default;
 
-Module::Module(const std::vector<Variable>& params)
-    : params_(params.begin(), params.end()) {}
+Module::Module(const std::vector<std::pair<std::string, Variable>>& params)
+    : params_(params.begin(), params.end()) {}
+
+Module::Module(const std::vector<Variable>& params) {
+  // Unnamed parameters get their positional index as their name.
+  for (auto const& p : params) {
+    params_.push_back({std::to_string(params_.size()), p});
+  }
+}
 
 Variable Module::param(int position) const {
+  return paramNamed(position).second;
+}
+
+std::pair<std::string, Variable> Module::paramNamed(int position) const {
   if (!(position >= 0 && position < params_.size())) {
     throw std::out_of_range("Module param index out of range");
   }
   return params_[position];
 }
 
 void Module::setParams(const Variable& var, int position) {
   if (!(position >= 0 && position < params_.size())) {
     throw std::out_of_range("Module param index out of range");
   }
-  params_[position] = var;
+  params_[position].second = var;
 }
 
 void Module::train() {
   train_ = true;
   for (auto& param : params_) {
-    param.setCalcGrad(true);
+    param.second.setCalcGrad(true);
   }
 }
 
 void Module::zeroGrad() {
   for (auto& param : params_) {
-    param.zeroGrad();
+    param.second.zeroGrad();
   }
 }
 
 void Module::eval() {
   train_ = false;
   for (auto& param : params_) {
-    param.setCalcGrad(false);
+    param.second.setCalcGrad(false);
   }
 }
 
 std::vector<Variable> Module::params() const {
-  return params_;
+  // NOTE(review): an earlier revision of this hunk transformed the empty
+  // range [begin, begin) and combined a pre-sized vector with
+  // std::back_inserter, which would yield params_.size() default
+  // Variables followed by the appended copies. Reserve + back_inserter
+  // over [begin, end) is the correct form.
+  std::vector<Variable> out;
+  out.reserve(params_.size());
+  std::transform(
+      params_.begin(),
+      params_.end(),
+      std::back_inserter(out),
+      [](const std::pair<std::string, Variable>& p) { return p.second; });
+  return out;
 }
 
 std::vector<Variable> Module::operator()(const std::vector<Variable>& input) {
   return this->forward(input);
 }
 
+void Module::loadStateDict(StateDict const& sd) {
+  // Name -> index map; building it also enforces parameter-name uniqueness.
+  std::unordered_map<std::string, size_t> myParams;
+  for (auto i = 0U; i < params_.size(); ++i) {
+    auto const& p = params_[i];
+    if (myParams.find(p.first) != myParams.end()) {
+      throw std::runtime_error(
+          "Duplicate parameter with name " + p.first + " (parameters indices " +
+          std::to_string(myParams[p.first]) + " and " + std::to_string(i) +
+          ") in loadStateDict for module " + prettyString());
+    }
+    myParams[p.first] = i;
+    if (sd.find(p.first) == sd.end()) {
+      VLOG(1) << "Parameter " << p.first << " not in state dict";
+    }
+  }
+
+  for (auto const& p : sd) {
+    auto it = myParams.find(p.first);
+    if (it == myParams.end()) {
+      VLOG(1) << "Parameter " << p.first
+              << " in state dict but not in current Module - ignored";
+      continue;
+    }
+    // Borrow, don't copy, the stored Variable just to compare shapes.
+    auto const& currentParam = params_[it->second].second;
+    if (currentParam.dims() != p.second.dims()) {
+      throw std::runtime_error(
+          "Loading parameter with name " + p.first +
+          ": size mismatch in loadStateDict for module " + prettyString());
+    }
+    setParams(p.second, it->second);
+  }
+}
+
+StateDict Module::stateDict() const {
+  StateDict myStateDict;
+  for (auto i = 0U; i < params_.size(); ++i) {
+    auto const& p = params_[i];
+    if (myStateDict.find(p.first) != myStateDict.end()) {
+      throw std::runtime_error(
+          "Duplicate parameter with name " + p.first +
+          " in stateDict for module " + prettyString());
+    }
+    myStateDict[p.first] = p.second;
+  }
+  return myStateDict;
+}
+
 UnaryModule::UnaryModule() = default;
 
 UnaryModule::UnaryModule(const std::vector<Variable>& params)
diff --git a/flashlight/nn/modules/Module.h b/flashlight/nn/modules/Module.h
index ad893a000..a75786ab7 100644
--- a/flashlight/nn/modules/Module.h
+++ b/flashlight/nn/modules/Module.h
@@ -22,9 +22,14 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include <unordered_map>
 
 namespace fl {
 
+/// Flat mapping of fully-qualified parameter name -> Variable.
+using StateDict = std::unordered_map<std::string, Variable>;
+// Or maybe recursive? map<std::string, variant<Variable, StateDict>>?
+// Would be better for Container modules (but we can also flatten everything)
+
 /**
  * An abstract computation unit capable of forward computation. Also
  * contains a collection of parameters that can be mutated, and will be
@@ -42,7 +47,7 @@ class Module {
   /**
    * Parameters of module, represented as a collection of `Variable`, whose
    * ordering is based on the implementation of the respective module.
    */
-  std::vector<Variable> params_;
+  std::vector<std::pair<std::string, Variable>> params_;
 
   /**
    * A flag specifying whether or not the module is in `train` mode. If
@@ -62,6 +67,7 @@ class Module {
    * @param params a vector of `Variable` which will replace `params_`
    */
   explicit Module(const std::vector<Variable>& params);
+  explicit Module(const std::vector<std::pair<std::string, Variable>>& params);
 
  public:
   /**
@@ -90,6 +96,7 @@ class Module {
    * @return a `Variable` tensor for the parameter at the requested position
    */
   Variable param(int position) const;
+  std::pair<std::string, Variable> paramNamed(int position) const;
 
   /**
    * Sets a parameter at a specified position with a new, given one.
@@ -104,6 +111,9 @@ class Module {
    */
   virtual void setParams(const Variable& var, int position);
 
+  void loadStateDict(StateDict const& sd);
+  StateDict stateDict() const;
+
   /**
    * Clears references to gradient Variables for all parameters in the module.
    */
diff --git a/flashlight/optim/Optimizers.h b/flashlight/optim/Optimizers.h
index 4498eb4fc..7cdbf792a 100644
--- a/flashlight/optim/Optimizers.h
+++ b/flashlight/optim/Optimizers.h
@@ -24,6 +24,10 @@
 
 namespace fl {
 
+// Does it make sense to have a common abstraction between `Module` and
+// optimizers? Both have parameters, can be serialized, have a `prettyString`,
+// and a state dict; optimizers just lack `train`/`eval`/`forward`/`backward`.
+
 /** An abstract base class for first-order gradient-based optimizers. Any
  * derived class must implement the step() function.
  * Example usage: