@@ -27,7 +27,19 @@ typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;
// Single-pipeline API. A pipeline is handled through the opaque
// ggml_metal_pipeline_t handle: init takes no arguments and returns a handle,
// free takes that handle.
2727ggml_metal_pipeline_t ggml_metal_pipeline_init (void );
2828void ggml_metal_pipeline_free (ggml_metal_pipeline_t pipeline );
2929
30- void * ggml_metal_pipeline_get_obj (ggml_metal_pipeline_t pipeline );
// Integer per-pipeline parameter "nsg", exposed as a setter/getter pair.
// NOTE(review): presumably the number of simdgroups per threadgroup - confirm
// against the implementation.
30+ void ggml_metal_pipeline_set_nsg (ggml_metal_pipeline_t pipeline , int nsg );
31+ int ggml_metal_pipeline_get_nsg (ggml_metal_pipeline_t pipeline );
32+
// Integer per-pipeline parameter "nr0", exposed as a setter/getter pair.
// NOTE(review): presumably a per-threadgroup row count along dimension 0 - confirm.
33+ void ggml_metal_pipeline_set_nr0 (ggml_metal_pipeline_t pipeline , int nr0 );
34+ int ggml_metal_pipeline_get_nr0 (ggml_metal_pipeline_t pipeline );
35+
// Integer per-pipeline parameter "nr1", exposed as a setter/getter pair.
// NOTE(review): presumably the dimension-1 counterpart of nr0 - confirm.
36+ void ggml_metal_pipeline_set_nr1 (ggml_metal_pipeline_t pipeline , int nr1 );
37+ int ggml_metal_pipeline_get_nr1 (ggml_metal_pipeline_t pipeline );
38+
// size_t per-pipeline parameter "smem", exposed as a setter/getter pair.
// NOTE(review): presumably a threadgroup memory size in bytes - confirm.
39+ void ggml_metal_pipeline_set_smem (ggml_metal_pipeline_t pipeline , size_t smem );
40+ size_t ggml_metal_pipeline_get_smem (ggml_metal_pipeline_t pipeline );
41+
// Returns an int for the given pipeline; by its name, a maximum thread count
// per threadgroup.
// NOTE(review): "theads" looks like a typo for "threads"; the name is part of
// the public interface, so fixing it means updating the definition and every
// caller in the same change.
42+ int ggml_metal_pipeline_max_theads_per_threadgroup (ggml_metal_pipeline_t pipeline );
3143
3244// a collection of pipelines
3345typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t ;
@@ -38,6 +50,37 @@ void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);
// add: takes the collection, a C-string name and a pipeline handle.
// get: takes the collection and a C-string name and returns a pipeline handle.
// NOTE(review): presumably add registers the pipeline under `name` and get looks
// it up by the same name; what get returns for an unknown name, and who owns
// the `name` string and the pipeline after add, is not visible here - confirm.
3850void ggml_metal_pipelines_add (ggml_metal_pipelines_t ppls , const char * name , ggml_metal_pipeline_t pipeline );
3951ggml_metal_pipeline_t ggml_metal_pipelines_get (ggml_metal_pipelines_t ppls , const char * name );
4052
53+ //
54+ // MTLCommandBuffer wrapper
55+ //
56+
// Untyped handle: the command buffer crosses this interface as a plain void *.
// ggml_metal_encoder_init accepts a value of this type as `cmd_buf_raw`.
57+ typedef void * ggml_metal_cmd_buf_t ;
58+
59+ //
60+ // MTLComputeCommandEncoder wrapper
61+ //
62+
// Opaque handle to the encoder wrapper; the struct is only forward-declared here.
63+ typedef struct ggml_metal_encoder * ggml_metal_encoder_t ;
64+
// init: builds an encoder from a raw command-buffer handle and a `concurrent` flag.
// NOTE(review): `concurrent` presumably selects concurrent vs. serial dispatch
// on the underlying encoder - confirm in the implementation.
// free: takes the encoder handle returned by init.
65+ ggml_metal_encoder_t ggml_metal_encoder_init (ggml_metal_cmd_buf_t cmd_buf_raw , bool concurrent );
66+ void ggml_metal_encoder_free (ggml_metal_encoder_t encoder );
67+
// Debug-group push/pop pair; push takes a C-string name, pop takes none.
// NOTE(review): presumably every push must be balanced by a pop - confirm.
68+ void ggml_metal_encoder_debug_group_push (ggml_metal_encoder_t encoder , const char * name );
69+ void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder );
70+
// Passes a pipeline handle to the encoder.
71+ void ggml_metal_encoder_set_pipeline (ggml_metal_encoder_t encoder , ggml_metal_pipeline_t pipeline );
72+
// set_bytes: passes `size` bytes starting at `data`, associated with index `idx`.
// set_buffer: passes a ggml_metal_buffer_id (by value), associated with index `idx`.
// NOTE(review): `idx` is presumably the kernel argument/binding index - confirm.
// NOTE(review): `data` is a non-const void *; if the implementation only reads
// from it, `const void *` would be the more accurate type.
73+ void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder , void * data , size_t size , int idx );
74+ void ggml_metal_encoder_set_buffer (ggml_metal_encoder_t encoder , struct ggml_metal_buffer_id buffer , int idx );
75+
// Passes a threadgroup memory size and an index to the encoder.
// NOTE(review): units are presumably bytes - confirm.
76+ void ggml_metal_encoder_set_threadgroup_memory_size (ggml_metal_encoder_t encoder , size_t size , int idx );
77+
// Dispatch with two integer triples.
// NOTE(review): by naming, tg0..tg2 are presumably the threadgroup counts per
// dimension and tptg0..tptg2 the threads per threadgroup per dimension - confirm.
78+ void ggml_metal_encoder_dispatch_threadgroups (ggml_metal_encoder_t encoder , int tg0 , int tg1 , int tg2 , int tptg0 , int tptg1 , int tptg2 );
79+
// Takes only the encoder; the scope of the barrier is defined by the implementation.
80+ void ggml_metal_encoder_memory_barrier (ggml_metal_encoder_t encoder );
81+
// Takes only the encoder.
// NOTE(review): whether the handle must still be passed to
// ggml_metal_encoder_free afterwards is not visible here - confirm.
82+ void ggml_metal_encoder_end_encoding (ggml_metal_encoder_t encoder );
83+
4184//
4285// backend
4386//
@@ -63,6 +106,39 @@ void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort
// supports_family: returns a bool for the given backend context and integer `family`.
// NOTE(review): the numbering scheme of `family` is not defined in the visible
// part of this header - confirm which values callers are expected to pass.
// capture_next_compute: takes only the backend context and returns nothing.
63106bool ggml_metal_supports_family (ggml_metal_t ctx , int family );
64107void ggml_metal_capture_next_compute (ggml_metal_t ctx );
65108
109+ //
110+ // graph encoder
111+ //
112+
113+ typedef struct ggml_metal_graph_encoder * ggml_metal_graph_encoder_t ;
114+
115+ // TODO: tmp
116+ #include "ggml-metal-common.h"
117+
118+ // TODO: tmp
// State bundle for the graph encoder. The full definition is visible in this
// header, so code that includes it can access the fields directly.
// NOTE(review): the "TODO: tmp" marker directly above suggests this exposure is
// meant to be temporary - confirm before depending on the layout.
119+ struct ggml_metal_graph_encoder {
// backend context handle
120+ ggml_metal_t ctx ;
121+
// device properties, not modifiable through this pointer
122+ const struct ggml_metal_device_props * props_dev ;
123+
// compute command encoder wrapper handle
124+ ggml_metal_encoder_t encoder ;
125+
// NOTE(review): ggml_mem_ranges_t is not declared in the visible part of this
// header; presumably it comes from "ggml-metal-common.h" and backs the
// ggml_metal_graph_encoder_concurrency_* functions - confirm.
126+ ggml_mem_ranges_t mem_ranges ;
127+
// compute graph pointer
128+ struct ggml_cgraph * gf ;
129+
// NOTE(review): presumably the range of node indices in `gf` handled by this
// encoder; whether idx_end is inclusive is not visible here - confirm.
130+ int idx_start ;
131+ int idx_end ;
132+
// boolean fusion switch
133+ bool use_fusion ;
134+
// NOTE(review): integer rather than bool, so presumably a verbosity level - confirm.
135+ int debug_fusion ;
136+ };
137+
// Concurrency helpers operating on a graph encoder. Note that the first
// parameter is named `ctx` but is a ggml_metal_graph_encoder_t, not a
// ggml_metal_t backend context.
// reset: takes only the graph encoder.
// check / add: additionally take a tensor node that they cannot modify through
// the pointer.
// NOTE(review): the meaning of the bool results is not visible here; presumably
// check reports whether `node` can be encoded concurrently with the nodes
// recorded so far, and add records `node` - confirm in the implementation.
138+ bool ggml_metal_graph_encoder_concurrency_reset (ggml_metal_graph_encoder_t ctx );
139+ bool ggml_metal_graph_encoder_concurrency_check (ggml_metal_graph_encoder_t ctx , const struct ggml_tensor * node );
140+ bool ggml_metal_graph_encoder_concurrency_add (ggml_metal_graph_encoder_t ctx , const struct ggml_tensor * node );
141+
66142#ifdef __cplusplus
67143}
68144#endif
0 commit comments