|
44 | 44 | // note: assumes single GPU device - the default one |
45 | 45 | // TODO: support multiple GPU devices |
46 | 46 | static struct ggml_backend_metal_device_context { |
47 | | - id<MTLDevice> mtl_device; |
48 | | - int mtl_device_ref_count; |
| 47 | + id<MTLDevice> mtl_device; |
| 48 | + int mtl_device_ref_count; |
49 | 49 | id<MTLLibrary> mtl_library; |
50 | 50 |
|
51 | 51 | bool has_simdgroup_reduction; |
@@ -470,6 +470,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte |
470 | 470 |
|
471 | 471 | struct ggml_backend_metal_context { |
472 | 472 | id<MTLCommandQueue> queue; |
| 473 | + id<MTLHeap> heap; |
473 | 474 |
|
474 | 475 | dispatch_queue_t d_queue; |
475 | 476 |
|
@@ -693,6 +694,19 @@ @implementation GGMLMetalClass |
693 | 694 |
|
694 | 695 | ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); |
695 | 696 |
|
| 697 | + // allocate tmp heap with fixed size for testing |
| 698 | + // TODO: figure out how to dynamically resize it |
| 699 | + { |
| 700 | + MTLHeapDescriptor *heapDescriptor = [[MTLHeapDescriptor alloc] init]; |
| 701 | + heapDescriptor.storageMode = MTLStorageModePrivate; |
| 702 | + heapDescriptor.cpuCacheMode = MTLCPUCacheModeDefaultCache; |
| 703 | + heapDescriptor.size = 32*1024*1024; |
| 704 | + |
| 705 | + ctx->heap = [device newHeapWithDescriptor:heapDescriptor]; |
| 706 | + |
| 707 | + [heapDescriptor release]; |
| 708 | + } |
| 709 | + |
696 | 710 | // load library |
697 | 711 | if (ctx_dev->mtl_library == nil) { |
698 | 712 | ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat); |
@@ -1136,6 +1150,7 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { |
1136 | 1150 | Block_release(ctx->encode_async); |
1137 | 1151 |
|
1138 | 1152 | [ctx->queue release]; |
| 1153 | + [ctx->heap release]; |
1139 | 1154 |
|
1140 | 1155 | dispatch_release(ctx->d_queue); |
1141 | 1156 |
|
@@ -1438,7 +1453,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex |
1438 | 1453 | static void ggml_metal_encode_node( |
1439 | 1454 | ggml_backend_t backend, |
1440 | 1455 | int idx, |
1441 | | - id<MTLComputeCommandEncoder> encoder) { |
| 1456 | + id<MTLComputeCommandEncoder> encoder, |
| 1457 | + id<MTLHeap> heap) { |
1442 | 1458 | struct ggml_backend_metal_context * ctx = backend->context; |
1443 | 1459 | struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; |
1444 | 1460 |
|
@@ -2110,26 +2126,65 @@ static void ggml_metal_encode_node( |
2110 | 2126 | const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); |
2111 | 2127 | const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); |
2112 | 2128 |
|
2113 | | - ggml_metal_kargs_soft_max args = { |
| 2129 | + // cpy to tmp buffer in MTLHeap |
| 2130 | + |
| 2131 | + ggml_metal_kargs_cpy args_cpy = { |
2114 | 2132 | /*.ne00 =*/ ne00, |
2115 | 2133 | /*.ne01 =*/ ne01, |
2116 | 2134 | /*.ne02 =*/ ne02, |
2117 | | - /*.scale =*/ scale, |
2118 | | - /*.max_bias =*/ max_bias, |
2119 | | - /*.m0 =*/ m0, |
2120 | | - /*.m1 =*/ m1, |
| 2135 | + /*.ne03 =*/ ne03, |
| 2136 | + /*.nb00 =*/ nb00, |
| 2137 | + /*.nb01 =*/ nb01, |
| 2138 | + /*.nb02 =*/ nb02, |
| 2139 | + /*.nb03 =*/ nb03, |
| 2140 | + /*.ne0 =*/ ne00, |
| 2141 | + /*.ne1 =*/ ne01, |
| 2142 | + /*.ne2 =*/ ne02, |
| 2143 | + /*.ne3 =*/ ne03, |
| 2144 | + /*.nb0 =*/ nb00, |
| 2145 | + /*.nb1 =*/ nb01, |
| 2146 | + /*.nb2 =*/ nb02, |
| 2147 | + /*.nb3 =*/ nb03, |
| 2148 | + }; |
| 2149 | + |
| 2150 | + id<MTLBuffer> id_src0h = [heap newBufferWithLength:ggml_nbytes(src0) options:MTLResourceStorageModePrivate]; |
| 2151 | + |
| 2152 | + if (src0->type == GGML_TYPE_F16) { |
| 2153 | + [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline]; |
| 2154 | + } else { |
| 2155 | + [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline]; |
| 2156 | + } |
| 2157 | + [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0]; |
| 2158 | + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; |
| 2159 | + [encoder setBuffer:id_src0h offset:0 atIndex:2]; |
| 2160 | + |
| 2161 | + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); |
| 2162 | + int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type)); |
| 2163 | + |
| 2164 | + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)]; |
| 2165 | + |
| 2166 | + // softmax |
| 2167 | + |
| 2168 | + ggml_metal_kargs_soft_max args = { |
| 2169 | + /*.ne00 =*/ ne00, |
| 2170 | + /*.ne01 =*/ ne01, |
| 2171 | + /*.ne02 =*/ ne02, |
| 2172 | + /*.scale =*/ scale, |
| 2173 | + /*.max_bias =*/ max_bias, |
| 2174 | + /*.m0 =*/ m0, |
| 2175 | + /*.m1 =*/ m1, |
2121 | 2176 | /*.n_head_log2 =*/ n_head_log2, |
2122 | 2177 | }; |
2123 | 2178 |
|
2124 | 2179 | [encoder setComputePipelineState:pipeline]; |
2125 | | - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |
| 2180 | + [encoder setBuffer:id_src0h offset:0 atIndex:0]; |
2126 | 2181 | if (id_src1) { |
2127 | | - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; |
| 2182 | + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; |
2128 | 2183 | } else { |
2129 | | - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; |
| 2184 | + [encoder setBuffer:id_src0h offset:0 atIndex:1]; |
2130 | 2185 | } |
2131 | | - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; |
2132 | | - [encoder setBytes:&args length:sizeof(args) atIndex:3]; |
| 2186 | + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; |
| 2187 | + [encoder setBytes:&args length:sizeof(args) atIndex:3]; |
2133 | 2188 |
|
2134 | 2189 | [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; |
2135 | 2190 |
|
@@ -4991,7 +5046,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { |
4991 | 5046 | [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; |
4992 | 5047 | } |
4993 | 5048 |
|
4994 | | - ggml_metal_encode_node(backend, idx, encoder); |
| 5049 | + ggml_metal_encode_node(backend, idx, encoder, ctx->heap); |
4995 | 5050 |
|
4996 | 5051 | if (should_capture) { |
4997 | 5052 | [encoder popDebugGroup]; |
|
0 commit comments