@@ -20,11 +20,11 @@ static __global__ void k_set_rows(
2020 const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
2121 const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
2222 const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
23- const size_t nb01 , const size_t nb02 , const size_t nb03 ,
24- const size_t nb10 , const size_t nb11 , const size_t nb12 ,
25- const size_t nb1 , const size_t nb2 , const size_t nb3 ) {
23+ const int64_t s01 , const int64_t s02 , const int64_t s03 ,
24+ const int64_t s10 , const int64_t s11 , const int64_t s12 ,
25+ const int64_t s1 , const int64_t s2 , const int64_t s3 ) {
2626
27- const int64_t i = blockDim .x * blockIdx .x + threadIdx .x ;
27+ const int64_t i = int64_t ( blockDim .x ) * blockIdx .x + threadIdx .x ;
2828 const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
2929
3030 if (i >= ne_total) {
@@ -40,10 +40,10 @@ static __global__ void k_set_rows(
4040 const int64_t i11 = i02 % ne11;
4141 const int64_t i10 = i01;
4242
43- const int64_t dst_row = *(src1 + i10*nb10 + i11*nb11 + i12*nb12 );
43+ const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12 );
4444
45- const src_t * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03 ;
46- dst_t * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3 ;
45+ const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03 ;
46+ dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3 ;
4747
4848 const src_t * src_elem = src0_row + i00;
4949 dst_t * dst_elem = dst_row_ptr + i00;
0 commit comments