Skip to content

Commit 13509da

Browse files
author
Yibing Liu
committed
Merge upstream to branch wrap_squeezes
2 parents 03f6292 + 9be39bb commit 13509da

File tree

79 files changed

+3144
-737
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

79 files changed

+3144
-737
lines changed

paddle/fluid/API.spec

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size
113113
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
114114
paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
115115
paddle.fluid.layers.sequence_expand ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None))
116+
paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen'], varargs=None, keywords=None, defaults=(None,))
116117
paddle.fluid.layers.lstm_unit ArgSpec(args=['x_t', 'hidden_t_prev', 'cell_t_prev', 'forget_bias', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(0.0, None, None, None))
117118
paddle.fluid.layers.reduce_sum ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
118119
paddle.fluid.layers.reduce_mean ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
@@ -148,6 +149,7 @@ paddle.fluid.layers.unsqueeze ArgSpec(args=['input', 'axes', 'name'], varargs=No
148149
paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None))
149150
paddle.fluid.layers.lrn ArgSpec(args=['input', 'n', 'k', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(5, 1.0, 0.0001, 0.75, None))
150151
paddle.fluid.layers.pad ArgSpec(args=['x', 'paddings', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None))
152+
paddle.fluid.layers.pad_constant_like ArgSpec(args=['x', 'y', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None))
151153
paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 0.1, 'float32', None))
152154
paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0))
153155
paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,))
@@ -166,6 +168,7 @@ paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], vara
166168
paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
167169
paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None))
168170
paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
171+
paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None))
169172
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
170173
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
171174
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
@@ -380,7 +383,7 @@ paddle.fluid.LoDTensor.__init__ 1. __init__(self: paddle.fluid.core.LoDTensor, a
380383
paddle.fluid.LoDTensor.has_valid_recursive_sequence_lengths has_valid_recursive_sequence_lengths(self: paddle.fluid.core.LoDTensor) -> bool
381384
paddle.fluid.LoDTensor.lod lod(self: paddle.fluid.core.LoDTensor) -> List[List[int]]
382385
paddle.fluid.LoDTensor.recursive_sequence_lengths recursive_sequence_lengths(self: paddle.fluid.core.LoDTensor) -> List[List[int]]
383-
paddle.fluid.LoDTensor.set 1. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CPUPlace) -> None 2. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CPUPlace) -> None 3. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CPUPlace) -> None 4. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CPUPlace) -> None 5. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CPUPlace) -> None 6. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CPUPlace) -> None 7. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CPUPlace) -> None 8. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPlace) -> None 9. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPlace) -> None 10. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPlace) -> None 11. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPlace) -> None 12. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPlace) -> None 13. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPlace) -> None 14. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPlace) -> None 15. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPinnedPlace) -> None 16. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPinnedPlace) -> None 17. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPinnedPlace) -> None 18. 
set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPinnedPlace) -> None 19. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPinnedPlace) -> None 20. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPinnedPlace) -> None 21. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPinnedPlace) -> None
386+
paddle.fluid.LoDTensor.set 1. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CPUPlace) -> None 2. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CPUPlace) -> None 3. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CPUPlace) -> None 4. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CPUPlace) -> None 5. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CPUPlace) -> None 6. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CPUPlace) -> None 7. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CPUPlace) -> None 8. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CPUPlace) -> None 9. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPlace) -> None 10. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPlace) -> None 11. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPlace) -> None 12. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPlace) -> None 13. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPlace) -> None 14. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPlace) -> None 15. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPlace) -> None 16. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CUDAPlace) -> None 17. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPinnedPlace) -> None 18. 
set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPinnedPlace) -> None 19. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPinnedPlace) -> None 20. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPinnedPlace) -> None 21. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPinnedPlace) -> None 22. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPinnedPlace) -> None 23. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPinnedPlace) -> None 24. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CUDAPinnedPlace) -> None
384387
paddle.fluid.LoDTensor.set_lod set_lod(self: paddle.fluid.core.LoDTensor, arg0: List[List[int]]) -> None
385388
paddle.fluid.LoDTensor.set_recursive_sequence_lengths set_recursive_sequence_lengths(self: paddle.fluid.core.LoDTensor, arg0: List[List[int]]) -> None
386389
paddle.fluid.LoDTensor.shape shape(self: paddle.fluid.core.Tensor) -> List[int]

