@@ -1273,11 +1273,17 @@ class Linear : public UnaryBlock {
12731273 bool force_f32;
12741274
12751275 void init_params (struct ggml_context * ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = " " ) {
1276- enum ggml_type wtype = (tensor_types.find (prefix + " weight" ) != tensor_types.end ()) ? tensor_types[prefix + " weight" ] : GGML_TYPE_F32;
1277- if (in_features % ggml_blck_size (wtype) != 0 || force_f32) {
1278- wtype = GGML_TYPE_F32;
1276+ if (tensor_types.find (prefix + " A" ) != tensor_types.end ()) {
1277+ params[" A" ] = ggml_new_tensor_2d (ctx, wtype, in_features, 64 );
1278+ params[" B" ] = ggml_new_tensor_2d (ctx, wtype, 64 , out_features);
1279+ } else {
1280+ enum ggml_type wtype = (tensor_types.find (prefix + " weight" ) != tensor_types.end ()) ? tensor_types[prefix + " weight" ] : GGML_TYPE_F32;
1281+ if (in_features % ggml_blck_size (wtype) != 0 || force_f32) {
1282+ wtype = GGML_TYPE_F32;
1283+ }
1284+ params[" weight" ] = ggml_new_tensor_2d (ctx, wtype, in_features, out_features);
12791285 }
1280- params[ " weight " ] = ggml_new_tensor_2d (ctx, wtype, in_features, out_features);
1286+
12811287 if (bias) {
12821288 enum ggml_type wtype = GGML_TYPE_F32; // (tensor_types.ypes.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
12831289 params[" bias" ] = ggml_new_tensor_1d (ctx, wtype, out_features);
@@ -1295,12 +1301,22 @@ class Linear : public UnaryBlock {
12951301 force_f32(force_f32) {}
12961302
12971303 struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * x) {
1298- struct ggml_tensor * w = params[" weight" ];
1299- struct ggml_tensor * b = NULL ;
1300- if (bias) {
1301- b = params[" bias" ];
1304+ if (params.find (" A" ) != tensor_types.end ()) {
1305+ struct ggml_tensor * down = params[" A" ];
1306+ struct ggml_tensor * up = params[" B" ];
1307+ struct ggml_tensor * b = NULL ;
1308+ if (bias) {
1309+ b = params[" bias" ];
1310+ }
1311+ return ggml_nn_linear (ctx, ggml_nn_linear (ctx, x, down, NULL ), up, b);
1312+ } else {
1313+ struct ggml_tensor * w = params[" weight" ];
1314+ struct ggml_tensor * b = NULL ;
1315+ if (bias) {
1316+ b = params[" bias" ];
1317+ }
1318+ return ggml_nn_linear (ctx, x, w, b);
13021319 }
1303- return ggml_nn_linear (ctx, x, w, b);
13041320 }
13051321};
13061322
0 commit comments