@@ -975,38 +975,28 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
975975 struct ggml_tensor * x,
976976 struct ggml_tensor * w,
977977 struct ggml_tensor * b,
978- int s0 = 1 ,
979- int s1 = 1 ,
980- int p0 = 0 ,
981- int p1 = 0 ,
982- int d0 = 1 ,
983- int d1 = 1 ) {
984- x = ggml_conv_2d (ctx, w, x, s0, s1, p0, p1, d0, d1);
985- if (b != NULL ) {
986- b = ggml_reshape_4d (ctx, b, 1 , 1 , b->ne [0 ], 1 );
987- // b = ggml_repeat(ctx, b, x);
988- x = ggml_add_inplace (ctx, x, b);
978+ int s0 = 1 ,
979+ int s1 = 1 ,
980+ int p0 = 0 ,
981+ int p1 = 0 ,
982+ int d0 = 1 ,
983+ int d1 = 1 ,
984+ bool direct = false ,
985+ float scale = 1 .f) {
986+ if (scale != 1 .f ) {
987+ x = ggml_scale (ctx, x, scale);
988+ }
989+ if (direct) {
990+ x = ggml_conv_2d_direct (ctx, w, x, s0, s1, p0, p1, d0, d1);
991+ } else {
992+ x = ggml_conv_2d (ctx, w, x, s0, s1, p0, p1, d0, d1);
993+ }
994+ if (scale != 1 .f ) {
995+ x = ggml_scale (ctx, x, 1 .f / scale);
989996 }
990- return x;
991- }
992-
993- // w: [OC*IC, KD, KH, KW]
994- // x: [N*IC, ID, IH, IW]
995- __STATIC_INLINE__ struct ggml_tensor * ggml_nn_conv_2d_direct (struct ggml_context * ctx,
996- struct ggml_tensor * x,
997- struct ggml_tensor * w,
998- struct ggml_tensor * b,
999- int s0 = 1 ,
1000- int s1 = 1 ,
1001- int p0 = 0 ,
1002- int p1 = 0 ,
1003- int d0 = 1 ,
1004- int d1 = 1 ) {
1005- x = ggml_conv_2d_direct (ctx, w, x, s0, s1, p0, p1, d0, d1);
1006997 if (b != NULL ) {
1007998 b = ggml_reshape_4d (ctx, b, 1 , 1 , b->ne [0 ], 1 );
1008- // b = ggml_repeat(ctx, b, x);
1009- x = ggml_add (ctx, x, b);
999+ x = ggml_add_inplace (ctx, x, b);
10101000 }
10111001 return x;
10121002}
@@ -2067,6 +2057,7 @@ class Conv2d : public UnaryBlock {
20672057 std::pair<int , int > dilation;
20682058 bool bias;
20692059 bool direct = false ;
2060+ float scale = 1 .f;
20702061
20712062 void init_params (struct ggml_context * ctx, const String2GGMLType& tensor_types, const std::string prefix = " " ) {
20722063 enum ggml_type wtype = GGML_TYPE_F16;
@@ -2097,6 +2088,10 @@ class Conv2d : public UnaryBlock {
20972088 direct = true ;
20982089 }
20992090
2091+ void set_scale (float scale_value) {
2092+ scale = scale_value;
2093+ }
2094+
21002095 std::string get_desc () {
21012096 return " Conv2d" ;
21022097 }
@@ -2107,11 +2102,18 @@ class Conv2d : public UnaryBlock {
21072102 if (bias) {
21082103 b = params[" bias" ];
21092104 }
2110- if (direct) {
2111- return ggml_nn_conv_2d_direct (ctx, x, w, b, stride.second , stride.first , padding.second , padding.first , dilation.second , dilation.first );
2112- } else {
2113- return ggml_nn_conv_2d (ctx, x, w, b, stride.second , stride.first , padding.second , padding.first , dilation.second , dilation.first );
2114- }
2105+ return ggml_nn_conv_2d (ctx,
2106+ x,
2107+ w,
2108+ b,
2109+ stride.second ,
2110+ stride.first ,
2111+ padding.second ,
2112+ padding.first ,
2113+ dilation.second ,
2114+ dilation.first ,
2115+ direct,
2116+ scale);
21152117 }
21162118};
21172119
0 commit comments