@@ -65,6 +65,7 @@ struct random_tensor {
         for (int64_t d : shape) {
             prod *= d;
         }
+        GGML_ASSERT(prod != 0);
         return ggml_row_size(type, prod);
     }
 
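Aside (illustration only, not part of the patch): the new assert guards against a zero element count, since `ggml_row_size(type, prod)` would then report a zero-byte tensor. A minimal standalone sketch of the same computation, assuming ggml is available to include and link against:

```cpp
// Standalone sketch, not part of the patch; assumes ggml is available.
// ggml_row_size(type, n) gives the byte size of n elements of `type`,
// so a zero element count would silently produce a zero-sized tensor.
#include <cstdio>
#include "ggml.h"

int main() {
    const int64_t shape[] = { 4, 8, 2 }; // hypothetical tensor shape

    int64_t prod = 1;
    for (int64_t d : shape) {
        prod *= d;
    }
    GGML_ASSERT(prod != 0);

    printf("%zu bytes\n", ggml_row_size(GGML_TYPE_F32, prod));
    return 0;
}
```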
@@ -266,8 +267,20 @@ struct model_variant {
         tensors(other.tensors),
         metadata(other.metadata) {}
 
+    void add_tensor(const std::string & name, const std::vector<int64_t> & shape, float gain = 1.0f) {
+        // ref: https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/nn/init.py#L515-L516
+        const auto init_kaiming_uniform = [gain](uint32_t fan_in) {
+            const float std = gain * std::sqrt(fan_in);
+            const float bound = std::sqrt(3.0f) * std;
+
+            return std::uniform_real_distribution<float>(-bound, bound);
+        };
+
+        tensors.push_back(random_tensor(name, shape, init_kaiming_uniform(shape[0])));
+    }
+
     void add_tensor(const std::string & name, const std::vector<int64_t> & shape,
-                    const std::function<float(std::mt19937 &)> & distribution = std::normal_distribution<float>()) {
+                    const std::function<float(std::mt19937 &)> & distribution) {
         tensors.push_back(random_tensor(name, shape, distribution));
     }
 
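Aside (illustration only, not part of the patch): the new overload builds a `std::uniform_real_distribution<float>` whose bound depends on `shape[0]`, and such a distribution object is itself callable with a `std::mt19937 &`, so it fits the `std::function<float(std::mt19937 &)>` parameter kept on the other overload. A standalone sketch with made-up `gain` and `fan_in` values:

```cpp
// Standalone sketch, not part of the patch; fan_in and gain are arbitrary.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <random>

int main() {
    const float    gain   = 1.0f;
    const uint32_t fan_in = 256; // stands in for shape[0] of some tensor

    // same construction as the init_kaiming_uniform lambda in the patch
    const float std_dev = gain * std::sqrt((float) fan_in);
    const float bound   = std::sqrt(3.0f) * std_dev;

    // uniform_real_distribution<float> is callable with a std::mt19937 &,
    // so it can be stored directly in the std::function used by add_tensor
    std::function<float(std::mt19937 &)> dist =
        std::uniform_real_distribution<float>(-bound, bound);

    std::mt19937 rng(42);
    for (int i = 0; i < 4; ++i) {
        printf("%f\n", dist(rng));
    }
    return 0;
}
```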
@@ -299,7 +312,7 @@ struct model_variant {
 
         size_t total_size = 0;
         for (const auto & t : tensors) {
-            total_size += t.n_bytes() + ggml_tensor_overhead();
+            total_size += GGML_PAD(t.n_bytes() + ggml_tensor_overhead(), GGML_MEM_ALIGN);
         }
 
         ggml_init_params init_params = {
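Aside (illustration only, not part of the patch): `GGML_PAD` rounds a size up to a multiple of the given alignment, so each tensor's contribution to the no-alloc context size is rounded up to a `GGML_MEM_ALIGN` boundary instead of being summed exactly, which should keep the estimate from coming up short. A tiny sketch of that rounding with made-up numbers; the helper below mirrors the usual power-of-two rounding idiom, not ggml's actual macro body:

```cpp
// Tiny sketch with made-up sizes; pad_to plays the role of a
// GGML_PAD-style rounding helper for a power-of-two alignment.
#include <cstddef>
#include <cstdio>

static size_t pad_to(size_t x, size_t n) { // n must be a power of two
    return (x + n - 1) & ~(n - 1);
}

int main() {
    const size_t align    = 16;  // stand-in for GGML_MEM_ALIGN
    const size_t overhead = 368; // stand-in for ggml_tensor_overhead()

    const size_t n_bytes[] = { 1000, 4096, 12 };

    size_t total = 0;
    for (size_t b : n_bytes) {
        total += pad_to(b + overhead, align);
    }
    printf("padded total = %zu bytes\n", total);
    return 0;
}
```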
@@ -356,6 +369,11 @@ struct model_variant {
             m.add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE, vocab_types);
         };
 
+        // don't actually use bias
+        const auto init_bias = [](std::mt19937 &) {
+            return 0.0f;
+        };
+
         // TODO: fill the variants
         // TODO: how to make the variants more modular?
         switch (arch) {
@@ -591,12 +609,12 @@ struct model_variant {
                         cur.add_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, 2 * d_inner });
 
                         cur.add_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { d_conv, d_inner });
-                        cur.add_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), { d_inner });
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), { d_inner }, init_bias);
 
                         cur.add_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), { d_inner, dt_rank + 2 * d_state });
 
                         cur.add_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), { dt_rank, d_inner });
-                        cur.add_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { d_inner });
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { d_inner }, init_bias);
 
                         // no "weight" suffix for these
                         cur.add_tensor(tn(LLM_TENSOR_SSM_A, i), { d_state, d_inner }, init_A_S4D);
@@ -674,19 +692,19 @@ struct model_variant {
 
                     // Block 0, LN0
                     cur.add_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
-                    cur.add_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
+                    cur.add_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, init_bias);
 
                     // output
                     cur.add_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                    cur.add_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
+                    cur.add_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, init_bias);
                     cur.add_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
                         cur.add_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        cur.add_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
+                        cur.add_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, init_bias);
 
                         cur.add_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
-                        cur.add_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd});
+                        cur.add_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, init_bias);
 
                         cur.add_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd});
                         cur.add_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay});
@@ -721,7 +739,7 @@ struct model_variant {
                         cur.add_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd});
 
                         cur.add_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd});
-                        cur.add_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd});
+                        cur.add_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, init_bias);
                         cur.add_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size});
 
                         cur.add_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1});
@@ -1036,7 +1054,7 @@ int main(int argc, char ** argv) {
             for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
                 float err = ref_outputs[seq_id].validate_batch(ctx, batch, seq_id);
                 if (!isfinite(err) || err > 1.0f / 1024.0f) {
-                    fprintf(stderr, "Error for seq_id %i is %f\n", seq_id, err);
+                    fprintf(stderr, "Error for seq_id %i is %f at n_past=%i\n", seq_id, err, seq_id_n_past[seq_id]);
                     valid[seq_id] = false;
                 }
             }
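Aside (illustration only, not part of the patch): `seq_id_n_past` is presumably just a per-sequence count of tokens already decoded, so the error message can report the position at which validation failed. A minimal sketch of that bookkeeping with made-up values:

```cpp
// Minimal sketch with made-up values; assumes seq_id_n_past is simply a
// per-sequence count of tokens decoded so far.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int32_t n_seq_max = 4;
    std::vector<int32_t> seq_id_n_past(n_seq_max, 0);

    // whenever n_tokens are appended to the batch for seq_id,
    // advance that sequence's past-token count by the same amount
    const int32_t seq_id   = 2;
    const int32_t n_tokens = 7;
    seq_id_n_past[seq_id] += n_tokens;

    printf("seq_id %i is now at n_past=%i\n", seq_id, seq_id_n_past[seq_id]);
    return 0;
}
```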