diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp
index 581d12caa..3894f72ec 100644
--- a/src/models/kv_cache.cpp
+++ b/src/models/kv_cache.cpp
@@ -175,8 +175,14 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state)
   }

   // Set the size after empty_past_ has been created with 0 for this field
-  if (past_present_share_buffer_)
+  if (state.model_.p_device_->GetType() == DeviceType::NvTensorRtRtx &&
+      model_.config_->model.decoder.sliding_window.has_value() &&
+      model_.config_->model.decoder.sliding_window->window_size > 0) {
+    shape_[2] = std::min(state_.params_->search.max_length,
+                         model_.config_->model.decoder.sliding_window->window_size);
+  } else if (past_present_share_buffer_) {
     shape_[2] = state_.params_->search.max_length;
+  }

   try {
     for (int i = 0; i < layer_count_ * 2; ++i) {
@@ -422,7 +428,8 @@ std::unique_ptr<KeyValueCache> CreateKeyValueCache(State& state) {
     return nullptr;
   }

-  if (state.model_.config_->model.decoder.sliding_window &&
+  if (state.model_.p_device_->GetType() != DeviceType::NvTensorRtRtx &&
+      state.model_.config_->model.decoder.sliding_window &&
       state.model_.config_->model.decoder.sliding_window->slide_key_value_cache) {
     return std::make_unique<WindowedKeyValueCache>(state);
   }
@@ -430,4 +437,4 @@ std::unique_ptr<KeyValueCache> CreateKeyValueCache(State& state) {
   return std::make_unique<DefaultKeyValueCache>(state);
 }

-} // namespace Generators
+} // namespace Generators
\ No newline at end of file
diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index 1303269e5..79303f9d2 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -321,6 +321,11 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             self.quant_attrs["config"] = config.quantization_config
             self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False

+        # KV cache quantization attributes
+        self.kv_cache_attrs = {
+            "quantize_kv_cache": extra_options.get("quantize_kv_cache", False),
+            "kv_scale_factor": 127.0,  # Scale factor for INT8 quantization
+        }

     def make_outputs_init(self):
         # Always use float32 logits to improve accuracy in the case of bf16 models.
@@ -334,17 +339,20 @@ def make_outputs_init(self):
             self.output_names = [name.replace("logits", "hidden_states") for name in self.output_names]
         elif self.include_hidden_states:
             self.output_names = ["hidden_states"] + self.output_names
+
+        # Update output types for quantized KV cache
+        if hasattr(self, 'kv_cache_attrs') and self.kv_cache_attrs["quantize_kv_cache"]:
+            self.output_types["present.key"] = TensorProto.INT8
+            self.output_types["present.value"] = TensorProto.INT8

     def make_attention_init(self):
         valid_gqa_configurations = [
             ("cpu", TensorProto.FLOAT),
-            ("cuda", TensorProto.FLOAT16),
             ("cuda", TensorProto.BFLOAT16),
             ("rocm", TensorProto.FLOAT16),
             ("dml", TensorProto.FLOAT16),
             ("webgpu", TensorProto.FLOAT16),
             ("webgpu", TensorProto.FLOAT),
-            ("NvTensorRtRtx", TensorProto.FLOAT16),
         ]
         if (self.ep, self.io_dtype) in valid_gqa_configurations:
             # Change model settings for GroupQueryAttention
@@ -644,17 +652,21 @@ def make_inputs_and_outputs(self):

         # Add KV cache to inputs and outputs
         for i in range(self.num_layers):
+            # Determine input/output types based on quantization setting
+            kv_input_dtype = TensorProto.INT8 if (hasattr(self, 'kv_cache_attrs') and self.kv_cache_attrs["quantize_kv_cache"]) else self.input_types["past_key_values.key"]
+            kv_output_dtype = TensorProto.INT8 if (hasattr(self, 'kv_cache_attrs') and self.kv_cache_attrs["quantize_kv_cache"]) else self.output_types["present.key"]
+
             # Add KV cache to inputs
             key_name = f"past_key_values.{i}.key"
-            inputs.append(helper.make_tensor_value_info(key_name, self.input_types["past_key_values.key"], shape=self.input_shapes["past_key_values.key"]))
+            inputs.append(helper.make_tensor_value_info(key_name, kv_input_dtype, shape=self.input_shapes["past_key_values.key"]))
             value_name = f"past_key_values.{i}.value"
-            inputs.append(helper.make_tensor_value_info(value_name, self.input_types["past_key_values.value"], shape=self.input_shapes["past_key_values.value"]))
+            inputs.append(helper.make_tensor_value_info(value_name, kv_input_dtype, shape=self.input_shapes["past_key_values.value"]))

             # Add KV cache to outputs
             key_name = f"present.{i}.key"
-            outputs.append(helper.make_tensor_value_info(key_name, self.output_types["present.key"], shape=self.output_shapes["present.key"]))
+            outputs.append(helper.make_tensor_value_info(key_name, kv_output_dtype, shape=self.output_shapes["present.key"]))
             value_name = f"present.{i}.value"
-            outputs.append(helper.make_tensor_value_info(value_name, self.output_types["present.value"], shape=self.output_shapes["present.value"]))
+            outputs.append(helper.make_tensor_value_info(value_name, kv_output_dtype, shape=self.output_shapes["present.value"]))

         self.inputs = inputs
         self.outputs = outputs
@@ -1545,7 +1557,6 @@ def _make_simplified_layer_norm(self, basename, root_input, weight_name, output_
         self.make_node("Mul", inputs=make_mul_1_inputs, outputs=[output_0], name=make_mul_1_name)
         self.make_value_info(output_0, dtype=io_dtype, shape=shape)

-
     def make_qk_norm(self, layer_id, attention):
         # Make subgraph to compute SimplifiedLayerNorm after Q and K MatMuls in attention:
         #
@@ -1692,7 +1703,9 @@ def make_repeat_kv(self, layer_id, root_input, past_kv, present_kv, **kwargs):
         #                           Transpose
         #                               |
         #                            Reshape
-        basename = f"/model/layers.{layer_id}/attn/{'k_proj' if past_kv.endswith('key') else 'v_proj'}/repeat_kv"
+        # Determine if this is for key or value based on past_kv (original case) or present_kv (quantized case)
+        is_key = past_kv.endswith('key') or (present_kv and 'temp_present_k' in present_kv)
+        basename = f"/model/layers.{layer_id}/attn/{'k_proj' if is_key else 'v_proj'}/repeat_kv"
f"/model/layers.{layer_id}/attn/{'k_proj' if is_key else 'v_proj'}/repeat_kv" # Make the initial subgraph # @@ -1712,9 +1725,16 @@ def make_repeat_kv(self, layer_id, root_input, past_kv, present_kv, **kwargs): concat_1_name = f"{basename}/Concat_1" concat_1_inputs = [past_kv, f"{transpose_1_name}/output_0"] self.make_node("Concat", inputs=concat_1_inputs, outputs=[present_kv], name=concat_1_name, axis=2) + self.make_value_info(present_kv, self.io_dtype, shape=['batch_size', self.num_kv_heads, 'total_sequence_length', self.head_size]) + # Use Identity node to ensure proper shape flow for quantized KV cache scenarios + # This prevents shape inference issues when present_kv is a temporary tensor name + identity_present_name = f"{basename}/Identity_present" + self.make_node("Identity", inputs=[present_kv], outputs=[f"{identity_present_name}/output_0"], name=identity_present_name) + self.make_value_info(f"{identity_present_name}/output_0", self.io_dtype, shape=['batch_size', self.num_kv_heads, 'total_sequence_length', self.head_size]) + shape_1_name = f"{basename}/Shape_1" - self.make_shape(shape_1_name, present_kv, shape=[4]) + self.make_shape(shape_1_name, f"{identity_present_name}/output_0", shape=[4]) gather_1_name = f"{basename}/Gather_1" gather_1_inputs = [f"{shape_1_name}/output_0", "/model/constants/TensorProto.INT64/0D/0"] self.make_gather(gather_1_name, gather_1_inputs, axis=0) @@ -1781,7 +1801,7 @@ def make_repeat_kv(self, layer_id, root_input, past_kv, present_kv, **kwargs): # \ \ # Unsqueeze --> Expand --> Reshape --> Transpose --> Reshape unsqueeze_5_name = f"{basename}/Unsqueeze_5" - unsqueeze_5_inputs = [present_kv, "/model/constants/TensorProto.INT64/1D/2"] + unsqueeze_5_inputs = [f"{identity_present_name}/output_0", "/model/constants/TensorProto.INT64/1D/2"] self.make_unsqueeze(unsqueeze_5_name, unsqueeze_5_inputs, dtype=self.io_dtype, shape=['batch_size', self.num_kv_heads, 1, 'sequence_length', self.head_size]) expand_name = f"{basename}/Expand" expand_inputs = [f"{unsqueeze_5_name}/output_0", f"{where_name}/output_0"] @@ -1802,6 +1822,111 @@ def make_repeat_kv(self, layer_id, root_input, past_kv, present_kv, **kwargs): def make_attention_op(self, name, **kwargs): op_type = self.attention_attrs["op_type"] + # Handle KV cache quantization if enabled + original_past_k = kwargs.get("past_k", "") + original_past_v = kwargs.get("past_v", "") + original_present_k = kwargs.get("present_k", "") + original_present_v = kwargs.get("present_v", "") + + # Check if we need to handle quantized KV cache with MHA + GQA configuration + has_quantized_kv = hasattr(self, 'kv_cache_attrs') and self.kv_cache_attrs["quantize_kv_cache"] + needs_kv_repeat = (self.num_attn_heads != self.num_kv_heads and op_type == "MultiHeadAttention") + + if has_quantized_kv and needs_kv_repeat: + # Special handling for quantized KV cache + MHA + different head counts + # We need to: dequantize -> repeat_kv -> attention -> quantize + + # Extract layer_id from the attention name + import re + layer_match = re.search(r'/model/layers\.(\d+)/', name) + if layer_match: + layer_id = int(layer_match.group(1)) + else: + raise ValueError(f"Could not extract layer_id from attention name: {name}") + + if original_past_k and original_past_v: + # 1. 
+                dequant_past_k = self.make_dequantize_kv_cache(
+                    f"{name}/dequant_past_k",
+                    original_past_k,
+                    ['batch_size', self.num_kv_heads, 'past_sequence_length', self.head_size]
+                )
+                dequant_past_v = self.make_dequantize_kv_cache(
+                    f"{name}/dequant_past_v",
+                    original_past_v,
+                    ['batch_size', self.num_kv_heads, 'past_sequence_length', self.head_size]
+                )
+
+                # 2. Apply repeat_kv with direct dequantized inputs (remove barriers that cause TensorRT issues)
+                # Create temporary present outputs (unquantized)
+                temp_present_k_unquant = f"{name}/temp_present_k_unquant"
+                temp_present_v_unquant = f"{name}/temp_present_v_unquant"
+
+                # Call repeat_kv to handle head expansion and KV cache update
+                repeated_k = self.make_repeat_kv(layer_id, root_input=kwargs["k_path"], past_kv=dequant_past_k, present_kv=temp_present_k_unquant)
+                repeated_v = self.make_repeat_kv(layer_id, root_input=kwargs["v_path"], past_kv=dequant_past_v, present_kv=temp_present_v_unquant)
+
+                # 3. Update kwargs for attention operation
+                kwargs["k_path"] = repeated_k
+                kwargs["v_path"] = repeated_v
+                kwargs["past_k"] = ""     # Already handled by repeat_kv
+                kwargs["past_v"] = ""     # Already handled by repeat_kv
+                kwargs["present_k"] = ""  # Will be handled by quantization
+                kwargs["present_v"] = ""  # Will be handled by quantization
+
+                # 4. Run MultiHeadAttention
+                self.make_multi_head_attention(name, add_qk=f"{self.mask_attrs['mask_name']}/output_0", **kwargs)
+
+                # 5. Quantize the present KV cache from repeat_kv outputs
+                if original_present_k:
+                    quant_present_k = self.make_quantize_kv_cache(
+                        f"{name}/quant_present_k",
+                        temp_present_k_unquant,
+                        ['batch_size', self.num_kv_heads, 'total_sequence_length', self.head_size]
+                    )
+                    self.make_node("Identity", inputs=[quant_present_k], outputs=[original_present_k],
+                                   name=f"{name}/present_k_identity")
+                    self.make_value_info(original_present_k, TensorProto.INT8,
+                                         shape=['batch_size', self.num_kv_heads, 'total_sequence_length', self.head_size])
+
+                if original_present_v:
+                    quant_present_v = self.make_quantize_kv_cache(
+                        f"{name}/quant_present_v",
+                        temp_present_v_unquant,
+                        ['batch_size', self.num_kv_heads, 'total_sequence_length', self.head_size]
+                    )
+                    self.make_node("Identity", inputs=[quant_present_v], outputs=[original_present_v],
+                                   name=f"{name}/present_v_identity")
+                    self.make_value_info(original_present_v, TensorProto.INT8,
+                                         shape=['batch_size', self.num_kv_heads, 'total_sequence_length', self.head_size])
+                return
+
+        elif has_quantized_kv:
+            # Standard quantized KV cache handling (for GQA or MHA without head mismatch)
+            if original_past_k:
+                dequant_past_k = self.make_dequantize_kv_cache(
+                    f"{name}/dequant_past_k",
+                    original_past_k,
+                    ['batch_size', self.num_kv_heads, 'past_sequence_length', self.head_size]
+                )
+                kwargs["past_k"] = dequant_past_k
+
+            if original_past_v:
+                dequant_past_v = self.make_dequantize_kv_cache(
+                    f"{name}/dequant_past_v",
+                    original_past_v,
+                    ['batch_size', self.num_kv_heads, 'past_sequence_length', self.head_size]
+                )
+                kwargs["past_v"] = dequant_past_v
+
+            # Use temporary names for present outputs
+            temp_present_k = f"{name}/temp_present_k" if original_present_k else ""
+            temp_present_v = f"{name}/temp_present_v" if original_present_v else ""
+            if original_present_k:
+                kwargs["present_k"] = temp_present_k
+            if original_present_v:
+                kwargs["present_v"] = temp_present_v

         if op_type == "MultiHeadAttention":
             self.make_multi_head_attention(name, add_qk=f"{self.mask_attrs['mask_name']}/output_0", **kwargs)
         elif op_type == "GroupQueryAttention":
"GroupQueryAttention": @@ -1811,6 +1936,31 @@ def make_attention_op(self, name, **kwargs): else: raise NotImplementedError(f"The {op_type} op is not currently supported.") + # Quantize present KV cache for storage if quantization is enabled + # (Skip if already handled in the special repeat_kv case above) + if has_quantized_kv and not needs_kv_repeat: + if original_present_k and temp_present_k: + quant_present_k = self.make_quantize_kv_cache( + f"{name}/quant_present_k", + temp_present_k, + ['batch_size', self.num_kv_heads, 'total_sequence_length', self.head_size] + ) + self.make_node("Identity", inputs=[quant_present_k], outputs=[original_present_k], + name=f"{name}/present_k_identity") + self.make_value_info(original_present_k, TensorProto.INT8, + shape=['batch_size', self.num_kv_heads, 'total_sequence_length', self.head_size]) + + if original_present_v and temp_present_v: + quant_present_v = self.make_quantize_kv_cache( + f"{name}/quant_present_v", + temp_present_v, + ['batch_size', self.num_kv_heads, 'total_sequence_length', self.head_size] + ) + self.make_node("Identity", inputs=[quant_present_v], outputs=[original_present_v], + name=f"{name}/present_v_identity") + self.make_value_info(original_present_v, TensorProto.INT8, + shape=['batch_size', self.num_kv_heads, 'total_sequence_length', self.head_size]) + def make_multi_head_attention(self, name, **kwargs): inputs = [ kwargs["q_path"], kwargs["k_path"], kwargs["v_path"], kwargs.get("bias", ""), @@ -1958,10 +2108,20 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): past_v = f"past_key_values.{layer_id}.value" present_k = f"present.{layer_id}.key" present_v = f"present.{layer_id}.value" + + # Check if we have quantized KV cache - if so, handle it differently + has_quantized_kv = hasattr(self, 'kv_cache_attrs') and self.kv_cache_attrs["quantize_kv_cache"] + if self.num_attn_heads != self.num_kv_heads and self.attention_attrs["op_type"] == "MultiHeadAttention": - self.attention_attrs["k_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["k_path"], past_kv=past_k, present_kv=present_k) - self.attention_attrs["v_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["v_path"], past_kv=past_v, present_kv=present_v) - past_k, past_v, present_k, present_v = "", "", "", "" + if has_quantized_kv: + # For quantized KV cache, we'll do repeat_kv in make_attention_op after dequantization + # So we keep the KV cache parameters for the attention_op to handle + pass # KV cache will be handled in make_attention_op with proper quantization + else: + # Original logic for non-quantized case + self.attention_attrs["k_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["k_path"], past_kv=past_k, present_kv=present_k) + self.attention_attrs["v_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["v_path"], past_kv=past_v, present_kv=present_v) + past_k, past_v, present_k, present_v = "", "", "", "" # Make attention node (e.g. MultiHeadAttention, GroupQueryAttention, etc.) 
attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}" @@ -3078,6 +3238,66 @@ def make_position_ids_reformatting(self): return reshape_name + def make_quantize_kv_cache(self, name, input_tensor, shape): + """Create standard ONNX quantization for KV cache using QuantizeLinear""" + # Use a more conservative scale factor for KV cache values + # KV cache values typically have much smaller range than full FP16 + # Using scale = 0.1 maps [-12.7, 12.7] to [-127, 127] which is more appropriate + scale_tensor_name = name.replace("/", ".") + "_scale" + scale_value = 0.1 # Conservative scale factor for better precision + self.make_external_tensor( + torch.tensor([scale_value], dtype=torch.float32).contiguous(), + scale_tensor_name + ) + + # Create zero_point tensor (0 for symmetric quantization) + zero_point_tensor_name = name.replace("/", ".") + "_zero_point" + self.make_external_tensor( + torch.tensor([0], dtype=torch.int8).contiguous(), + zero_point_tensor_name + ) + + # Cast input to FP32 for QuantizeLinear (must match scale tensor type) + cast_name = f"{name}/Cast_to_fp32" + self.make_cast(cast_name, input_tensor, dtype=TensorProto.FLOAT, shape=shape) + + # Use standard QuantizeLinear operator + quantize_name = f"{name}/QuantizeLinear" + quantize_inputs = [f"{cast_name}/output_0", scale_tensor_name, zero_point_tensor_name] + self.make_node("QuantizeLinear", inputs=quantize_inputs, outputs=[f"{quantize_name}/output_0"], name=quantize_name) + self.make_value_info(f"{quantize_name}/output_0", TensorProto.INT8, shape=shape) + + return f"{quantize_name}/output_0" + + def make_dequantize_kv_cache(self, name, input_tensor, shape): + """Create standard ONNX dequantization for KV cache using DequantizeLinear""" + # Use the same scale as quantization for proper round-trip + scale_tensor_name = name.replace("/", ".") + "_scale" + scale_value = 0.1 # Same scale as quantization + self.make_external_tensor( + torch.tensor([scale_value], dtype=torch.float32).contiguous(), + scale_tensor_name + ) + + # Create zero_point tensor (0 for symmetric quantization) + zero_point_tensor_name = name.replace("/", ".") + "_zero_point" + self.make_external_tensor( + torch.tensor([0], dtype=torch.int8).contiguous(), + zero_point_tensor_name + ) + + # Use standard DequantizeLinear operator (output is FP32) + dequantize_name = f"{name}/DequantizeLinear" + dequantize_inputs = [input_tensor, scale_tensor_name, zero_point_tensor_name] + self.make_node("DequantizeLinear", inputs=dequantize_inputs, outputs=[f"{dequantize_name}/output_0"], name=dequantize_name) + self.make_value_info(f"{dequantize_name}/output_0", TensorProto.FLOAT, shape=shape) + + # Cast from FP32 back to model's io_dtype (FP16) for compatibility + cast_name = f"{name}/Cast_to_fp16" + self.make_cast(cast_name, f"{dequantize_name}/output_0", dtype=self.io_dtype, shape=shape) + + return f"{cast_name}/output_0" + class LlamaModel(Model): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): @@ -3646,7 +3866,9 @@ def check_extra_options(kv_pairs): """ Check key-value pairs and set values correctly """ - bools = ["int4_is_symmetric", "exclude_embeds", "exclude_lm_head", "include_hidden_states", "enable_cuda_graph", "use_8bits_moe", "use_qdq", "use_webgpu_fp32"] + bools = ["int4_is_symmetric", "exclude_embeds", "exclude_lm_head", "include_hidden_states", + "enable_cuda_graph", "use_8bits_moe", "use_qdq", "use_webgpu_fp32", + "quantize_kv_cache"] # Add quantize_kv_cache to bools list for key in bools: if key in kv_pairs: 
             if kv_pairs[key] in {"false", "False", "0"}:
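--
Notes (illustrative sketches, not part of the patch):

The kv_cache.cpp change caps the KV buffer's sequence axis on NvTensorRtRtx when a
sliding window is configured. A one-line restatement of that rule in Python, with
hypothetical names, is:

    def kv_sequence_capacity(max_length, window_size=None):
        # Mirrors the new branch: min(search.max_length, window_size) on
        # NvTensorRtRtx when window_size > 0; otherwise max_length, as before.
        if window_size is not None and window_size > 0:
            return min(max_length, window_size)
        return max_length

    assert kv_sequence_capacity(4096, window_size=2048) == 2048
    assert kv_sequence_capacity(1024, window_size=2048) == 1024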
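The QuantizeLinear/DequantizeLinear pair emitted by make_quantize_kv_cache and
make_dequantize_kv_cache uses a fixed symmetric scale of 0.1 with an int8 zero point
of 0. A minimal numpy sketch of that round trip, showing the saturation range of
roughly +/-12.8 and the worst-case in-range error of scale/2 = 0.05:

    import numpy as np

    SCALE, ZERO_POINT = 0.1, 0  # the fixed values written by the new helpers

    def quantize(x):
        # QuantizeLinear semantics: saturate(round(x / scale) + zero_point) to int8
        return np.clip(np.rint(x / SCALE) + ZERO_POINT, -128, 127).astype(np.int8)

    def dequantize(q):
        # DequantizeLinear semantics: (q - zero_point) * scale
        return (q.astype(np.float32) - ZERO_POINT) * SCALE

    x = np.array([-20.0, -0.04, 0.06, 3.14, 12.7, 20.0], dtype=np.float32)
    q = quantize(x)    # [-128    0    1   31  127  127]
    y = dequantize(q)  # [-12.8   0.   0.1  3.1 12.7 12.7]
    # Values beyond the representable range saturate; in-range values
    # round to the nearest multiple of 0.1.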
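With the check_extra_options change, quantize_kv_cache parses like the other boolean
extra options. Assuming the usual builder invocation (model, output path, and EP below
are placeholders), enabling it would look like:

    python3 -m onnxruntime_genai.models.builder \
        -m <model_name> -o <output_dir> -p fp16 -e cuda \
        --extra_options quantize_kv_cache=true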