@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
11
11
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
+ #include < gflags/gflags.h>
14
15
#include < glog/logging.h>
15
16
16
17
#include < algorithm>
@@ -21,6 +22,10 @@ limitations under the License. */
21
22
#include " paddle/framework/shape_inference.h"
22
23
#include " paddle/framework/var_type.h"
23
24
25
+ DEFINE_bool (op_sync, false ,
26
+ " Default cuda is asynchronous device, set to True will"
27
+ " force op run in synchronous mode." );
28
+
24
29
namespace paddle {
25
30
namespace framework {
26
31
@@ -542,8 +547,14 @@ void OperatorWithKernel::Run(const Scope& scope,
542
547
543
548
auto kernel_iter = kernels.find (expected_kernel_key);
544
549
545
- kernel_iter->second ->Compute (ExecutionContext (
546
- *this , new_scope, *pool.Get (expected_kernel_key.place_ )));
550
+ auto * new_dev_ctx = pool.Get (expected_kernel_key.place_ );
551
+ kernel_iter->second ->Compute (
552
+ ExecutionContext (*this , new_scope, *new_dev_ctx));
553
+
554
+ /* For profiling/benchmark only*/
555
+ if (FLAGS_op_sync) {
556
+ new_dev_ctx->Wait ();
557
+ }
547
558
}
548
559
549
560
proto::DataType OperatorWithKernel::IndicateDataType (
0 commit comments