Skip to content

Commit 85c8f18

Browse files
committed
Add Profiling Code
1 parent c2f5fdf commit 85c8f18

File tree

7 files changed: 106 additions and 9 deletions

7 files changed: 106 additions and 9 deletions

cmake/CPU.cmake

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ ELSE()
3030
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DNDEBUG")
3131
ENDIF()
3232

33+
IF("${IPEX_DISP_OP}" STREQUAL "1")
34+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIPEX_DISP_OP")
35+
ENDIF()
36+
37+
IF("${IPEX_PROFILE_OP}" STREQUAL "1")
38+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIPEX_PROFILE_OP")
39+
ENDIF()
40+
3341
# ---[ Build flags
3442
set(CMAKE_C_STANDARD 11)
3543
set(CMAKE_CXX_STANDARD 14)

scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -517,16 +517,19 @@ def is_conv_overrideable_func(fname):
517517
# Gen definition code for cpp file
518518
code = '{} {{\n'.format(cpp_func_str_cpp)
519519

520-
# Gen profile info
521-
code += '#if defined(_DEBUG)\n'
520+
# Gen OP Name
521+
code += '#if defined(IPEX_DISP_OP)\n'
522522
code += ' printf("{}::{}\\n");\n'.format(_IPEX_OP_FUNC_NS, cpp_sig.def_name)
523523
code += '#endif\n'
524+
525+
# Gen profile info
524526
profiler_inputs = []
525527
for param in cpp_sig.input_params:
526528
if param.core_type in ['Tensor', 'Scalar']:
527529
profiler_inputs.append(param.name)
530+
code += '#if defined(IPEX_PROFILE_OP)\n'
528531
code += ' RECORD_FUNCTION("{ns}::{name}", std::vector<c10::IValue>({{{input_names}}}), torch::autograd::Node::peek_at_next_sequence_nr());\n'.format(ns=_IPEX_OP_FUNC_NS, name=cpp_sig.def_name, input_names=', '.join(profiler_inputs))
529-
532+
code += '#endif\n'
530533

531534
if is_conv_overrideable_func(cpp_sig.def_name):
532535
code += ' return AtenIpexCPUDev::dil_{}({});\n'.format(cpp_sig.def_name, ', '.join([param.name for param in cpp_sig.input_params]))

