Skip to content

Commit 95c0c12

Browse files
authored
Merge pull request #7384 from dzhwinter/feature/sync_wait
Feature/sync wait
2 parents 9867a37 + 92eb247 commit 95c0c12

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

paddle/framework/operator.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14+
#include <gflags/gflags.h>
1415
#include <glog/logging.h>
1516

1617
#include <algorithm>
@@ -21,6 +22,10 @@ limitations under the License. */
2122
#include "paddle/framework/shape_inference.h"
2223
#include "paddle/framework/var_type.h"
2324

25+
// Command-line/environment flag `op_sync` (default: false). When set,
// OperatorWithKernel::Run calls Wait() on the device context after each
// kernel's Compute(); intended for profiling/benchmarking only.
// Note: adjacent string literals are concatenated verbatim, so the first
// fragment must end with a space to keep the help text readable.
DEFINE_bool(op_sync, false,
            "Default cuda is asynchronous device, set to True will "
            "force op run in synchronous mode.");
28+
2429
namespace paddle {
2530
namespace framework {
2631

@@ -542,8 +547,14 @@ void OperatorWithKernel::Run(const Scope& scope,
542547

543548
auto kernel_iter = kernels.find(expected_kernel_key);
544549

545-
kernel_iter->second->Compute(ExecutionContext(
546-
*this, new_scope, *pool.Get(expected_kernel_key.place_)));
550+
auto* new_dev_ctx = pool.Get(expected_kernel_key.place_);
551+
kernel_iter->second->Compute(
552+
ExecutionContext(*this, new_scope, *new_dev_ctx));
553+
554+
/*For profiling/benchmark only*/
555+
if (FLAGS_op_sync) {
556+
new_dev_ctx->Wait();
557+
}
547558
}
548559

549560
proto::DataType OperatorWithKernel::IndicateDataType(

python/paddle/v2/fluid/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __bootstrap__():
5858

5959
read_env_flags = ['use_pinned_memory', 'check_nan_inf']
6060
if core.is_compile_gpu():
61-
read_env_flags.append('fraction_of_gpu_memory_to_use')
61+
read_env_flags += ['fraction_of_gpu_memory_to_use', 'op_sync']
6262
core.init_gflags([sys.argv[0]] +
6363
["--tryfromenv=" + ",".join(read_env_flags)])
6464
core.init_glog(sys.argv[0])

0 commit comments

Comments
 (0)