
Commit ff55993

fix dtype error of compare op, test=develop (#25059) (#25234)

1 parent 07a65aa
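
This commit threads a separate output template parameter, OutType (defaulting to T), through the CUDA common-broadcast path so that comparison operators can write bool results instead of reusing the input dtype, and adds a dygraph unit test covering the broadcasted `!=` case.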

2 files changed: +20 −7 lines changed

paddle/fluid/operators/elementwise/elementwise_op_function.h

Lines changed: 10 additions & 7 deletions
@@ -15,10 +15,12 @@ limitations under the License. */
 #pragma once
 
 #include <glog/logging.h>
+
 #include <algorithm>
 #include <functional>  // for multiplies
 #include <iterator>
 #include <vector>
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
@@ -30,6 +32,7 @@ limitations under the License. */
 #ifdef __NVCC__
 #include <cuda.h>
 #include <thrust/iterator/iterator_adaptor.h>
+
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
@@ -194,11 +197,11 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
 }
 
 #ifdef __NVCC__
-template <typename Functor, typename T>
+template <typename Functor, typename T, typename OutType = T>
 __global__ void CommonForwardBroadcastCUDAKernel(
     const int *x_strides_array, const int *y_strides_array,
-    const int *out_dims_array, const T *x, const T *y, T *out, int out_size,
-    int max_dim, Functor func, const bool is_xsize_larger) {
+    const int *out_dims_array, const T *x, const T *y, OutType *out,
+    int out_size, int max_dim, Functor func, const bool is_xsize_larger) {
   for (int out_index = blockIdx.x * blockDim.x + threadIdx.x;
        out_index < out_size; out_index += blockDim.x * gridDim.x) {
     int x_index = 0;
@@ -220,7 +223,7 @@ __global__ void CommonForwardBroadcastCUDAKernel(
   }
 }
 
-template <typename Functor, typename T>
+template <typename Functor, typename T, typename OutType = T>
 void CommonForwardBroadcastCUDA(
     const framework::Tensor *x, const framework::Tensor *y,
     framework::Tensor *z, int *x_dims_array, int *y_dims_array,
@@ -230,7 +233,7 @@ void CommonForwardBroadcastCUDA(
   auto cplace = platform::CPUPlace();
   const T *x_data = x->data<T>();
   const T *y_data = y->data<T>();
-  T *out_data = z->mutable_data<T>(ctx.GetPlace());
+  OutType *out_data = z->mutable_data<OutType>(ctx.GetPlace());
 
   std::vector<int> x_strides_array(max_dim);
   std::vector<int> y_strides_array(max_dim);
@@ -268,7 +271,7 @@ void CommonForwardBroadcastCUDA(
   dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
 
   CommonForwardBroadcastCUDAKernel<
-      Functor, T><<<gird_size, block_size, 0, ctx.stream()>>>(
+      Functor, T, OutType><<<gird_size, block_size, 0, ctx.stream()>>>(
       x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu, x_data,
       y_data, out_data, out_size, max_dim, func, is_xsize_larger);
 }
@@ -1796,7 +1799,7 @@ void CommonElementwiseBroadcastForward(
 
   if (platform::is_gpu_place(ctx.GetPlace())) {
 #ifdef __NVCC__
-    CommonForwardBroadcastCUDA<Functor, T>(
+    CommonForwardBroadcastCUDA<Functor, T, OutType>(
         x, y, z, x_dims_array.data(), y_dims_array.data(),
         out_dims_array.data(), max_dim,
         ctx.template device_context<platform::CUDADeviceContext>(), func,
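
The core pattern of this change can be shown outside of Paddle. Below is a minimal, self-contained CUDA sketch, not PaddlePaddle code: the kernel, functor, and launch configuration are illustrative only. It shows how giving the output pointer its own template parameter, defaulting to the input type, lets a comparison functor write bool results while existing arithmetic instantiations compile unchanged.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative comparison functor: reads T, returns bool.
template <typename T>
struct NotEqualFunctor {
  __device__ bool operator()(T a, T b) const { return a != b; }
};

// OutType defaults to T, so arithmetic ops are unaffected;
// comparison ops instantiate with OutType = bool.
template <typename Functor, typename T, typename OutType = T>
__global__ void ElementwiseKernel(const T* x, const T* y, OutType* out,
                                  int n, Functor func) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = func(x[i], y[i]);
}

int main() {
  const int n = 4;
  float hx[n] = {1.f, 2.f, 3.f, 4.f};
  float hy[n] = {1.f, 0.f, 3.f, 5.f};
  float *dx, *dy;
  bool* dout;
  cudaMalloc(&dx, n * sizeof(float));
  cudaMalloc(&dy, n * sizeof(float));
  cudaMalloc(&dout, n * sizeof(bool));
  cudaMemcpy(dx, hx, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dy, hy, n * sizeof(float), cudaMemcpyHostToDevice);

  // Instantiate with OutType = bool, mirroring the fixed compare-op path;
  // writing through a float* here was the kind of dtype mismatch the commit fixes.
  ElementwiseKernel<NotEqualFunctor<float>, float, bool>
      <<<1, n>>>(dx, dy, dout, n, NotEqualFunctor<float>());

  bool hout[n];
  cudaMemcpy(hout, dout, n * sizeof(bool), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%d ", hout[i] ? 1 : 0);  // prints: 0 1 0 1
  printf("\n");
  cudaFree(dx);
  cudaFree(dy);
  cudaFree(dout);
  return 0;
}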

python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py

Lines changed: 10 additions & 0 deletions
@@ -237,6 +237,16 @@ def test_index(self):
         str1 = "just test"
         self.assertTrue(str1[var1] == 's')
 
+    def test_conpare_op_broadcast(self):
+        a_np = np.random.uniform(-1, 1, [10, 1, 10]).astype(self.dtype)
+        b_np = np.random.uniform(-1, 1, [1, 1, 10]).astype(self.dtype)
+        with fluid.dygraph.guard():
+            a = fluid.dygraph.to_variable(a_np)
+            b = fluid.dygraph.to_variable(b_np)
+
+            self.assertEqual((a != b).dtype, fluid.core.VarDesc.VarType.BOOL)
+            self.assertTrue(np.array_equal((a != b).numpy(), a_np != b_np))
+
 
 if __name__ == '__main__':
     unittest.main()
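
Read together, the two assertions pin down both halves of the fix: (a != b).dtype must come back as VarDesc.VarType.BOOL rather than the float input dtype, and the elementwise values must agree with NumPy's broadcasted a_np != b_np across the [10, 1, 10] and [1, 1, 10] shapes.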
