diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index 9b5230ffd..0b32c0249 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -88,11 +88,11 @@ class BatchConfig { // Maximum possible values for different parameters // These maximum values are used for copying BatchConfig // across workers - inline static int const MAX_NUM_REQUESTS = 64; + inline static int const MAX_NUM_REQUESTS = 96; inline static int const MAX_NUM_TOKENS = 1024; inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 8; - inline static int const MAX_TREE_DEPTH = 8; - inline static int const MAX_TREE_WIDTH = 16; + inline static int const MAX_TREE_DEPTH = 10; + inline static int const MAX_TREE_WIDTH = 12; inline static int const MAX_SPEC_TREE_TOKEN_NUM = MAX_TREE_DEPTH * MAX_TREE_WIDTH; inline static int const MAX_K_LOGITS = 16; diff --git a/include/flexflow/flexflow_c.h b/include/flexflow/flexflow_c.h index 9423d7b4c..bd00965bc 100644 --- a/include/flexflow/flexflow_c.h +++ b/include/flexflow/flexflow_c.h @@ -285,8 +285,7 @@ flexflow_tensor_t *flexflow_model_add_add_bias_residual_layer_norm( flexflow_tensor_t flexflow_model_add_sigmoid_silu_multi(flexflow_model_t handle, - flexflow_tensor_t const input1, - flexflow_tensor_t const input2, + flexflow_tensor_t const input, int intermediate_size, char const *name); diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 825c8e995..dc8ab8726 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -579,8 +579,7 @@ class FFModel { DataType data_type = DT_NONE, char const *name = NULL); // Add a sigmoid_silu_multi layer - Tensor sigmoid_silu_multi(Tensor const input1, - Tensor const input2, + Tensor sigmoid_silu_multi(Tensor const input, int intermediate_size, DataType data_type = DT_NONE, char const *name = NULL); @@ -711,7 +710,7 @@ class FFModel { Initializer *kernel_initializer = NULL, char const *name = NULL); Tensor inc_multihead_self_attention( - const Tensor input, + Tensor const input, int embed_dim, int num_heads, int kdim = 0, @@ -730,7 +729,7 @@ class FFModel { bool streaming_cache = false, char const *name = NULL); Tensor spec_inc_multihead_self_attention( - const Tensor input, + Tensor const input, int embed_dim, int num_heads, int kdim = 0, @@ -1211,10 +1210,8 @@ class FFModel { std::pair<std::pair<ParallelTensorShape, ParallelTensorShape>, AddBiasResidualLayerNormParams>, AddBiasResidualLayerNorm *>, - std::unordered_map< - std::pair<std::pair<ParallelTensorShape, ParallelTensorShape>, - SigmoidSiluMultiParams>, - SigmoidSiluMulti *>, + std::unordered_map<std::pair<ParallelTensorShape, SigmoidSiluMultiParams>, + SigmoidSiluMulti *>, std::unordered_map<std::pair<ParallelTensorShape, LinearParams>, Linear *>, std::unordered_map, diff --git a/include/flexflow/ops/inc_multihead_self_attention.h b/include/flexflow/ops/inc_multihead_self_attention.h index 8bc3b15a3..00694498a 100644 --- a/include/flexflow/ops/inc_multihead_self_attention.h +++ b/include/flexflow/ops/inc_multihead_self_attention.h @@ -178,7 +178,7 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta { quantized_weightSize; int hidden_size, qk_dim, v_dim, o_dim; int global_num_q_heads, global_num_kv_heads, num_q_heads, num_kv_heads, - local_hidden_size; + local_hidden_size, total_heads_dim; bool *has_load_weights; RotaryEmbeddingMeta *rotary_embedding_meta; bool *qkv_bias; diff --git a/include/flexflow/ops/sigmoid_silu_multi.h b/include/flexflow/ops/sigmoid_silu_multi.h index bc07e253e..3dda1b23c 100644 --- a/include/flexflow/ops/sigmoid_silu_multi.h +++ b/include/flexflow/ops/sigmoid_silu_multi.h @@ -10,15 +10,14 @@ class SigmoidSiluMultiMeta; class SigmoidSiluMulti : public Op { public: using Params =
SigmoidSiluMultiParams; - using Input = std::pair<ParallelTensor, ParallelTensor>; + using Input = ParallelTensor; SigmoidSiluMulti(FFModel &model, Params const &params, - Input const &inputs, + Input const &input, char const *name = nullptr); SigmoidSiluMulti(FFModel &model, LayerID const &_layer_guid, - const ParallelTensor _input1, - const ParallelTensor _input2, + ParallelTensor const _input, int _intermediate_size, int _tensor_parallelism_degree, char const *name = nullptr); @@ -63,13 +62,11 @@ class SigmoidSiluMulti : public Op { template <typename T> static void inference_kernel(SigmoidSiluMultiMeta const *m, int num_elements, - T const *input1_ptr, - T const *input2_ptr, + T const *input_ptr, T *output_ptr, ffStream_t stream); static void inference_kernel_wrapper(SigmoidSiluMultiMeta const *m, - GenericTensorAccessorR const &input1, - GenericTensorAccessorR const &input2, + GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, int token_size); diff --git a/include/flexflow/ops/sigmoid_silu_multi_params.h b/include/flexflow/ops/sigmoid_silu_multi_params.h index 0e92c0aa6..1d6c8b867 100644 --- a/include/flexflow/ops/sigmoid_silu_multi_params.h +++ b/include/flexflow/ops/sigmoid_silu_multi_params.h @@ -10,8 +10,7 @@ struct SigmoidSiluMultiParams { LayerID layer_guid; int intermediate_size, tensor_parallelism_degree; char name[MAX_OPNAME]; - bool is_valid( - std::pair<ParallelTensorShape, ParallelTensorShape> const &) const; + bool is_valid(ParallelTensorShape const &) const; }; bool operator==(SigmoidSiluMultiParams const &, SigmoidSiluMultiParams const &); diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index b95cc37e4..864b445b5 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -269,9 +269,12 @@ struct ProfileInfo { std::vector<double> tree_operation_step_times; // Number of generated tokens at each step std::vector<int> generated_tokens_per_step; + // Number of proposed tokens at each step + std::vector<int> tokens_in_verification_per_step; // To calculate the E2E time of serving long long server_start_time = 0; long long server_end_time = 0; + int prefilling_steps = 0; }; class RequestManager { @@ -444,6 +447,7 @@ class RequestManager { // configuration parameters int max_requests_per_batch; int max_tokens_per_batch; + int config_max_token_per_batch; int max_tokens_per_ssm_batch; int max_tokens_per_prefilling_batch; int max_spec_tree_token_num; @@ -586,6 +590,7 @@ class RequestManager { void prune_token_tree_greedy(); void add_tokens_toward_slo(RequestGuid guid, int &budget, + double num_tokens_to_decode, int num_req_with_slo); void add_tokens_toward_memory_occupancy(int budget); void add_tokens_toward_goodput(int budget); diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index f1b00617a..f7db6e7ef 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -306,10 +306,14 @@ void FlexFlow::top_level_task(Task const *task, /*ignore_comments */ true); ModelType model_type = ModelType::UNKNOWN; auto architectures = model_config["architectures"]; + bool qwen = false; for (auto const &str : architectures) { if (str == "LlamaForCausalLM" || str == "LLaMAForCausalLM" || - str == "MistralForCausalLM") { + str == "MistralForCausalLM" || str == "Qwen2ForCausalLM") { model_type = ModelType::LLAMA; + if (str == "Qwen2ForCausalLM") { + qwen = true; + } break; } else if (str == "OPTForCausalLM") { model_type = ModelType::OPT; @@ -361,8 +365,8 @@ void FlexFlow::top_level_task(Task const *task,
rm->set_baseline_latency(baseline_latency_ms); rm->set_ssm_spec_latency(ssm_spec_latency_ms); rm->set_llm_verify_latency(llm_verify_latency_ms); - rm->set_max_tree_depth(8); - rm->set_max_tree_width(16); + rm->set_max_tree_depth(2); + rm->set_max_tree_width(2); rm->set_verbose(verbose); rm->set_streaming_cache(streaming_cache); rm->set_fcfs_slo(fcfs_slo); @@ -379,7 +383,8 @@ void FlexFlow::top_level_task(Task const *task, INC_DECODING_MODE, generationConfig, streaming_cache, - use_full_precision); + use_full_precision, + /*qkv_bias*/ qwen); } else if (model_type == ModelType::OPT) { OPT::create_opt_model(model, config_filepath, diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 988f8f4b5..7ef3a5089 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -26,7 +26,8 @@ void LLAMA::create_llama_model(FFModel &ff, InferenceMode mode, GenerationConfig generation_config, bool streaming_cache, - bool use_full_precision) { + bool use_full_precision, + bool qkv_bias) { // do not apply cpu offload in beam search model. LLAMAConfig llama_config(model_config_file_path); llama_config.print(); @@ -104,7 +105,7 @@ void LLAMA::create_llama_model(FFModel &ff, llama_config.hidden_size / llama_config.num_attention_heads, llama_config.hidden_size / llama_config.num_attention_heads, 0.0f, /*dropout*/ - false, /*qkv_bias*/ + qkv_bias, /*qkv_bias*/ false, /*final_bias*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ @@ -129,7 +130,7 @@ void LLAMA::create_llama_model(FFModel &ff, llama_config.hidden_size / llama_config.num_attention_heads, llama_config.hidden_size / llama_config.num_attention_heads, 0.0f, /*dropout*/ - false, /*qkv_bias*/ + qkv_bias, /*qkv_bias*/ false, /*final_bias*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ @@ -153,7 +154,7 @@ void LLAMA::create_llama_model(FFModel &ff, llama_config.hidden_size / llama_config.num_attention_heads, llama_config.hidden_size / llama_config.num_attention_heads, 0.0f, /*dropout*/ - false, /*qkv_bias*/ + qkv_bias, /*qkv_bias*/ false, /*final_bias*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ @@ -188,34 +189,22 @@ void LLAMA::create_llama_model(FFModel &ff, token = token_ff_norm[0]; Tensor ff_norm = token_ff_norm[1]; - Tensor w1 = ff.dense( - ff_norm, - llama_config.intermediate_size, - AC_MODE_NONE, - false, - DT_NONE, - nullptr, - nullptr, - nullptr, - REG_MODE_NONE, - 0.0f, - std::string("layers." + std::to_string(i) + ".mlp.gate_proj").c_str()); - - Tensor w3 = ff.dense( - ff_norm, - llama_config.intermediate_size, - AC_MODE_NONE, - false, - DT_NONE, - nullptr, - nullptr, - nullptr, - REG_MODE_NONE, - 0.0f, - std::string("layers." + std::to_string(i) + ".mlp.up_proj").c_str()); + Tensor hidden_gate_and_up = + ff.dense(ff_norm, + llama_config.intermediate_size * 2, + AC_MODE_NONE, + false, + DT_NONE, + nullptr, + nullptr, + nullptr, + REG_MODE_NONE, + 0.0f, + std::string("layers." 
+ std::to_string(i) + ".mlp.gate_and_up") + .c_str()); - Tensor multi = - ff.sigmoid_silu_multi(w1, w3, llama_config.intermediate_size); + Tensor multi = ff.sigmoid_silu_multi(hidden_gate_and_up, + llama_config.intermediate_size); w2 = ff.dense( multi, diff --git a/inference/models/llama.h b/inference/models/llama.h index 3f11ca96d..74d675dd5 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -109,7 +109,8 @@ class LLAMA { InferenceMode mode, GenerationConfig generation_config, bool streaming_cache, - bool use_full_precision = false); + bool use_full_precision = false, + bool qkv_bias = false); }; }; // namespace FlexFlow diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 2d1572ecf..2d6be63e8 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -55,6 +55,8 @@ struct ModelMeta { std::vector ssm_model_types; std::vector ssm_model_config_paths; std::vector ssm_model_weights_paths; + + bool qkv_bias = false; }; void parse_input_args(char **argv, @@ -288,8 +290,11 @@ void get_model_meta(FilePaths &file_paths, auto architectures = llm_model_config["architectures"]; for (auto const &str : architectures) { if (str == "LlamaForCausalLM" || str == "LLaMAForCausalLM" || - str == "MistralForCausalLM") { + str == "MistralForCausalLM" || str == "Qwen2ForCausalLM") { model_metadata.llm_model_type = ModelType::LLAMA; + if (str == "Qwen2ForCausalLM") { + model_metadata.qkv_bias = true; + } break; } else if (str == "OPTForCausalLM") { model_metadata.llm_model_type = ModelType::OPT; @@ -350,7 +355,7 @@ void get_model_meta(FilePaths &file_paths, auto architectures = ssm_model_config["architectures"]; for (auto const &str : architectures) { if (str == "LlamaForCausalLM" || str == "LLaMAForCausalLM" || - str == "MistralForCausalLM") { + str == "MistralForCausalLM" || str == "Qwen2ForCausalLM") { ssm_model_type = ModelType::LLAMA; break; } else if (str == "OPTForCausalLM") { @@ -525,7 +530,8 @@ void FlexFlow::top_level_task(Task const *task, TREE_VERIFY_MODE, generationConfig, false, - use_full_precision); + use_full_precision, + /*qkv_bias*/ model_metadata.qkv_bias); } else if (model_metadata.llm_model_type == ModelType::OPT) { OPT::create_opt_model(tree_model, model_metadata.llm_model_config_path, @@ -574,7 +580,8 @@ void FlexFlow::top_level_task(Task const *task, TREE_SEARCH_MODE, generationConfig, streaming_cache, - use_full_precision); + use_full_precision, + /*qkv_bias*/ model_metadata.qkv_bias); } else if (model_metadata.ssm_model_types[ssm_id] == ModelType::OPT) { OPT::create_opt_model(beam_model, model_metadata.ssm_model_config_paths[ssm_id], diff --git a/inference/trace_generator/trace_generator.cc b/inference/trace_generator/trace_generator.cc index 14abf5976..dae940d49 100644 --- a/inference/trace_generator/trace_generator.cc +++ b/inference/trace_generator/trace_generator.cc @@ -213,7 +213,7 @@ void get_model_meta(FilePaths &file_paths, auto architectures = llm_model_config["architectures"]; for (auto const &str : architectures) { if (str == "LlamaForCausalLM" || str == "LLaMAForCausalLM" || - str == "MistralForCausalLM") { + str == "MistralForCausalLM" || str == "Qwen2ForCausalLM") { model_metadata.llm_model_type = ModelType::LLAMA; break; } else if (str == "OPTForCausalLM") { @@ -275,7 +275,7 @@ void get_model_meta(FilePaths &file_paths, auto architectures = ssm_model_config["architectures"]; for (auto const &str : architectures) { if (str == "LlamaForCausalLM" || str == "LLaMAForCausalLM" || - str 
== "MistralForCausalLM") { + str == "MistralForCausalLM" || str == "Qwen2ForCausalLM") { ssm_model_type = ModelType::LLAMA; break; } else if (str == "OPTForCausalLM") { @@ -336,8 +336,6 @@ void FlexFlow::top_level_task(Task const *task, int max_tokens_per_ssm_batch = -1; int max_tokens_per_prefilling_batch = -1; int expansion_degree = 3; - int max_tree_depth = 8; - int max_tree_width = 16; RequestManager::DecodingMode decoding_mode = RequestManager::SPECULATIVE_DECODING; bool spec_sampling = false; @@ -405,8 +403,8 @@ void FlexFlow::top_level_task(Task const *task, rm->set_max_tokens_per_prefilling_batch(max_tokens_per_prefilling_batch); rm->set_max_sequence_length(max_sequence_length); rm->set_max_output_length(max_output_length); - rm->set_max_tree_depth(max_tree_depth); - rm->set_max_tree_width(max_tree_width); + rm->set_max_tree_depth(2); + rm->set_max_tree_width(2); rm->set_verbose(verbose); rm->set_streaming_cache(streaming_cache); rm->register_tokenizer(model_metadata.llm_model_type, diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 1ac9a511d..aa3f64fce 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -36,7 +36,9 @@ class FFCObjectWrapper { t_.impl = const_cast(static_cast(t)); \ return t_; \ } \ - static T unwrap(T_ t_) { return static_cast(t_.impl); } \ + static T unwrap(T_ t_) { \ + return static_cast(t_.impl); \ + } \ static const T unwrap_const(const T_ t_) { \ return static_cast(t_.impl); \ } @@ -746,19 +748,16 @@ flexflow_tensor_t *flexflow_model_add_add_bias_residual_layer_norm( flexflow_tensor_t flexflow_model_add_sigmoid_silu_multi(flexflow_model_t handle_, - flexflow_tensor_t const input1_, - flexflow_tensor_t const input2_, + flexflow_tensor_t const input_, int intermediate_size, char const *name) { FFModel *handle = FFCObjectWrapper::unwrap(handle_); - Tensor const input1 = FFCObjectWrapper::unwrap(input1_); - Tensor const input2 = FFCObjectWrapper::unwrap(input2_); + Tensor const input = FFCObjectWrapper::unwrap(input_); Tensor tensor = handle->sigmoid_silu_multi( - input1, input2, intermediate_size, input1->data_type, name); - DEBUG_PRINT("[SigmoidSiluMulti] new Tensor %p, input1 %p, input2 %p, name %s", + input, intermediate_size, input->data_type, name); + DEBUG_PRINT("[SigmoidSiluMulti] new Tensor %p, input %p, name %s", tensor, - input1, - input2, + input, name); return FFCObjectWrapper::wrap(tensor); } diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 78983d579..0e537554f 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -1081,14 +1081,13 @@ __host__ void break; } case OP_SIGMOID_SILU_MULTI: { - assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_outputs[op] == 1); SigmoidSiluMultiMeta const *m = (SigmoidSiluMultiMeta *)metas->meta[op]; // use active number of tokens SigmoidSiluMulti::inference_kernel_wrapper(m, my_input_accessor[0], - my_input_accessor[1], my_output_accessor[0], bc->num_active_tokens()); break; diff --git a/src/ops/inc_multihead_self_attention.cc b/src/ops/inc_multihead_self_attention.cc index bfcc7dc4c..c1324068a 100644 --- a/src/ops/inc_multihead_self_attention.cc +++ b/src/ops/inc_multihead_self_attention.cc @@ -55,7 +55,7 @@ bool IncMultiHeadSelfAttentionParams::is_valid( } Tensor FFModel::inc_multihead_self_attention( - const Tensor input, + Tensor const input, int embed_dim, int num_heads, int kdim, @@ -95,7 +95,7 @@ Tensor FFModel::inc_multihead_self_attention( } Tensor FFModel::groupquery_self_attention( - const Tensor input, + Tensor const input, int 
embed_dim, int num_q_heads, int num_kv_heads, @@ -159,9 +159,8 @@ Tensor FFModel::groupquery_self_attention( int vParas = v_dim * hidden_size; int oParas = o_dim * (v_dim > 0 ? v_dim : hidden_size); - // allocate num_q_heads for key, value for replication - int weight_size = qParas * num_q_heads + kParas * num_q_heads + - vParas * num_q_heads + oParas * num_q_heads; + int weight_size = qParas * num_q_heads + kParas * num_kv_heads + + vParas * num_kv_heads + oParas * num_q_heads; int one_head_size = qParas + kParas + vParas + oParas; { @@ -182,7 +181,7 @@ Tensor FFModel::groupquery_self_attention( } if (qkv_bias || final_bias) { // q, k, v, o - int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_q_heads; + int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_kv_heads; int dims[1] = {(qkv_bias ? qkv_bias_size : 0) + (final_bias ? o_dim : 0)}; li->weights[1] = create_weight_legion_ordering(1, dims, @@ -308,7 +307,7 @@ Op *IncMultiHeadSelfAttention::create_operator_from_layer( IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( FFModel &model, LayerID const &_layer_guid, - const ParallelTensor _input, + ParallelTensor const _input, int _embed_dim, int _num_q_heads, int _num_kv_heads, @@ -376,7 +375,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); + this->num_kv_heads * (kParas + vParas); dims[1].is_replica_dim = false; if (quantization_type != DT_NONE) { @@ -394,7 +393,8 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( CHOSEN_SYNC_TYPE); if (qkv_bias || final_bias) { ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_q_heads; + int qkv_bias_size = + qk_dim * num_q_heads + (qk_dim + v_dim) * num_kv_heads; bias_shape.dims[0].size = (qkv_bias ? qkv_bias_size : 0) + (final_bias ? o_dim : 0); bias_shape.dims[1].size = bias_shape.dims[2].size = 1; @@ -420,8 +420,8 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( FFModel &model, - const ParallelTensor _input, - const ParallelTensor _weight, + ParallelTensor const _input, + ParallelTensor const _weight, int _embed_dim, int _num_q_heads, int _num_kv_heads, @@ -488,7 +488,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); + this->num_kv_heads * (kParas + vParas); dims[1].is_replica_dim = false; // dims[2].size = this->num_q_heads * (qParas + oParas) + this->num_kv_heads // * (kParas + vParas); @@ -507,7 +507,8 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( CHOSEN_SYNC_TYPE); if (qkv_bias || final_bias) { ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_q_heads; + int qkv_bias_size = + qk_dim * num_q_heads + (qk_dim + v_dim) * num_kv_heads; bias_shape.dims[0].size = (qkv_bias ? qkv_bias_size : 0) + (final_bias ? 
o_dim : 0); bias_shape.dims[1].size = bias_shape.dims[2].size = 1; @@ -537,7 +538,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( FFModel &model, IncMultiHeadSelfAttention const &other, - const ParallelTensor input, + ParallelTensor const input, bool allocate_weights) : IncMultiHeadSelfAttention(model, other.layer_guid, diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index e220e8285..1cf894434 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -64,6 +64,9 @@ void incr_attention(IncMultiHeadSelfAttentionMeta *m, uint32_t const num_kv_heads = m->num_kv_heads; uint32_t const head_dim = m->qk_dim; uint32_t const batch_size = bc->num_active_requests(); + if (batch_size == 0) { + return; + } float const sm_scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->qk_dim) : 1.0f; // cudaEventCreate(&t_start); @@ -425,6 +428,8 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( hidden_size = _hidden_size; qk_dim = _qk_dim; v_dim = _v_dim; + assert(v_dim == qk_dim && + "v_dim must be equal to qk_dim for current implementation"); o_dim = _o_dim; size_t size_of_dt = data_type_size(attn->data_type); quantization_type = _quantization_type; @@ -436,11 +441,13 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( num_q_heads = _num_q_heads; num_kv_heads = _num_kv_heads; local_hidden_size = num_q_heads * qk_dim; + total_heads_dim = + (num_q_heads + num_kv_heads) * qk_dim + num_kv_heads * v_dim; weightSize = - ((hidden_size * qk_dim + o_dim * (v_dim > 0 ? v_dim : hidden_size)) * + ((hidden_size * qk_dim + (v_dim > 0 ? v_dim : hidden_size) * o_dim) * num_q_heads + - (hidden_size * qk_dim + hidden_size * v_dim) * num_q_heads) * + hidden_size * (qk_dim + v_dim) * num_kv_heads) * size_of_dt; if (quantization_type != DT_NONE) { quantized_weightSize = get_quantization_to_byte_size( @@ -448,7 +455,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( } // biasSize = _bias ? o_dim * size_of_dt * 4 : 0; - int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_q_heads; + int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_kv_heads; int final_bias_size = o_dim; biasSize = (_qkv_bias ? qkv_bias_size : 0) + (final_bias ? 
final_bias_size : 0); @@ -484,7 +491,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( BatchConfig::max_tokens_per_prefilling_batch()); size_t qkv_max_proj_size = max_tokens_per_batch * - (qk_dim * num_q_heads + qk_dim * num_q_heads + v_dim * num_q_heads); + (qk_dim * num_q_heads + qk_dim * num_kv_heads + v_dim * num_kv_heads); size_t query_tmp_size = 0, key_cache_size = 0, value_cache_size = 0; size_t streaming_pre_pos_enc_size = 0; // assert((BatchConfig::max_sequence_length() + diff --git a/src/ops/kernels/inc_multihead_self_attention_kernels.cu b/src/ops/kernels/inc_multihead_self_attention_kernels.cu index 2ed554e43..7be011947 100644 --- a/src/ops/kernels/inc_multihead_self_attention_kernels.cu +++ b/src/ops/kernels/inc_multihead_self_attention_kernels.cu @@ -87,34 +87,45 @@ __global__ void apply_proj_bias_qkv(DT *input_ptr, int v_dim, int global_num_q_heads, int num_q_heads, + int global_num_kv_heads, + int num_kv_heads, bool scaling_query, - float scaling_factor, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size * QKV_WEIGHT_NUM) { - // for simplicity, assume q, k, v is in same shape - // 0->q, 1->k, 2->v - // int qkv_index = i / (num_tokens * qk_dim) % 3; + float scaling_factor) { + CUDA_KERNEL_LOOP( + i, + num_tokens * (qk_dim * num_q_heads + (qk_dim + v_dim) * num_kv_heads)) { - int token_idx = i / (hidden_size * QKV_WEIGHT_NUM); - size_t in_token_idx = i - token_idx * hidden_size * QKV_WEIGHT_NUM; + size_t proj_dim = (qk_dim * num_q_heads + (qk_dim + v_dim) * num_kv_heads); + size_t in_token_idx = i % proj_dim; - int qkv_index = in_token_idx / hidden_size; + size_t local_k_offset = qk_dim * num_q_heads; + size_t local_v_offset = local_k_offset + qk_dim * num_kv_heads; - int proj_size = qkv_index == 0 ? qk_dim : qk_dim; + size_t global_k_offset = qk_dim * global_num_q_heads; + size_t global_v_offset = global_k_offset + qk_dim * global_num_kv_heads; - int head_idx = - (in_token_idx - qkv_index * num_q_heads * proj_size) / proj_size; - int global_head_idx = head_idx + shard_id * num_q_heads; + size_t global_q_head_idx = shard_id * num_q_heads + in_token_idx / qk_dim; + size_t global_k_head_idx = + shard_id * num_kv_heads + (in_token_idx - local_k_offset) / qk_dim; + size_t global_v_head_idx = + shard_id * num_kv_heads + (in_token_idx - local_v_offset) / v_dim; - size_t pre_length = - qkv_index == 0 - ? 0 - : (qkv_index == 1 ? qk_dim * global_num_q_heads - : qk_dim * global_num_q_heads * KV_WEIGHT_NUM); + size_t q_head_offset = in_token_idx % qk_dim; + size_t k_head_offset = (in_token_idx - local_k_offset) % qk_dim; + size_t v_head_offset = (in_token_idx - local_v_offset) % v_dim; - size_t bias_idx = pre_length + global_head_idx * proj_size + i % proj_size; + size_t qkv_index = in_token_idx < local_k_offset + ? 0 + : (in_token_idx < local_v_offset ? 1 : 2); - input_ptr[i] += bias_ptr[bias_idx]; + input_ptr[i] += + qkv_index == 0 + ? bias_ptr[global_q_head_idx * qk_dim + q_head_offset] + : (qkv_index == 1 + ? 
bias_ptr[global_k_offset + global_k_head_idx * qk_dim + + k_head_offset] + : bias_ptr[global_v_offset + global_v_head_idx * v_dim + + v_head_offset]); if (scaling_query && qkv_index == 0) { input_ptr[i] *= scaling_factor; @@ -124,15 +135,13 @@ __global__ void apply_proj_bias_qkv(DT *input_ptr, template __global__ void scaling_query_kernel(DT *input_ptr, - int qk_dim, int num_tokens, - int num_q_heads, float scaling_factor, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - int token_idx = i / hidden_size; - input_ptr[i % hidden_size + token_idx * hidden_size * QKV_WEIGHT_NUM] *= - scaling_factor; + int q_heads_dim, + int total_heads_dim) { + CUDA_KERNEL_LOOP(i, num_tokens * q_heads_dim) { + int token_idx = i / q_heads_dim; + input_ptr[i % q_heads_dim + token_idx * total_heads_dim] *= scaling_factor; } } @@ -160,21 +169,17 @@ void compute_qkv(IncMultiHeadSelfAttentionMeta const *m, { DT alpha = 1.0f, beta = 0.0f; // after transpositions - int m_q = m->qk_dim * m->num_q_heads; - int m_k = m->qk_dim * m->num_q_heads; - int m_v = m->v_dim * m->num_q_heads; - assert(m_q == m_k && m_k == m_v); // keep things simple for now int n = bc->num_active_tokens(); int k = m->hidden_size; - int m_ = m_q * QKV_WEIGHT_NUM; + int m_ = m->total_heads_dim; // before transpositions int lda = k, ldb = k, ldc = m_; // matrix A: QKV weights - // matrix A's layout: [hidden_size (hidden_dim), qk_dim, num_heads, 3] + // matrix A's layout: [hidden_size (hidden_dim), num_heads_dim(q+k+v)] // matrix B: input // matrix B's layout: [hidden_size (hidden_dim), num_new_tokens] // matrix C: devQKVProjArray - // matrix B's layout: [qk_dim, num_heads, 3, num_new_tokens] + // matrix B's layout: [num_heads_dim(q+k+v), num_new_tokens] m->handle.gemm_engine->gemm_internal(CUBLAS_OP_T, CUBLAS_OP_N, m_, @@ -209,6 +214,8 @@ void compute_qkv(IncMultiHeadSelfAttentionMeta const *m, // Step 2: apply bias for QKV, or scale the query if (*m->qkv_bias) { + parallelism = num_tokens * (m->qk_dim * m->num_q_heads + + (m->qk_dim + m->v_dim) * m->num_kv_heads); apply_proj_bias_qkv<<v_dim, m->global_num_q_heads, m->num_q_heads, + m->global_num_kv_heads, + m->num_kv_heads, *m->scaling_query, - m->scaling_factor, - m->local_hidden_size); - } else if (m->scaling_query) { + m->scaling_factor); + } else if (*m->scaling_query) { scaling_query_kernel<<>>(output_ptr, num_tokens, - m->num_q_heads, - m->qk_dim, m->scaling_factor, - m->local_hidden_size); + m->qk_dim * m->num_q_heads, + m->total_heads_dim); } } @@ -248,22 +255,24 @@ __global__ void apply_pos_encoding_to_tokens_in_batch_kernel( int original_max_position_embeddings, int qk_dim, int num_tokens, - size_t q_array_size, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int q_heads_dim, + int k_heads_dim, + int total_heads_dim) { + int per_token_size = (q_heads_dim + k_heads_dim) / 2; + CUDA_KERNEL_LOOP(i, num_tokens * per_token_size) { // create complex number - bool q_tensor = i < (q_array_size / 2); - int proj_size = q_tensor ? qk_dim : qk_dim; - int real_i = q_tensor ? i : i - q_array_size / 2; - - int token_idx = real_i / (hidden_size / 2); - int idx = real_i % (proj_size / 2); - int head_idx = (real_i - (token_idx * (hidden_size / 2))) / (proj_size / 2); - - int real_part_index = idx + head_idx * proj_size + - token_idx * hidden_size * QKV_WEIGHT_NUM + - hidden_size * (q_tensor ? 
0 : 1); - int complex_part_index = real_part_index + (proj_size / 2); + int token_idx = i / per_token_size; + int idx = i % per_token_size; + bool q_tensor = idx < q_heads_dim / 2; + if (!q_tensor) { + idx -= q_heads_dim / 2; + } + int head_idx = idx / (qk_dim / 2); + idx = idx % (qk_dim / 2); + int real_part_index = token_idx * total_heads_dim + + (q_tensor ? 0 : q_heads_dim) + head_idx * qk_dim + + idx; + int complex_part_index = real_part_index + qk_dim / 2; cuFloatComplex cii = {input_ptr[real_part_index], input_ptr[complex_part_index]}; @@ -276,7 +285,7 @@ __global__ void apply_pos_encoding_to_tokens_in_batch_kernel( size_t pos = tokenInfos[token_idx].abs_depth_in_request; - float freq = pos * (1.0 / pow(rope_theta, (float)2 * idx / proj_size)); + float freq = pos * (1.0 / pow(rope_theta, (float)2 * idx / qk_dim)); if (llama3_rope) { float pi = CUDART_PI_F; @@ -319,8 +328,9 @@ void apply_pos_encoding_to_tokens_in_batch( if (num_tokens == 0) { return; } - int parallelism = num_tokens * m->local_hidden_size; - size_t q_array_size = m->qk_dim * num_tokens * m->num_q_heads; + int q_heads_dim = m->qk_dim * m->num_q_heads; + int k_heads_dim = m->qk_dim * m->num_kv_heads; + int parallelism = num_tokens * (q_heads_dim + k_heads_dim) / 2; bool llama3_rope = (m->rotary_embedding_meta->rope_type == "llama3"); apply_pos_encoding_to_tokens_in_batch_kernel<<rotary_embedding_meta->original_max_position_embeddings, m->qk_dim, num_tokens, - q_array_size, - m->local_hidden_size); + q_heads_dim, + k_heads_dim, + m->total_heads_dim); } __global__ void apply_pos_encoding_to_streaming_proj_kernel( @@ -472,12 +483,11 @@ __global__ void int num_kv_heads, int head_dim, int num_new_tokens) { - int const q_hidden_size = num_q_heads * head_dim; - int const temp_kv_hidden_size = num_q_heads * head_dim; // temporary hard code - int const kv_hidden_size = num_kv_heads * head_dim; + int const q_heads_dim = num_q_heads * head_dim; + int const kv_heads_dim = num_kv_heads * head_dim; int const thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - int const token_idx = thread_idx / q_hidden_size; - int const offset = thread_idx % q_hidden_size; + int const token_idx = thread_idx / q_heads_dim; + int const offset = thread_idx % q_heads_dim; if (token_idx >= num_new_tokens) { return; } @@ -485,24 +495,20 @@ __global__ void int const req_idx = tokenInfos[token_idx].request_index; int token_abs_idx = tokenInfos[token_idx].abs_index_in_request; - size_t from_idx = token_idx * (q_hidden_size + temp_kv_hidden_size * 2); - qTmp_ptr[token_idx * q_hidden_size + offset] = + size_t from_idx = token_idx * (q_heads_dim + kv_heads_dim * 2); + qTmp_ptr[token_idx * q_heads_dim + offset] = static_cast(qkv_proj_array[from_idx + offset]); - if (offset < kv_hidden_size) { + if (offset < kv_heads_dim) { size_t to_k_idx = get_k_entry_offset( req_idx, token_abs_idx, max_num_pages, num_kv_heads, head_dim), to_v_idx = get_v_entry_offset( req_idx, token_abs_idx, max_num_pages, num_kv_heads, head_dim); // key and value cache should be stored interleaved - int const stride = num_q_heads / num_kv_heads; - int const kv_offset = - offset / head_dim * stride * head_dim + offset % head_dim; kvCache_ptr[to_k_idx + offset] = - static_cast(qkv_proj_array[from_idx + q_hidden_size + kv_offset]); - kvCache_ptr[to_v_idx + offset] = - static_cast(qkv_proj_array[from_idx + q_hidden_size + - temp_kv_hidden_size + kv_offset]); + static_cast(qkv_proj_array[from_idx + q_heads_dim + offset]); + kvCache_ptr[to_v_idx + offset] = static_cast( + qkv_proj_array[from_idx + 
q_heads_dim + kv_heads_dim + offset]); } } @@ -514,7 +520,7 @@ void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m, if (num_new_tokens == 0) { return; } - int parallelism = m->local_hidden_size * num_new_tokens; + int parallelism = m->num_q_heads * m->qk_dim * num_new_tokens; int const max_num_pages = round_up_pages(BatchConfig::max_sequence_length() + BatchConfig::max_spec_tree_token_num()); @@ -649,12 +655,11 @@ __global__ void int head_dim, StreamingCacheInfo const *streaming_cache_infos, int num_new_tokens) { - int const q_hidden_size = num_q_heads * head_dim; - int const temp_kv_hidden_size = num_q_heads * head_dim; // temporary hard code - int const kv_hidden_size = num_kv_heads * head_dim; + int const q_heads_dim = num_q_heads * head_dim; + int const kv_heads_dim = num_kv_heads * head_dim; int const thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - int const token_idx = thread_idx / kv_hidden_size; - int const offset = thread_idx % kv_hidden_size; + int const token_idx = thread_idx / kv_heads_dim; + int const offset = thread_idx % kv_heads_dim; if (token_idx >= num_new_tokens) { return; } @@ -684,21 +689,16 @@ __global__ void // for more than once. In this case, we should only count the last tokens in // the same window position. - size_t from_idx = token_idx * (q_hidden_size + temp_kv_hidden_size * 2); + size_t from_idx = token_idx * (q_heads_dim + kv_heads_dim * 2); size_t to_k_idx = get_k_entry_offset( request_idx, to_idx, max_num_pages, num_kv_heads, head_dim), to_v_idx = get_v_entry_offset( request_idx, to_idx, max_num_pages, num_kv_heads, head_dim); - int const stride = num_q_heads / num_kv_heads; - int const kv_offset = - offset / head_dim * stride * head_dim + offset % head_dim; - pre_pos_enc_buf[to_k_idx + offset] = - static_cast(qkv_proj_array[from_idx + q_hidden_size + kv_offset]); - pre_pos_enc_buf[to_v_idx + offset] = - static_cast(qkv_proj_array[from_idx + q_hidden_size + - temp_kv_hidden_size + kv_offset]); + static_cast(qkv_proj_array[from_idx + q_heads_dim + offset]); + pre_pos_enc_buf[to_v_idx + offset] = static_cast( + qkv_proj_array[from_idx + q_heads_dim + kv_heads_dim + offset]); } template @@ -778,9 +778,7 @@ void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m, int lda = k, ldb = k, ldc = m_; // matrix A: output projection weight // matrix A's layout: [v_dim * num_heads, o_dim] - DT const *A = weight_ptr + m->hidden_size * (m->qk_dim * m->num_q_heads + - m->qk_dim * m->num_q_heads + - m->v_dim * m->num_q_heads); + DT const *A = weight_ptr + m->hidden_size * m->total_heads_dim; // matrix B: attn heads // matrix B's layout: [v_dim * num_heads, num_new_tokens] DT const *B = static_cast
(m->attn_heads); @@ -805,6 +803,8 @@ void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m, } // Add final output bias if (*m->final_bias && shard_id == 0) { + assert(false && "TODO not resolved"); + // TODO: bias, change QKV_WEIGHT_NUM to q_heads+k_heads+v_heads int parallelism = m->o_dim * num_tokens; int qkv_weight_size = m->qk_dim * m->global_num_q_heads + m->qk_dim * m->global_num_q_heads + diff --git a/src/ops/sigmoid_silu_multi.cc b/src/ops/sigmoid_silu_multi.cc index 9d1261123..19b664ce5 100644 --- a/src/ops/sigmoid_silu_multi.cc +++ b/src/ops/sigmoid_silu_multi.cc @@ -46,9 +46,8 @@ bool operator==(SigmoidSiluMultiParams const &lhs, lhs.tensor_parallelism_degree == rhs.tensor_parallelism_degree; } -bool SigmoidSiluMultiParams::is_valid( - std::pair const &input) const { - return input.first.is_valid() && input.second.is_valid(); +bool SigmoidSiluMultiParams::is_valid(ParallelTensorShape const &input) const { + return input.is_valid(); } SigmoidSiluMultiParams SigmoidSiluMulti::get_params() const { @@ -62,43 +61,39 @@ SigmoidSiluMultiParams SigmoidSiluMulti::get_params() const { return params; } -Tensor FFModel::sigmoid_silu_multi(const Tensor input1, - const Tensor input2, +Tensor FFModel::sigmoid_silu_multi(Tensor const input, int intermediate_size, DataType data_type, char const *name) { // Check dims - assert(input1->num_dims == input2->num_dims); - for (int i = 0; i < input1->num_dims; i++) { - assert(input1->dims[i] == input2->dims[i]); - } + assert(input->dims[0] == intermediate_size * 2); + // Tensor Data type if (data_type == DT_NONE) { - data_type = input1->data_type; - assert(input2->data_type == input1->data_type); + data_type = input->data_type; } - Tensor casted_input1 = - (data_type != input1->data_type) - ? cast(input1, data_type, "type cast for sigmoid_silu_multi") - : input1; - Tensor casted_input2 = - (data_type != input2->data_type) - ? cast(input2, data_type, "type cast for sigmoid_silu_multi") - : input2; + Tensor casted_input = + (data_type != input->data_type) + ? 
cast(input, data_type, "type cast for sigmoid_silu_multi") : input; // Create layer Layer *ssm = new Layer(this, OP_SIGMOID_SILU_MULTI, data_type, name, - 2 /*inputs*/, + 1 /*inputs*/, 0 /*weights*/, 1 /*outputs*/, - casted_input1, - casted_input2); + casted_input); + int dims[MAX_TENSOR_DIM] = {0}; + for (int i = 0; i < input->num_dims; i++) { + dims[i] = input->dims[i]; + } + dims[0] = intermediate_size; ssm->outputs[0] = create_tensor_legion_ordering( - input1->num_dims, input1->dims, data_type, ssm, 0, false /*create_grad*/); + input->num_dims, dims, data_type, ssm, 0, false /*create_grad*/); ssm->add_int_property("intermediate_size", intermediate_size); ssm->add_int_property("tensor_parallelism_degree", config.tensor_parallelism_degree); @@ -118,50 +113,47 @@ Op *SigmoidSiluMulti::create_operator_from_layer( return new SigmoidSiluMulti(model, layer->layer_guid, inputs[0], - inputs[1], intermediate_size, tensor_parallelism_degree, layer->name); } -SigmoidSiluMulti::SigmoidSiluMulti( - FFModel &model, - SigmoidSiluMultiParams const &params, - std::pair<ParallelTensor, ParallelTensor> const &inputs, - char const *name) +SigmoidSiluMulti::SigmoidSiluMulti(FFModel &model, + SigmoidSiluMultiParams const &params, + ParallelTensor const &input, + char const *name) : SigmoidSiluMulti(model, params.layer_guid, - inputs.first, - inputs.second, + input, params.intermediate_size, params.tensor_parallelism_degree, params.name) {} SigmoidSiluMulti::SigmoidSiluMulti(FFModel &model, LayerID const &_layer_guid, - const ParallelTensor _input1, - const ParallelTensor _input2, + ParallelTensor const _input, int _intermediate_size, int _tensor_parallelism_degree, char const *name) : Op(model, OP_SIGMOID_SILU_MULTI, - _input1->data_type, + _input->data_type, name, - 2 /*inputs*/, + 1 /*inputs*/, 0 /*weights*/, 1 /*outputs*/, - _input1, - _input2), + _input), intermediate_size(_intermediate_size), tensor_parallelism_degree(_tensor_parallelism_degree) { // overwrite layer_guid layer_guid = _layer_guid; - outputs[0] = model.create_parallel_tensor_legion_ordering(_input1->num_dims, - _input1->dims, - _input1->data_type, - this, - 0 /*owner_idx*/); + ParallelDim dims[MAX_TENSOR_DIM]; + for (int i = 0; i < _input->num_dims; i++) { + dims[i] = _input->dims[i]; + } + dims[0].size = _intermediate_size; + outputs[0] = model.create_parallel_tensor_legion_ordering( + _input->num_dims, dims, _input->data_type, this, 0 /*owner_idx*/); } void SigmoidSiluMulti::init_inference( @@ -185,27 +177,20 @@ void SigmoidSiluMulti::init_inference( false /*must*/, 0 /*mapper_id*/, machine_view_hash); - // input 1 + // input launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(0, FID_DATA); - // input 2 - launcher.add_region_requirement(RegionRequirement(batch_inputs[1]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - batch_inputs[1]->region)); - launcher.add_field(1, FID_DATA); // output launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, batch_outputs[0]->region)); - launcher.add_field(2, FID_DATA); + launcher.add_field(1, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]); @@ -226,36 +211,28 @@ void SigmoidSiluMulti::init(FFModel const &ff) { false /*must*/, 0 /*mapper_id*/, outputs[0]->machine_view.hash()); - // input 1 + // input
launcher.add_region_requirement(RegionRequirement(inputs[0]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region)); launcher.add_field(0, FID_DATA); - // input 2 - launcher.add_region_requirement(RegionRequirement(inputs[1]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[1]->region)); - launcher.add_field(1, FID_DATA); // output launcher.add_region_requirement(RegionRequirement(outputs[0]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, outputs[0]->region)); - launcher.add_field(2, FID_DATA); + launcher.add_field(1, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); set_opmeta_from_futuremap(ff, fm); } /* - regions[0](I): input 1 - regions[1](I): input 2 - regions[2](O): output + regions[0](I): input + regions[1](O): output */ OpMeta *SigmoidSiluMulti::init_task(Task const *task, std::vector const ®ions, @@ -276,7 +253,6 @@ OpMeta *SigmoidSiluMulti::init_task(Task const *task, ssm->intermediate_size, intermediate_size); meta->input_type[0] = ssm->inputs[0]->data_type; - meta->input_type[1] = ssm->inputs[1]->data_type; meta->output_type[0] = ssm->outputs[0]->data_type; std::strcpy(meta->op_name, ssm->name); meta->layer_guid = ssm->layer_guid; @@ -316,34 +292,26 @@ FutureMap SigmoidSiluMulti::inference( 0 /*mapper_id*/, machine_view_hash); launcher.add_future(bc); - // input 1 + // input launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(0, FID_DATA); - // input 2 - launcher.add_region_requirement(RegionRequirement(batch_inputs[1]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - batch_inputs[1]->region)); - launcher.add_field(1, FID_DATA); // output launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, batch_outputs[0]->region)); - launcher.add_field(2, FID_DATA); + launcher.add_field(1, FID_DATA); return runtime->execute_index_space(ctx, launcher); } /* - regions[0](I): input 1 - regions[1](I): input 2 - regions[2](O): output + regions[0](I): input + regions[1](O): output */ void SigmoidSiluMulti::inference_task( Task const *task, @@ -352,7 +320,7 @@ void SigmoidSiluMulti::inference_task( Runtime *runtime) { assert(task->regions.size() == regions.size()); - assert(regions.size() == 3); + assert(regions.size() == 2); BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); if (bc->num_tokens == 0) { @@ -361,34 +329,26 @@ void SigmoidSiluMulti::inference_task( SigmoidSiluMultiMeta *m = *((SigmoidSiluMultiMeta **)task->local_args); - GenericTensorAccessorR input1 = helperGetGenericTensorAccessorRO( + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorR input2 = helperGetGenericTensorAccessorRO( - m->input_type[1], regions[1], task->regions[1], FID_DATA, ctx, runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); - Domain input1_domain = runtime->get_index_space_domain( + Domain input_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); - Domain input2_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); Domain output_domain = runtime->get_index_space_domain( - ctx, 
task->regions[2].region.get_index_space()); - - assert(input1_domain.get_volume() == input2_domain.get_volume()); - assert(input1_domain.get_volume() == output_domain.get_volume()); + ctx, task->regions[1].region.get_index_space()); - assert(input1_domain == input2_domain); - assert(input1_domain == output_domain); + assert(input_domain.get_volume() == output_domain.get_volume() * 2); // use active number of tokens SigmoidSiluMulti::inference_kernel_wrapper( - m, input1, input2, output, bc->num_active_tokens()); + m, input, output, bc->num_active_tokens()); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; SigmoidSiluMulti::save_inference_tensors_to_file( - m, shard_id, bc, {input1, input2}, {}, {output}); + m, shard_id, bc, {input}, {}, {output}); } } @@ -414,7 +374,7 @@ Node SigmoidSiluMulti::deserialize(FFModel &ff, Legion::Deserializer &dez, ParallelTensor inputs[], int num_inputs) { - assert(num_inputs == 2); + assert(num_inputs == 1); size_t id, transformer_layer_id, deserialized_model_id; int intermediate_size, tensor_parallelism_degree; dez.deserialize(id); @@ -433,8 +393,7 @@ Node SigmoidSiluMulti::deserialize(FFModel &ff, params.intermediate_size = intermediate_size; params.tensor_parallelism_degree = tensor_parallelism_degree; strcpy(params.name, name); - return ff.get_or_create_node<SigmoidSiluMulti>({inputs[0], inputs[1]}, - params); + return ff.get_or_create_node<SigmoidSiluMulti>(inputs[0], params); } }; // namespace FlexFlow diff --git a/src/ops/sigmoid_silu_multi.cu b/src/ops/sigmoid_silu_multi.cu index 962777ff3..33c1bba5f 100644 --- a/src/ops/sigmoid_silu_multi.cu +++ b/src/ops/sigmoid_silu_multi.cu @@ -39,21 +39,26 @@ SigmoidSiluMultiMeta::~SigmoidSiluMultiMeta(void) { template <typename T> __global__ void SigmoidSiluMultiKernel(int num_elements, - T const *input1_ptr, - T const *input2_ptr, + int intermediate_size, + T const *input, T *output_ptr) { CUDA_KERNEL_LOOP(i, num_elements) { - float sigmoid_val = static_cast<float>(input1_ptr[i]); - sigmoid_val = 1.0f / (1.0f + exp(-sigmoid_val)); - output_ptr[i] = input1_ptr[i] * T(sigmoid_val) * input2_ptr[i]; + int row = i / intermediate_size; + int col = i % intermediate_size; + int gate_idx = row * intermediate_size * 2 + col; + int up_idx = row * intermediate_size * 2 + col + intermediate_size; + T gate = input[gate_idx]; + T up = input[up_idx]; + float sigmoid_val = static_cast<float>(gate); + sigmoid_val = 1.0f / (1.0f + __expf(-sigmoid_val)); + output_ptr[i] = gate * static_cast<T>(sigmoid_val) * up; } } /*static*/ void SigmoidSiluMulti::inference_kernel_wrapper( SigmoidSiluMultiMeta const *m, - GenericTensorAccessorR const &input1, - GenericTensorAccessorR const &input2, + GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, int token_size) { if (token_size == 0) { @@ -62,8 +67,7 @@ void SigmoidSiluMulti::inference_kernel_wrapper( cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - assert(input2.domain.get_volume() == input1.domain.get_volume()); - assert(output.domain.get_volume() == input1.domain.get_volume()); + assert(output.domain.get_volume() * 2 == input.domain.get_volume()); int num_elements = token_size * m->intermediate_size; @@ -78,16 +82,16 @@ void SigmoidSiluMulti::inference_kernel_wrapper( min(CUDA_NUM_THREADS, num_elements), 0, stream>>>(num_elements, - input1.get_float_ptr(), - input2.get_float_ptr(), + m->intermediate_size, + input.get_float_ptr(), output.get_float_ptr()); } else if (m->input_type[0] == DT_HALF) { SigmoidSiluMultiKernel<<<GET_BLOCKS(num_elements), min(CUDA_NUM_THREADS, num_elements), 0, stream>>>(num_elements, -
input1.get_half_ptr(), - input2.get_half_ptr(), + m->intermediate_size, + input.get_half_ptr(), output.get_half_ptr()); } else { assert(false && "unsupport datatype in SigmoidSiluMulti"); diff --git a/src/ops/spec_inc_multihead_self_attention.cc b/src/ops/spec_inc_multihead_self_attention.cc index 2e5cc9fa7..bef3f5dcc 100644 --- a/src/ops/spec_inc_multihead_self_attention.cc +++ b/src/ops/spec_inc_multihead_self_attention.cc @@ -154,8 +154,8 @@ Tensor FFModel::spec_inc_multiquery_self_attention( int kParas = qk_dim * hidden_size; int vParas = v_dim * hidden_size; int oParas = o_dim * (v_dim > 0 ? v_dim : hidden_size); - int weight_size = qParas * num_q_heads + kParas * num_q_heads + - vParas * num_q_heads + oParas * num_q_heads; + int weight_size = qParas * num_q_heads + kParas * num_kv_heads + + vParas * num_kv_heads + oParas * num_q_heads; { int dims[1] = {weight_size}; li->weights[0] = create_weight_legion_ordering(1, @@ -168,7 +168,7 @@ Tensor FFModel::spec_inc_multiquery_self_attention( } if (qkv_bias || final_bias) { // q, k, v, o - int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_q_heads; + int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_kv_heads; int dims[1] = {(qkv_bias ? qkv_bias_size : 0) + (final_bias ? o_dim : 0)}; li->weights[1] = create_weight_legion_ordering(1, dims, @@ -350,7 +350,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); + this->num_kv_heads * (kParas + vParas); dims[1].is_replica_dim = false; int seed = std::rand(); Initializer *initializer = new GlorotUniform(seed); @@ -362,7 +362,8 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( CHOSEN_SYNC_TYPE); if (qkv_bias || final_bias) { ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_q_heads; + int qkv_bias_size = + qk_dim * num_q_heads + (qk_dim + v_dim) * num_kv_heads; bias_shape.dims[0].size = (qkv_bias ? qkv_bias_size : 0) + (final_bias ? o_dim : 0); bias_shape.dims[1].size = bias_shape.dims[2].size = 1; @@ -453,7 +454,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); + this->num_kv_heads * (kParas + vParas); dims[1].is_replica_dim = false; // dims[2].size = qParas + kParas + vParas + oParas; int seed = std::rand(); @@ -466,7 +467,8 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( CHOSEN_SYNC_TYPE); if (qkv_bias || final_bias) { ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_q_heads; + int qkv_bias_size = + qk_dim * num_q_heads + (qk_dim + v_dim) * num_kv_heads; bias_shape.dims[0].size = (qkv_bias ? qkv_bias_size : 0) + (final_bias ? 
o_dim : 0); bias_shape.dims[1].size = bias_shape.dims[2].size = 1; diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 6d7bf1364..9f2ec008c 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -61,6 +61,9 @@ void tree_search_attention(SpecIncMultiHeadSelfAttentionMeta *m, uint32_t const num_kv_heads = m->num_kv_heads; uint32_t const head_dim = m->qk_dim; uint32_t const batch_size = bc->num_active_requests(); + if (batch_size == 0) { + return; + } float const sm_scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->qk_dim) : 1.0f; // cudaEventCreate(&t_start); diff --git a/src/ops/tree_inc_multihead_self_attention.cc b/src/ops/tree_inc_multihead_self_attention.cc index 0e1c83b6e..57ceaa53a 100644 --- a/src/ops/tree_inc_multihead_self_attention.cc +++ b/src/ops/tree_inc_multihead_self_attention.cc @@ -156,8 +156,8 @@ Tensor FFModel::inc_multiquery_self_attention_verify( int vParas = v_dim * hidden_size; int oParas = o_dim * (v_dim > 0 ? v_dim : hidden_size); int one_head_size = qParas + kParas + vParas + oParas; - int weight_size = qParas * num_q_heads + kParas * num_q_heads + - vParas * num_q_heads + oParas * num_q_heads; + int weight_size = qParas * num_q_heads + kParas * num_kv_heads + + vParas * num_kv_heads + oParas * num_q_heads; { // compress the weight size if quantization. if (quantization_type != DT_NONE) { @@ -177,7 +177,7 @@ Tensor FFModel::inc_multiquery_self_attention_verify( } if (qkv_bias || final_bias) { // q, k, v, o - int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_q_heads; + int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_kv_heads; int dims[1] = {(qkv_bias ? qkv_bias_size : 0) + (final_bias ? o_dim : 0)}; li->weights[1] = create_weight_legion_ordering(1, dims, @@ -360,7 +360,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); + this->num_kv_heads * (kParas + vParas); dims[1].is_replica_dim = false; // dims[2].size = qParas + kParas + vParas + oParas; if (quantization_type != DT_NONE) { @@ -380,7 +380,8 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( CHOSEN_SYNC_TYPE); if (qkv_bias || final_bias) { ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_q_heads; + int qkv_bias_size = + qk_dim * num_q_heads + (qk_dim + v_dim) * num_kv_heads; bias_shape.dims[0].size = (qkv_bias ? qkv_bias_size : 0) + (final_bias ? o_dim : 0); bias_shape.dims[1].size = bias_shape.dims[2].size = 1; @@ -472,7 +473,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); + this->num_kv_heads * (kParas + vParas); dims[1].is_replica_dim = false; // dims[2].size = qParas + kParas + vParas + oParas; if (quantization_type != DT_NONE) { @@ -490,7 +491,8 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( CHOSEN_SYNC_TYPE); if (qkv_bias || final_bias) { ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = qk_dim * num_q_heads + (qk_dim + v_dim) * num_q_heads; + int qkv_bias_size = + qk_dim * num_q_heads + (qk_dim + v_dim) * num_kv_heads; bias_shape.dims[0].size = (qkv_bias ? 
qkv_bias_size : 0) + (final_bias ? o_dim : 0); bias_shape.dims[1].size = bias_shape.dims[2].size = 1; diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 7266eb78c..07eba1733 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -152,6 +152,9 @@ void tree_verify_attention(TreeIncMultiHeadSelfAttentionMeta *m, uint32_t const num_kv_heads = m->num_kv_heads; uint32_t const head_dim = m->qk_dim; uint32_t const batch_size = bc->num_active_requests(); + if (batch_size == 0) { + return; + } float const sm_scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->qk_dim) : 1.0f; // cudaEventCreate(&t_start); diff --git a/src/parallel_ops/kernels/allreduce_kernels.cu b/src/parallel_ops/kernels/allreduce_kernels.cu index 879be72b8..10486c6b7 100644 --- a/src/parallel_ops/kernels/allreduce_kernels.cu +++ b/src/parallel_ops/kernels/allreduce_kernels.cu @@ -150,8 +150,9 @@ void inference_kernel_wrapper(Context ctx, tensorrt_llm::SelectImplementation( num_elements * ((get_bits(dtype) + 7) / 8), num_devices); - if (strategy == tensorrt_llm::AllReduceStrategyType::RING || - !CanApplyCustomAllReduce(num_elements, dtype)) { + // if (strategy == tensorrt_llm::AllReduceStrategyType::RING || + // !CanApplyCustomAllReduce(num_elements, dtype)) { + if (true) { // Dispatch to nccl AllReduce if the customized all-reduce cannot apply. ncclDataType_t nccl_data_type = ff_to_nccl_datatype(dtype); runtime->concurrent_task_barrier(ctx); diff --git a/src/runtime/file_loader.cc b/src/runtime/file_loader.cc index 14e806d49..d6f14ef77 100644 --- a/src/runtime/file_loader.cc +++ b/src/runtime/file_loader.cc @@ -80,55 +80,6 @@ std::string removeGuidOperatorName(std::string const &input) { } } -template -void load_attention_weights_multi_query(DT *ptr, - std::string layer_name, - std::string weights_folder, - size_t hidden_dim, - int num_heads) { - - std::string qkv_file = layer_name.substr(0, layer_name.find("attention")) + - "attention_query_key_value_weight"; - std::string o_file = layer_name.substr(0, layer_name.find("attention")) + - "attention_dense_weight"; - - // q has n_heads heads, k and v only have one head, o have n_head heads - std::vector weight_filenames = {qkv_file, o_file}; - int file_index = 0; - int data_index = 0; - for (auto filename : weight_filenames) { - std::cout << "Loading weight file " << filename << std::endl; - std::string weight_filepath = join_path({weights_folder, filename}); - size_t partial_size = - file_index == 0 ? (hidden_dim + 2 * hidden_dim / num_heads) * hidden_dim - : hidden_dim * hidden_dim; - - std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); - // std::cout << "Loading filename: " << weight_filepath << std::endl; - if (!in.good()) { - std::cout << "Could not open file: " << weight_filepath << std::endl; - } - assert(in.good() && "incorrect weight file path"); - std::vector
host_array(partial_size); - size_t loaded_data_size = sizeof(DT) * partial_size; - in.seekg(0, in.end); - in.seekg(0, in.beg); - in.read((char *)host_array.data(), loaded_data_size); - size_t in_get_size = in.gcount(); - - if (in_get_size != loaded_data_size) { - std::cout << "load data error " << in_get_size << ", " - << loaded_data_size; - assert(false && "data size mismatch"); - } - for (int i = 0; i < partial_size; i++) { - ptr[data_index++] = host_array.at(i); - } - file_index++; - in.close(); - } -} - template void load_attention_bias_v2(DT *ptr, int num_heads, @@ -149,69 +100,60 @@ void load_attention_bias_v2(DT *ptr, int file_index = 0; - // now only opt use this. - // assert(num_heads == num_kv_heads); int idx = 0; + size_t q_dim = num_heads * head_dim; + size_t kv_dim = num_kv_heads * head_dim; + size_t o_dim = hidden_dim; for (auto filename : bias_files) { std::cout << "Loading weight file " << filename << std::endl; std::string weight_filepath = join_path({weights_folder, filename}); - int n_heads = file_index == 0 ? num_heads : num_kv_heads; - - int replicate_num = num_heads / num_kv_heads; + size_t bias_size = 0; + switch (file_index) { + case 0: + bias_size = q_dim; + break; + case 1: + case 2: + bias_size = kv_dim; + break; + case 3: + bias_size = o_dim; + break; + default: + std::cout << "file index is " << file_index << std::endl; + assert(false && "file index out of range"); + } - size_t qkv_partial_size = head_dim * n_heads; - size_t qkv_replicate_size = head_dim * num_heads; - size_t out_partial_size = hidden_dim; - size_t partial_size = - (file_index < 3) ? qkv_partial_size : out_partial_size; std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); assert(in.good() && "incorrect bias file path"); - std::vector
host_array(partial_size); - size_t loaded_data_size = sizeof(DT) * partial_size; + std::vector
host_array(bias_size); + size_t loaded_data_size = sizeof(DT) * bias_size; in.seekg(0, in.end); in.seekg(0, in.beg); in.read((char *)host_array.data(), loaded_data_size); size_t in_get_size = in.gcount(); if (in_get_size != loaded_data_size) { - printf( - "load bias data error: in_get_size (%lu) != loaded_data_size (%lu)\n", - in_get_size, - loaded_data_size); + std::cout << "load bias data error" << std::endl; + std::cout << "in_get_size: " << in_get_size << std::endl; + std::cout << "loaded_data_size: " << loaded_data_size << std::endl; assert(false); } - assert(partial_size == host_array.size()); - - size_t data_index = 0; - - // q, o - if (file_index == 0 || file_index == 3) { - for (int i = 0; i < partial_size; i++) { - ptr[idx + i] = host_array.at(data_index); - data_index++; - } - } else { - // k, v - for (int i = 0; i < partial_size; i++) { - for (int j = 0; j < replicate_num; j++) { - ptr[idx + j * partial_size + i] = host_array.at(data_index); - } - data_index++; - } - } + assert(bias_size == host_array.size()); + in.close(); + // Copy the data to the pointer with memcpy + memcpy(ptr + idx, host_array.data(), sizeof(DT) * bias_size); + idx += bias_size; file_index++; - idx += qkv_replicate_size; - - in.close(); } } template void load_attention_weights_v2(DT *ptr, - int num_heads, + int num_q_heads, int num_kv_heads, size_t hidden_dim, size_t head_dim, @@ -230,101 +172,91 @@ void load_attention_weights_v2(DT *ptr, size_t single_proj_size = hidden_dim * head_dim; // size of each of Q,K,V,O weights for a single head - size_t one_weight_file_size = - num_heads * single_proj_size; // size of each of Q/K/V/O for all heads - size_t q_size = one_weight_file_size, o_size = one_weight_file_size; - size_t k_size = single_proj_size * num_kv_heads, - v_size = single_proj_size * num_kv_heads; + size_t qo_weight_file_size = num_q_heads * single_proj_size; + size_t kv_weight_file_size = num_kv_heads * single_proj_size; - size_t k_replicate_size = one_weight_file_size; - size_t v_replicate_size = one_weight_file_size; + size_t q_size = qo_weight_file_size, o_size = qo_weight_file_size; + size_t k_size = kv_weight_file_size, v_size = kv_weight_file_size; - int replicate_num = num_heads / num_kv_heads; + int replicate_num = num_q_heads / num_kv_heads; // stride for q, k, v, o - size_t stride_size = (q_size + v_replicate_size + k_replicate_size + o_size) / - tensor_parallelism_degree; + size_t stride_size = + (q_size + k_size + v_size + o_size) / tensor_parallelism_degree; for (auto filename : weight_filenames) { std::cout << "Loading weight file " << filename << std::endl; std::string weight_filepath = join_path({weights_folder, filename}); - int data_index = 0; size_t partial_size = (file_index == 0 || file_index == 3) - ? one_weight_file_size - : single_proj_size * num_kv_heads; - size_t one_partition_size = - one_weight_file_size / tensor_parallelism_degree; - - std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); - if (!in.good()) { - std::cout << "Could not open file: " << weight_filepath << std::endl; - } - assert(in.good() && "incorrect weight file path"); + ? qo_weight_file_size + : kv_weight_file_size; std::vector
host_array(partial_size); size_t loaded_data_size = sizeof(DT) * partial_size; - in.seekg(0, in.end); - in.seekg(0, in.beg); - in.read((char *)host_array.data(), loaded_data_size); - size_t in_get_size = in.gcount(); - - if (in_get_size != loaded_data_size) { - std::cout << "load attention data error " << in_get_size << ", " - << loaded_data_size << ", " << file_index << ", " - << weight_filepath << "\n"; - assert(false && "data size mismatch"); + { + std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); + if (!in.good()) { + std::cout << "Could not open file: " << weight_filepath << std::endl; + } + assert(in.good() && "incorrect weight file path"); + in.seekg(0, in.end); + in.seekg(0, in.beg); + in.read((char *)host_array.data(), loaded_data_size); + size_t in_get_size = in.gcount(); + in.close(); + if (in_get_size != loaded_data_size) { + std::cout << "load attention data error " << in_get_size << ", " + << loaded_data_size << ", " << file_index << ", " + << weight_filepath << "\n"; + assert(false && "data size mismatch"); + } } - // wq, wk, wo + + size_t one_partition_size = 0; if (file_index == 0) { - for (int i = 0; i < tensor_parallelism_degree; i++) { - for (int j = 0; j < one_partition_size; j++) { - ptr[base_index + i * stride_size + j] = host_array.at(data_index++); - } - } + one_partition_size = qo_weight_file_size / tensor_parallelism_degree; } else { - for (int i = 0; i < num_heads; i++) { - int kv_idx = i / (num_heads / num_kv_heads); - int head_idx = i % (num_heads / tensor_parallelism_degree); - int tp_idx = (i / (num_heads / tensor_parallelism_degree)); - for (int j = 0; j < single_proj_size; j++) { - ptr[base_index + tp_idx * stride_size + single_proj_size * head_idx + - j] = host_array.at(kv_idx * single_proj_size + j); - } + one_partition_size = kv_weight_file_size / tensor_parallelism_degree; + } + int data_index = 0; + for (int i = 0; i < tensor_parallelism_degree; i++) { + for (int j = 0; j < one_partition_size; j++) { + ptr[i * stride_size + base_index + j] = host_array.at(data_index++); } } - - // assert(data_index == partial_size); base_index += one_partition_size; + file_index++; } - assert(base_index == (q_size + k_replicate_size + v_replicate_size) / - tensor_parallelism_degree); + assert(base_index == (q_size + k_size + v_size) / tensor_parallelism_degree); { std::cout << "Loading weight file " << o_file << std::endl; std::string weight_filepath = join_path({weights_folder, o_file}); - std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); - if (!in.good()) { - std::cout << "Could not open file: " << weight_filepath << std::endl; + std::vector
host_array(qo_weight_file_size); + size_t loaded_data_size = sizeof(DT) * qo_weight_file_size; + { + std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); + if (!in.good()) { + std::cout << "Could not open file: " << weight_filepath << std::endl; + } + assert(in.good() && "incorrect weight file path"); + in.seekg(0, in.end); + in.seekg(0, in.beg); + in.read((char *)host_array.data(), loaded_data_size); + size_t in_get_size = in.gcount(); + in.close(); + if (in_get_size != loaded_data_size) { + std::cout << "load data error" << std::endl; + assert(false); + } } - assert(in.good() && "incorrect weight file path"); - std::vector
host_array(one_weight_file_size); - size_t loaded_data_size = sizeof(DT) * one_weight_file_size; - in.seekg(0, in.end); - in.seekg(0, in.beg); - in.read((char *)host_array.data(), loaded_data_size); - size_t in_get_size = in.gcount(); - if (in_get_size != loaded_data_size) { - std::cout << "load data error" << std::endl; - assert(false); - } - assert(one_weight_file_size == host_array.size()); int data_index = 0; - - int one_partition_size = head_dim * (num_heads / tensor_parallelism_degree); - for (int i = 0; i < one_weight_file_size; i++) { + int one_partition_size = + head_dim * (num_q_heads / tensor_parallelism_degree); + for (int i = 0; i < qo_weight_file_size; i++) { int part_idx = (i / one_partition_size) % tensor_parallelism_degree; int block_num = (i / one_partition_size); int offset = block_num / tensor_parallelism_degree * one_partition_size + @@ -333,10 +265,77 @@ void load_attention_weights_v2(DT *ptr, host_array.at(data_index++); } - in.close(); + assert(data_index == qo_weight_file_size); + } +} + +template +void load_gate_and_up(DT *ptr, + std::string layer_name, + std::string weights_folder, + size_t volume, + int tensor_parallelism_degree) { + // Replace the "gate_and_up" with the actual prefix of the file names + std::string prefix = layer_name.substr(0, layer_name.find("gate_and_up")); + std::string gate_file = prefix + "gate_proj.weight"; + std::string up_file = prefix + "up_proj.weight"; + std::string up_file_path = join_path({weights_folder, up_file}); + std::string gate_file_path = join_path({weights_folder, gate_file}); + + size_t single_weight_size = volume / 2; + size_t stride_size = volume / tensor_parallelism_degree; + size_t partition_size = single_weight_size / tensor_parallelism_degree; + // std::vector weight_filenames = {q_file, k_file, v_file}; + // int file_index = 0; + std::ifstream gate_stream(gate_file_path, std::ios::in | std::ios::binary); + if (!gate_stream.good()) { + std::cout << "Could not open file: " << gate_file_path << std::endl; + assert(false && "incorrect weight file path"); + } + std::cout << "Loading weight file " << prefix + "gate_proj.weight" + << std::endl; + std::ifstream up_stream(up_file_path, std::ios::in | std::ios::binary); + if (!up_stream.good()) { + std::cout << "Could not open file: " << up_file_path << std::endl; + assert(false && "incorrect weight file path"); + } + std::cout << "Loading weight file " << prefix + "up_proj.weight" << std::endl; + + std::vector
gate_array(single_weight_size); + std::vector
up_array(single_weight_size); - assert(data_index == one_weight_file_size); + size_t loaded_data_size = sizeof(DT) * single_weight_size; + gate_stream.seekg(0, gate_stream.end); + gate_stream.seekg(0, gate_stream.beg); + gate_stream.read((char *)gate_array.data(), loaded_data_size); + size_t in_get_size = gate_stream.gcount(); + if (in_get_size != loaded_data_size) { + std::cout << "load gate data error " << in_get_size << ", " + << loaded_data_size; + assert(false && "data size mismatch"); + } + up_stream.seekg(0, up_stream.end); + up_stream.seekg(0, up_stream.beg); + up_stream.read((char *)up_array.data(), loaded_data_size); + in_get_size = up_stream.gcount(); + if (in_get_size != loaded_data_size) { + std::cout << "load up data error " << in_get_size << ", " + << loaded_data_size; + assert(false && "data size mismatch"); } + assert(single_weight_size == gate_array.size()); + assert(single_weight_size == up_array.size()); + + for (int i = 0; i < tensor_parallelism_degree; i++) { + for (int j = 0; j < partition_size; j++) { + ptr[i * stride_size + j] = gate_array.at(j + i * partition_size); + ptr[i * stride_size + j + partition_size] = + up_array.at(j + i * partition_size); + } + } + + gate_stream.close(); + up_stream.close(); } template @@ -736,7 +735,6 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff, } assert(volume_ == volume * num_replicas); // assert(data_type_size(weight->data_type) == sizeof(DT)); - DT *data = (DT *)malloc(sizeof(DT) * volume); std::string weight_filename = removeGuidOperatorName(std::string(l->name)); @@ -750,7 +748,7 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff, l->op_type == OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION || l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION) { if (weight_idx == 0) { - load_attention_weights_v2(data, + load_attention_weights_v2(weight, num_heads, num_kv_heads, hidden_dim, @@ -763,7 +761,7 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff, long long value; l->get_int_property("final_bias", value); bool final_bias = (bool)value; - load_attention_bias_v2(data, + load_attention_bias_v2(weight, num_heads, num_kv_heads, hidden_dim, @@ -780,7 +778,16 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff, std::cout << "Loading weight file " << weight_filename << std::endl; std::string weight_filepath = join_path({weights_folder, weight_filename}); - load_from_file(data, volume, weight_filepath); + load_from_file(weight, volume, weight_filepath); + } else if (l->name != nullptr && + std::string(l->name).find("gate_and_up") != std::string::npos) { + assert(weight_idx == 0); + assert(l->numWeights == 1); // We do not support bias in SwiGLU now + load_gate_and_up(weight, + weight_filename, + weights_folder, + volume, + tensor_parallelism_degree); } else { // default op assert(weight_idx == 0 || weight_idx == 1); @@ -791,19 +798,14 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff, std::cout << "Loading weight file " << weight_filename << std::endl; std::string weight_filepath = join_path({weights_folder, weight_filename}); - load_from_file(data, volume, weight_filepath); + load_from_file(weight, volume, weight_filepath); } } - // Copy the weight data from the buffer to the weight - DT *ptr = weight; - for (size_t i = 0; i < num_replicas; i++) { - memcpy(ptr, data, volume * sizeof(DT)); - ptr += volume; + // Copy the weight data from the first replica to other replicas + for (size_t i = 1; i < num_replicas; i++) { + memcpy(weight + i * volume, weight, volume * sizeof(DT)); } - - // Free 
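Note on the load_gate_and_up hunk above: each tensor-parallel partition of the destination buffer stores its slice of gate_proj.weight first and the matching slice of up_proj.weight immediately after it. Below is a minimal, self-contained sketch of that layout; the helper name pack_gate_and_up and the std::vector-based interface are illustrative only and are not part of this patch.

#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative layout helper: dst holds `volume` elements split into
// `tp_degree` partitions; each partition is laid out as [gate slice | up slice].
template <typename DT>
void pack_gate_and_up(std::vector<DT> &dst,
                      std::vector<DT> const &gate, // volume / 2 elements
                      std::vector<DT> const &up,   // volume / 2 elements
                      int tp_degree) {
  size_t const volume = dst.size();
  assert(gate.size() == volume / 2 && up.size() == volume / 2);
  size_t const stride_size = volume / tp_degree;          // one dst partition
  size_t const partition_size = (volume / 2) / tp_degree; // one gate/up slice
  for (int i = 0; i < tp_degree; i++) {
    for (size_t j = 0; j < partition_size; j++) {
      dst[i * stride_size + j] = gate[i * partition_size + j];
      dst[i * stride_size + partition_size + j] = up[i * partition_size + j];
    }
  }
}

With tp_degree == 1 this degenerates to plain concatenation of the gate and up matrices, which is what the fused gate/up linear layer consumes as a single weight tensor.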
buffer memory - free(data); } void FileDataLoader::load_weight_task( diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 2a72029c5..96365ca1f 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -3408,10 +3408,13 @@ void FFModel::create_operators_from_layers() { layers[layer_idx - 4]->op_type == OP_LINEAR && layers[layer_idx - 5]->op_type == OP_LINEAR) || // LLAMA with element-wise operator fusion - (l->op_type == OP_LINEAR && layer_idx >= 3 && + // (l->op_type == OP_LINEAR && layer_idx >= 3 && + // layers[layer_idx - 1]->op_type == OP_SIGMOID_SILU_MULTI && + // layers[layer_idx - 2]->op_type == OP_LINEAR && + // layers[layer_idx - 3]->op_type == OP_LINEAR) + (l->op_type == OP_LINEAR && layer_idx >= 2 && layers[layer_idx - 1]->op_type == OP_SIGMOID_SILU_MULTI && - layers[layer_idx - 2]->op_type == OP_LINEAR && - layers[layer_idx - 3]->op_type == OP_LINEAR))) { + layers[layer_idx - 2]->op_type == OP_LINEAR))) { assert(op->numOutputs == 1); AllReduce *allreduce = new AllReduce(*this, op->outputs[0], op->outputs[0]->num_dims - 1); diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 9da88d8ed..8a5ef7bd5 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -879,46 +880,47 @@ void RequestManager::request_complete_clean_up(int batch_index) { std::cout << ""; } std::cout << std::endl << std::endl; - { - RequestProfileInfo profile_info = profiling_requests[guid]; - - std::ostream *os = &std::cout; - std::ofstream output_file; - if (!output_filepath.empty()) { - output_file.open(output_filepath, std::ios::app); - if (output_file.is_open()) { - os = &output_file; - } else { - std::cout << "Unable to open the output file: " << output_filepath - << std::endl; - assert(false); - } - } - *os << "Request " << guid << " profiling: " << std::endl; - if (profile_info.start_decoding_time != 0) { - *os << "Decoding time: " - << (profile_info.finish_time - profile_info.start_decoding_time) * - 1e-3 - << " ms" << std::endl; - } else { - *os << "Decoding time: 0 ms" << std::endl; - } - *os << "Total time: " - << (profile_info.finish_time - profile_info.start_time) * 1e-3 << " ms" - << std::endl; - *os << "LLM decoding steps: " << profile_info.llm_decoding_steps - << std::endl; - if (decoding_mode == SPECULATIVE_DECODING) { - *os << "SSM decoding steps: " << profile_info.ssm_decoding_steps - << std::endl; - } - *os << std::endl; - // *os << output << std::endl << std::endl; - - if (!output_filepath.empty()) { - output_file.close(); - } - } + // { + // RequestProfileInfo profile_info = profiling_requests[guid]; + + // std::ostream *os = &std::cout; + // std::ofstream output_file; + // if (!output_filepath.empty()) { + // output_file.open(output_filepath, std::ios::app); + // if (output_file.is_open()) { + // os = &output_file; + // } else { + // std::cout << "Unable to open the output file: " << output_filepath + // << std::endl; + // assert(false); + // } + // } + // *os << "Request " << guid << " profiling: " << std::endl; + // if (profile_info.start_decoding_time != 0) { + // *os << "Decoding time: " + // << (profile_info.finish_time - profile_info.start_decoding_time) * + // 1e-3 + // << " ms" << std::endl; + // } else { + // *os << "Decoding time: 0 ms" << std::endl; + // } + // *os << "Total time: " + // << (profile_info.finish_time - profile_info.start_time) * 1e-3 << " + // ms" + // << std::endl; + // *os << "LLM decoding steps: " << 
profile_info.llm_decoding_steps + // << std::endl; + // if (decoding_mode == SPECULATIVE_DECODING) { + // *os << "SSM decoding steps: " << profile_info.ssm_decoding_steps + // << std::endl; + // } + // *os << std::endl; + // // *os << output << std::endl << std::endl; + + // if (!output_filepath.empty()) { + // output_file.close(); + // } + // } // RequestProfileInfo profile_info = profiling_requests[guid]; // std::string str = // "[" + std::to_string(guid) + @@ -1216,7 +1218,9 @@ void RequestManager::update_ssm_prefill_results( BatchConfig RequestManager::prepare_next_batch() { if (is_background_server_terminated()) { - return BatchConfig(); + return BatchConfig(decoding_mode == SPECULATIVE_DECODING + ? InferenceMode::TREE_SEARCH_MODE + : InferenceMode::INC_DECODING_MODE); } switch (request_manager_status) { case PREFILLING: @@ -1228,7 +1232,9 @@ BatchConfig RequestManager::prepare_next_batch() { return prepare_ssm_prefilling_batch(); } else { // Return an empty batch config - return BatchConfig(); + return BatchConfig(decoding_mode == SPECULATIVE_DECODING + ? InferenceMode::TREE_SEARCH_MODE + : InferenceMode::INC_DECODING_MODE); } } else if (prefill_model == LLM) { return prepare_llm_prefilling_batch(); @@ -1254,7 +1260,9 @@ BatchConfig RequestManager::prepare_next_batch() { return prepare_next_spec_batch_config(); } else { // Return an empty batch config - return BatchConfig(); + return BatchConfig(decoding_mode == SPECULATIVE_DECODING + ? InferenceMode::TREE_SEARCH_MODE + : InferenceMode::INC_DECODING_MODE); } case LLM_VERIFY: return prepare_verify_batch_config(); @@ -1339,6 +1347,11 @@ BatchConfig RequestManager::prepare_llm_prefilling_batch() { } bc.num_tokens = num_tokens; + // Debug + std::cout << "[Debug] RequestManager::prepare_llm_prefilling_batch: " + << "num_tokens: " << num_tokens << std::endl; + profiling.prefilling_steps++; + if (verbose) { std::cout << "prepare_llm_prefilling_batch NEW batchconfig:" << std::endl; bc.print(); @@ -2013,6 +2026,7 @@ BatchConfig RequestManager::prepare_verify_batch_config() { new_bc.print(); } profiling.llm_step_start = Realm::Clock::current_time_in_microseconds(); + profiling.tokens_in_verification_per_step.push_back(new_bc.num_tokens); return new_bc; } @@ -2906,6 +2920,8 @@ void RequestManager::serve_spec_infer(FFModel *llm) { im->init_operators_inference(ssm); } + std::cout << "Finish loading model weights." << std::endl; + InferenceResultFuture irf_0; { // Initialize futures for incr decoding @@ -2923,11 +2939,13 @@ void RequestManager::serve_spec_infer(FFModel *llm) { // reset_profiling_statistics(); background_server_status = SERVING; while (!is_background_server_terminated()) { + // std::cout << "[Debug] Begin blocking." << std::endl; if (infer_result_future_pipeline.size() >= 4) { // Block here to avoid launching too many batches auto const &ir = infer_result_future_pipeline.front(); ir.get_void_result(); } + // std::cout << "[Debug] End blocking." << std::endl; // deque finished batches while (infer_result_future_pipeline.size() > 1) { auto const &ir = infer_result_future_pipeline.front(); @@ -2939,6 +2957,7 @@ void RequestManager::serve_spec_infer(FFModel *llm) { } runtime->begin_trace(ctx, 12345 /*trace_id*/); + // std::cout << "[Debug] Trace 12345 starts." 
<< std::endl; for (int ssm_step_i = 0; ssm_step_i < get_max_tree_depth(); ssm_step_i++) { InferenceResultFuture irf = infer_result_future_pipeline.back(); BatchConfigFuture bcf = get_next_batch_config(irf, ctx, runtime); @@ -2950,6 +2969,7 @@ void RequestManager::serve_spec_infer(FFModel *llm) { FutureMap fm = im->inference(llm, 0, bcf); infer_result_future_pipeline.push(fm.get_future(0)); runtime->end_trace(ctx, 12345 /*trace_id*/); + // std::cout << "[Debug] Trace 12345 ends." << std::endl; } } @@ -3025,6 +3045,13 @@ void RequestManager::trigger_request_completion_future( void RequestManager::terminate_background_server_at_exit() { RequestManager *rm = RequestManager::get_request_manager(); rm->terminate_background_server(); + std::cout << "Background server terminated." << std::endl; +} + +std::string format_float(double value, int precision = 2) { + std::ostringstream oss; + oss << std::fixed << std::setprecision(precision) << value; + return oss.str(); } void RequestManager::terminate_background_server() { @@ -3066,17 +3093,22 @@ void RequestManager::terminate_background_server() { str += std::to_string(request.decode_length()); float speedup = (float)request.decode_length() / profiling_info.second.llm_decoding_steps; - str += " " + std::to_string(speedup) + "\n"; + str += "\t" + format_float(speedup) + "\t"; + str += "(SLO: " + format_float(request.slo_ratio) + ")\t"; + if (request.attained == false) { + str += "(Not Attained)\t"; + } + str += "\n"; } - str += "\n total_time_ms(" + std::to_string(total_time / 1000.0) + ")"; + str += "\n total_time_ms(" + format_float(total_time / 1000.0) + ")"; str += "\n total_requests(" + std::to_string(total_requests) + "/" + std::to_string(all_requests.size()) + ")"; str += "\n total_tokens(" + std::to_string(total_tokens) + ")"; // throughput str += "\n throughput_requests_per_sec(" + - std::to_string(total_requests / (total_time / 1e6)) + ")"; + format_float(total_requests / (total_time / 1e6)) + ")"; str += "\n throughput_tokens_per_sec(" + - std::to_string(total_tokens / (total_time / 1e6)) + ")"; + format_float(total_tokens / (total_time / 1e6)) + ")"; double average_latency_per_request = 0; std::string latency_per_request_ms = "\n latency_per_request_ms( "; @@ -3087,7 +3119,7 @@ void RequestManager::terminate_background_server() { // latency_per_request_ms += "[" + std::to_string(profiling_info.first) // + // ","; latency_per_request_ms += std::to_string(latency_ms) + "] "; - latency_per_request_ms += std::to_string(latency_ms) + " "; + latency_per_request_ms += format_float(latency_ms) + " "; average_latency_per_request += latency_ms; } latency_per_request_ms += ")"; @@ -3095,7 +3127,7 @@ void RequestManager::terminate_background_server() { average_latency_per_request /= total_requests; str += "\n average_latency_per_request_ms(" + - std::to_string(average_latency_per_request) + ")"; + format_float(average_latency_per_request) + ")"; std::string ttft_per_request_ms = "\n ttft_per_request_ms( "; for (auto const &profiling_info : profiling_requests) { @@ -3108,7 +3140,7 @@ void RequestManager::terminate_background_server() { prefilling_time_ms = (profiling.finish_time - profiling.start_time) / 1000.0; } - ttft_per_request_ms += std::to_string(prefilling_time_ms) + " "; + ttft_per_request_ms += format_float(prefilling_time_ms) + " "; } ttft_per_request_ms += ")"; str += ttft_per_request_ms; @@ -3124,7 +3156,7 @@ void RequestManager::terminate_background_server() { (profiling.finish_time - profiling.start_decoding_time) / 1000.0 / 
request.decode_length(); } - tpot_per_request_ms += std::to_string(per_token_time_ms) + " "; + tpot_per_request_ms += format_float(per_token_time_ms, 3) + "\t"; auto &tpot = tpots[request.slo_ratio]; tpot.first++; tpot.second += per_token_time_ms; @@ -3136,12 +3168,45 @@ void RequestManager::terminate_background_server() { for (auto const &kv : tpots) { double average_tpot = kv.second.second / kv.second.first; average_tpot_per_slo_ms += - std::to_string(kv.first) + ":" + std::to_string(average_tpot) + " "; + format_float(kv.first) + ":" + format_float(average_tpot, 3) + "\t"; } average_tpot_per_slo_ms += ")"; str += average_tpot_per_slo_ms; - std::string req_per_step = "\n requests_per_step( "; + // Create a map to store all individual tpot values per SLO ratio + std::unordered_map> tpots_by_slo; + for (auto const &profiling_info : profiling_requests) { + double per_token_time_ms = 0; + auto const &request = all_requests[profiling_info.first]; + auto const &profiling = profiling_info.second; + if (profiling.start_decoding_time != 0) { + per_token_time_ms = + (profiling.finish_time - profiling.start_decoding_time) / 1000.0 / + request.decode_length(); + // Store the individual tpot value in the appropriate SLO group + tpots_by_slo[request.slo_ratio].push_back(per_token_time_ms); + } + } + + // Format and output all individual tpot values by SLO group + std::string all_tpots_by_slo = "\n all_tpots_by_slo_ms( "; + for (auto const &kv : tpots_by_slo) { + double slo_ratio = kv.first; + std::vector const &tpot_values = kv.second; + + all_tpots_by_slo += format_float(slo_ratio) + ":["; + for (size_t i = 0; i < tpot_values.size(); ++i) { + all_tpots_by_slo += format_float(tpot_values[i], 3); + if (i < tpot_values.size() - 1) { + all_tpots_by_slo += "\t"; + } + } + all_tpots_by_slo += "] "; + } + all_tpots_by_slo += ")"; + str += all_tpots_by_slo; + + std::string req_per_step = "\nrequests_per_step( "; for (int nb : profiling.requests_per_step) { req_per_step += std::to_string(nb) + " "; } @@ -3151,16 +3216,16 @@ void RequestManager::terminate_background_server() { if (profiling.ssm_step_times.size() > 0) { // assert(profiling.ssm_step_times.size() == // profiling.llm_step_times.size()); - std::string ssm_step_times_ms = "\n ssm_step_times_ms( "; + std::string ssm_step_times_ms = "\nssm_step_times_ms( "; for (double time : profiling.ssm_step_times) { - ssm_step_times_ms += std::to_string(time) + " "; + ssm_step_times_ms += format_float(time, 3) + "\t"; } ssm_step_times_ms += ")"; str += ssm_step_times_ms; } if (profiling.ssm_steps.size() > 0) { - std::string ssm_steps = "\n ssm_steps( "; + std::string ssm_steps = "\nssm_steps( "; for (int nb : profiling.ssm_steps) { ssm_steps += std::to_string(nb) + " "; } @@ -3168,22 +3233,178 @@ void RequestManager::terminate_background_server() { str += ssm_steps; } - std::string llm_step_times_ms = "\n llm_step_times_ms( "; + std::string llm_step_times_ms = "\nllm_step_times_ms( "; for (double time : profiling.llm_step_times) { - llm_step_times_ms += std::to_string(time) + " "; + llm_step_times_ms += format_float(time, 3) + " "; } llm_step_times_ms += ")"; str += llm_step_times_ms; - std::string generated_tokens_per_step = "\n generated_tokens_per_step( "; + std::string generated_tokens_per_step = "\ngenerated_tokens_per_step( "; for (int nb : profiling.generated_tokens_per_step) { generated_tokens_per_step += std::to_string(nb) + " "; } generated_tokens_per_step += ")"; str += generated_tokens_per_step; + std::string tokens_in_verification_per_step = + 
"\ntokens_in_verification_per_step( "; + for (int nb : profiling.tokens_in_verification_per_step) { + tokens_in_verification_per_step += std::to_string(nb) + " "; + } + tokens_in_verification_per_step += ")"; + str += tokens_in_verification_per_step; + + // Add this after the llm_step_times_ms section but before the + // generated_tokens_per_step section + + // Group llm_step_times_ms by tokens_in_verification_per_step + std::unordered_map> + llm_times_by_verification_tokens; + for (size_t i = 0; i < profiling.llm_step_times.size() && + i < profiling.tokens_in_verification_per_step.size(); + i++) { + int tokens = profiling.tokens_in_verification_per_step[i]; + llm_times_by_verification_tokens[tokens].push_back( + profiling.llm_step_times[i]); + } + + // Calculate and output average llm_step_times for each group + std::string avg_llm_time_by_verification_tokens = + "\navg_llm_time_by_verification_tokens("; + for (auto const &group : llm_times_by_verification_tokens) { + int tokens = group.first; + std::vector const × = group.second; + + // Calculate average time for this group + double avg_time = 0.0; + if (!times.empty()) { + avg_time = + std::accumulate(times.begin(), times.end(), 0.0) / times.size(); + } + + // Add to output string + avg_llm_time_by_verification_tokens += + std::to_string(tokens) + ":" + format_float(avg_time) + " "; + } + avg_llm_time_by_verification_tokens += ")"; + str += avg_llm_time_by_verification_tokens; + + // Calculate verification throughput (tokens/s) for each step + std::unordered_map> + verification_throughput_by_tokens; + for (size_t i = 0; i < profiling.llm_step_times.size() && + i < profiling.tokens_in_verification_per_step.size(); + i++) { + int tokens = profiling.tokens_in_verification_per_step[i]; + double step_time = profiling.llm_step_times[i]; + + // Avoid division by zero + if (step_time > 0) { + double throughput = tokens / step_time * 1000; // tokens per s + verification_throughput_by_tokens[tokens].push_back(throughput); + } + } + + // Calculate and output average verification throughput for each group + std::string avg_verification_throughput = + "\navg_verification_throughput_tokens_per_s("; + for (auto const &group : verification_throughput_by_tokens) { + int tokens = group.first; + std::vector const &throughputs = group.second; + + // Calculate average throughput for this group + double avg_throughput = 0.0; + if (!throughputs.empty()) { + avg_throughput = + std::accumulate(throughputs.begin(), throughputs.end(), 0.0) / + throughputs.size(); + } + + // Add to output string + avg_verification_throughput += + std::to_string(tokens) + ":" + format_float(avg_throughput) + " "; + } + avg_verification_throughput += ")"; + str += avg_verification_throughput; + + // Group generated_tokens_per_step by tokens_in_verification_per_step + std::unordered_map> + generated_tokens_by_verification_tokens; + for (size_t i = 0; i < profiling.generated_tokens_per_step.size() && + i < profiling.tokens_in_verification_per_step.size(); + i++) { + int verification_tokens = profiling.tokens_in_verification_per_step[i]; + int generated_tokens = profiling.generated_tokens_per_step[i]; + generated_tokens_by_verification_tokens[verification_tokens].push_back( + generated_tokens); + } + + // Calculate and output average generated tokens for each verification token + // group + std::string avg_generated_tokens_by_verification = + "\navg_generated_tokens_by_verification( "; + for (auto const &group : generated_tokens_by_verification_tokens) { + int verification_tokens = 
group.first; + std::vector const &generated_tokens = group.second; + + // Calculate average generated tokens for this group + double avg_tokens = 0.0; + if (!generated_tokens.empty()) { + avg_tokens = std::accumulate(generated_tokens.begin(), + generated_tokens.end(), + 0.0) / + generated_tokens.size(); + } + + // Add to output string + avg_generated_tokens_by_verification += + std::to_string(verification_tokens) + ":" + format_float(avg_tokens) + + " "; + } + avg_generated_tokens_by_verification += ")"; + str += avg_generated_tokens_by_verification; + + // Calculate generation throughput (average generated tokens / average llm + // step time) for each verification token group + std::string generation_throughput_by_verification = + "\ngeneration_throughput_tokens_per_s("; + for (auto const &group : generated_tokens_by_verification_tokens) { + int verification_tokens = group.first; + std::vector const &generated_tokens = group.second; + + // Get corresponding LLM step times for this verification token count + auto const &llm_times_it = + llm_times_by_verification_tokens.find(verification_tokens); + if (llm_times_it != llm_times_by_verification_tokens.end() && + !llm_times_it->second.empty() && !generated_tokens.empty()) { + // Calculate average generated tokens + double avg_tokens = std::accumulate(generated_tokens.begin(), + generated_tokens.end(), + 0.0) / + generated_tokens.size(); + + // Calculate average LLM step time + std::vector const &llm_times = llm_times_it->second; + double avg_llm_time = + std::accumulate(llm_times.begin(), llm_times.end(), 0.0) / + llm_times.size(); + + // Calculate generation throughput (avoid division by zero) + double throughput = + (avg_llm_time > 0) ? (avg_tokens / avg_llm_time) * 1000 : 0.0; + + // Add to output string + generation_throughput_by_verification += + std::to_string(verification_tokens) + ":" + + format_float(throughput) + " "; + } + } + generation_throughput_by_verification += ")"; + str += generation_throughput_by_verification; + std::string mean_generated_tokens_per_step = - "\n mean_generated_tokens_per_step( "; + "\nmean_generated_tokens_per_step("; double mean_generated_tokens = (double)std::accumulate(profiling.generated_tokens_per_step.begin(), profiling.generated_tokens_per_step.end(), @@ -3193,10 +3414,62 @@ void RequestManager::terminate_background_server() { profiling.requests_per_step.end(), 0); mean_generated_tokens /= total_request_steps; - mean_generated_tokens_per_step += std::to_string(mean_generated_tokens); + mean_generated_tokens_per_step += format_float(mean_generated_tokens); mean_generated_tokens_per_step += ")"; str += mean_generated_tokens_per_step; + str += "\nPrefilling steps: "; + str += std::to_string(profiling.prefilling_steps); + str += "\nVerifying steps: "; + str += std::to_string(profiling.llm_step_times.size()); + + // Compute SLO attainment by SLO scale + std::unordered_map> + attainment_by_slo; // pair + + // Count attained vs total requests for each SLO ratio + for (auto const &request_pair : all_requests) { + int request_id = request_pair.first; + Request const &request = request_pair.second; + + // Skip requests that aren't completed + if (request.status == Request::COMPLETED) { + // Initialize the entry if needed + if (attainment_by_slo.find(request.slo_ratio) == + attainment_by_slo.end()) { + attainment_by_slo[request.slo_ratio] = std::make_pair(0, 0); + } + + // Increment total count for this SLO ratio + attainment_by_slo[request.slo_ratio].second++; + + // Increment attained count if SLO was met + 
if (request.attained) { + attainment_by_slo[request.slo_ratio].first++; + } + } + } + + // Format and output attainment percentages by SLO ratio + std::string slo_attainment_by_scale = "\nslo_attainment_by_scale("; + for (auto const &kv : attainment_by_slo) { + double slo_ratio = kv.first; + int attained_count = kv.second.first; + int total_count = kv.second.second; + + // Calculate attainment percentage + double attainment_pct = + (total_count > 0) ? (double)attained_count / total_count : 0.0; + + // Add to output string with format: slo_ratio:attainment(attained/total) + slo_attainment_by_scale += format_float(slo_ratio) + " : " + + format_float(attainment_pct * 100) + "% (" + + std::to_string(attained_count) + "/" + + std::to_string(total_count) + ") "; + } + slo_attainment_by_scale += ")"; + str += slo_attainment_by_scale; + double attainment = 0, goodput = 0; for (auto request_pair : all_requests) { Request &request = request_pair.second; @@ -3208,20 +3481,20 @@ void RequestManager::terminate_background_server() { attainment /= total_requests; goodput /= total_time / 1e6; - std::string slo_attainment = "\n slo_attainment( "; - slo_attainment += std::to_string(attainment); - slo_attainment += ")"; + std::string slo_attainment = "\nslo_attainment("; + slo_attainment += format_float(attainment * 100); + slo_attainment += "%)"; str += slo_attainment; - std::string goodput_str = "\n goodput( "; - goodput_str += std::to_string(goodput); + std::string goodput_str = "\ngoodput("; + goodput_str += format_float(goodput); goodput_str += ")"; str += goodput_str; if (get_eval_overhead_breakdown()) { eval_process_latency_us -= eval_schedule_latency_us + eval_other_latency_us; - std::string eval_overhead_breakdown_str = "\n eval_overhead_breakdown( "; + std::string eval_overhead_breakdown_str = "\neval_overhead_breakdown( "; eval_overhead_breakdown_str += "\n ssm_prefill_us: " + std::to_string(eval_ssm_prefill_latency_us); eval_overhead_breakdown_str += @@ -3301,10 +3574,8 @@ void RequestManager::add_tokens_to_spec_token_tree( // TODO: parameterize MAX_SPECULATIVE_TREE_BRANCHES // TODO: support gumbel sampling - int tree_width = - min(get_max_tokens_per_ssm_batch() / get_num_active_requests(), - get_max_tree_width()); - assert(tree_width >= 1); + int remaining_budget = get_max_tokens_per_ssm_batch(); + int remaining_requests = get_num_active_requests(); for (int request_index = 0; request_index < get_max_requests_per_batch(); ++request_index) { @@ -3344,7 +3615,7 @@ void RequestManager::add_tokens_to_spec_token_tree( result_idx++) { double log_prob = log((double)ssm_inference_result.probs[result_idx]); if (log_prob == -std::numeric_limits::infinity()) { - continue; + log_prob = -1e10; } if (log_prob == 0.0) { // Slightly perturb the log prob to make it strictly less than 0 @@ -3359,7 +3630,13 @@ void RequestManager::add_tokens_to_spec_token_tree( } spec_token_tree.add_layer(); - int actual_width = min(tree_width, (int)child_probs_v.size()); + + int tree_width_limit = + min(remaining_budget / remaining_requests, get_max_tree_width()); + int actual_width = min(tree_width_limit, (int)child_probs_v.size()); + remaining_budget -= actual_width; + remaining_requests--; + if (actual_width == 0) { continue; } @@ -3380,6 +3657,9 @@ void RequestManager::add_tokens_to_spec_token_tree( std::make_pair(node_ptr, accumulated_log_prob)); } } + std::cout << "[Debug] RequestManager::add_tokens_to_spec_token_tree() added " + << get_max_tokens_per_ssm_batch() - remaining_budget << " tokens " + << std::endl; } void 
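Note on the add_tokens_to_spec_token_tree hunk above: the fixed per-batch tree width is replaced by a running budget, so each remaining request receives at most an equal share of the leftover SSM token budget, capped by the configured tree width. A small sketch of that arithmetic follows; the function name next_tree_width and the standalone main are illustrative and not part of this patch.

#include <algorithm>
#include <cassert>
#include <cstdio>

// Width granted to the next request, given the remaining token budget, the
// number of requests still to be scheduled this step, the configured width
// cap, and how many candidate children this request actually has.
int next_tree_width(int &remaining_budget, int &remaining_requests,
                    int max_tree_width, int num_candidates) {
  assert(remaining_requests > 0);
  int width_limit =
      std::min(remaining_budget / remaining_requests, max_tree_width);
  int actual_width = std::min(width_limit, num_candidates);
  remaining_budget -= actual_width;
  remaining_requests--;
  return actual_width;
}

int main() {
  int budget = 64, requests = 3;
  // With a width cap of 12: min(64 / 3, 12) = 12 for the first request,
  // then min(52 / 2, 12) = 12 for the second, leaving 40 for the last.
  std::printf("%d\n", next_tree_width(budget, requests, 12, 16)); // 12
  std::printf("%d\n", next_tree_width(budget, requests, 12, 16)); // 12
  std::printf("%d %d\n", budget, requests);                       // 40 1
  return 0;
}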
RequestManager::add_tokens_to_spec_token_tree_old_version( @@ -3476,6 +3756,8 @@ void RequestManager::prune_token_tree() { std::vector> num_tokens_to_decode_2_request_index; num_tokens_to_decode_2_request_index.reserve(get_max_requests_per_batch()); + double ssm_spec_latency_estimated = + ssm_spec_latency_ms / get_max_tree_depth() * ssm_tree_depth; for (int request_index = 0; request_index < get_max_requests_per_batch(); ++request_index) { if (!request_available[request_index]) { @@ -3487,9 +3769,12 @@ void RequestManager::prune_token_tree() { if (request.get_slo_ratio() > 999) { // infinity continue; } + // double num_tokens_to_decode_per_step = + // (ssm_spec_latency_ms + llm_verify_latency_ms) * correction_factor / + // get_slo_constraint(request); double num_tokens_to_decode_per_step = - (ssm_spec_latency_ms + llm_verify_latency_ms) * correction_factor / - get_slo_constraint(request); + (ssm_spec_latency_estimated + llm_verify_latency_ms) * + correction_factor / get_slo_constraint(request); double expected_num_tokens_decoded = request.decode_latency_ms / get_slo_constraint(request); double num_tokens_to_decode = @@ -3502,20 +3787,27 @@ void RequestManager::prune_token_tree() { std::make_pair(num_tokens_to_decode, request_index)); } - // Sort the requests by spare latency in ascending order + // Sort the requests by number of tokens to decode in descending order std::sort(num_tokens_to_decode_2_request_index.begin(), num_tokens_to_decode_2_request_index.end(), - std::less>()); + std::greater>()); - for (auto const &spare_latency_request_index_pair : + // Debug + std::cout + << "[Debug] RequestManager::prune_token_tree() num of active requests: " + << num_tokens_to_decode_2_request_index.size() << std::endl; + + for (auto const &num_tokens_to_decode_request_index_pair : num_tokens_to_decode_2_request_index) { - int request_index = spare_latency_request_index_pair.second; + int request_index = num_tokens_to_decode_request_index_pair.second; RequestGuid guid = guid_of_requests[request_index]; - if (all_requests[guid].get_slo_ratio() < 0) { - continue; - } - add_tokens_toward_slo( - guid, budget, num_tokens_to_decode_2_request_index.size()); + // if (all_requests[guid].get_slo_ratio() < 0) { + // continue; + // } + add_tokens_toward_slo(guid, + budget, + num_tokens_to_decode_request_index_pair.first, + num_tokens_to_decode_2_request_index.size()); } assert(budget >= 0); @@ -3526,6 +3818,24 @@ void RequestManager::prune_token_tree() { add_tokens_toward_goodput(budget); } } + // Clear the priority queue in each requests + for (int request_index = 0; request_index < get_max_requests_per_batch(); + ++request_index) { + if (!request_available[request_index]) { + continue; + } + RequestGuid guid = guid_of_requests[request_index]; + Request &request = all_requests[guid]; + assert(request.status == Request::RUNNING); + std::vector, double>> + _prealloc_vector; + _prealloc_vector.reserve(BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); + request.token_tree_nodes_acc_prob_pair_pq = std::priority_queue< + std::pair, double>, + std::vector, double>>, + SharedTokenTreeNodePtrDoubleLess>(SharedTokenTreeNodePtrDoubleLess(), + std::move(_prealloc_vector)); + } } void RequestManager::prune_token_tree_equal() { @@ -3547,6 +3857,14 @@ void RequestManager::prune_token_tree_equal() { if (budget > 0) { add_tokens_toward_goodput_per_request(budget, request_index); } + std::vector, double>> + _prealloc_vector; + _prealloc_vector.reserve(BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); + request.token_tree_nodes_acc_prob_pair_pq = 
std::priority_queue< + std::pair, double>, + std::vector, double>>, + SharedTokenTreeNodePtrDoubleLess>(SharedTokenTreeNodePtrDoubleLess(), + std::move(_prealloc_vector)); } } @@ -3568,29 +3886,53 @@ void RequestManager::prune_token_tree_greedy() { if (budget > 0) { add_tokens_toward_goodput(budget); } + for (int request_index = 0; request_index < get_max_requests_per_batch(); + ++request_index) { + if (!request_available[request_index]) { + continue; + } + RequestGuid guid = guid_of_requests[request_index]; + Request &request = all_requests[guid]; + assert(request.status == Request::RUNNING); + std::vector, double>> + _prealloc_vector; + _prealloc_vector.reserve(BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); + request.token_tree_nodes_acc_prob_pair_pq = std::priority_queue< + std::pair, double>, + std::vector, double>>, + SharedTokenTreeNodePtrDoubleLess>(SharedTokenTreeNodePtrDoubleLess(), + std::move(_prealloc_vector)); + } } void RequestManager::add_tokens_toward_slo(RequestGuid guid, int &budget, + double num_tokens_to_decode, int num_req_with_slo) { Request &request = all_requests[guid]; - double num_tokens_to_decode_per_step = - (ssm_spec_latency_ms + llm_verify_latency_ms) * correction_factor / - get_slo_constraint(request); - double expected_num_tokens_decoded = - request.decode_latency_ms / get_slo_constraint(request); - - double num_tokens_to_decode = - max(1.0, - num_tokens_to_decode_per_step + expected_num_tokens_decoded - - request.decode_length()); - num_tokens_to_decode = min(num_tokens_to_decode, (double)ssm_tree_depth + 1); + // double num_tokens_to_decode_per_step = + // (ssm_spec_latency_ms + llm_verify_latency_ms) * correction_factor / + // get_slo_constraint(request); + // double expected_num_tokens_decoded = + // request.decode_latency_ms / get_slo_constraint(request); + + // double num_tokens_to_decode = + // max(1.0, + // num_tokens_to_decode_per_step + expected_num_tokens_decoded - + // request.decode_length()); + // num_tokens_to_decode = min(num_tokens_to_decode, (double)ssm_tree_depth + + // 1); // The root is already included // In function add_root_to_spec_token_tree double current_added = 1.0; - // The max token that can be added to the token tree when fulfilling the SLO + std::cout << "[Debug] Request " << request.guid + << " SLO constraint: " << get_slo_constraint(request) << " " + << "num_tokens_to_decode: " << num_tokens_to_decode << std::endl; + + // The max token that can be added to the token tree when fulfilling the + // SLO int max_token_toward_slo = int(get_max_tokens_per_batch() * 1.2 / num_available_requests); @@ -3602,8 +3944,14 @@ void RequestManager::add_tokens_toward_slo(RequestGuid guid, auto [node_ptr, log_acc_prob] = request.token_tree_nodes_acc_prob_pair_pq.top(); request.token_tree_nodes_acc_prob_pair_pq.pop(); + double prob = exp(log_acc_prob); + if (prob < 8e-2) { + break; + } node_ptr->included = true; - current_added += exp(log_acc_prob); + current_added += prob; + std::cout << "[Debug] added token with prob: " << prob + << " current_added: " << current_added << std::endl; budget--; max_token_toward_slo--; } @@ -3658,25 +4006,6 @@ void RequestManager::add_tokens_toward_memory_occupancy(int budget) { } budget--; } - - // Clear the priority queue in each requests - for (int request_index = 0; request_index < get_max_requests_per_batch(); - ++request_index) { - if (!request_available[request_index]) { - continue; - } - RequestGuid guid = guid_of_requests[request_index]; - Request &request = all_requests[guid]; - assert(request.status == 
Request::RUNNING); - std::vector, double>> - _prealloc_vector; - _prealloc_vector.reserve(BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); - request.token_tree_nodes_acc_prob_pair_pq = std::priority_queue< - std::pair, double>, - std::vector, double>>, - SharedTokenTreeNodePtrDoubleLess>(SharedTokenTreeNodePtrDoubleLess(), - std::move(_prealloc_vector)); - } } void RequestManager::add_tokens_toward_goodput(int budget) { @@ -3719,6 +4048,15 @@ void RequestManager::add_tokens_toward_goodput(int budget) { while (budget > 0 and !global_token_tree_node_pq.empty()) { auto [node_ptr, acc_log_prob, guid] = global_token_tree_node_pq.top(); global_token_tree_node_pq.pop(); + double prob = exp(acc_log_prob); + if (prob < 2e-2 && budget % 32 == 0) { + std::cout << "[Debug] prob: " << prob << " is too small, break" + << std::endl; + break; + } + // Debug + std::cout << "[Debug] added token with prob: " << prob << " toward goodput " + << std::endl; node_ptr->included = true; if (!get_request_with_guid(guid) .token_tree_nodes_acc_prob_pair_pq.empty()) { @@ -3734,25 +4072,6 @@ void RequestManager::add_tokens_toward_goodput(int budget) { } budget--; } - - // Clear the priority queue in each requests - for (int request_index = 0; request_index < get_max_requests_per_batch(); - ++request_index) { - if (!request_available[request_index]) { - continue; - } - RequestGuid guid = guid_of_requests[request_index]; - Request &request = all_requests[guid]; - assert(request.status == Request::RUNNING); - std::vector, double>> - _prealloc_vector; - _prealloc_vector.reserve(BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); - request.token_tree_nodes_acc_prob_pair_pq = std::priority_queue< - std::pair, double>, - std::vector, double>>, - SharedTokenTreeNodePtrDoubleLess>(SharedTokenTreeNodePtrDoubleLess(), - std::move(_prealloc_vector)); - } } void RequestManager::add_tokens_toward_goodput_per_request(int budget, @@ -3773,16 +4092,6 @@ void RequestManager::add_tokens_toward_goodput_per_request(int budget, node_ptr->included = true; budget--; } - - // Clear the priority queue in the request - std::vector, double>> - _prealloc_vector; - _prealloc_vector.reserve(BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); - request.token_tree_nodes_acc_prob_pair_pq = std::priority_queue< - std::pair, double>, - std::vector, double>>, - SharedTokenTreeNodePtrDoubleLess>(SharedTokenTreeNodePtrDoubleLess(), - std::move(_prealloc_vector)); } std::ostream &operator<<(std::ostream &os, TokenTree const &token_tree) { diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index 56437bd0c..09d1f2da0 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -389,8 +389,8 @@ void RequestManager::load_batch_config_task( } // prepare attention forward handler - { - int batch_size = batch_config->num_active_requests(); + int batch_size = batch_config->num_active_requests(); + if (batch_size != 0) { static int32_t q_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1], kv_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1], kv_last_page_len_h[BatchConfig::MAX_NUM_REQUESTS]; @@ -538,8 +538,8 @@ void RequestManager::load_batch_config_task( } // prepare attention forward handler - { - int batch_size = batch_config->num_active_requests(); + int batch_size = batch_config->num_active_requests(); + if (batch_size != 0) { BatchPrefillHandler *handler = nullptr; if (!batch_config->prompt_phase) { @@ -660,8 +660,8 @@ void RequestManager::load_batch_config_task( } // prepare attention forward handler - { - int batch_size = batch_config->num_active_requests(); + 
int batch_size = batch_config->num_active_requests(); + if (batch_size != 0) { BatchPrefillHandler *handler = nullptr; if (!batch_config->prompt_phase) { diff --git a/src/runtime/substitution.cc b/src/runtime/substitution.cc index 176133c49..aede906f2 100644 --- a/src/runtime/substitution.cc +++ b/src/runtime/substitution.cc @@ -57,7 +57,7 @@ using namespace Legion; Legion::Logger log_xfers("xfers"); Legion::Logger log_xfer_matches("xfer_matches"); -const TensorX TensorX::NO_TX = TensorX(); +TensorX const TensorX::NO_TX = TensorX(); bool TensorX::operator==(TensorX const &other) const { return this->op == other.op && this->idx == other.idx; @@ -155,7 +155,7 @@ tl::optional TensorX::to_tensor(GraphXfer const *xfer) const { } } -OpX::OpX(const OperatorType _type, +OpX::OpX(OperatorType const _type, int num_inputs, int num_outputs, TensorX const &input0, @@ -177,7 +177,7 @@ OpX::OpX(const OperatorType _type, } } -OpX::OpX(const OperatorType _type, +OpX::OpX(OperatorType const _type, int num_inputs, int num_outputs, TensorX const *input_array) @@ -1301,16 +1301,10 @@ void Graph::export_strategy_computation_graph( weight_mem /= (*node.ptr->weights)->get_total_num_parts(); } - runtime_code << "fwd" - << "bwd" - << "sync" - << "secs"; + runtime_code << "fwd" << "bwd" << "sync" << "secs"; runtime_cost_row << op_cost.forward_time << op_cost.backward_time << op_cost.sync_time; - memory_code << "in" - << "out" - << "weight" - << "bytes"; + memory_code << "in" << "out" << "weight" << "bytes"; memory_cost_row << input_mem << output_mem << weight_mem; rf << runtime_code << runtime_cost_row << memory_code << memory_cost_row; @@ -1402,8 +1396,7 @@ void create_mapping_xfers( op_type_name.begin(), [](unsigned char c) { return std::tolower(c); }); oss << "mapping::" << pre_name << "_" << op_type_name << "_" << post_name - << "[" - << "input_dim=" << input_dim << ",degree=" << degree << "]"; + << "[" << "input_dim=" << input_dim << ",degree=" << degree << "]"; subst->name = oss.str(); xfers.push_back(subst); @@ -3082,9 +3075,9 @@ GraphXfer *create_partition_linear_combine(FFModel *model, subst->dstOps.push_back(combine); std::ostringstream oss; - oss << "partition_linear_combine[" - << "num_dims=" << num_dims << ",num_parts=" << num_parts - << ",activation=" << activation << ",use_bias=" << use_bias << "]"; + oss << "partition_linear_combine[" << "num_dims=" << num_dims + << ",num_parts=" << num_parts << ",activation=" << activation + << ",use_bias=" << use_bias << "]"; subst->name = oss.str(); return subst; @@ -3109,8 +3102,8 @@ GraphXfer *create_partition_conv2d_combine(FFModel *model, subst->dstOps.push_back(combine); std::ostringstream oss; - oss << "partition_conv2d_combine[" - << "num_dims=" << num_dims << ",num_parts=" << num_parts << "]"; + oss << "partition_conv2d_combine[" << "num_dims=" << num_dims + << ",num_parts=" << num_parts << "]"; subst->name = oss.str(); return subst; @@ -3207,8 +3200,8 @@ GraphXfer *create_partition_attention_combine(FFModel *model, subst->dstOps.push_back(combine); std::ostringstream oss; - oss << "partition_attention_combine[" - << "num_heads=" << num_heads << ",num_parts=" << num_parts << "]"; + oss << "partition_attention_combine[" << "num_heads=" << num_heads + << ",num_parts=" << num_parts << "]"; subst->name = oss.str(); return subst; @@ -3236,8 +3229,8 @@ GraphXfer *create_replicate_attention_reduce(FFModel *model, subst->dstOps.push_back(reduce); std::ostringstream oss; - oss << "replicate_attention_reduce[" - << "num_heads=" << num_heads << ",num_parts=" << 
num_parts << "]"; + oss << "replicate_attention_reduce[" << "num_heads=" << num_heads + << ",num_parts=" << num_parts << "]"; subst->name = oss.str(); return subst; @@ -3266,9 +3259,9 @@ GraphXfer *create_replicate_linear_combine(FFModel *model, subst->dstOps.push_back(combine); std::ostringstream oss; - oss << "replicate_linear_combine[" - << "num_dims=" << num_dims << ",num_parts=" << num_parts - << ",activation=" << activation << ",use_bias=" << use_bias << "]"; + oss << "replicate_linear_combine[" << "num_dims=" << num_dims + << ",num_parts=" << num_parts << ",activation=" << activation + << ",use_bias=" << use_bias << "]"; subst->name = oss.str(); return subst; @@ -3297,8 +3290,8 @@ GraphXfer *create_partition_add_combine(FFModel *model, subst->dstOps.push_back(combine); std::ostringstream oss; - oss << "partition_add_combine[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; + oss << "partition_add_combine[" << "parallel_dim=" << parallel_dim + << ",num_parts=" << num_parts << "]"; subst->name = oss.str(); return subst; @@ -3326,8 +3319,8 @@ GraphXfer *create_combine_add_partition(FFModel *model, subst->dstOps.push_back(repartition); std::ostringstream oss; - oss << "combine_add_partition[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; + oss << "combine_add_partition[" << "parallel_dim=" << parallel_dim + << ",num_parts=" << num_parts << "]"; subst->name = oss.str(); return subst; @@ -3354,8 +3347,8 @@ GraphXfer *create_partition_relu_combine(FFModel *model, subst->dstOps.push_back(combine); std::ostringstream oss; - oss << "partition_relu_combine[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; + oss << "partition_relu_combine[" << "parallel_dim=" << parallel_dim + << ",num_parts=" << num_parts << "]"; subst->name = oss.str(); return subst; @@ -3382,8 +3375,8 @@ GraphXfer *create_combine_relu_partition(FFModel *model, subst->dstOps.push_back(partition); std::ostringstream oss; - oss << "combine_relu_partition[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; + oss << "combine_relu_partition[" << "parallel_dim=" << parallel_dim + << ",num_parts=" << num_parts << "]"; subst->name = oss.str(); return subst; @@ -3419,9 +3412,9 @@ GraphXfer *create_partition_concat_combine(FFModel *model, subst->map_output(concat->outputs[0], combine->outputs[0]); std::ostringstream oss; - oss << "partition_concat_combine[" - << "num_inputs=" << num_inputs << ",concat_dim=" << concat_dim - << ",parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; + oss << "partition_concat_combine[" << "num_inputs=" << num_inputs + << ",concat_dim=" << concat_dim << ",parallel_dim=" << parallel_dim + << ",num_parts=" << num_parts << "]"; subst->name = oss.str(); return subst; @@ -3446,9 +3439,8 @@ GraphXfer *create_partition_softmax_combine(FFModel *model, subst->dstOps.push_back(combine); std::ostringstream oss; - oss << "partition_softmax_combine[" - << "softmax_dim=" << softmax_dim << ",parallel_dim=" << parallel_dim - << ",num_parts=" << num_parts << "]"; + oss << "partition_softmax_combine[" << "softmax_dim=" << softmax_dim + << ",parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; subst->name = oss.str(); return subst; @@ -3473,9 +3465,8 @@ GraphXfer *create_combine_softmax_partition(FFModel *model, subst->dstOps.push_back(repartition); std::ostringstream oss; - oss << "combine_softmax_partition[" - << "softmax_dim=" << softmax_dim << ",parallel_dim=" << parallel_dim - << 
",num_parts=" << num_parts << "]"; + oss << "combine_softmax_partition[" << "softmax_dim=" << softmax_dim + << ",parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; subst->name = oss.str(); return subst; @@ -3514,9 +3505,8 @@ GraphXfer *leading_relu_branch_combine(FFModel *model, subst->dstOps.insert(subst->dstOps.end(), new_noops.begin(), new_noops.end()); std::ostringstream oss; - oss << "leading_relu_branch_combine[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts - << ",num_combines=" << num_combines << "]"; + oss << "leading_relu_branch_combine[" << "parallel_dim=" << parallel_dim + << ",num_parts=" << num_parts << ",num_combines=" << num_combines << "]"; subst->name = oss.str(); return subst; @@ -3553,9 +3543,9 @@ GraphXfer *leading_relu_branch_partition(FFModel *model, subst->dstOps.insert(subst->dstOps.end(), new_noops.begin(), new_noops.end()); std::ostringstream oss; - oss << "leading_relu_branch_partition[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts - << ",num_partitions=" << num_partitions << "]"; + oss << "leading_relu_branch_partition[" << "parallel_dim=" << parallel_dim + << ",num_parts=" << num_parts << ",num_partitions=" << num_partitions + << "]"; subst->name = oss.str(); return subst; @@ -3578,8 +3568,8 @@ GraphXfer * subst->dstOps.push_back(new_linear); std::ostringstream oss; - oss << "linear_relu_merge[" - << "num_dims=" << num_dims << ",use_bias=" << use_bias << "]"; + oss << "linear_relu_merge[" << "num_dims=" << num_dims + << ",use_bias=" << use_bias << "]"; subst->name = oss.str(); return subst; @@ -3824,12 +3814,11 @@ bool FFModel::convert_graph_to_operators( break; } case OP_SIGMOID_SILU_MULTI: { - assert(inList.size() == 2); + assert(inList.size() == 1); SigmoidSiluMulti *ssm = (SigmoidSiluMulti *)node.ptr; new_op = new SigmoidSiluMulti(*this, ssm->layer_guid, inputs[0], - inputs[1], ssm->intermediate_size, ssm->tensor_parallelism_degree, NULL);