scripts/cpu/gen-sparse-cpu-ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ class AtenIpexCPUSparse {{
8181
#include <ATen/CPUGenerator.h>
8282
#include <c10/util/Exception.h>
8383
#include <c10/util/Logging.h>
84+
#include <torch/csrc/autograd/function.h>
85+
#include <torch/csrc/autograd/record_function.h>
8486
8587
#include "aten_ipex_bridge.h"
8688
#include "ipex_sparse_tensor_impl.h"
@@ -405,6 +407,21 @@ def gen_code(self):
405407

406408
# Gen definition code for cpp file
407409
code += '{} {{\n'.format(cpp_func_str_cpp)
410+
411+
# Gen OP Name
412+
code += '#if defined(IPEX_DISP_OP)\n'
413+
code += ' printf("{}::{}\\n");\n'.format(_IPEX_OP_FUNC_NS, cpp_sparse_sig.def_name)
414+
code += '#endif\n'
415+
416+
# Gen profile info
417+
profiler_inputs = []
418+
for param in cpp_sparse_sig.input_params:
419+
if param.core_type in ['Tensor', 'Scalar']:
420+
profiler_inputs.append(param.name)
421+
code += '#if defined(IPEX_PROFILE_OP)\n'
422+
code += ' RECORD_FUNCTION("{ns}::{name}", std::vector<c10::IValue>({{{input_names}}}), torch::autograd::Node::peek_at_next_sequence_nr());\n'.format(ns=_IPEX_OP_FUNC_NS, name=cpp_sparse_sig.def_name, input_names=', '.join(profiler_inputs))
423+
code += '#endif\n'
424+
408425
code += self.gen_fallback_prepare_code(cpp_sparse_sig)
409426
code += self.gen_fallback_code(cpp_sparse_sig)
410427
code += self.gen_fallback_post_code(cpp_sparse_sig)

setup.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,12 @@ def build_extension(self, ext):
193193
'-DPYTHON_INCLUDE_DIR=' + python_include_dir,
194194
]
195195

196+
if _check_env_flag("IPEX_DISP_OP"):
197+
cmake_args += ['-DIPEX_DISP_OP=1']
198+
199+
if _check_env_flag("IPEX_PROFILE_OP"):
200+
cmake_args += ['-DIPEX_PROFILE_OP=1']
201+
196202
if _check_env_flag("USE_SYCL"):
197203
cmake_args += ['-DUSE_SYCL=1']
198204

torch_ipex/csrc/cpu/CustomOPs.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@
99
#include <c10/util/Optional.h>
1010
#include <torch/csrc/autograd/custom_function.h>
1111
#include <torch/csrc/autograd/function.h>
12+
#include <torch/csrc/autograd/record_function.h>
1213
#include <torch/csrc/autograd/variable.h>
1314
#include <torch/script.h>
1415

1516
class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
1617
public:
1718
static at::Tensor _forward(at::Tensor input, at::Tensor weight,
1819
at::Tensor bias = at::Tensor()) {
20+
#if defined(IPEX_PROFILE_OP)
21+
RECORD_FUNCTION("IPEXLinearOp::_forward", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
22+
#endif
1923
try {
2024
if (torch_ipex::check_auto_dnnl() &&
2125
input.device().type() == c10::DeviceType::DPCPP) {
@@ -48,13 +52,19 @@ class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
4852
static at::Tensor forward(torch::autograd::AutogradContext *ctx,
4953
at::Tensor input, at::Tensor weight,
5054
at::Tensor bias = at::Tensor()) {
55+
#if defined(IPEX_PROFILE_OP)
56+
RECORD_FUNCTION("IPEXLinearOp::forward", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
57+
#endif
5158
ctx->save_for_backward({input, weight, bias});
5259
return _forward(input, weight, bias);
5360
}
5461

5562
static torch::autograd::tensor_list
5663
backward(torch::autograd::AutogradContext *ctx,
5764
torch::autograd::tensor_list grad_outputs) {
65+
#if defined(IPEX_PROFILE_OP)
66+
RECORD_FUNCTION("IPEXLinearOp::backward", std::vector<c10::IValue>({}), torch::autograd::Node::peek_at_next_sequence_nr());
67+
#endif
5868
auto saved = ctx->get_saved_variables();
5969
at::Tensor input = saved[0];
6070
at::Tensor weight = saved[1];
@@ -149,6 +159,9 @@ class NewMaxPool2dOp : public torch::autograd::Function<NewMaxPool2dOp> {
149159
_forward(at::Tensor input, at::IntArrayRef kernel_size,
150160
at::IntArrayRef stride, at::IntArrayRef padding,
151161
at::IntArrayRef dilation, bool ceil_mode) {
162+
#if defined(IPEX_PROFILE_OP)
163+
RECORD_FUNCTION("IPEXMaxPool2dOp::_forward", std::vector<c10::IValue>({input}), torch::autograd::Node::peek_at_next_sequence_nr());
164+
#endif
152165
try {
153166
if (torch_ipex::check_auto_dnnl() &&
154167
input.device().type() == c10::DeviceType::DPCPP) {
@@ -187,6 +200,9 @@ class NewMaxPool2dOp : public torch::autograd::Function<NewMaxPool2dOp> {
187200
at::Tensor input, at::IntArrayRef kernel_size,
188201
at::IntArrayRef stride, at::IntArrayRef padding,
189202
at::IntArrayRef dilation, bool ceil_mode) {
203+
#if defined(IPEX_PROFILE_OP)
204+
RECORD_FUNCTION("IPEXMaxPool2dOp::forward", std::vector<c10::IValue>({input}), torch::autograd::Node::peek_at_next_sequence_nr());
205+
#endif
190206
ctx->saved_data["kernel_size"] = kernel_size;
191207
ctx->saved_data["stride"] = stride;
192208
ctx->saved_data["padding"] = padding;
@@ -203,6 +219,9 @@ class NewMaxPool2dOp : public torch::autograd::Function<NewMaxPool2dOp> {
203219
static torch::autograd::tensor_list
204220
backward(torch::autograd::AutogradContext *ctx,
205221
torch::autograd::tensor_list grad_outputs) {
222+
#if defined(IPEX_PROFILE_OP)
223+
RECORD_FUNCTION("IPEXMaxPool2dOp::backward", std::vector<c10::IValue>({}), torch::autograd::Node::peek_at_next_sequence_nr());
224+
#endif
206225
auto saved = ctx->get_saved_variables();
207226
at::Tensor input = saved[0];
208227
at::Tensor indices = saved[1];
@@ -263,6 +282,9 @@ class NewMaxPool3dOp : public torch::autograd::Function<NewMaxPool3dOp> {
263282
_forward(at::Tensor input, at::IntArrayRef kernel_size,
264283
at::IntArrayRef stride, at::IntArrayRef padding,
265284
at::IntArrayRef dilation, bool ceil_mode) {
285+
#if defined(IPEX_PROFILE_OP)
286+
RECORD_FUNCTION("IPEXMaxPool3dOp::_forward", std::vector<c10::IValue>({input}), torch::autograd::Node::peek_at_next_sequence_nr());
287+
#endif
266288
try {
267289
if (torch_ipex::check_auto_dnnl() &&
268290
input.device().type() == c10::DeviceType::DPCPP) {
@@ -298,6 +320,9 @@ class NewMaxPool3dOp : public torch::autograd::Function<NewMaxPool3dOp> {
298320
at::Tensor input, at::IntArrayRef kernel_size,
299321
at::IntArrayRef stride, at::IntArrayRef padding,
300322
at::IntArrayRef dilation, bool ceil_mode) {
323+
#if defined(IPEX_PROFILE_OP)
324+
RECORD_FUNCTION("IPEXMaxPool3dOp::forward", std::vector<c10::IValue>({input}), torch::autograd::Node::peek_at_next_sequence_nr());
325+
#endif
301326
ctx->saved_data["kernel_size"] = kernel_size;
302327
ctx->saved_data["stride"] = stride;
303328
ctx->saved_data["padding"] = padding;
@@ -314,6 +339,9 @@ class NewMaxPool3dOp : public torch::autograd::Function<NewMaxPool3dOp> {
314339
static torch::autograd::tensor_list
315340
backward(torch::autograd::AutogradContext *ctx,
316341
torch::autograd::tensor_list grad_outputs) {
342+
#if defined(IPEX_PROFILE_OP)
343+
RECORD_FUNCTION("IPEXMaxPool3dOp::backward", std::vector<c10::IValue>({}), torch::autograd::Node::peek_at_next_sequence_nr());
344+
#endif
317345
auto saved = ctx->get_saved_variables();
318346
at::Tensor input = saved[0];
319347
at::Tensor indices = saved[1];
@@ -372,6 +400,9 @@ class NewApaptiveAvgPoolingOp
372400
: public torch::autograd::Function<NewApaptiveAvgPoolingOp> {
373401
public:
374402
static at::Tensor _forward(at::Tensor input, at::IntArrayRef output_size) {
403+
#if defined(IPEX_PROFILE_OP)
404+
RECORD_FUNCTION("IPEXApaptiveAvgPoolingOp::_forward", std::vector<c10::IValue>({input}), torch::autograd::Node::peek_at_next_sequence_nr());
405+
#endif
375406
try {
376407
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
377408
auto src_dil_type = torch_ipex::cpu::dbl::comm::try_gen_dil_tensor(input).get_data_type();
@@ -397,13 +428,19 @@ class NewApaptiveAvgPoolingOp
397428

398429
static at::Tensor forward(torch::autograd::AutogradContext *ctx,
399430
at::Tensor input, at::IntArrayRef output_size) {
431+
#if defined(IPEX_PROFILE_OP)
432+
RECORD_FUNCTION("IPEXApaptiveAvgPoolingOp::forward", std::vector<c10::IValue>({input}), torch::autograd::Node::peek_at_next_sequence_nr());
433+
#endif
400434
ctx->save_for_backward({input});
401435
return _forward(input, output_size);
402436
}
403437

404438
static torch::autograd::tensor_list
405439
backward(torch::autograd::AutogradContext *ctx,
406440
torch::autograd::tensor_list grad_outputs) {
441+
#if defined(IPEX_PROFILE_OP)
442+
RECORD_FUNCTION("IPEXApaptiveAvgPoolingOp::backward", std::vector<c10::IValue>({}), torch::autograd::Node::peek_at_next_sequence_nr());
443+
#endif
407444
auto saved = ctx->get_saved_variables();
408445
at::Tensor input = saved[0];
409446

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
namespace torch_ipex {
2727
namespace cpu {
2828

29-
#if defined(_DEBUG)
29+
#if defined(IPEX_DISP_OP)
3030
#define DEBUG(fmt) printf(fmt);
3131
#else
3232
#define DEBUG(fmt)
@@ -78,7 +78,7 @@ at::Tensor AtenIpexCPUDev::dil_convolution(
7878
dbl::conv::prepack_conv_weights(input, dil_input,
7979
weight, stride, padding, dilation, groups);
8080
}
81-
81+
8282
dil_weight = dbl::comm::try_gen_dil_tensor(weight);
8383

8484
if (bias.defined()) {
@@ -360,7 +360,7 @@ at::Tensor& dil_add_common(
360360
IPEX_CHECK(self.sizes().equals(other.sizes()),
361361
"dil add not support broadcast yet");
362362
if (check_auto_mix_int8_fp32()) {
363-
// for accuracy, reorder int8 to fp32
363+
// for accuracy, reorder int8 to fp32
364364
dbl::comm::reorder_to_dtype(self, at::kFloat);
365365
dbl::comm::reorder_to_dtype(other, at::kFloat);
366366
} else {
@@ -824,7 +824,7 @@ at::Tensor AtenIpexCPUDev::dil_linear(
824824
if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
825825
insert_or_updata_observer({self}, {aten_output}, "Linear");
826826
}
827-
827+
828828
if (self.dim() > 2) {
829829
auto input_size = self.sizes();
830830
std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
@@ -1027,7 +1027,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_
10271027
dil::batch_normalization_forward_inference::compute(
10281028
x, w, b, y, eps, input_scales, output_scales);
10291029
}
1030-
1030+
10311031
auto aten_output = dbl::comm::gen_aten_tensor_by(std::move(y));
10321032

10331033
if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
@@ -1421,7 +1421,7 @@ at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
14211421
dil::algorithm::eltwise_relu,
14221422
dil::prop_kind::forward_training,
14231423
/*alpha*/ 0.0);
1424-
1424+
14251425
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_self.is_public_format() || check_tensor_own_whole_storage(input));
14261426
dbl::comm::sync_shape_from_dil_to_aten(input, dil_self);
14271427
return input;

torch_ipex/csrc/cpu/FusionOPs.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <ATen/InferSize.h>
66
#include <c10/util/Exception.h>
77
#include <c10/util/Logging.h>
8+
#include <torch/csrc/autograd/function.h>
9+
#include <torch/csrc/autograd/record_function.h>
810

911
#include <limits>
1012

@@ -220,6 +222,9 @@ at::Tensor AtenIpexJITDev::dil_convolution_swish(
220222
at::IntArrayRef padding,
221223
at::IntArrayRef dilation,
222224
int64_t groups) {
225+
#if defined(IPEX_PROFILE_OP)
226+
RECORD_FUNCTION("AtenIpexJITDev::dil_convolution_swish", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
227+
#endif
223228
return dil_convolution_outplace_fusion(
224229
input,
225230
weight,
@@ -239,6 +244,9 @@ at::Tensor AtenIpexJITDev::dil_convolution_sigmoid(
239244
at::IntArrayRef padding,
240245
at::IntArrayRef dilation,
241246
int64_t groups) {
247+
#if defined(IPEX_PROFILE_OP)
248+
RECORD_FUNCTION("AtenIpexJITDev::dil_convolution_sigmoid", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
249+
#endif
242250
return dil_convolution_outplace_fusion(
243251
input,
244252
weight,
@@ -260,6 +268,9 @@ at::Tensor AtenIpexJITDev::dil_convolution_clamp(
260268
int64_t groups,
261269
float lower_bound,
262270
float upper_bound) {
271+
#if defined(IPEX_PROFILE_OP)
272+
RECORD_FUNCTION("AtenIpexJITDev::dil_convolution_clamp", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
273+
#endif
263274
return dil_convolution_outplace_fusion(
264275
input,
265276
weight,
@@ -279,6 +290,9 @@ at::Tensor AtenIpexJITDev::dil_convolution_relu(
279290
at::IntArrayRef padding,
280291
at::IntArrayRef dilation,
281292
int64_t groups) {
293+
#if defined(IPEX_PROFILE_OP)
294+
RECORD_FUNCTION("AtenIpexJITDev::dil_convolution_relu", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
295+
#endif
282296
return dil_convolution_outplace_fusion(
283297
input,
284298
weight,
@@ -302,6 +316,9 @@ at::Tensor AtenIpexJITDev::dil_convolution_elu(
302316
float alpha,
303317
at::Scalar scale,
304318
at::Scalar input_scale) {
319+
#if defined(IPEX_PROFILE_OP)
320+
RECORD_FUNCTION("AtenIpexJITDev::dil_convolution_elu", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
321+
#endif
305322
auto scale_value = scale.to<float>();
306323
auto input_scale_value = input_scale.to<float>();
307324
return dil_convolution_outplace_fusion(
@@ -325,6 +342,9 @@ at::Tensor& AtenIpexJITDev::dil_convolution_sum(
325342
int64_t groups,
326343
at::Tensor& accumu,
327344
at::Scalar alpha) {
345+
#if defined(IPEX_PROFILE_OP)
346+
RECORD_FUNCTION("AtenIpexJITDev::dil_convolution_sum", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
347+
#endif
328348
auto scale = alpha.to<float>();
329349
return dil_convolution_inplace_fusion(
330350
input,
@@ -349,6 +369,9 @@ at::Tensor& AtenIpexJITDev::dil_convolution_sum_relu(
349369
int64_t groups,
350370
at::Tensor& accumu,
351371
at::Scalar alpha) {
372+
#if defined(IPEX_PROFILE_OP)
373+
RECORD_FUNCTION("AtenIpexJITDev::dil_convolution_sum_relu", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
374+
#endif
352375
auto scale = alpha.to<float>();
353376
return dil_convolution_inplace_fusion(
354377
input,
@@ -367,6 +390,9 @@ at::Tensor AtenIpexJITDev::dil_linear_fuse_relu(
367390
const at::Tensor& self,
368391
const at::Tensor& weight,
369392
const at::Tensor& bias) {
393+
#if defined(IPEX_PROFILE_OP)
394+
RECORD_FUNCTION("AtenIpexJITDev::dil_linear_fuse_relu", std::vector<c10::IValue>({self, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
395+
#endif
370396
IPEX_CHECK(self.dim() >= 2,
371397
"dil_linear: input needs to has dim at least 2, input dim ", self.dim());
372398
auto input_contiguous = self.is_contiguous() ? self : self.contiguous();

0 commit comments

Comments
 (0)