Skip to content

Commit 48cd384

Browse files
authored
Linear
1 parent a9a661e commit 48cd384

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

ggml_extend.hpp

Lines changed: 25 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1273,11 +1273,17 @@ class Linear : public UnaryBlock {
12731273
bool force_f32;
12741274

12751275
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
1276-
enum ggml_type wtype = (tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
1277-
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
1278-
wtype = GGML_TYPE_F32;
1276+
if (tensor_types.find(prefix + "A") != tensor_types.end()) {
1277+
params["A"] = ggml_new_tensor_2d(ctx, wtype, in_features, 64);
1278+
params["B"] = ggml_new_tensor_2d(ctx, wtype, 64, out_features);
1279+
} else {
1280+
enum ggml_type wtype = (tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
1281+
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
1282+
wtype = GGML_TYPE_F32;
1283+
}
1284+
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
12791285
}
1280-
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
1286+
12811287
if (bias) {
12821288
enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
12831289
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features);
@@ -1295,12 +1301,22 @@ class Linear : public UnaryBlock {
12951301
force_f32(force_f32) {}
12961302

12971303
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
1298-
struct ggml_tensor* w = params["weight"];
1299-
struct ggml_tensor* b = NULL;
1300-
if (bias) {
1301-
b = params["bias"];
1304+
if (params.find("A") != tensor_types.end()) {
1305+
struct ggml_tensor* down = params["A"];
1306+
struct ggml_tensor* up = params["B"];
1307+
struct ggml_tensor* b = NULL;
1308+
if (bias) {
1309+
b = params["bias"];
1310+
}
1311+
return ggml_nn_linear(ctx, ggml_nn_linear(ctx, x, down, NULL), up, b);
1312+
} else {
1313+
struct ggml_tensor* w = params["weight"];
1314+
struct ggml_tensor* b = NULL;
1315+
if (bias) {
1316+
b = params["bias"];
1317+
}
1318+
return ggml_nn_linear(ctx, x, w, b);
13021319
}
1303-
return ggml_nn_linear(ctx, x, w, b);
13041320
}
13051321
};
13061322

0 commit comments

Comments
 (0)