diff --git a/examples/compress/Dockerfile b/examples/compress/Dockerfile
new file mode 100644
index 000000000..5a65839de
--- /dev/null
+++ b/examples/compress/Dockerfile
@@ -0,0 +1,26 @@
+# Dockerfile for the compress example
+
+FROM nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc5
+
+# TODO: The MIP solver does not work with this torch version.
+# Fix it; otherwise, Mamba models will not be supported by the Compress algorithm.
+# # Required for mamba_ssm to work (the default torch version in the 1.1.0rc5 image does not work)
+# RUN pip uninstall -y torch
+# RUN pip uninstall -y torchvision
+# RUN pip install torch==2.7.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
+
+# # Mamba SSM
+# RUN pip install causal-conv1d --no-build-isolation
+# RUN pip install mamba_ssm --no-build-isolation
+
+# Required for puzzletron calc_subblock_stats
+RUN pip install hydra-core==1.3.2
+RUN pip install wandb~=0.17.5
+RUN pip install "frozendict>=2.4.4"
+RUN pip install fire
+RUN pip install mip
+RUN pip install lru-dict
+
+WORKDIR /workspace/
+
+COPY . .
diff --git a/examples/compress/README.md b/examples/compress/README.md
new file mode 100644
index 000000000..0b165f46b
--- /dev/null
+++ b/examples/compress/README.md
@@ -0,0 +1,194 @@
+# Compress Algorithm Tutorial
+
+This tutorial demonstrates how to compress large language models using the Compress algorithm, which is based on the [Puzzle paper](https://arxiv.org/abs/2411.19146).
+The goal of the algorithm is to find the optimal set of modifications to the MLP and attention layers of the model, resulting in a heterogeneous model architecture.
+The supported modifications are:
+
+- `ffn_intermediate_size`: different FFN intermediate sizes
+- `attention op/noop`: complete removal of attention layers
+
+To use the Puzzle algorithm effectively, we need to specify the target number of parameters and/or the target memory. The final stage uses a Mixed-Integer Programming (MIP) solver to find the combination of layer modifications that best satisfies the target requirements.
+
+In this example, we compress the [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model, reducing GPU memory usage from 113 GiB to 96 GiB (a 15% reduction) with less than 1% regression in the `token_accuracy_top_10` metric.
+
+## Environment
+
+- Use the provided [Dockerfile](./Dockerfile).
+- 2x NVIDIA H100 80GB HBM3 (a single GPU also works).
+
+## Compress the Model
+
+1. Specify the `puzzle_dir`, `input_hf_model_path`, `dataset_path`, `intermediate_size_list`, and `target_memory` arguments in the [llama-3_1-8B_pruneffn_memory.yaml](./configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml) configuration file.
+
+   **_NOTE:_**
+   How to choose `intermediate_size_list`?
+   The list specifies the candidate FFN sizes to search over. It is recommended to choose several pruning ratios (e.g., 15%, 20%, 30% of the original size). The values should be hardware-friendly (e.g., divisible by 128, as in this example) to avoid issues with tensor operations in subsequent steps.
+
+   Let's first aim for roughly a 30% GPU memory reduction by setting `target_memory: 78_000` (MiB). This means that the algorithm will choose the candidates with the highest accuracy that also meet the specified memory budget.
+
+2. Download and prepare the [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2).
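+
+   The preparation command is shown below; once it finishes, you can sanity-check the saved splits with `datasets.load_from_disk` (a minimal sketch; the path should match whatever you pass to `--output_dir`):
+
+   ```python
+   from datasets import load_from_disk
+
+   # Illustrative path: must match the --output_dir used with prepare_dataset
+   ds = load_from_disk("path/to/Nemotron-Post-Training-Dataset-v2")
+   print(ds)                        # DatasetDict with 'train' and 'valid' splits
+   print(ds["train"].column_names)  # the 'messages' column is used downstream
+   ```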
+ + dataset split: "code", "math", "stem", "chat", excluding reasoning samples (2.62GB) + + ```bash + python -m modelopt.torch._compress.dataset.prepare_dataset --dataset_name nvidia/Nemotron-Post-Training-Dataset-v2 --output_dir path/to/Nemotron-Post-Training-Dataset-v2 + ``` + +3. Run the compression script. + + ```bash + torchrun --nproc_per_node 2 examples/compress/main.py --config path/to/llama-3_1-8B_pruneffn_memory.yaml 2>&1 | tee ./log.txt | grep "Compress Progress" + ``` + + This will save the full output to `log.txt` and display the following progress on screen: + + ```bash + [2025-11-02 12:06:34] Compress Progress 1/8: starting compression pipeline + [2025-11-02 12:06:45] Compress Progress 2/8: converting model from HF to DeciLM (single-gpu) + [2025-11-02 12:07:07] Compress Progress 3/8: scoring pruning activations (multi-gpu) + [2025-11-02 12:11:36] Compress Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu) + [2025-11-02 12:12:20] Compress Progress 5/8: building replacement library and subblock statistics (single-gpu) + [2025-11-02 12:12:21] Compress Progress 6/8: calculating one block scores (multi-gpu) + [2025-11-02 12:50:41] Compress Progress 7/8: running MIP and realizing models (multi-gpu) + [2025-11-02 12:52:34] Compress Progress 8/8: compression pipeline completed (multi-gpu) + ``` + + Once the process is complete, the resulting network architecture will be recorded in `log.txt` for your review: + + ```bash + ... + block_0: attention gqa_4 ffn intermediate_14336 + block_1: attention gqa_4 ffn intermediate_14336 + block_2: attention gqa_4 ffn intermediate_14336 + block_3: attention gqa_4 ffn intermediate_14336 + block_4: attention gqa_4 ffn intermediate_14336 + block_5: attention gqa_4 ffn intermediate_14336 + block_6: attention gqa_4 ffn intermediate_14336 + block_7: attention gqa_4 ffn intermediate_14336 + block_8: attention gqa_4 ffn intermediate_14336 + block_9: attention gqa_4 ffn intermediate_14336 + block_10: attention gqa_4 ffn intermediate_14336 + block_11: attention gqa_4 ffn intermediate_14336 + block_12: attention gqa_4 ffn intermediate_14336 + block_13: attention gqa_4 ffn intermediate_14336 + block_14: attention gqa_4 ffn intermediate_14336 + block_15: attention gqa_4 ffn intermediate_14336 + block_16: attention gqa_4 ffn intermediate_14336 + block_17: attention no_op ffn intermediate_14336 + block_18: attention no_op ffn intermediate_14336 + block_19: attention no_op ffn intermediate_14336 + block_20: attention no_op ffn intermediate_14336 + block_21: attention no_op ffn intermediate_14336 + block_22: attention no_op ffn intermediate_14336 + block_23: attention no_op ffn intermediate_14336 + block_24: attention no_op ffn intermediate_14336 + block_25: attention no_op ffn intermediate_14336 + block_26: attention no_op ffn intermediate_14336 + block_27: attention no_op ffn intermediate_14336 + block_28: attention no_op ffn intermediate_14336 + block_29: attention gqa_4 ffn intermediate_14336 + block_30: attention gqa_4 ffn intermediate_14336 + block_31: attention gqa_4 ffn intermediate_14336 + + [2025-11-02 04:53:11,332]^[[92m[rank-0]^[[0m[run_puzzle.py:295] Total costs: {'stats.memory_mib': 75796.4140625, 'stats.ffn_num_params': 5637275648, 'stats.num_kv_heads': 160, 'stats.kv_cache_memory_mib': 61440.0, 'stats.ffn_memory_mib': 10752.25, 'stats.attention_memory_mib': 63040.15625, 'stats.attention_num_params': 838942720, 'stats.num_params': 7526895616, 'stats.has_attention': 20, 'stats.has_ffn': 32} + ... 
+ ################################################################ + validate_model_and_extract_token_probs(model_name='teacher') + ################################################################ + ... + Average losses = {'lm_loss': 1.118250765837729, 'token_accuracy_top_1': 0.7331905364990234, 'token_accuracy_top_5': 0.9094219207763672, 'token_accuracy_top_10': 0.9423646926879883} + ... + ################################################################ + validate_model_with_kl_div(model_name='solution_0', is_calc_kl_div=True) + ################################################################ + .... + Average losses = {'lm_loss': 1.7577573340386152, 'token_accuracy_top_1': 0.6225490570068359, 'token_accuracy_top_5': 0.846257209777832, 'token_accuracy_top_10': 0.8987817764282227} + ``` + + 30% GPU memory reduction leads to nearly 5% regression in token_accuracy_top_10 metric (0.898 / 0.942). Let's rerun MIP search aiming for 15% memory reduction. + +## Re-run MIP Search with different constraints + +If you want to try different constraints without re-running the expensive pruning and scoring steps, use the `--mip-only` flag. +This assumes pruning, replacement library building, NAS scoring, and subblock stats calculation have already been completed. + +For example, let's set `target_memory: 96_000` in `llama-3_1-8B_pruneffn_memory.yaml`. + +```bash +torchrun --nproc_per_node 2 examples/compress/main.py --config path/to/llama-3_1-8B_pruneffn_memory.yaml --mip-only 2>&1 | tee ./log.txt | grep "Compress Progress" +``` + +This will generate the following network architecture (see `log.txt`): + +```bash +block_0: attention gqa_4 ffn intermediate_14336 +block_1: attention gqa_4 ffn intermediate_14336 +block_2: attention gqa_4 ffn intermediate_14336 +block_3: attention gqa_4 ffn intermediate_14336 +block_4: attention gqa_4 ffn intermediate_14336 +block_5: attention gqa_4 ffn intermediate_14336 +block_6: attention gqa_4 ffn intermediate_14336 +block_7: attention gqa_4 ffn intermediate_14336 +block_8: attention gqa_4 ffn intermediate_14336 +block_9: attention gqa_4 ffn intermediate_14336 +block_10: attention gqa_4 ffn intermediate_14336 +block_11: attention gqa_4 ffn intermediate_14336 +block_12: attention gqa_4 ffn intermediate_14336 +block_13: attention gqa_4 ffn intermediate_14336 +block_14: attention gqa_4 ffn intermediate_14336 +block_15: attention gqa_4 ffn intermediate_14336 +block_16: attention gqa_4 ffn intermediate_14336 +block_17: attention gqa_4 ffn intermediate_14336 +block_18: attention no_op ffn intermediate_14336 +block_19: attention no_op ffn intermediate_14336 +block_20: attention no_op ffn intermediate_14336 +block_21: attention gqa_4 ffn intermediate_14336 +block_22: attention no_op ffn intermediate_14336 +block_23: attention no_op ffn intermediate_14336 +block_24: attention no_op ffn intermediate_14336 +block_25: attention gqa_4 ffn intermediate_14336 +block_26: attention gqa_4 ffn intermediate_14336 +block_27: attention gqa_4 ffn intermediate_14336 +block_28: attention gqa_4 ffn intermediate_14336 +block_29: attention gqa_4 ffn intermediate_14336 +block_30: attention gqa_4 ffn intermediate_14336 +block_31: attention gqa_4 ffn intermediate_14336 + +[2025-11-02 12:50:42,024]^[[92m[rank-0]^[[0m[run_puzzle.py:295] Total costs: {'stats.memory_mib': 94708.4609375, 'stats.has_ffn': 32, 'stats.ffn_memory_mib': 10752.25, 'stats.kv_cache_memory_mib': 79872.0, 'stats.attention_num_params': 1090625536, 'stats.ffn_num_params': 5637275648, 'stats.has_attention': 26, 'stats.num_params': 
7778578432, 'stats.attention_memory_mib': 81952.203125, 'stats.num_kv_heads': 208}
+...
+################################################################
+validate_model_with_kl_div(model_name='solution_0', is_calc_kl_div=True)
+################################################################
+Average losses = {'lm_loss': 1.2425934937782586, 'token_accuracy_top_1': 0.703862190246582, 'token_accuracy_top_5': 0.8954982757568359, 'token_accuracy_top_10': 0.9336576461791992}
+```
+
+On the other hand, if you set `target_memory: 28_000`, you'll observe that the intermediate FFN sizes are significantly reduced in certain layers (see `log.txt` for details):
+
+```bash
+block_5: attention no_op ffn intermediate_11520
+block_6: attention no_op ffn intermediate_14336
+block_7: attention no_op ffn intermediate_8704
+block_8: attention no_op ffn intermediate_14336
+block_9: attention no_op ffn intermediate_3072
+block_10: attention no_op ffn intermediate_11520
+block_11: attention no_op ffn intermediate_11520
+block_12: attention no_op ffn intermediate_11520
+block_13: attention no_op ffn intermediate_11520
+block_14: attention no_op ffn intermediate_3072
+```
+
+## Evaluation
+
+Once the model is ready, you can evaluate it using the [Language Model Evaluation Harness](https://pypi.org/project/lm-eval/). For example, run the following to evaluate the model on a subset of [MMLU](https://huggingface.co/datasets/cais/mmlu):
+
+```bash
+lm_eval --model hf \
+    --model_args pretrained=path/to/model,dtype=bfloat16,trust_remote_code=true,parallelize=True \
+    --tasks mmlu_humanities \
+    --num_fewshot 5 \
+    --batch_size 4
+```
+
+## Advanced usage
+
+Modify the [Llama-3_1-8B.yaml](./configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml) configuration file for advanced compression scenarios.
diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
new file mode 100644
index 000000000..70b5304c5
--- /dev/null
+++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
@@ -0,0 +1,110 @@
+defaults:
+  - pruning: ffn_pruning
+  - scoring: ../validate_solutions_defaults
+  - realize_model: ../validate_solutions_defaults
+  - bypass:
+  - override hydra/hydra_logging: disabled
+  - _self_
+
+puzzle_dir: ???
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ???
# path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 10 # default is 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml new file mode 100644 index 000000000..cfd7f93e8 --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml @@ -0,0 +1,21 @@ +defaults: + - Llama-3_1-8B + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.1-8B-Instruct + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 96_000 # 96 GiB + 
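+# The README walkthrough also tries 78_000 (~30% memory reduction) and 28_000
+# (aggressive FFN pruning). To try a different budget without redoing pruning and
+# scoring, change target_memory and re-run main.py with --mip-only.
+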
+# FFN intermediate sizes to search over (heterogeneous architecture) +pruning: + intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 000000000..01886607e --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 000000000..96a8ca72e --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,12 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 000000000..407c835d8 --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 000000000..5d5307b9c --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,32 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? 
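+# Fields marked ??? are mandatory and are provided by the concrete pruning configs
+# (ffn_pruning.yaml, attn_pruning.yaml, hidden_dim_pruning.yaml).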
+ +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_outpt_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 000000000..572331a84 --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,15 @@ +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 000000000..ec1390237 --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/compress/main.py b/examples/compress/main.py new file mode 100644 index 000000000..93ea0b8ab --- /dev/null +++ b/examples/compress/main.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Main script for running the compress algorithm on large language models (based on Puzzle paper https://arxiv.org/abs/2411.19146). + +This script provides two modes: +1. Default mode: Runs the full compression pipeline +2. 
MIP-only mode: Runs only the MIP search and realize models phase + +Usage: + # Full compression pipeline + torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml + + # Only MIP search and realize models phase + torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml --mip-only +""" + +import argparse +import datetime +from pathlib import Path + +import mip_and_realize_models +import torch +from puzzle_tools.hydra_utils import register_hydra_resolvers + +import modelopt.torch.nas as mtn +from modelopt.torch._compress.dateutils import timestamped +from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel +from modelopt.torch._compress.runtime import NativeDdpRuntime +from tests.utils.test_utils import initialize_hydra_config_for_dir + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Compress large language models using the Compress algorithm (based on Puzzle paper https://arxiv.org/abs/2411.19146)" + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to the main config YAML file (e.g., ./configs/llama_3.2_1B_pruneffn_memory.yaml)", + ) + parser.add_argument( + "--mip-only", + action="store_true", + help="Run only the MIP search and realize models phase (skip pruning and NAS scoring)", + ) + + return parser.parse_args() + + +def run_full_compress(hydra_config_path: str): + """Run the full compression pipeline. + + Args: + config_path: Path to the YAML configuration file + """ + + print(timestamped("Compress Progress 1/8: starting compression pipeline")) + with NativeDdpRuntime(dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)): + # Register Hydra custom resolvers (needed for config resolution) + register_hydra_resolvers() + + hydra_config_path = Path(hydra_config_path).resolve() + hydra_config_dir = str(hydra_config_path.parent) + hydra_config_name = hydra_config_path.stem + + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[], + ) + + # Convert model (convert from HF to DeciLM, score pruning activations, + # prune the model and save pruned checkpoints) + input_model = CompressModel() + converted_model = mtn.convert( + input_model, + mode=[ + ( + "compress", + { + "puzzle_dir": str(hydra_cfg.puzzle_dir), + "input_model_path": hydra_cfg.input_hf_model_path, + "hydra_config_dir": hydra_config_dir, + "hydra_config_name": hydra_config_name, + "dataset_path": str(hydra_cfg.dataset_path), + }, + ) + ], + ) + + # Run NAS search (build replacement library and compute stats, + # compute one block scores, run MIP and realize models) + mtn.search( + converted_model, + constraints={}, # this is not used as the search space is defined in the hydra config + dummy_input=None, # Not used + config={}, # this is not used as the search space is defined in the hydra config + ) + + print(timestamped("Compress Progress 8/8: compression pipeline completed (multi-gpu)")) + + +def run_mip_only(hydra_config_path: str): + """Run only the MIP search and realize models phase. + + This assumes that pruning, replacement library building, NAS scoring, and subblock stats calculation + have already been completed. 
+ + Args: + hydra_config_path: Path to the YAML configuration file + """ + + with NativeDdpRuntime( + dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) + ) as runtime: + # Register Hydra custom resolvers (needed for config resolution) + register_hydra_resolvers() + + hydra_config_path = Path(hydra_config_path).resolve() + hydra_config_dir = str(hydra_config_path.parent) + hydra_config_name = hydra_config_path.stem + + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[], + ) + + # mip_and_realize_models (distributed processing) + # TODO: How to make it part of mnt.search() api, similarly to run_full_compress() API + print(timestamped("Compress Progress 7/8: running MIP and realizing models (multi-gpu)")) + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime) + + print(timestamped("Compress Progress 8/8: compression pipeline completed (multi-gpu)")) + + +def main(): + args = parse_args() + + if args.mip_only: + run_mip_only(hydra_config_path=args.config) + else: + run_full_compress(hydra_config_path=args.config) + + +if __name__ == "__main__": + main() diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 3efa9eb79..54f7322b1 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -23,6 +23,8 @@ This section focuses on applying Model Optimizer's state-of-the-art complementar +For more advanced pruning strategies, such as the [Puzzle methodology](https://arxiv.org/pdf/2411.19146), please see [Puzzle pruning example](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/feature/compress/examples/compress). + ## Pre-Requisites For Minitron pruning for Megatron-LM / NeMo models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:25.07`) which has all the dependencies installed. diff --git a/modelopt/torch/_compress/__init__.py b/modelopt/torch/_compress/__init__.py new file mode 100644 index 000000000..47f1c65a1 --- /dev/null +++ b/modelopt/torch/_compress/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/modelopt/torch/_compress/dataset/__init__.py b/modelopt/torch/_compress/dataset/__init__.py new file mode 100644 index 000000000..47f1c65a1 --- /dev/null +++ b/modelopt/torch/_compress/dataset/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/modelopt/torch/_compress/dataset/prepare_dataset.py b/modelopt/torch/_compress/dataset/prepare_dataset.py new file mode 100644 index 000000000..49d63d122 --- /dev/null +++ b/modelopt/torch/_compress/dataset/prepare_dataset.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import datasets +import fire +import numpy as np +from logger import mprint + + +def process_and_save_dataset( + dataset_name: str, + output_dir: str, + split: tuple = ("code", "math", "stem", "chat"), + overwrite: bool = False, +): + # Check if output_dir contains an existing dataset + dataset_dict_path = os.path.join(output_dir, "dataset_dict.json") + if os.path.exists(output_dir) and os.path.exists(dataset_dict_path): + if not overwrite: + mprint( + f"Output directory '{output_dir}' already contains a dataset. " + "Use '--overwrite True' to overwrite existing data." + ) + return + + ds = datasets.load_dataset(dataset_name, split=split) + ds = datasets.concatenate_datasets(ds) + # Filter out samples with reasoning = on + ds = ds.filter(lambda x: x["reasoning"] == "off") + # Hardcoded for dynamically create a deterministic train-val split + seed = 408 + generator = np.random.RandomState(seed=seed) + ds_split = ds.train_test_split(test_size=0.05, shuffle=True, generator=generator) + # Rename dataset names to follow previous conventions + ds_dict = datasets.DatasetDict( + { + "train": ds_split["train"], + "valid": ds_split["test"], + } + ) + # Save locally + os.makedirs(output_dir, exist_ok=True) + ds_dict.save_to_disk(output_dir) + + mprint(f"Dataset splits:\n{ds_dict}") + mprint(f"Saved processed datasets to {output_dir}") + + +if __name__ == "__main__": + fire.Fire(process_and_save_dataset) diff --git a/modelopt/torch/_compress/dateutils.py b/modelopt/torch/_compress/dateutils.py new file mode 100644 index 000000000..76a8aec2a --- /dev/null +++ b/modelopt/torch/_compress/dateutils.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Date and time utility functions for the compress module. +""" + +import datetime + + +def get_timestamp() -> str: + """Get a formatted timestamp string for logging. + + Returns: + A formatted timestamp string in the format 'YYYY-MM-DD HH:MM:SS'. + """ + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def timestamped(message: str) -> str: + """Add a timestamp prefix to a message. + + Args: + message: The message to prefix with a timestamp. + + Returns: + The message with a timestamp prefix in the format '[YYYY-MM-DD HH:MM:SS] message'. + """ + return f"[{get_timestamp()}] {message}" diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index d821fbd02..03a7f9a98 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -28,6 +28,7 @@ import torch from torch import nn +from modelopt.torch._compress.dateutils import timestamped from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) @@ -119,17 +120,27 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR ) # Convert Llama3 model to DeciLM model - hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable - convert_llama3_to_decilm( - input_dir=config.input_model_path, - output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, - ) + if runtime.global_rank == 0: + print(timestamped("Compress Progress 2/8: converting model from HF to DeciLM (single-gpu)")) + hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable + convert_llama3_to_decilm( + input_dir=config.input_model_path, + output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, + ) + runtime.wait_for_everyone() # Score_pruning_activations (distributed processing) + print(timestamped("Compress Progress 3/8: scoring pruning activations (multi-gpu)")) score_pruning_activations.launch_score_activations(hydra_cfg, runtime) # Prune the model and save pruned checkpoints + if runtime.global_rank == 0: + print( + timestamped( + "Compress Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)" + ) + ) pruning_ckpts.launch_prune_ckpt(hydra_cfg) runtime.wait_for_everyone() @@ -209,11 +220,20 @@ def run_search(self) -> None: # Build_library_and_stats (single process) if runtime.global_rank == 0: - build_library_and_stats.launch_build_library_and_stats(hydra_cfg) + print( + timestamped( + "Compress Progress 5/8: building replacement library and subblock statistics (single-gpu)" + ) + ) + + build_library_and_stats.launch_build_library_and_stats(hydra_cfg) runtime.wait_for_everyone() # Calc_one_block_scores (distributed processing) + + print(timestamped("Compress Progress 6/8: calculating one block scores (multi-gpu)")) scoring.launch_scoring(hydra_cfg, runtime) # mip_and_realize_models (distributed processing) + print(timestamped("Compress Progress 7/8: running MIP and realizing models (multi-gpu)")) mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime) diff --git 
a/tests/experimental/torch/_compress/test_compress.py b/tests/experimental/torch/_compress/test_compress.py index 3d5d6b666..f945e75ff 100644 --- a/tests/experimental/torch/_compress/test_compress.py +++ b/tests/experimental/torch/_compress/test_compress.py @@ -40,13 +40,10 @@ # /workspace/puzzletron # # submit_job --partition interactive --time 0 \ -# --image gitlab-master.nvidia.com/deci/puzzletron:trtllm_main \ +# --image gitlab-master.nvidia.com/deci/puzzletron:modelopt_main \ # --workdir $MODELOPT SRC DIRECTORY --interactive --gpu 1 # -# pip install mip -# pip install lru-dict -# -# export PYTHONPATH=$PYTHONPATH:/workspace/puzzletron/v1 +# export PYTHONPATH=$PYTHONPATH:.:/workspace/puzzletron/v1 # # pytest -s -v ./tests/experimental/torch/_compress/test_compress.py::test_compress -o addopts=""