Skip to content

Commit df69100

Browse files
[Common] Fix long compile time in padding.cu on arch 75 (#2562)
* Fix long compile time in padding.cu Signed-off-by: Jeremy Berchtold <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jeremy Berchtold <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a976740 commit df69100

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

transformer_engine/common/util/padding.cu

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
9494
#pragma unroll
9595
for (int i2 = 0; i2 < nvec; ++i2) {
9696
const int row = tile_row + i1 * nvec + i2;
97-
size_t row_offset = static_cast<size_t>(row) * row_length;
9897
const int col = tile_col + j1 * nvec;
9998
Vec local_input;
10099
Vec local_output;
101100
local_input.clear();
102101
if (row < num_rows) {
103102
for (int j2 = 0; j2 < nvec; ++j2) {
104103
if (col + j2 < row_length) {
105-
local_input.data.elt[j2] = input[row_offset + col + j2];
104+
local_input.data.elt[j2] = input[static_cast<size_t>(row) * row_length + col + j2];
106105
}
107106
}
108107
}
@@ -113,14 +112,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
113112
if (row < num_rows) {
114113
for (int j2 = 0; j2 < nvec; ++j2) {
115114
if (col + j2 < row_length) {
116-
output[row_offset + col + j2] = local_output.data.elt[j2];
115+
output[static_cast<size_t>(row) * row_length + col + j2] = local_output.data.elt[j2];
117116
}
118117
}
119118
} else if (row < padded_num_rows) {
120119
// padding
121120
for (int j2 = 0; j2 < nvec; ++j2) {
122121
if (col + j2 < row_length) {
123-
output[row_offset + col + j2] = local_zero;
122+
output[static_cast<size_t>(row) * row_length + col + j2] = local_zero;
124123
}
125124
}
126125
}
@@ -179,15 +178,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
179178
#pragma unroll
180179
for (int i2 = 0; i2 < nvec; ++i2) {
181180
const int row = tile_row + i1 * nvec + i2;
182-
size_t row_offset = static_cast<size_t>(row) * row_length;
183181
const int col = tile_col + j1 * nvec;
184182
Vec local_input;
185183
Vec local_output;
186184
local_input.clear();
187185
if (row < num_rows) {
188186
for (int j2 = 0; j2 < nvec; ++j2) {
189187
if (col + j2 < row_length) {
190-
local_input.data.elt[j2] = input[row_offset + col + j2];
188+
local_input.data.elt[j2] = input[static_cast<size_t>(row) * row_length + col + j2];
191189
}
192190
}
193191
}
@@ -198,7 +196,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
198196
if (row < num_rows) {
199197
for (int j2 = 0; j2 < nvec; ++j2) {
200198
if (col + j2 < row_length) {
201-
output[row_offset + col + j2] = local_output.data.elt[j2];
199+
output[static_cast<size_t>(row) * row_length + col + j2] = local_output.data.elt[j2];
202200
}
203201
}
204202
}

0 commit comments

Comments
 (0)