#include "ggml-cuda/common.cuh"
#include "set.cuh"

- static __global__ void set_f32_cuda_copy(const float * __restrict__ src1,
+ static __global__ void set_f32_cuda_copy(const float * __restrict__ src0,
                                         float * __restrict__ dst,
                                         const size_t ne0,
                                         const size_t ne1,
                                         const size_t ne2,
                                         const size_t ne3,
-                                         const int    offset,  // element-offset
-                                         const int    nb1,     // stride in elements along dim1
-                                         const int    nb2,     // stride in elements along dim2
-                                         const int    nb3      // stride in elements along dim3
- ) {
+                                         const size_t nb0,
+                                         const size_t nb1,
+                                         const size_t nb2,
+                                         const size_t nb3) {
    const size_t total = ne0 * ne1 * ne2 * ne3;
    const size_t gid   = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= total) {
        return;
    }

-     // unravel into 4D indices (i0 fastest, then i1, i2, i3):
-     size_t       tmp = gid;
-     const size_t i0  = tmp % ne0;
+     size_t tmp = gid;
+
+     const size_t i0 = tmp % ne0;
    tmp /= ne0;
    const size_t i1 = tmp % ne1;
    tmp /= ne1;
    const size_t i2 = tmp % ne2;
    tmp /= ne2;
-     const size_t i3 = tmp;  // < ne3
+     const size_t i3 = tmp;

-     // compute flat positions with strides + offset
-     const size_t pos = offset + i0 + i1 * (size_t) nb1 + i2 * (size_t) nb2 + i3 * (size_t) nb3;
+     const size_t pos = (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);

-     dst[pos] = src1[pos];
+     *((float *) ((char *) dst + pos)) = *((const float *) ((const char *) src0 + pos));
}

- static __global__ void set_f32_cuda(const float * __restrict__ src0,
+ static __global__ void set_f32_cuda(const float * __restrict__ src1,
                                    float * __restrict__ dst,
-                                    const size_t ne0,
-                                    const size_t ne1,
-                                    const size_t ne2,
-                                    const size_t ne3,
-                                    const int    offset,  // element-offset into dst
-                                    const int    nb1,     // stride in elements along dim1
-                                    const int    nb2,     // stride in elements along dim2
-                                    const int    nb3      // stride in elements along dim3
+                                    const size_t ne10,
+                                    const size_t ne11,
+                                    const size_t ne12,
+                                    const size_t ne13,
+                                    const size_t nb10,
+                                    const size_t nb11,
+                                    const size_t nb12,
+                                    const size_t nb13,
+                                    const size_t nb0,
+                                    const size_t nb1,
+                                    const size_t nb2,
+                                    const size_t nb3,
+                                    const size_t offset
+
) {
-     // src0 is contiguous over ne0*ne1*ne2*ne3 elements
-     const size_t total = ne0 * ne1 * ne2 * ne3;
+     const size_t total = ne10 * ne11 * ne12 * ne13;
    const size_t gid   = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= total) {
        return;
    }

-     // unravel gid to 4D (same as copy)
-     size_t       tmp = gid;
-     const size_t i0  = tmp % ne0;
-     tmp /= ne0;
-     const size_t i1 = tmp % ne1;
-     tmp /= ne1;
-     const size_t i2 = tmp % ne2;
-     tmp /= ne2;
+     size_t tmp = gid;
+
+     const size_t i0 = tmp % ne10;
+     tmp /= ne10;
+     const size_t i1 = tmp % ne11;
+     tmp /= ne11;
+     const size_t i2 = tmp % ne12;
+     tmp /= ne12;
    const size_t i3 = tmp;

-     // dst position has the same formula:
-     const size_t pos = offset + i0 + i1 * (size_t) nb1 + i2 * (size_t) nb2 + i3 * (size_t) nb3;
+     size_t dst_offset  = offset + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3;
+     size_t src1_offset = i0 * nb10 + i1 * nb11 + i2 * nb12 + i3 * nb13;

-     // src0 is contiguous: flat index = gid
-     dst[pos] = src0[gid];
+     *((float *) ((char *) dst + dst_offset)) = *((const float *) ((const char *) src1 + src1_offset));
}

void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+     // nb0 is implicitly element_size because src0 and dst are contiguous
    const int32_t nb1     = dst->op_params[0];
    const int32_t nb2     = dst->op_params[1];
    const int32_t nb3     = dst->op_params[2];
@@ -80,31 +83,37 @@ void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+     // TODO: support more dtypes.
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

-     // dims
-     const size_t ne0 = dst->ne[0];
-     const size_t ne1 = dst->ne[1];
-     const size_t ne2 = dst->ne[2];
-     const size_t ne3 = dst->ne[3];
+     GGML_TENSOR_BINARY_OP_LOCALS01;
+     const int nb0 = ggml_element_size(dst);

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    float *       dst_d  = (float *) dst->data;

    cudaStream_t stream = ctx.stream();

-     const size_t total   = ne0 * ne1 * ne2 * ne3;
-     const int    threads = 256;
-     const int    blocks  = (total + threads - 1) / threads;
-
    if (!inplace) {
-         // copy whole src1 → dst
-         set_f32_cuda_copy<<<blocks, threads, 0, stream>>>(src1_d, dst_d, ne0, ne1, ne2, ne3, offset, nb1, nb2, nb3);
+         // copy whole src0 -> dst.
+         const size_t total = ne00 * ne01 * ne02 * ne03;
+
+         const int num_blocks = (total + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;
+
+         set_f32_cuda_copy<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>(
+             src0_d, dst_d, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03);
    }

-     // then overwrite from src0 → dst at same offsets/strides
-     set_f32_cuda<<<blocks, threads, 0, stream>>>(src0_d, dst_d, ne0, ne1, ne2, ne3, offset, nb1, nb2, nb3);
+     // set: src1 -> dst
+     // set_f32_cuda
+
+     const size_t total      = ne10 * ne11 * ne12 * ne13;
+     const size_t num_blocks = (total + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;
+
+     set_f32_cuda<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>(
+         src1_d, dst_d, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, offset);
}
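
For orientation, a minimal host-side sketch (not part of this commit) of how a SET node that reaches ggml_cuda_op_set is typically built with the public ggml_set API. The 8x8 and 4x4 shapes, the (2, 2) placement, and the helper name set_example are illustrative assumptions; the byte strides and byte offset passed to ggml_set are the values the launcher above reads back from dst->op_params.

// Illustrative sketch, assuming the standard ggml public API (ggml_set records
// nb1/nb2/nb3/offset in dst->op_params, which ggml_cuda_op_set consumes).
#include "ggml.h"

static struct ggml_tensor * set_example(struct ggml_context * ctx) {
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 8);  // becomes src0 (and dst)
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4);  // becomes src1

    // write b into a starting at element (2, 2); strides and offset are in bytes
    const size_t offset = (2 * 8 + 2) * sizeof(float);
    return ggml_set(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset);
}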