@@ -2374,6 +2374,8 @@ static bool ggml_metal_encode_node(
23742374                const  float  m0 = powf (2 .0f , -(max_bias       ) / n_head_log2);
23752375                const  float  m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
23762376
2377+ //  use this branch to test the ggml_metal_mem_pool functionality
2378+ #if  0 
23772379                // cpy to tmp buffer in MTLHeap
23782380
23792381                id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
@@ -2382,6 +2384,8 @@ static bool ggml_metal_encode_node(
23822384                    return false;
23832385                }
23842386
2387+                 offs_src0 = 0;
2388+ 
23852389                ggml_metal_kargs_cpy args_cpy = {
23862390                    /*.ne00 =*/ ne00,
23872391                    /*.ne01 =*/ ne01,
@@ -2415,6 +2419,9 @@ static bool ggml_metal_encode_node(
24152419
24162420                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
24172421
2422+ #else 
2423+                 id <MTLBuffer > h_src0 = id_src0;
2424+ #endif 
24182425                //  softmax
24192426
24202427                ggml_metal_kargs_soft_max args = {
@@ -2429,11 +2436,11 @@ static bool ggml_metal_encode_node(
24292436                };
24302437
24312438                [encoder setComputePipelineState: pipeline];
2432-                 [encoder setBuffer: h_src0 offset: 0                atIndex: 0 ];
2439+                 [encoder setBuffer: h_src0 offset: offs_src0       atIndex: 0 ];
24332440                if  (id_src1) {
24342441                    [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
24352442                } else  {
2436-                     [encoder setBuffer: h_src0 offset: 0            atIndex: 1 ];
2443+                     [encoder setBuffer: h_src0 offset: offs_src0   atIndex: 1 ];
24372444                }
24382445                [encoder setBuffer: id_dst offset: offs_dst       atIndex: 2 ];
24392446                [encoder setBytes: &args   length: sizeof (args)   atIndex: 3 ];
0 commit comments