@@ -137,6 +137,7 @@ struct webgpu_context_struct {
    wgpu::ComputePipeline mul_ip_pipeline[2];
    wgpu::ComputePipeline rms_norm_pipeline;
    wgpu::ComputePipeline rms_norm_ip_pipeline;
+   wgpu::ComputePipeline rope_pipeline[2][2];  // [dst type: f32/f16][has frequency factors]

    size_t memset_bytes_per_thread;
@@ -693,9 +694,6 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    bool in_place = ggml_webgpu_tensor_equal(src, dst);

-   uint32_t eps;
-   memcpy(&eps, dst->op_params, sizeof(float));
-
    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
    };
@@ -714,7 +712,7 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
    params.push_back((uint32_t) src->ne[1]);
    params.push_back((uint32_t) src->ne[2]);
    params.push_back((uint32_t) src->ne[3]);
-   params.push_back(eps);  // epsilon, will be bitcast to float in shader
+   params.push_back(*(uint32_t *) dst->op_params);  // epsilon, treated as f32 in the shader

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
@@ -740,6 +738,90 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
    ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}

+static void ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) {
+   bool in_place        = ggml_webgpu_tensor_equal(src0, dst);
+   int  has_freq_factor = (src2 != nullptr);  // src2, when present, supplies per-dimension frequency factors
+
+   float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+   const int n_dims     = ((int32_t *) dst->op_params)[1];
+   const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
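+   // The float parameters live at int32 offsets 5..10 of op_params; read them
+   // with memcpy rather than a cast-and-dereference to avoid aliasing issues.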
+   memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+   memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+   memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+   memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+   memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+   memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+   float theta_scale = powf(freq_base, -2.0f / n_dims);  // ratio between the rotation frequencies of consecutive dimension pairs
+
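+   // YaRN: compute the start/end dims over which extrapolated and interpolated
+   // rotation angles are blended; passed through to the shader.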
+   float corr_dims[2];
+   ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
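+   // Offsets and strides are pushed in elements rather than bytes; the in-place
+   // variant writes into src0's buffer, so it needs no dst offset or strides.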
+   std::vector<uint32_t> params = {
+       (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+       (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+   };
+   if (!in_place) {
+       params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
+   }
+   params.push_back((uint32_t) (src0->nb[1] / ggml_type_size(src0->type)));
+   params.push_back((uint32_t) (src0->nb[2] / ggml_type_size(src0->type)));
+   params.push_back((uint32_t) (src0->nb[3] / ggml_type_size(src0->type)));
+   if (!in_place) {
+       params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
+       params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
+       params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
+   }
+   params.push_back((uint32_t) ggml_nelements(src0) / 2);  // total number of rotated pairs
+   params.push_back((uint32_t) src0->ne[0]);
+   params.push_back((uint32_t) src0->ne[1]);
+   params.push_back((uint32_t) src0->ne[2]);
+
+   params.push_back((uint32_t) n_dims);
+   params.push_back(*(uint32_t *) &theta_scale);
+   params.push_back(*(uint32_t *) &attn_factor);
+   params.push_back(*(uint32_t *) &freq_scale);
+   params.push_back(*(uint32_t *) &ext_factor);
+   params.push_back(*(uint32_t *) &corr_dims[0]);
+   params.push_back(*(uint32_t *) &corr_dims[1]);
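+   // The float parameters above are bitcast to u32 for the params buffer and
+   // reinterpreted as f32 in the shader.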
+
+   std::vector<wgpu::BindGroupEntry> entries = {
+       { .binding = 0,
+         .buffer = ggml_webgpu_tensor_buf(src0),
+         .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+         .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
+       { .binding = 1,
+         .buffer = ggml_webgpu_tensor_buf(src1),
+         .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+         .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
+   };
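+   // The optional frequency-factor buffer takes binding 2, pushing dst to binding 3.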
+   uint32_t dst_binding = 2;
+   if (has_freq_factor) {
+       dst_binding = 3;
+       entries.push_back({ .binding = 2,
+                           .buffer = ggml_webgpu_tensor_buf(src2),
+                           .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
+                           .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
+   }
+   if (!in_place) {
+       entries.push_back({ .binding = dst_binding,
+                           .buffer = ggml_webgpu_tensor_buf(dst),
+                           .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+                           .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+   }
+
+   // the same pipeline serves both the in-place and separate-dst paths
+   wgpu::ComputePipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor];
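+   // Each invocation rotates one pair of elements, hence ggml_nelements(src0) / 2
+   // invocations in total.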
+   size_t max_wg_size = ctx->max_wg_size_x;
+   uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size;
+   ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
+}
+
// Returns true if node has enqueued work into the queue, false otherwise
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
    if (ggml_is_empty(node)) {
@@ -749,6 +831,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {

    ggml_tensor * src0 = node->src[0];
    ggml_tensor * src1 = node->src[1];
+   ggml_tensor * src2 = node->src[2];

    switch (node->op) {
        // no-ops
@@ -787,6 +870,9 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
        case GGML_OP_RMS_NORM:
            ggml_webgpu_rms_norm(ctx, src0, node);
            break;
+       case GGML_OP_ROPE:
+           ggml_webgpu_rope(ctx, src0, src1, src2, node);
+           break;
        default:
            return false;
    }
@@ -1206,6 +1292,14 @@ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
                                "rms_norm_in_place", constants);
}

+static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
+   std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
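+   // "norm" = the standard rope mode (mode == 0, matching the supports_op check);
+   // the "_ff" variants additionally bind a frequency-factor buffer.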
+   ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0], wgsl_rope_f32_norm, "rope_f32_norm", constants);
+   ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1], wgsl_rope_f32_norm_ff, "rope_f32_norm_ff", constants);
+   ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0], wgsl_rope_f16_norm, "rope_f16_norm", constants);
+   ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1], wgsl_rope_f16_norm_ff, "rope_f16_norm_ff", constants);
+}
+
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
    GGML_UNUSED(params);

@@ -1287,6 +1381,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const

    ggml_tensor * src0 = op->src[0];
    ggml_tensor * src1 = op->src[1];
+   ggml_tensor * src2 = op->src[2];
+
    // on smaller devices (or CI), tensors may be larger than the max storage buffer size
    if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
        (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
@@ -1360,6 +1456,23 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
        case GGML_OP_RMS_NORM:
            supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
            break;
+       case GGML_OP_ROPE:
+       {
+           // only the standard rope mode (mode == 0) is supported for now
+           const int mode = ((int32_t *) op->op_params)[2];
+           supports_op = mode == 0 && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
+           break;
+       }
        default:
            break;
    }
@@ -1486,6 +1599,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
    ggml_webgpu_init_add_pipeline(ctx);
    ggml_webgpu_init_mul_pipeline(ctx);
    ggml_webgpu_init_rms_norm_pipeline(ctx);
+   ggml_webgpu_init_rope_pipeline(ctx);

#ifdef GGML_WEBGPU_DEBUG
    // Initialize debug buffers