paddle/fluid/framework/data_type.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ static DataTypeMap* InitDataTypeMap() {
6464
RegType(size_t, proto::VarType::SIZE_T);
6565
RegType(int16_t, proto::VarType::INT16);
6666
RegType(uint8_t, proto::VarType::UINT8);
67+
RegType(int8_t, proto::VarType::INT8);
6768

6869
#undef RegType
6970
return retv;

paddle/fluid/framework/data_type.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
5454
case proto::VarType::INT16:
5555
visitor.template operator()<int16_t>();
5656
break;
57+
case proto::VarType::INT8:
58+
visitor.template operator()<int8_t>();
59+
break;
5760
default:
5861
PADDLE_THROW("Not supported %d", type);
5962
}

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -754,17 +754,26 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
754754
node->Op()->Type());
755755

756756
CreateComputationalOp(result, node, op_dev_id);
757-
if (node->Op()->Type() == "concat") {
758-
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
759-
"fetch_barrier");
757+
}
758+
759+
void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
760+
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
761+
for (ir::Node *input : node->inputs) {
762+
VarHandle *var = nullptr;
763+
for (int place_offset = 0; place_offset < num_places; ++place_offset) {
764+
auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset];
765+
auto &var_holder = var_holders[input->Name()];
766+
if (!var_holder.empty()) {
767+
var = var_holder.rbegin()->get();
768+
op_handle->AddInput(var);
769+
}
770+
}
760771
}
761772
}
762773

763774
// Create RPC related op handles that connects its in ops and out ops.
764775
void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
765776
ir::Node *node) const {
766-
// FIXME(typhoonzero): Cleanup this deps for both sync mode and async mode
767-
// put them into transpiler.
768777
int op_dev_id = -1;
769778
if (node->Op()->Type() == "send") {
770779
// TODO(paddle-dev): getting the first var is not safe.
@@ -799,8 +808,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
799808
}
800809
auto recv_param_grad = boost::get<std::vector<std::string>>(
801810
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
802-
// FIXME(typhoonzero): assume each recv op output one param
803-
// Use the same place as send.
804811
if (recv_param_grad.size() == 2U) {
805812
op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
806813
VLOG(10) << "recv param " << recv_param_grad[0]
@@ -814,34 +821,44 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
814821
.emplace(varname, op_dev_id);
815822
}
816823
} else {
817-
// send_barrier and fetch_barrier op can be scheduled on device 0
824+
// send_barrier, fetch_barrier will run on place 0;
818825
op_dev_id = 0;
819826
}
820827

821828
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
822829
node->Op()->Type());
823-
824830
result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
825831
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
826832
node->Op()->Type(), places_[op_dev_id]));
827833

828-
// TODO(panyx0718): This might not be needed anymore.
829-
if (node->Op()->Type() == "send_barrier") {
830-
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "send");
831-
} else if (node->Op()->Type() == "recv") {
832-
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
833-
"send_barrier");
834-
} else if (node->Op()->Type() == "fetch_barrier") {
835-
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "recv");
836-
} else if (node->Op()->Type() == "send") {
837-
// do nothing
834+
if (node->Op()->Type() == "send") {
835+
CreateOpHandleIOs(result, node, op_dev_id);
838836
} else {
839-
PADDLE_THROW(
840-
"rpc op should be in ["
841-
"send, send_barrier. recv, fetch_barrier]");
842-
}
837+
// send_barrier, recv, fetch_barrier's inputs are deps var, get them from
838+
// all places
839+
auto p = places_[op_dev_id];
840+
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
841+
op_handle->SetDeviceContext(p,
842+
platform::DeviceContextPool::Instance().Get(p));
843843

844-
CreateOpHandleIOs(result, node, op_dev_id);
844+
SetOpInputsAllPlaces(result, node, places_.size());
845+
for (ir::Node *output : node->outputs) {
846+
int outvar_dev_id = op_dev_id;
847+
if (node->Op()->Type() == "fetch_barrier") {
848+
outvar_dev_id = GetVarDeviceID(*result, output->Name());
849+
PADDLE_ENFORCE_NE(outvar_dev_id, -1);
850+
}
851+
p = places_[outvar_dev_id];
852+
ir::Node *new_node = nullptr;
853+
if (output->Var()) {
854+
new_node = result->CreateVarNode(output->Var());
855+
} else {
856+
new_node =
857+
result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
858+
}
859+
CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id);
860+
}
861+
}
845862
}
846863

847864
bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {

paddle/fluid/framework/framework.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ message VarType {
107107
// Tensor<size_t> is used in C++.
108108
SIZE_T = 19;
109109
UINT8 = 20;
110+
INT8 = 21;
110111

111112
// Other types that may need additional descriptions
112113
LOD_TENSOR = 7;

paddle/fluid/framework/ir/graph.cc

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -132,63 +132,6 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
132132
}
133133
}
134134

