@@ -137,7 +137,7 @@ struct webgpu_context_struct {
137137 wgpu::ComputePipeline mul_ip_pipeline[2 ];
138138 wgpu::ComputePipeline rms_norm_pipeline;
139139 wgpu::ComputePipeline rms_norm_ip_pipeline;
140- wgpu::ComputePipeline rope_pipeline[2 ][2 ];
140+ wgpu::ComputePipeline rope_pipeline[2 ][2 ][ 2 ][ 2 ]; // indexed as [type][mode][ff][inplace]; every index must be 0 or 1. mode 0 holds the "norm" shaders, ff = frequency-factor tensor (src2) present, inplace = result of ggml_webgpu_tensor_equal(src0, dst)
141141
142142 size_t memset_bytes_per_thread;
143143
@@ -734,11 +734,17 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
734734 ggml_backend_webgpu_build_and_enqueue (ctx, pipeline, params, entries, wg_x, ggml_op_name (dst->op ));
735735}
736736
737- static void ggml_webgpu_rope (webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) {
738- bool in_place = ggml_webgpu_tensor_equal (src0, dst);
739- int has_freq_factor = (src2 != nullptr );
740-
741- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
737+ static void ggml_webgpu_rope (webgpu_context & ctx,
738+ ggml_tensor * src0,
739+ ggml_tensor * src1,
740+ ggml_tensor * src2,
741+ ggml_tensor * dst) {
742+ const int inplace = ggml_webgpu_tensor_equal (src0, dst); // 0/1; last rope_pipeline index, also decides whether a separate dst binding is added
743+ const int has_freq_factor = (src2 != nullptr ); // 0/1; src2 is optional
744+ const int mode = ((int32_t *) dst->op_params )[2 ]; // rope mode flags read from op_params[2]
745+ const int is_neox = (mode & GGML_ROPE_TYPE_NEOX) != 0; // normalize the masked flag to 0/1: it indexes a rope_pipeline dimension of size 2, and the raw mask value would be out of bounds
746+
747+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
742748 const int n_dims = ((int32_t *) dst->op_params )[1 ];
743749 const int n_ctx_orig = ((int32_t *) dst->op_params )[4 ];
744750
@@ -757,30 +763,25 @@ static void ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src0, ggml_tens
757763 std::vector<uint32_t > params = {
758764 (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, src0) / ggml_type_size (src0->type )),
759765 (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, src1) / ggml_type_size (src1->type )),
766+ (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, dst) / ggml_type_size (dst->type )), // dst offset and strides are now passed unconditionally, including for in-place dispatches
767+ (uint32_t ) (src0->nb [1 ] / ggml_type_size (src0->type )), // byte strides converted to element strides
768+ (uint32_t ) (src0->nb [2 ] / ggml_type_size (src0->type )),
769+ (uint32_t ) (src0->nb [3 ] / ggml_type_size (src0->type )),
770+ (uint32_t ) (dst->nb [1 ] / ggml_type_size (dst->type )),
771+ (uint32_t ) (dst->nb [2 ] / ggml_type_size (dst->type )),
772+ (uint32_t ) (dst->nb [3 ] / ggml_type_size (dst->type )),
773+ (uint32_t ) (ggml_nelements (src0) / 2 ), // halve in 64-bit first, then narrow; casting before the division truncated the element count
774+ (uint32_t ) src0->ne [0 ],
775+ (uint32_t ) src0->ne [1 ],
776+ (uint32_t ) src0->ne [2 ],
777+ (uint32_t ) n_dims,
778+ *(uint32_t *) &theta_scale, // float parameters are passed as their raw 32-bit patterns in the uint32_t list
779+ *(uint32_t *) &attn_factor,
780+ *(uint32_t *) &freq_scale,
781+ *(uint32_t *) &ext_factor,
782+ *(uint32_t *) &corr_dims[0 ],
783+ *(uint32_t *) &corr_dims[1 ]
760784 };
761- if (!in_place) {
762- params.push_back ((uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, dst) / ggml_type_size (dst->type )));
763- }
764- params.push_back ((uint32_t ) (src0->nb [1 ] / ggml_type_size (src0->type )));
765- params.push_back ((uint32_t ) (src0->nb [2 ] / ggml_type_size (src0->type )));
766- params.push_back ((uint32_t ) (src0->nb [3 ] / ggml_type_size (src0->type )));
767- if (!in_place) {
768- params.push_back ((uint32_t ) (dst->nb [1 ] / ggml_type_size (dst->type )));
769- params.push_back ((uint32_t ) (dst->nb [2 ] / ggml_type_size (dst->type )));
770- params.push_back ((uint32_t ) (dst->nb [3 ] / ggml_type_size (dst->type )));
771- }
772- params.push_back ((uint32_t ) ggml_nelements (src0) / 2 );
773- params.push_back ((uint32_t ) src0->ne [0 ]);
774- params.push_back ((uint32_t ) src0->ne [1 ]);
775- params.push_back ((uint32_t ) src0->ne [2 ]);
776-
777- params.push_back ((uint32_t ) n_dims);
778- params.push_back (*(uint32_t *) &theta_scale);
779- params.push_back (*(uint32_t *) &attn_factor);
780- params.push_back (*(uint32_t *) &freq_scale);
781- params.push_back (*(uint32_t *) &ext_factor);
782- params.push_back (*(uint32_t *) &corr_dims[0 ]);
783- params.push_back (*(uint32_t *) &corr_dims[1 ]);
784785
785786 std::vector<wgpu::BindGroupEntry> entries = {
786787 { .binding = 0 ,
@@ -800,21 +801,16 @@ static void ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src0, ggml_tens
800801 .offset = ggml_webgpu_tensor_align_offset (ctx, src2),
801802 .size = ggml_webgpu_tensor_binding_size (ctx, src2) });
802803 }
803- if (!in_place ) {
804+ if (!inplace ) { // a separate dst binding is only appended for the out-of-place variant
804805 entries.push_back ({ .binding = dst_binding,
805806 .buffer = ggml_webgpu_tensor_buf (dst),
806807 .offset = ggml_webgpu_tensor_align_offset (ctx, dst),
807808 .size = ggml_webgpu_tensor_binding_size (ctx, dst) });
808809 }
809810
810- wgpu::ComputePipeline pipeline;
811- if (in_place) {
812- pipeline = ctx->rope_pipeline [dst->type ][has_freq_factor];
813- } else {
814- pipeline = ctx->rope_pipeline [dst->type ][has_freq_factor];
815- }
816- size_t max_wg_size = ctx->max_wg_size_x ;
817- uint32_t wg_x = (ggml_nelements (src0) / 2 + max_wg_size - 1 ) / max_wg_size;
811+ wgpu::ComputePipeline pipeline = ctx->rope_pipeline [dst->type ][is_neox][has_freq_factor][inplace]; // NOTE(review): every dimension has size 2, so each index must be 0 or 1 -- confirm is_neox is normalized (not the raw masked flag) and that dst->type is limited to the two populated types
812+ size_t max_wg_size = ctx->max_wg_size_x ;
813+ uint32_t wg_x = (ggml_nelements (src0) / 2 + max_wg_size - 1 ) / max_wg_size; // ceil((nelements / 2) / max_wg_size) workgroups
818814 ggml_backend_webgpu_build_and_enqueue (ctx, pipeline, params, entries, wg_x, ggml_op_name (dst->op ));
819815}
820816
@@ -1290,10 +1286,22 @@ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
12901286
12911287static void ggml_webgpu_init_rope_pipeline (webgpu_context & webgpu_ctx) {
12921288 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry (webgpu_ctx);
1293- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rope_pipeline [GGML_TYPE_F32][0 ], wgsl_rope_f32_norm, " rope_f32_norm" , constants);
1294- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rope_pipeline [GGML_TYPE_F32][1 ], wgsl_rope_f32_norm_ff, " rope_f32_norm_ff" , constants);
1295- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rope_pipeline [GGML_TYPE_F16][0 ], wgsl_rope_f16_norm, " rope_f16_norm" , constants);
1296- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rope_pipeline [GGML_TYPE_F16][1 ], wgsl_rope_f16_norm_ff, " rope_f16_norm_ff" , constants);
1289+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rope_pipeline [GGML_TYPE_F32][0 ][0 ][0 ], // slot order: [type][mode][ff][inplace]; F32, no freq factors, out-of-place
1290+ wgsl_rope_f32_norm, " rope_f32_norm" , constants);
1291+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rope_pipeline [GGML_TYPE_F32][0 ][0 ][1 ], // F32, no freq factors, in-place
1292+ wgsl_rope_f32_norm_inplace, " rope_f32_norm_inplace" , constants);
1293+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rope_pipeline [GGML_TYPE_F32][0 ][1 ][0 ], // F32, freq factors, out-of-place
1294+ wgsl_rope_f32_norm_ff, " rope_f32_norm_ff" , constants);
1295+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rope_pipeline [GGML_TYPE_F32][0 ][1 ][1 ], // F32, freq factors, in-place
1296+ wgsl_rope_f32_norm_ff_inplace, " rope_f32_norm_ff_inplace" , constants);
1297+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rope_pipeline [GGML_TYPE_F16][0 ][0 ][0 ], // F16, no freq factors, out-of-place
1298+ wgsl_rope_f16_norm, " rope_f16_norm" , constants);
1299+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rope_pipeline [GGML_TYPE_F16][0 ][0 ][1 ], // F16, no freq factors, in-place
1300+ wgsl_rope_f16_norm_inplace, " rope_f16_norm_inplace" , constants);
1301+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rope_pipeline [GGML_TYPE_F16][0 ][1 ][0 ], // F16, freq factors, out-of-place
1302+ wgsl_rope_f16_norm_ff, " rope_f16_norm_ff" , constants);
1303+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rope_pipeline [GGML_TYPE_F16][0 ][1 ][1 ], // F16, freq factors, in-place
1304+ wgsl_rope_f16_norm_ff_inplace, " rope_f16_norm_ff_inplace" , constants);
12971305} // NOTE(review): only mode index 0 ("norm") is populated here; the rope_pipeline[*][1][*][*] slots are not created in this function -- confirm non-norm modes are rejected before dispatch or add their pipelines
12981306
12991307static ggml_backend_t ggml_backend_webgpu_device_init (ggml_backend_dev_t dev, const char * params) {
0 commit comments