Skip to content

Commit ab06556

Browse files
authored
[CHERRY-PICK 1.8]fix bug that diag API can't use on Windows(#24825)
* cherry-pick #24762
1 parent 863f9e5 commit ab06556

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

paddle/fluid/operators/trace_op.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ namespace paddle {
2424
namespace operators {
2525

2626
template <typename T>
27-
struct DiagFunctor {
28-
DiagFunctor(const T* input, const int64_t* diag_stride,
29-
const int64_t* ret_strides, int64_t pos, int64_t dim_size,
30-
T* diag)
27+
struct DiagonalFunctor {
28+
DiagonalFunctor(const T* input, const int64_t* diag_stride,
29+
const int64_t* ret_strides, int64_t pos, int64_t dim_size,
30+
T* diag)
3131
: input_(input),
3232
diag_stride_(diag_stride),
3333
ret_strides_(ret_strides),
@@ -157,8 +157,8 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context,
157157

158158
auto& dev_ctx = context.template device_context<DeviceContext>();
159159
platform::ForRange<DeviceContext> for_range(dev_ctx, diag.numel());
160-
DiagFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size,
161-
diag_data);
160+
DiagonalFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size,
161+
diag_data);
162162
for_range(functor);
163163
return diag;
164164
} else {

0 commit comments

Comments
 (0)