135-
std::vector<ir::Node *> send_ops;
136-
ir::Node *send_bar = nullptr;
137-
std::vector<ir::Node *> recv_ops;
138-
ir::Node *fetch_bar = nullptr;
139-
for (ir::Node *node : Nodes()) {
140-
if (node->Name() == "send") {
141-
send_ops.push_back(node);
142-
} else if (node->Name() == "send_barrier") {
143-
PADDLE_ENFORCE(!send_bar, "only has one send barrier");
144-
send_bar = node;
145-
} else if (node->Name() == "recv") {
146-
recv_ops.push_back(node);
147-
} else if (node->Name() == "fetch_barrier") {
148-
PADDLE_ENFORCE(!fetch_bar, "only has one fetch barrier");
149-
fetch_bar = node;
150-
}
151-
}
152-
if (send_bar) {
153-
for (ir::Node *send : send_ops) {
154-
ir::Node *dep_var = CreateControlDepVar();
155-
send->outputs.push_back(dep_var);
156-
dep_var->inputs.push_back(send);
157-
send_bar->inputs.push_back(dep_var);
158-
dep_var->outputs.push_back(send_bar);
159-
}
160-
for (ir::Node *recv : recv_ops) {
161-
ir::Node *dep_var = CreateControlDepVar();
162-
recv->inputs.push_back(dep_var);
163-
dep_var->outputs.push_back(recv);
164-
send_bar->outputs.push_back(dep_var);
165-
dep_var->inputs.push_back(send_bar);
166-
}
167-
}
168-
if (fetch_bar) {
169-
for (ir::Node *recv : recv_ops) {
170-
ir::Node *dep_var = CreateControlDepVar();
171-
recv->outputs.push_back(dep_var);
172-
dep_var->inputs.push_back(recv);
173-
fetch_bar->inputs.push_back(dep_var);
174-
dep_var->outputs.push_back(fetch_bar);
175-
}
176-
}
177-
178-
std::vector<std::string> send_vars = FindDistTrainSendVars(send_ops);
179-
std::vector<std::string> recv_vars = FindDistTrainRecvVars(recv_ops);
180-
for (ir::Node *node : Nodes()) {
181-
if (IsDistTrainOp(node, send_vars, recv_vars)) {
182-
if (fetch_bar && node->Name() == "concat") {
183-
ir::Node *dep_var = CreateControlDepVar();
184-
fetch_bar->outputs.push_back(dep_var);
185-
dep_var->inputs.push_back(fetch_bar);
186-
node->inputs.push_back(dep_var);
187-
dep_var->outputs.push_back(node);
188-
}
189-
}
190-
}
191-
192135
/**
193136
* We should handle write after read(WAR) and write after write(WAW) here.
194137
* Because some of the operators of the program can be executed parallelly.

paddle/fluid/framework/tensor.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type,
4040
"When calling this method, the Tensor's numel must be "
4141
"equal or larger than zero. "
4242
"Please check Tensor::Resize has been called first.");
43-
size_t size = requested_size ? requested_size : numel() * SizeOfType(type);
43+
size_t size = numel() * SizeOfType(type);
44+
if (requested_size) {
45+
PADDLE_ENFORCE_GE(requested_size, size);
46+
size = requested_size;
47+
}
4448
/* some versions of boost::variant don't have operator!= */
4549
if (holder_ == nullptr || !(holder_->place() == place) ||
4650
holder_->size() < size + offset_) {

paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
7272
auto trt_teller = [&](const Node* node) {
7373
std::unordered_set<std::string> teller_set(
7474
{"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax",
75-
"depthwise_conv2d", "batch_norm"});
75+
"depthwise_conv2d", "batch_norm", "concat"});
7676
if (!node->IsFunction()) return false;
7777

7878
const auto* func = static_cast<const Function*>(node);

paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
3232
: NativePaddlePredictor(config), config_(config) {}
3333

3434
bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
35+
FLAGS_IA_enable_tensorrt_subgraph_engine = true;
3536
VLOG(3) << "Predictor::init()";
3637
FLAGS_tensorrt_max_batch_size = config_.max_batch_size;
3738
FLAGS_tensorrt_workspace_size = config_.workspace_size;
@@ -161,3 +162,4 @@ USE_TRT_CONVERTER(fc);
161162
USE_TRT_CONVERTER(pool2d);
162163
USE_TRT_CONVERTER(softmax);
163164
USE_TRT_CONVERTER(batch_norm);
165+
USE_TRT_CONVERTER(concat);

paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ void CompareTensorRTWithFluid(bool enable_tensorrt) {
3737
config1.use_gpu = true;
3838
config1.fraction_of_gpu_memory = 0.3;
3939
config1.device = 0;
40+
config1.max_batch_size = 10;
4041

4142
auto predictor0 =
4243
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config0);

0 commit comments

Comments (0)