Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions flashlight/nn/modules/Container.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class Container : public Module {
/**
 * A collection of modules contained within a `Container`. Each entry pairs
 * the name the module was registered under with the module itself.
 */
std::vector<std::pair<std::string /* module name */, ModulePtr>> modules_;

Container();

Expand All @@ -68,16 +68,23 @@ class Container : public Module {
* @param module the module to add.
*/
template <typename T>
void add(std::shared_ptr<T> module) {
void add(std::shared_ptr<T> module, std::string const& name) {
// TODO: Assert that there is no other module with this name
if (!module) {
throw std::invalid_argument("can't add null Module to Container");
}
modules_.emplace_back(module);
modules_.emplace_back({name, module});
for (int i = 0; i < module->params().size(); i++) {
childParamIdx_[params_.size()] = std::make_tuple(modules_.size() - 1, i);
params_.push_back(module->param(i));
auto pn = module->paramNamed(i);
pn.first = name + "." + pn.first;
params_.push_back(pn);
}
}
template <typename T>
void add(std::shared_ptr<T> module) {
  // A module added without a name is named after the index it will occupy
  // in `modules_`.
  const std::string positionalName = std::to_string(modules_.size());
  add(module, positionalName);
}

/**
* Returns a pointer to the module at the specified index in the container's
Expand Down
67 changes: 59 additions & 8 deletions flashlight/nn/modules/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,60 +16,111 @@

#include <stdexcept>

#include "flashlight/common/Logging.h"
#include "flashlight/nn/modules/Module.h"


#include "flashlight/nn/Init.h"

namespace fl {

Module::Module() = default;

// Adopts parameters that already carry a name, preserving their order.
Module::Module(const std::vector<std::pair<std::string, Variable>>& params)
    : params_(params.begin(), params.end()) {}

// Adopts unnamed parameters, naming each one after its position
// ("0", "1", ...).
Module::Module(const std::vector<Variable>& params) {
  params_.reserve(params.size());
  for (auto const& p : params) {
    // params_.size() is the index the new entry is about to occupy.
    params_.emplace_back(std::to_string(params_.size()), p);
  }
}

Variable Module::param(int position) const {
  // Range checking is delegated to paramNamed(); only the Variable half of
  // the (name, Variable) pair is handed back.
  const auto named = paramNamed(position);
  return named.second;
}

std::pair<std::string, Variable> Module::paramNamed(int position) const {
  // Negative positions are rejected first, so the comparison against the
  // (unsigned) size only ever sees a non-negative value.
  const bool outOfRange = position < 0 || position >= params_.size();
  if (outOfRange) {
    throw std::out_of_range("Module param index out of range");
  }
  return params_[position];
}

// Replaces the Variable stored at `position`; the parameter's name is kept.
// Throws std::out_of_range when `position` is not a valid index.
void Module::setParams(const Variable& var, int position) {
  if (!(position >= 0 && position < params_.size())) {
    throw std::out_of_range("Module param index out of range");
  }
  // params_ holds (name, Variable) pairs: assign only the Variable half.
  params_[position].second = var;
}

// Sets the train flag and calls setCalcGrad(true) on every parameter.
void Module::train() {
  train_ = true;
  for (auto& param : params_) {
    // Each entry is a (name, Variable) pair; the Variable is `.second`.
    param.second.setCalcGrad(true);
  }
}

// Calls zeroGrad() on every parameter of the module.
void Module::zeroGrad() {
  for (auto& param : params_) {
    // Each entry is a (name, Variable) pair; the Variable is `.second`.
    param.second.zeroGrad();
  }
}

// Clears the train flag and calls setCalcGrad(false) on every parameter.
void Module::eval() {
  train_ = false;
  for (auto& param : params_) {
    // Each entry is a (name, Variable) pair; the Variable is `.second`.
    param.second.setCalcGrad(false);
  }
}

std::vector<Variable> Module::params() const {
return params_;
std::vector<Variable> out(params_.size());
std::transform(params_.begin(), params_.begin(), std::back_inserter(out), [](std::pair<std::string, Variable> p) { return p.second; });
return out;
}

std::vector<Variable> Module::operator()(const std::vector<Variable>& input) {
  // Calling a module is shorthand for running its forward pass.
  return forward(input);
}

// Overwrites this module's parameters with the same-named entries of `sd`.
//
// Parameters missing from `sd` are left untouched and entries of `sd` that
// match no parameter are ignored; both cases are reported through VLOG(1).
// Throws std::runtime_error when two parameters of this module share a name,
// or when a matching entry's dims differ from the current parameter's dims.
// Parameters assigned before a dims mismatch is found stay assigned.
void Module::loadStateDict(StateDict const& sd) {
  // Name -> index lookup for our parameters; building it doubles as the
  // uniqueness check.
  std::unordered_map<std::string, size_t> indexByName;
  for (size_t i = 0; i < params_.size(); ++i) {
    auto const& name = params_[i].first;
    auto const inserted = indexByName.emplace(name, i);
    if (!inserted.second) {
      // emplace left the earlier entry in place, so it still holds the index
      // of the first parameter carrying this name.
      throw std::runtime_error("Duplicate parameter with name " + name + " (parameters indices " + std::to_string(inserted.first->second) + " and " + std::to_string(i) + ") in loadStateDict for module " + prettyString());
    }
    if (sd.find(name) == sd.end()) {
      VLOG(1) << "Parameter " << name << " not in state dict";
    }
  }

  for (auto const& entry : sd) {
    auto const it = indexByName.find(entry.first);
    if (it == indexByName.end()) {
      VLOG(1) << "Parameter " << entry.first << " in state dict but not in current Module - ignored";
      continue;
    }
    auto const& current = params_[it->second].second;
    if (current.dims() != entry.second.dims()) {
      throw std::runtime_error("Loading parameter with name " + entry.first + ": size mismatch in loadStateDict for module " + prettyString());
    }
    // Go through setParams (virtual) rather than writing params_ directly so
    // that overriding classes see the update.
    setParams(entry.second, it->second);
  }
}

// Builds a name -> Variable map of this module's parameters.
// Throws std::runtime_error when two parameters share a name.
StateDict Module::stateDict() const {
  StateDict namedParams;
  for (auto const& entry : params_) {
    const bool alreadyPresent = namedParams.count(entry.first) != 0;
    if (alreadyPresent) {
      throw std::runtime_error("Duplicate parameter with name " + entry.first + " in stateDict for module " + prettyString());
    }
    namedParams[entry.first] = entry.second;
  }
  return namedParams;
}

UnaryModule::UnaryModule() = default;

UnaryModule::UnaryModule(const std::vector<Variable>& params)
Expand Down
12 changes: 11 additions & 1 deletion flashlight/nn/modules/Module.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@
#include <stdexcept>
#include <string>
#include <vector>
#include <unordered_map>

namespace fl {

using StateDict = std::unordered_map<std::string, Variable>;
// Or maybe recursive? map<string, variant<map, Variable>>
// Would be better for Container modules (but we can also flatten everything)

/**
* An abstract computation unit capable of forward computation. Also
* contains a collection of parameters that can be mutated, and will be
Expand All @@ -42,7 +47,7 @@ class Module {
* Parameters of module, represented as a collection of `Variable`, whose
* ordering is based on the implementation of the respective module.
*/
std::vector<std::pair<std::string /* param name */, Variable>> params_;

/**
* A flag specifying whether or not the module is in `train` mode. If
Expand All @@ -62,6 +67,7 @@ class Module {
* @param params a vector of `Variable` which will replace `params_`
*/
explicit Module(const std::vector<Variable>& params);
/**
 * Constructs a module from parameters that already carry a name.
 *
 * @param params a vector of (name, `Variable`) pairs which is copied into
 * `params_`
 */
explicit Module(const std::vector<std::pair<std::string, Variable>>& params);

public:
/**
Expand Down Expand Up @@ -90,6 +96,7 @@ class Module {
* @return a `Variable` tensor for the parameter at the requested position
*/
Variable param(int position) const;
/**
 * Gets the parameter at a specified position together with its name.
 *
 * @param position the index of the requested parameter
 * @return the (name, `Variable`) pair stored at the requested position
 * @throws std::out_of_range if `position` is not a valid parameter index
 */
std::pair<std::string, Variable> paramNamed(int position) const;

/**
* Sets a parameter at a specified position with a new, given one.
Expand All @@ -104,6 +111,9 @@ class Module {
*/
virtual void setParams(const Variable& var, int position);

/**
 * Replaces parameters of this module with the same-named entries of a state
 * dict, via `setParams`. Parameters absent from `sd` are left unchanged and
 * entries of `sd` that match no parameter are ignored.
 *
 * @param sd a map from parameter name to the `Variable` to load
 * @throws std::runtime_error if two parameters of this module share a name,
 * or if an entry's dims differ from those of the parameter it would replace
 */
void loadStateDict(StateDict const& sd);

/**
 * Collects the parameters of this module into a map keyed by parameter name.
 *
 * @return a map from parameter name to the corresponding `Variable`
 * @throws std::runtime_error if two parameters of this module share a name
 */
StateDict stateDict() const;

/**
* Clears references to gradient Variables for all parameters in the module.
*/
Expand Down
4 changes: 4 additions & 0 deletions flashlight/optim/Optimizers.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@

namespace fl {

// Does it make sense to have a common abstraction between `Module` and optimizers?
// Optimizers and Modules both have parameters, can be serialized, have a `prettyString`, have a state_dict...
// The only thing is that Optimizers don't have `train`/`eval`/`forward`/`backward`

/** An abstract base class for first-order gradient-based optimizers. Any
* derived class must implement the step() function.
* Example usage:
Expand Down