@@ -23,23 +23,23 @@ struct Dropout {
23
23
}
24
24
25
25
template <bool encode_dropout_in_sign_bit=false , typename Engine, typename Layout>
26
- __forceinline__ __device__ void apply_dropout (Tensor<Engine, Layout> &tensor_,
26
+ __forceinline__ __device__ void apply_dropout (cute:: Tensor<Engine, Layout> &tensor_,
27
27
int block_row_start, int block_col_start, int block_row_stride) {
28
28
// convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
29
- Tensor tensor = make_tensor (tensor_.data (), flash::convert_layout_acc_dropout (tensor_.layout ()));
29
+ cute:: Tensor tensor = make_tensor (tensor_.data (), flash::convert_layout_acc_dropout (tensor_.layout ()));
30
30
using T = typename Engine::value_type;
31
31
auto encode_dropout = [](bool keep, T val) {
32
32
return keep ? val : (encode_dropout_in_sign_bit ? -val : T (0 ));
33
33
};
34
- static_assert (decltype (size<2 >(tensor))::value % 2 == 0 );
34
+ static_assert (decltype (cute:: size<2 >(tensor))::value % 2 == 0 );
35
35
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t (p_dropout_in_uint8_t );
36
36
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t (p_dropout_8bit_in_uint16_t ) << 16 ) | uint32_t (p_dropout_8bit_in_uint16_t );
37
37
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
38
38
#pragma unroll
39
- for (int m = 0 ; m < size<1 >(tensor); ++m, block_row_start += block_row_stride) {
39
+ for (int m = 0 ; m < cute:: size<1 >(tensor); ++m, block_row_start += block_row_stride) {
40
40
uint2 rowcol = make_uint2 (block_row_start, block_col_start);
41
41
#pragma unroll
42
- for (int n = 0 ; n < size<2 >(tensor) / 2 ; ++n, ++rowcol.y ) {
42
+ for (int n = 0 ; n < cute:: size<2 >(tensor) / 2 ; ++n, ++rowcol.y ) {
43
43
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
44
44
uint4 random_uint4 = flash::philox (seed, reinterpret_cast <unsigned long long &>(rowcol), offset);
45
45
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
@@ -60,7 +60,7 @@ struct Dropout {
60
60
uint32_t (&rnd_32)[8 ] = reinterpret_cast <uint32_t (&)[8 ]>(rnd_16);
61
61
#pragma unroll
62
62
for (int j = 0 ; j < 2 ; j++) {
63
- Tensor tensor_uint32 = recast<uint32_t >(tensor (_, m, n * 2 + j));
63
+ cute:: Tensor tensor_uint32 = cute:: recast<uint32_t >(tensor (cute:: _, m, n * 2 + j));
64
64
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
65
65
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
66
66
#pragma unroll
@@ -78,7 +78,7 @@ struct Dropout {
78
78
for (int i = 0 ; i < 8 ; i++) {
79
79
tensor (i, m, n * 2 + j) = encode_dropout (rnd_8[j * 8 + i] <= p_dropout_in_uint8_t , tensor (i, m, n * 2 + j));
80
80
}
81
- Tensor tensor_uint32 = recast<uint32_t >(tensor (_, m, n * 2 + j));
81
+ cute:: Tensor tensor_uint32 = cute:: recast<uint32_t >(tensor (cute:: _, m, n * 2 + j));
82
82
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
83
83
}
84
84
}
0 commit comments