Skip to content

Commit a6a3b2f

Browse files
author
chengduo
authored
[Speed]Refine ParallelExecutor (#16190)
* refine parallelExecutor test=develop
* Polish op_handle test=develop
* Remove unnecessary op_handle test=develop
* Fix Travis CI test=develop
* Fix fetch bug test=develop
* Remove WaitInputVarGenerated
* Fix OpHandleBase::Run test=develop
* debug test=develop
* use origin fetch_op_handle test=develop
* Revert op_handle_base.cc test=develop
* Polish code test=develop
* Fix OpHandleBase::Run test=develop
* code refine
* test CI and CE test=develop
* fix OpHandle::Run test=develop
* refine AllReduceOpHandle test=develop
* Polish code test=develop
1 parent 3396552 commit a6a3b2f

15 files changed

+221
-404
lines changed

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -51,9 +51,7 @@ else()
5151
cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
5252
endif()
5353

54-
cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor)
5554
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
56-
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
5755

5856
if(WITH_GPU)
5957
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
@@ -74,7 +72,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
7472
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
7573

7674
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
77-
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
75+
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
7876

7977
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
8078

paddle/fluid/framework/details/all_reduce_op_handle.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,8 @@
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
14-
#include <algorithm>
15-
1614
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
15+
#include <algorithm>
1716
#include "paddle/fluid/framework/details/container_cast.h"
1817
#include "paddle/fluid/framework/details/reduce_and_gather.h"
1918
#include "paddle/fluid/framework/details/variable_visitor.h"
@@ -56,6 +55,7 @@ void AllReduceOpHandle::RunImpl() {
5655
platform::RecordEvent record_event(Name());
5756

5857
WaitInputVarGenerated();
58+
5959
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
6060
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
6161
PADDLE_ENFORCE_EQ(

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ struct BroadcastOpHandle : public OpHandleBase {
5757

5858
std::string Name() const override;
5959

60-
bool IsMultiDeviceTransfer() override { return false; };
60+
bool IsMultiDeviceTransfer() override { return true; };
6161

6262
protected:
6363
void RunImpl() override;

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -147,6 +147,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
147147
// Verify that the graph is correct for multi-device executor.
148148
AppendPass("multi_devices_check_pass");
149149

150+
if (VLOG_IS_ON(2)) {
151+
AppendPass("all_reduce_deps_pass");
152+
}
153+
150154
if (SeqOnlyAllReduceOps(strategy)) {
151155
VLOG(10) << "Add all_reduce_deps_pass";
152156
AppendPass("all_reduce_deps_pass");

paddle/fluid/framework/details/data_balance_op_handle.cc

Lines changed: 0 additions & 154 deletions
This file was deleted.

paddle/fluid/framework/details/data_balance_op_handle.h

Lines changed: 0 additions & 59 deletions
This file was deleted.

paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,8 @@ void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
8282
}
8383
}
8484

85+
bool FetchOpHandle::IsMultiDeviceTransfer() { return true; }
86+
8587
std::string FetchOpHandle::Name() const { return "Fetch"; }
8688

8789
} // namespace details

paddle/fluid/framework/details/fetch_op_handle.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,8 @@ struct FetchOpHandle : public OpHandleBase {
3939

4040
std::string Name() const override;
4141

42+
bool IsMultiDeviceTransfer() override;
43+
4244
protected:
4345
void RunImpl() override;
4446

paddle/fluid/framework/details/fuse_vars_op_handle.cc

Lines changed: 0 additions & 51 deletions
This file was deleted.

0 commit comments

Comments
 (0)