diff --git a/csrc/models/llama/llama_decoder_layer.cpp b/csrc/models/llama/llama_decoder_layer.cpp
index caebd62d..ae3f7132 100644
--- a/csrc/models/llama/llama_decoder_layer.cpp
+++ b/csrc/models/llama/llama_decoder_layer.cpp
@@ -1,6 +1,7 @@
 #include "llama_decoder_layer.hpp"
 #include "infinicore/nn/rmsnorm.hpp"
 #include "infinicore/ops.hpp"
+#include <utility>
 
 namespace infinilm::models::llama {
 
@@ -21,34 +22,50 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
     INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_);
 }
 
-infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
-                                              const infinicore::Tensor &position_ids,
-                                              std::shared_ptr kv_cache,
-                                              const infinicore::Tensor &cache_positions) const {
-    // Save residual for attention
-    auto residual = hidden_states;
-
-    // 1. Pre-attention layer normalization
-    auto normed_states = input_layernorm_->forward(hidden_states);
+std::pair<infinicore::Tensor, infinicore::Tensor> LlamaDecoderLayer::forward(
+    const infinicore::Tensor &hidden_states,
+    const infinicore::Tensor &position_ids,
+    std::shared_ptr kv_cache,
+    const infinicore::Tensor &cache_positions,
+    const std::optional<infinicore::Tensor> &residual_in) const {
+
+    infinicore::Tensor normed_states;
+    infinicore::Tensor residual;
+
+    // 1. Pre-attention layer normalization with optional residual add from previous layer
+    if (residual_in.has_value()) {
+        // Fuse previous layer's MLP residual add with current layer's input normalization
+        // This avoids a separate add operation: residual_in + hidden_states
+        auto [normed_result, add_result] = infinicore::op::add_rms_norm(
+            residual_in.value(), hidden_states,
+            input_layernorm_->weight(),
+            static_cast<float>(input_layernorm_->eps()));
+        normed_states = normed_result;
+        residual = add_result; // This is residual_in + hidden_states
+    } else {
+        // First layer: no residual to add, just normalize
+        normed_states = input_layernorm_->forward(hidden_states);
+        residual = hidden_states;
+    }
 
     // 2. Self-attention with residual connection
     auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, cache_positions);
 
-    // Add residual: hidden_states = hidden_states + attn_output
-    auto output = infinicore::op::add(residual, attn_output);
-    // Save residual for MLP
-    residual = output;
-
-    // 3. Post-attention layer normalization
-    normed_states = post_attention_layernorm_->forward(output);
+    // 3. Add attention residual and apply post-attention layer normalization (fused)
+    auto [normed_states_result, add_result] = infinicore::op::add_rms_norm(
+        residual, attn_output,
+        post_attention_layernorm_->weight(),
+        static_cast<float>(post_attention_layernorm_->eps()));
+
+    normed_states = normed_states_result;
+    residual = add_result; // Save for MLP residual connection
 
-    // 4. MLP with residual connection
+    // 4. MLP
     auto mlp_output = mlp_->forward(normed_states);
 
-    // Add residual: output = output + mlp_output
-    output = infinicore::op::add(residual, mlp_output);
-
-    return output;
+    // Return (mlp_output, residual) WITHOUT doing the final add
+    // Next layer will fuse this add with its input_layernorm using add_rms_norm
+    return std::make_pair(mlp_output, residual);
 }
 
 } // namespace infinilm::models::llama
diff --git a/csrc/models/llama/llama_decoder_layer.hpp b/csrc/models/llama/llama_decoder_layer.hpp
index 86f70d33..04bdf181 100644
--- a/csrc/models/llama/llama_decoder_layer.hpp
+++ b/csrc/models/llama/llama_decoder_layer.hpp
@@ -9,6 +9,7 @@
 #include "llama_mlp.hpp"
 
 #include "../../engine/distributed/distributed.hpp"
+#include <optional>
 
 namespace infinilm::models::llama {
 
@@ -44,12 +45,16 @@ class LlamaDecoderLayer : public infinicore::nn::Module {
      * @param hidden_states Input tensor of shape [batch, seq_len, hidden_size]
      * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
      * @param kv_cache Optional KV cache for incremental decoding
-     * @return Output tensor of shape [batch, seq_len, hidden_size]
+     * @param cache_positions Cache positions tensor
+     * @param residual Optional residual tensor from previous layer (for MLP residual connection)
+     * @return Pair of (output, residual) tensors, where residual can be reused by next layer
      */
-    infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
-                               const infinicore::Tensor &position_ids,
-                               std::shared_ptr kv_cache,
-                               const infinicore::Tensor &cache_positions) const;
+    std::pair<infinicore::Tensor, infinicore::Tensor> forward(
+        const infinicore::Tensor &hidden_states,
+        const infinicore::Tensor &position_ids,
+        std::shared_ptr kv_cache,
+        const infinicore::Tensor &cache_positions,
+        const std::optional<infinicore::Tensor> &residual = std::nullopt) const;
 
     /**
      * @brief Get the layer index
diff --git a/csrc/models/llama/llama_model.cpp b/csrc/models/llama/llama_model.cpp
index 5b2c44e5..5d738e45 100644
--- a/csrc/models/llama/llama_model.cpp
+++ b/csrc/models/llama/llama_model.cpp
@@ -50,18 +50,36 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
     auto hidden_states = embed_tokens_->forward(input_ids);
 
     // 2. Process through all decoder layers
+    // Reuse residual across layers to avoid redundant add operations
     size_t num_layers = layers_.size();
+    std::optional<infinicore::Tensor> residual = std::nullopt;
     for (size_t i = 0; i < num_layers; ++i) {
-        hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_positions);
+        auto [output, next_residual] = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_positions, residual);
+        hidden_states = output;
+        residual = next_residual;
     }
 
     // 3. Apply final layer normalization to last token only (aligns with transformers)
     // Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
     auto shape = hidden_states->shape();
     size_t seq_len = shape[1];
-    auto last_token = hidden_states->narrow({{1, seq_len - 1, 1}});
-
-    auto normalized_last_token = norm_->forward(last_token);
+
+    // Narrow both residual and mlp_output to last token before fusing add and norm
+    // Note: narrow() creates a view (no data copy), so this is equivalent to:
+    //   narrow(add(residual, mlp_output)) == add(narrow(residual), narrow(mlp_output))
+    // But doing narrow first allows us to:
+    //   1. Only compute the add for the last token (not the entire sequence) - saves computation
+    //   2. Fuse add with norm in a single kernel using add_rms_norm - avoids a separate add kernel
+    auto residual_last_token = residual.value()->narrow({{1, seq_len - 1, 1}});
+    auto mlp_output_last_token = hidden_states->narrow({{1, seq_len - 1, 1}});
+
+    // Fuse final residual add with layer normalization using add_rms_norm
+    // This avoids a separate add operation - add and norm are computed in one fused kernel
+    // Result is mathematically equivalent to: norm(add(residual, mlp_output))[last_token]
+    auto [normalized_last_token, _] = infinicore::op::add_rms_norm(
+        residual_last_token, mlp_output_last_token,
+        norm_->weight(),
+        static_cast<float>(norm_->eps()));
 
     return normalized_last_token;
 }
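Note for reviewers: the patch relies on `infinicore::op::add_rms_norm(a, b, w, eps)` returning both the normalized tensor and the raw sum `a + b`, so the sum can be threaded forward as the next layer's residual. The plain-C++ sketch below only illustrates that assumed contract; names such as `add_rms_norm_ref` and `rms_norm_ref` are made up for the sketch and are not InfiniCore API. It shows that the fused call matches a separate add followed by RMSNorm.

// Illustrative reference for the fused add + RMSNorm semantics assumed above.
// This is NOT the InfiniCore kernel; it is a plain-C++ sketch for review purposes.
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

using Vec = std::vector<float>;

// RMSNorm: x * w / sqrt(mean(x^2) + eps)
static Vec rms_norm_ref(const Vec &x, const Vec &w, float eps) {
    float mean_sq = 0.0f;
    for (float v : x) mean_sq += v * v;
    mean_sq /= static_cast<float>(x.size());
    const float inv_rms = 1.0f / std::sqrt(mean_sq + eps);
    Vec out(x.size());
    for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] * inv_rms * w[i];
    return out;
}

// Fused add + RMSNorm: returns {norm(a + b), a + b}.
// The second element is what the decoder layer threads through as "residual".
static std::pair<Vec, Vec> add_rms_norm_ref(const Vec &a, const Vec &b,
                                            const Vec &w, float eps) {
    Vec sum(a.size());
    for (size_t i = 0; i < a.size(); ++i) sum[i] = a[i] + b[i];
    return {rms_norm_ref(sum, w, eps), sum};
}

int main() {
    const Vec residual = {0.5f, -1.0f, 2.0f, 0.25f};
    const Vec mlp_out  = {1.5f,  0.5f, -1.0f, 0.75f};
    const Vec weight   = {1.0f,  1.0f,  1.0f, 1.0f};
    const float eps = 1e-6f;

    // Unfused: separate add, then norm (what the old code did).
    Vec sum(residual.size());
    for (size_t i = 0; i < sum.size(); ++i) sum[i] = residual[i] + mlp_out[i];
    const Vec unfused = rms_norm_ref(sum, weight, eps);

    // Fused: one call yields the normalized value plus the sum for the next layer.
    const auto [fused, next_residual] = add_rms_norm_ref(residual, mlp_out, weight, eps);

    for (size_t i = 0; i < fused.size(); ++i) {
        std::printf("%zu: unfused=%.6f fused=%.6f next_residual=%.6f\n",
                    i, unfused[i], fused[i], next_residual[i]);
    }
    return 0;
}

Built with a C++17 compiler, this prints identical unfused/fused values; that equivalence is what lets each layer's `op::add` disappear into the following layer's normalization.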
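A similar stand-alone check covers the last-token reasoning in llama_model.cpp: because an elementwise add commutes with selecting a position along the sequence dimension, adding only the last rows of `residual` and `hidden_states` gives the same values as adding the whole sequence and then narrowing. The sketch models tensors as nested vectors purely for illustration; it is not InfiniCore's tensor or narrow() API.

// Sketch: add-then-narrow vs narrow-then-add over the sequence dimension.
// Tensors are modeled as [seq_len][hidden] nested vectors; purely illustrative.
#include <cstdio>
#include <vector>

using Row = std::vector<float>;
using Seq = std::vector<Row>; // [seq_len][hidden]

int main() {
    const Seq residual = {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}};
    const Seq mlp_out  = {{0.5f, 0.5f}, {1.5f, 1.5f}, {2.5f, 2.5f}};
    const size_t last = residual.size() - 1;

    // Add-then-narrow: sum every position, then keep only the last one.
    Seq full_sum(residual.size(), Row(residual[0].size()));
    for (size_t i = 0; i < residual.size(); ++i)
        for (size_t j = 0; j < residual[i].size(); ++j)
            full_sum[i][j] = residual[i][j] + mlp_out[i][j];
    const Row add_then_narrow = full_sum[last];

    // Narrow-then-add: select the last position first, then sum only that row.
    Row narrow_then_add(residual[last].size());
    for (size_t j = 0; j < narrow_then_add.size(); ++j)
        narrow_then_add[j] = residual[last][j] + mlp_out[last][j];

    for (size_t j = 0; j < narrow_then_add.size(); ++j)
        std::printf("j=%zu add_then_narrow=%.1f narrow_then_add=%.1f\n",
                    j, add_then_narrow[j], narrow_then_add[j]);
    return 0;
}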