Skip to content

Commit 3931a65

Browse files
authored
[MLU] add floordiv kernel (#1344)
1 parent 3fb7eac commit 3931a65

File tree

2 files changed

+161
-0
lines changed

2 files changed

+161
-0
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "kernels/funcs/mlu_baseop.h"
16+
#include "kernels/funcs/mlu_funcs.h"
17+
18+
namespace custom_kernel {
19+
template <typename T, typename Context>
20+
void FloorDivideKernel(const Context& dev_ctx,
21+
const phi::DenseTensor& x,
22+
const phi::DenseTensor& y,
23+
phi::DenseTensor* out) {
24+
dev_ctx.template Alloc<T>(out);
25+
26+
MLUCnnlTensorDesc input_x_desc(x);
27+
MLUCnnlTensorDesc input_y_desc(y);
28+
MLUCnnlTensorDesc output_desc(*out);
29+
30+
cnnlComputationPreference_t prefer = CNNL_COMPUTATION_HIGH_PRECISION;
31+
32+
// when input x and input y dtype are int64
33+
// cast datatype to int32 for cnnlFloorDiv usage
34+
Tensor x_temp, y_temp, out_temp;
35+
x_temp.Resize(x.dims());
36+
y_temp.Resize(y.dims());
37+
out_temp.Resize(out->dims());
38+
if (x.dtype() != DataType::INT64 && y.dtype() != DataType::INT64) {
39+
MLUCnnl::FloorDiv(dev_ctx,
40+
prefer,
41+
input_x_desc.get(),
42+
GetBasePtr(&x),
43+
input_y_desc.get(),
44+
GetBasePtr(&y),
45+
output_desc.get(),
46+
GetBasePtr(out));
47+
} else {
48+
dev_ctx.template Alloc<int32_t>(&x_temp);
49+
dev_ctx.template Alloc<int32_t>(&y_temp);
50+
dev_ctx.template Alloc<int32_t>(&out_temp);
51+
MLUCnnlTensorDesc x_temp_desc(x_temp);
52+
MLUCnnlTensorDesc y_temp_desc(y_temp);
53+
MLUCnnlTensorDesc out_temp_desc(out_temp);
54+
cnnlCastDataType_t cast_int32 = GetCastDataType(x.dtype(), DataType::INT32);
55+
56+
MLUCnnl::Cast(dev_ctx,
57+
cast_int32,
58+
input_x_desc.get(),
59+
GetBasePtr(&x),
60+
x_temp_desc.get(),
61+
GetBasePtr(&x_temp));
62+
63+
MLUCnnl::Cast(dev_ctx,
64+
cast_int32,
65+
input_y_desc.get(),
66+
GetBasePtr(&y),
67+
y_temp_desc.get(),
68+
GetBasePtr(&y_temp));
69+
70+
MLUCnnl::FloorDiv(dev_ctx,
71+
prefer,
72+
x_temp_desc.get(),
73+
GetBasePtr(&x_temp),
74+
y_temp_desc.get(),
75+
GetBasePtr(&y_temp),
76+
out_temp_desc.get(),
77+
GetBasePtr(&out_temp));
78+
79+
cnnlCastDataType_t cast_int64 =
80+
GetCastDataType(x_temp.dtype(), DataType::INT64);
81+
82+
MLUCnnl::Cast(dev_ctx,
83+
cast_int64,
84+
out_temp_desc.get(),
85+
GetBasePtr(&out_temp),
86+
output_desc.get(),
87+
GetBasePtr(out));
88+
}
89+
}
90+
} // namespace custom_kernel
91+
92+
PD_REGISTER_PLUGIN_KERNEL(floor_divide,
93+
mlu,
94+
ALL_LAYOUT,
95+
custom_kernel::FloorDivideKernel,
96+
int,
97+
int64_t,
98+
float,
99+
phi::dtype::float16) {}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import numpy as np
18+
import unittest
19+
20+
from tests.op_test import OpTest
21+
import paddle
22+
23+
paddle.enable_static()
24+
25+
26+
class TestElementwiseFloorDiv(OpTest):
27+
def setUp(self):
28+
self.op_type = "elementwise_floordiv"
29+
self.set_mlu()
30+
self.init_dtype()
31+
self.init_input_output()
32+
33+
self.inputs = {
34+
"X": OpTest.np_dtype_to_base_dtype(self.x),
35+
"Y": OpTest.np_dtype_to_base_dtype(self.y),
36+
}
37+
self.attrs = {}
38+
self.outputs = {"Out": self.out}
39+
40+
def set_mlu(self):
41+
self.__class__.use_custom_device = True
42+
self.place = paddle.CustomPlace("mlu", 0)
43+
44+
def init_input_output(self):
45+
self.x = np.random.uniform(1, 1000, [10, 10]).astype(self.dtype)
46+
self.y = np.random.uniform(1, 1000, [10, 10]).astype(self.dtype)
47+
self.out = np.floor_divide(self.x, self.y)
48+
49+
def init_dtype(self):
50+
self.dtype = "int64"
51+
52+
def test_check_output(self):
53+
self.check_output_with_place(self.place)
54+
55+
56+
class TestElementwiseFloorDiv2(TestElementwiseFloorDiv):
57+
def init_dtype(self):
58+
self.dtype = "int32"
59+
60+
61+
if __name__ == "__main__":
62+
unittest.main()

0 commit comments

Comments
 (0)