@@ -2020,7 +2020,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     // output
                     output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
                     output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }

                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -2381,7 +2386,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     // output
                     output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
                     output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }

                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -2407,7 +2417,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } break;
             case LLM_ARCH_CODESHELL:
                 {
-                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if tok embd is NULL, init from output
+                    if (tok_embd == NULL) {
+                        tok_embd = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }

                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
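
All three hunks apply the same fallback pattern: the tensor is first requested with the TENSOR_NOT_REQUIRED flag, and when the model file does not contain it, the counterpart tensor is loaded with TENSOR_DUPLICATED so the two ends of the model share weights (tied embeddings). The sketch below is a minimal, self-contained illustration of that pattern; the loader type, flag values, and create_tensor() signature here are simplified stand-ins for illustration, not llama.cpp's actual API.

#include <cstdio>
#include <map>
#include <stdexcept>
#include <string>

// hypothetical flag values for illustration only
enum tensor_flags {
    TENSOR_REQUIRED     = 0,
    TENSOR_NOT_REQUIRED = 1 << 0, // a missing tensor yields NULL instead of an error
    TENSOR_DUPLICATED   = 1 << 1, // the returned tensor aliases data owned by another entry
};

struct tensor {
    std::string name;
    bool duplicated = false;
};

// stand-in for the model loader: just a map of the tensors found in the file
struct model_loader {
    std::map<std::string, tensor> file_tensors;

    tensor * create_tensor(const std::string & name, int flags) {
        auto it = file_tensors.find(name);
        if (it == file_tensors.end()) {
            if (flags & TENSOR_NOT_REQUIRED) {
                return nullptr; // the caller decides how to fall back
            }
            throw std::runtime_error("missing required tensor: " + name);
        }
        it->second.duplicated = (flags & TENSOR_DUPLICATED) != 0;
        return &it->second;
    }
};

int main() {
    model_loader ml;
    ml.file_tensors["token_embd.weight"] = {"token_embd.weight"};
    // no "output.weight" entry: a model with tied embeddings omits it

    // same shape as the diff: try the optional output projection first...
    tensor * output = ml.create_tensor("output.weight", TENSOR_NOT_REQUIRED);

    // ...and if it is absent, duplicate the input token embedding
    if (output == NULL) {
        output = ml.create_tensor("token_embd.weight", TENSOR_DUPLICATED);
    }

    std::printf("output uses '%s' (duplicated: %d)\n", output->name.c_str(), output->duplicated);
    return 0;
}

The CODESHELL hunk runs the same logic in the opposite direction, falling back from token_embd to the output tensor for files that ship only the output matrix.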