Commit 0c6b490

[oneDNN] Reshape attr_axes when going to oneDNN kernel (#59641)
1 parent 5a3c593 commit 0c6b490

File tree

3 files changed: +143, -0 lines changed


paddle/phi/kernels/onednn/squeeze_kernel.cc

Lines changed: 22 additions & 0 deletions
@@ -59,7 +59,29 @@ void SqueezeInferKernel(const Context& dev_ctx,
                         const IntArray& axes,
                         DenseTensor* out) {
   auto x_dims = x.dims();
+  auto x_dims_tz = x_dims.size();
   std::vector<int32_t> tmp(axes.GetData().begin(), axes.GetData().end());
+
+  // Currently only tensors are transformed, while attr axes still follow the
+  // default layout instead of the oneDNN layout, so manually change them here
+  if ((x_dims_tz >= 3) &&
+      (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
+           phi::DataLayout::NDHWC ||
+       phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
+           phi::DataLayout::NHWC)) {
+    int axes_size = tmp.size();
+    for (int i = 0; i < axes_size; i++) {
+      if (tmp[i] < 0) {
+        tmp[i] += x_dims_tz;
+      }
+      if (tmp[i] >= 1 && tmp[i] < (x_dims_tz - 1)) {
+        tmp[i] += 1;
+      } else if (tmp[i] == (x_dims_tz - 1)) {
+        tmp[i] = 1;
+      }
+    }
+  }
+
   auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true);
   ExecuteSqueeze<T, Context>(dev_ctx, x, x_dims, out_dims, out);
 }
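For reference, the axis remapping introduced above can be exercised on its own. The sketch below is a hypothetical, standalone helper (RemapSqueezeAxes is not part of the commit) that reproduces the same arithmetic: negative axes are normalized first, then a channels-last (NHWC/NDHWC) axis is shifted to the position oneDNN's channels-first order expects.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: mirrors the remapping done in SqueezeInferKernel for
// rank >= 3 tensors whose current Paddle layout is NHWC or NDHWC.
std::vector<int32_t> RemapSqueezeAxes(std::vector<int32_t> axes, int rank) {
  if (rank < 3) return axes;  // the kernel only remaps for rank >= 3
  for (auto& a : axes) {
    if (a < 0) a += rank;  // normalize negative axes first
    if (a >= 1 && a < rank - 1) {
      a += 1;  // spatial axes shift one position to the right
    } else if (a == rank - 1) {
      a = 1;  // the trailing channel axis moves next to the batch axis
    }
  }
  return axes;
}

int main() {
  // First test case below: axis -2 on a rank-4 tensor -> 2 -> 3.
  std::cout << RemapSqueezeAxes({-2}, 4)[0] << '\n';  // prints 3
  // Second test case: axis -1 -> 3 -> 1.
  std::cout << RemapSqueezeAxes({-1}, 4)[0] << '\n';  // prints 1
  return 0;
}

Running it prints 3 for the first case and 1 for the second, which is consistent with the comments in the new test file added below.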

test/cpp/fluid/mkldnn/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
@@ -104,3 +104,18 @@ cc_test(
   scope
   device_context
   enforce)
+
+cc_test(
+  test_mkldnn_squeeze
+  SRCS test_mkldnn_squeeze.cc
+  DEPS fleet_executor
+       conditional_block_op
+       standalone_executor
+       executor
+       op_registry
+       generated_static_op
+       generated_op
+       phi
+       scope
+       device_context
+       enforce)
test/cpp/fluid/mkldnn/test_mkldnn_squeeze.cc

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include <gtest/gtest.h>
+
+#include <fstream>
+
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/naive_executor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+template <typename DataType>
+void AddVarToScope(const std::string var_name,
+                   paddle::framework::Scope* scope,
+                   const paddle::framework::DDim& dims) {
+  std::random_device seed;
+  std::default_random_engine engine(seed());
+  std::uniform_real_distribution<float> dist(0, 100);
+
+  phi::DenseTensor tmp_tensor;
+  auto* tmp_data =
+      tmp_tensor.mutable_data<DataType>(dims, paddle::platform::CPUPlace());
+  auto* tensor = scope->Var(var_name)->GetMutable<phi::DenseTensor>();
+  tensor->mutable_data<DataType>(dims, paddle::platform::CPUPlace());
+  for (auto i = 0; i < tensor->numel(); ++i) {
+    tmp_data[i] = static_cast<DataType>(dist(engine));
+  }
+  paddle::framework::TensorCopySync(
+      tmp_tensor, paddle::platform::CPUPlace(), tensor);
+}
+void test_squeeze() {
+  framework::Scope scope;
+  paddle::platform::CPUPlace cpu_place;
+  // Prepare Op description
+  framework::OpDesc desc;
+  // We assume it is kNHWC, so that we can use this transformation
+  phi::OneDNNContext::tls().set_cur_paddle_data_layout(DataLayout::kNHWC);
+  desc.SetType("squeeze2");
+  desc.SetInput("X", {"squeeze-X"});
+  desc.SetOutput("Out", {"squeeze-Out"});
+  // DataLayout::kNHWC will make it become {2, 3, 2, 1}
+  AddVarToScope<float>("squeeze-X", &scope, {2, 2, 1, 3});
+  AddVarToScope<float>("squeeze-Out", &scope, {2, 3, 2});
+  // transform will make it become -1
+  std::vector<int> axes({-2});
+
+  desc.SetAttr("axes", axes);
+  desc.SetAttr("use_mkldnn", true);
+
+  auto op = paddle::framework::OpRegistry::CreateOp(desc);
+
+  op->Run(scope, cpu_place);
+}
+
+void test_squeeze2() {
+  framework::Scope scope;
+  paddle::platform::CPUPlace cpu_place;
+  // Prepare Op description
+  framework::OpDesc desc;
+  // We assume it is kNHWC, so that we can use this transformation
+  phi::OneDNNContext::tls().set_cur_paddle_data_layout(DataLayout::kNHWC);
+  desc.SetType("squeeze2");
+  desc.SetInput("X", {"squeeze-X"});
+  desc.SetOutput("Out", {"squeeze-Out"});
+  // DataLayout::kNHWC will make it become {2, 1, 3, 2}
+  AddVarToScope<float>("squeeze-X", &scope, {2, 3, 2, 1});
+  AddVarToScope<float>("squeeze-Out", &scope, {2, 3, 2});
+  // transform will make it become -3 (i.e. 1)
+  std::vector<int> axes({-1});
+
+  desc.SetAttr("axes", axes);
+  desc.SetAttr("use_mkldnn", true);
+
+  auto op = paddle::framework::OpRegistry::CreateOp(desc);
+
+  op->Run(scope, cpu_place);
+}
+
+TEST(SqueezeOpConverter, normal) { test_squeeze(); }
+TEST(SqueezeOpConverter_2, normal) { test_squeeze2(); }
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP_ITSELF(squeeze2);
+PD_DECLARE_KERNEL(squeeze_infer, OneDNN, ONEDNN);
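To make the expected shapes in test_squeeze concrete, here is a hypothetical standalone check (not part of the commit) that walks the first case through the same steps: reorder the {2, 2, 1, 3} NHWC dims into the NCHW order oneDNN uses, remap axis -2 as the kernel change now does, and squeeze the resulting size-1 dimension to obtain the {2, 3, 2} shape registered as "squeeze-Out".

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // "squeeze-X" is declared as {2, 2, 1, 3}; under kNHWC oneDNN sees the
  // dims in NCHW order, i.e. {2, 3, 2, 1}, as the test comment notes.
  std::vector<int64_t> nhwc = {2, 2, 1, 3};
  std::vector<int64_t> nchw = {nhwc[0], nhwc[3], nhwc[1], nhwc[2]};
  assert((nchw == std::vector<int64_t>{2, 3, 2, 1}));

  // The attr axis -2 is normalized and remapped exactly as in the kernel.
  int rank = static_cast<int>(nchw.size());
  int axis = -2;
  if (axis < 0) axis += rank;  // -> 2
  if (axis >= 1 && axis < rank - 1) {
    axis += 1;  // -> 3
  } else if (axis == rank - 1) {
    axis = 1;
  }

  // Squeezing the remapped size-1 dimension gives the expected output dims.
  assert(nchw[axis] == 1);
  nchw.erase(nchw.begin() + axis);
  assert((nchw == std::vector<int64_t>{2, 3, 2}));  // matches "squeeze-Out"
  return 0;
}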
