We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 272d56b commit 681e60eCopy full SHA for 681e60e
aten/src/ATen/cuda/detail/OffsetCalculator.cuh
@@ -45,6 +45,24 @@ struct OffsetCalculator {
45
46
C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
47
offset_type offsets;
48
+
49
+#if defined(USE_ROCM)
50
+ if ((dims > 0) && (dims <= 2)) {
51
+ auto divmod = sizes_[0].divmod(linear_idx);
52
+ #pragma unroll
53
+ for (int arg = 0; arg < NARGS; arg++)
54
+ offsets[arg] = divmod.mod * strides_[0][arg];
55
+ if (dims >= 2) {
56
+ divmod = sizes_[1].divmod(divmod.div);
57
58
59
+ offsets[arg] += divmod.mod * strides_[1][arg];
60
+ }
61
+ // [...]
62
+ return offsets;
63
64
+#endif
65
66
#pragma unroll
67
for (int arg = 0; arg < NARGS; arg++) {
68
offsets[arg] = 0;
0 commit comments