Skip to content

Commit fe7c181

Browse files
authored
Merge pull request #8538 from typhoonzero/add_raw_var_type
fix short connection again
2 parents 1924aa1 + 6a68679 commit fe7c181

File tree

7 files changed

+31
-11
lines changed

7 files changed

+31
-11
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,13 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
5858
var->GetMutable<ReaderHolder>();
5959
} else if (var_type == proto::VarType::CHANNEL) {
6060
var->GetMutable<ChannelHolder>();
61-
} else if (var_type == proto::VarType::NCCL_COM) {
62-
// GetMutable will be called in ncclInit
61+
} else if (var_type == proto::VarType::RAW) {
62+
// GetMutable will be called in the operator
6363
} else {
6464
PADDLE_THROW(
6565
"Variable type %d is not in "
6666
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
67-
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]",
67+
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
6868
var_type);
6969
}
7070
}

paddle/fluid/framework/framework.proto

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ message VarType {
113113
PLACE_LIST = 14;
114114
READER = 15;
115115
CHANNEL = 16;
116-
NCCL_COM = 17;
116+
// Any variable whose type is decided at runtime is RAW;
117+
// raw variables should manage their own allocations
118+
// in operators like nccl_op
119+
RAW = 17;
117120
}
118121

119122
required Type type = 1;

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
177177
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
178178
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
179179

180-
auto ch = std::shared_ptr<grpc::Channel>(
181-
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args));
180+
auto ch =
181+
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
182182

183183
channels_[ep] = ch;
184184
return ch;

paddle/fluid/operators/nccl_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
6565
framework::BlockDesc *block) const override {
6666
auto out_var_name = op_desc.Output("Communicator").front();
6767
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
68-
auto var_type = framework::proto::VarType::NCCL_COM;
68+
auto var_type = framework::proto::VarType::RAW;
6969
out_var.SetType(var_type);
7070
}
7171
};

paddle/fluid/operators/send_op.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,27 @@ This operator will send tensor to recv_op at the parameter server.
121121
}
122122
};
123123

124+
class SendOpVarTypeInference : public framework::VarTypeInference {
125+
public:
126+
void operator()(const framework::OpDesc& op_desc,
127+
framework::BlockDesc* block) const override {
128+
auto out_var_name = op_desc.Output("RPCClient").front();
129+
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
130+
auto var_type = framework::proto::VarType::RAW;
131+
out_var.SetType(var_type);
132+
}
133+
};
134+
135+
class SendOpShapeInference : public framework::InferShapeBase {
136+
public:
137+
void operator()(framework::InferShapeContext* ctx) const override {}
138+
};
139+
124140
} // namespace operators
125141
} // namespace paddle
126142

127143
namespace ops = paddle::operators;
128144

129-
REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker);
145+
REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker,
146+
ops::SendOpMaker, ops::SendOpVarTypeInference,
147+
ops::SendOpShapeInference);

paddle/fluid/pybind/protobuf.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ void BindVarDsec(py::module &m) {
252252
.value("CHANNEL", proto::VarType::CHANNEL)
253253
.value("PLACE_LIST", proto::VarType::PLACE_LIST)
254254
.value("READER", proto::VarType::READER)
255-
.value("NCCL_COM", proto::VarType::NCCL_COM);
255+
.value("RAW", proto::VarType::RAW);
256256
}
257257

258258
void BindOpDesc(py::module &m) {

python/paddle/fluid/distribute_transpiler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,7 @@ def transpile(self,
226226
rpc_client_var = program.global_block().create_var(
227227
name="RPC_CLIENT_VAR",
228228
persistable=True,
229-
dtype='float32', # dtype and shape is not used in fact
230-
shape=[0])
229+
type=core.VarDesc.VarType.RAW)
231230

232231
# create send_op
233232
program.global_block().append_op(

0 commit comments

Comments
 (0)