Skip to content

Commit 40a2918

Browse files
authored
[ROCM] fixbug for arg_min_max (#36113)
ATT, cherry-pick #36098
1 parent fe5cddf commit 40a2918

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

paddle/fluid/operators/arg_min_max_op_base.cu.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -89,22 +89,25 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
8989
const int64_t n) {
9090
auto cu_stream = ctx.stream();
9191
auto ComputeBlockSize = [](int64_t col) {
92+
auto block_size = 8;
9293
if (col > 512)
93-
return 1024;
94+
block_size = 1024;
9495
else if (col > 256)
95-
return 512;
96+
block_size = 512;
9697
else if (col > 128)
97-
return 256;
98+
block_size = 256;
9899
else if (col > 64)
99-
return 128;
100+
block_size = 128;
100101
else if (col > 32)
101-
return 64;
102+
block_size = 64;
102103
else if (col > 16)
103-
return 32;
104+
block_size = 32;
104105
else if (col > 8)
105-
return 16;
106-
else
107-
return 8;
106+
block_size = 16;
107+
#ifdef __HIPCC__
108+
block_size = std::min(block_size, 256);
109+
#endif
110+
return block_size;
108111
};
109112

110113
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;

0 commit comments

Comments
 (0)