Skip to content
This repository was archived by the owner on Sep 27, 2019. It is now read-only.

Commit fd2cf84

Browse files
authored
Merge pull request #1207 from saatviks/saatviks_tfsessionwrapper
Tensorflow Integration: Step 2
2 parents 4091b50 + 48cda43 commit fd2cf84

File tree

11 files changed

+634
-3
lines changed

11 files changed

+634
-3
lines changed
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Peloton
4+
//
5+
// tf_session_entity.cpp
6+
//
7+
// Identification: src/brain/tf_session_entity/tf_session_entity.cpp
8+
//
9+
// Copyright (c) 2015-2018, Carnegie Mellon University Database Group
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "brain/tf_session_entity/tf_session_entity.h"
14+
#include "brain/tf_session_entity/tf_session_entity_input.h"
15+
#include "brain/tf_session_entity/tf_session_entity_output.h"
16+
17+
namespace peloton {
18+
namespace brain {
19+
20+
/**
21+
* Constructor/Desctructor
22+
**/
23+
24+
TFSE_TEMPLATE_ARGUMENTS
25+
TFSE_TYPE::TfSessionEntity() {
26+
graph_ = TF_NewGraph();
27+
status_ = TF_NewStatus();
28+
session_options_ = TF_NewSessionOptions();
29+
session_ = TF_NewSession(graph_, session_options_, status_);
30+
}
31+
32+
TFSE_TEMPLATE_ARGUMENTS
33+
TFSE_TYPE::~TfSessionEntity() {
34+
TF_DeleteStatus(status_);
35+
TF_DeleteGraph(graph_);
36+
}
37+
38+
/*
39+
********
40+
* Graph Import Utilities
41+
********
42+
*/
43+
44+
TFSE_TEMPLATE_ARGUMENTS
45+
void TFSE_TYPE::FreeBuffer(void *data, UNUSED_ATTRIBUTE size_t length) {
46+
::operator delete(data);
47+
}
48+
49+
TFSE_TEMPLATE_ARGUMENTS
50+
void TFSE_TYPE::ImportGraph(const std::string &filename) {
51+
TF_Buffer *graph_def = ReadFile(filename);
52+
TF_ImportGraphDefOptions *opts = TF_NewImportGraphDefOptions();
53+
TF_GraphImportGraphDef(graph_, graph_def, opts, status_);
54+
TF_DeleteImportGraphDefOptions(opts);
55+
TF_DeleteBuffer(graph_def);
56+
PL_ASSERT(IsStatusOk());
57+
}
58+
59+
TFSE_TEMPLATE_ARGUMENTS
60+
TF_Buffer *TFSE_TYPE::ReadFile(const std::string &filename) {
61+
FILE *f = fopen(filename.c_str(), "rb");
62+
fseek(f, 0, SEEK_END);
63+
size_t fsize = (size_t)ftell(f);
64+
fseek(f, 0, SEEK_SET); // same as rewind(f);
65+
// Reference:
66+
// https://stackoverflow.com/questions/14111900/using-new-on-void-pointer
67+
void *data = ::operator new(fsize);
68+
UNUSED_ATTRIBUTE size_t size_read = fread(data, fsize, 1, f);
69+
fclose(f);
70+
71+
TF_Buffer *buf = TF_NewBuffer();
72+
buf->data = data;
73+
buf->length = fsize;
74+
buf->data_deallocator = TfSessionEntity::FreeBuffer;
75+
return buf;
76+
}
77+
78+
/*
79+
********
80+
* Evaluation/Session.Run()
81+
********
82+
*/
83+
84+
// Evaluate op with no inputs/outputs
85+
TFSE_TEMPLATE_ARGUMENTS
86+
void TFSE_TYPE::Eval(const std::string &opName) {
87+
TF_Operation *op = TF_GraphOperationByName(graph_, opName.c_str());
88+
TF_SessionRun(session_, nullptr, nullptr, nullptr, 0, // inputs
89+
nullptr, nullptr, 0, // outputs
90+
&op, 1, // targets
91+
nullptr, status_);
92+
}
93+
94+
// Evaluate op with inputs and outputs
95+
TFSE_TEMPLATE_ARGUMENTS
96+
OutputType *TFSE_TYPE::Eval(
97+
const std::vector<TfSessionEntityInput<InputType>>& helper_inputs,
98+
const std::vector<TfSessionEntityOutput<OutputType>>& helper_outputs) {
99+
std::vector<TF_Tensor *> in_vals, out_vals;
100+
std::vector<TF_Output> ins, outs;
101+
for (auto helperIn : helper_inputs) {
102+
ins.push_back(
103+
{TF_GraphOperationByName(graph_, helperIn.GetPlaceholderName().c_str()),
104+
0});
105+
in_vals.push_back(helperIn.GetTensor());
106+
}
107+
for (auto helperOut : helper_outputs) {
108+
outs.push_back({TF_GraphOperationByName(
109+
graph_, helperOut.GetPlaceholderName().c_str()),
110+
0});
111+
out_vals.push_back(helperOut.GetTensor());
112+
}
113+
TF_SessionRun(session_, nullptr, &(ins.at(0)), &(in_vals.at(0)),
114+
ins.size(), // Inputs
115+
&(outs.at(0)), &(out_vals.at(0)), outs.size(), // Outputs
116+
nullptr, 0, // Operations
117+
nullptr, status_);
118+
PL_ASSERT(TF_GetCode(status_) == TF_OK);
119+
return static_cast<OutputType *>(TF_TensorData(out_vals.at(0)));
120+
}
121+
122+
// Evaluate op with only inputs(where nothing is output eg. Backprop)
123+
TFSE_TEMPLATE_ARGUMENTS
124+
void TFSE_TYPE::Eval(const std::vector<TfSessionEntityInput<InputType>>& helper_inputs,
125+
const std::string &op_name) {
126+
std::vector<TF_Tensor *> in_vals;
127+
std::vector<TF_Output> ins;
128+
for (auto helperIn : helper_inputs) {
129+
ins.push_back(
130+
{TF_GraphOperationByName(graph_, helperIn.GetPlaceholderName().c_str()),
131+
0});
132+
in_vals.push_back(helperIn.GetTensor());
133+
}
134+
TF_Operation *op = TF_GraphOperationByName(graph_, op_name.c_str());
135+
TF_SessionRun(session_, nullptr, &(ins.at(0)), &(in_vals.at(0)),
136+
ins.size(), // Inputs
137+
nullptr, nullptr, 0, // Outputs
138+
&op, 1, // Operations
139+
nullptr, status_);
140+
PL_ASSERT(TF_GetCode(status_) == TF_OK);
141+
}
142+
143+
/*
144+
********
145+
* Helper Operations
146+
********
147+
*/
148+
149+
TFSE_TEMPLATE_ARGUMENTS
150+
void TFSE_TYPE::PrintOperations() {
151+
TF_Operation *oper;
152+
size_t pos = 0;
153+
std::string graph_ops = "Graph Operations List:";
154+
while ((oper = TF_GraphNextOperation(graph_, &pos)) != nullptr) {
155+
graph_ops += TF_OperationName(oper);
156+
graph_ops += "\n";
157+
}
158+
LOG_DEBUG("%s", graph_ops.c_str());
159+
}
160+
161+
TFSE_TEMPLATE_ARGUMENTS
162+
bool TFSE_TYPE::IsStatusOk() { return TF_GetCode(status_) == TF_OK; }
163+
164+
// Explicit template Initialization
165+
template class TfSessionEntity<float, float>;
166+
} // namespace brain
167+
} // namespace peloton
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Peloton
4+
//
5+
// tf_session_entity_input.cpp
6+
//
7+
// Identification: src/brain/tf_session_entity/tf_session_entity_input.cpp
8+
//
9+
// Copyright (c) 2015-2018, Carnegie Mellon University Database Group
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "brain/tf_session_entity/tf_session_entity_input.h"
14+
15+
namespace peloton {
16+
namespace brain {
17+
TFSEIN_TEMPLATE_ARGUMENTS
18+
TFSEIN_TYPE::TfSessionEntityInput(const InputType& input, const std::string &op) {
19+
this->placeholder_name_ = op;
20+
this->DetermineDataType();
21+
InputType input_for_tf = input;
22+
this->tensor_ =
23+
TF_AllocateTensor(this->data_type_, nullptr, 0, sizeof(InputType));
24+
auto buff = (InputType *)TF_TensorData(this->tensor_);
25+
PL_MEMCPY(buff, &input_for_tf, sizeof(InputType));
26+
}
27+
28+
// 1d vector
29+
TFSEIN_TEMPLATE_ARGUMENTS
30+
TFSEIN_TYPE::TfSessionEntityInput(const std::vector<InputType> &input,
31+
const std::string &op) {
32+
this->placeholder_name_ = op;
33+
this->DetermineDataType();
34+
int64_t dims[] = {static_cast<int64_t>(input.size())};
35+
const InputType *input_for_tf = input.data();
36+
this->tensor_ =
37+
TF_AllocateTensor(this->data_type_, dims, 1, dims[0] * sizeof(InputType));
38+
auto buff = (InputType *)TF_TensorData(this->tensor_);
39+
PL_MEMCPY(buff, input_for_tf, dims[0] * sizeof(InputType));
40+
}
41+
42+
// 2d vector
43+
TFSEIN_TEMPLATE_ARGUMENTS
44+
TFSEIN_TYPE::TfSessionEntityInput(const std::vector<std::vector<InputType>>& input,
45+
const std::string &op) {
46+
this->placeholder_name_ = op;
47+
this->DetermineDataType();
48+
int64_t dims[] = {static_cast<int64_t>(input.size()),
49+
static_cast<int64_t>(input[0].size())};
50+
InputType *input_for_tf = Flatten(input);
51+
this->tensor_ = TF_AllocateTensor(this->data_type_, dims, 2,
52+
dims[0] * dims[1] * sizeof(InputType));
53+
auto buff = (InputType *)TF_TensorData(this->tensor_);
54+
PL_MEMCPY(buff, input_for_tf, dims[0] * dims[1] * sizeof(InputType));
55+
}
56+
57+
// raw flattened input
58+
TFSEIN_TEMPLATE_ARGUMENTS
59+
TFSEIN_TYPE::TfSessionEntityInput(InputType *input, const std::vector<int64_t>& dims,
60+
const std::string &op) {
61+
this->placeholder_name_ = op;
62+
this->DetermineDataType();
63+
InputType *input_for_tf = input;
64+
int64_t num_elems = 1;
65+
for (auto elem : dims) {
66+
num_elems *= elem;
67+
}
68+
this->tensor_ = TF_AllocateTensor(this->data_type_, dims.data(), dims.size(),
69+
num_elems * sizeof(InputType));
70+
auto buff = (InputType *)TF_TensorData(this->tensor_);
71+
PL_MEMCPY(buff, input_for_tf, num_elems * sizeof(InputType));
72+
}
73+
74+
// Flattens 2d inputs
75+
TFSEIN_TEMPLATE_ARGUMENTS
76+
InputType *TFSEIN_TYPE::Flatten(const std::vector<std::vector<InputType>>& elems) {
77+
std::vector<InputType> flattened;
78+
for (auto row : elems) {
79+
for (float elem : row) {
80+
flattened.push_back(elem);
81+
}
82+
}
83+
return flattened.data();
84+
}
85+
86+
// Explicit template Initialization
87+
template class TfSessionEntityInput<float>;
88+
89+
} // namespace brain
90+
} // namespace peloton
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Peloton
4+
//
5+
// tf_session_entity_io.cpp
6+
//
7+
// Identification: src/brain/tf_session_entity/tf_session_entity_io.cpp
8+
//
9+
// Copyright (c) 2015-2018, Carnegie Mellon University Database Group
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "brain/tf_session_entity/tf_session_entity_io.h"
14+
15+
namespace peloton {
16+
namespace brain {
17+
18+
TFSEIO_BASE_TEMPLATE_ARGUMENTS
19+
void TFSEIO_BASE_TYPE::DetermineDataType() {
20+
if (std::is_same<N, int64_t>::value) {
21+
data_type_ = TF_INT64;
22+
} else if (std::is_same<N, int32_t>::value) {
23+
data_type_ = TF_INT32;
24+
} else if (std::is_same<N, int16_t>::value) {
25+
data_type_ = TF_INT16;
26+
} else if (std::is_same<N, int8_t>::value) {
27+
data_type_ = TF_INT8;
28+
} else if (std::is_same<N, int>::value) {
29+
data_type_ = TF_INT32;
30+
} else if (std::is_same<N, float>::value) {
31+
data_type_ = TF_FLOAT;
32+
} else if (std::is_same<N, double>::value) {
33+
data_type_ = TF_DOUBLE;
34+
}
35+
}
36+
37+
TFSEIO_BASE_TEMPLATE_ARGUMENTS
38+
std::string TFSEIO_BASE_TYPE::GetPlaceholderName() { return placeholder_name_; }
39+
40+
TFSEIO_BASE_TEMPLATE_ARGUMENTS
41+
TF_Tensor *TFSEIO_BASE_TYPE::GetTensor() { return tensor_; }
42+
43+
// Explicit template Initialization
44+
template class TfSessionEntityIOBase<float>;
45+
46+
} // namespace brain
47+
} // namespace peloton
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Peloton
4+
//
5+
// tf_session_entity_output.cpp
6+
//
7+
// Identification: src/brain/tf_session_entity/tf_session_entity_output.cpp
8+
//
9+
// Copyright (c) 2015-2018, Carnegie Mellon University Database Group
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "brain/tf_session_entity/tf_session_entity_output.h"
14+
15+
namespace peloton {
16+
namespace brain {
17+
18+
TFSEOUT_TEMPLATE_ARGUMENTS
19+
TFSEOUT_TYPE::TfSessionEntityOutput(const std::string &op) {
20+
this->placeholder_name_ = op;
21+
this->DetermineDataType();
22+
this->tensor_ =
23+
TF_AllocateTensor(this->data_type_, nullptr, 0, sizeof(OutputType));
24+
}
25+
26+
TFSEOUT_TEMPLATE_ARGUMENTS
27+
TFSEOUT_TYPE::TfSessionEntityOutput(const std::vector<int64_t>& dims,
28+
const std::string &op) {
29+
this->placeholder_name_ = op;
30+
this->DetermineDataType();
31+
int64_t num_elems = 1;
32+
for (auto elem : dims) {
33+
num_elems *= elem;
34+
}
35+
this->tensor_ = TF_AllocateTensor(this->data_type_, dims.data(), dims.size(),
36+
sizeof(OutputType) * num_elems);
37+
}
38+
39+
// Explicit template Initialization
40+
template class TfSessionEntityOutput<float>;
41+
} // namespace brain
42+
} // namespace peloton

0 commit comments

Comments
 (0)