
Commit 9753b8d

Merge branch 'feature/compress' into dkorzekwa/nas_search
Signed-off-by: Daniel Korzekwa <[email protected]>
2 parents: a0cfd13 + 002b8b5

File tree: 8 files changed, +185 −10 lines

modelopt/torch/_compress/compress.py

Lines changed: 1 addition & 2 deletions

@@ -28,8 +28,7 @@
 from omegaconf import DictConfig
 from puzzle_tools.runtime import IRuntime
 
-# TODO Move initialize_hydra_config_for_dir from tests to main
-from tests.utils.test_utils import initialize_hydra_config_for_dir
+from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir
 
 
 def compress(

modelopt/torch/_compress/hydra.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from hydra import compose, initialize, initialize_config_dir
+from omegaconf import DictConfig, OmegaConf
+
+"""
+Utilities for hydra config initialization.
+"""
+
+
+def initialize_hydra_config_for_dir(
+    config_dir: str, config_name: str, overrides: list[str]
+) -> DictConfig:
+    """Initialize a hydra config from an absolute path for a config directory
+
+    Args:
+        config_dir (str):
+        config_name (str):
+        overrides (List[str]):
+
+    Returns:
+        DictConfig:
+    """
+
+    with initialize_config_dir(version_base=None, config_dir=config_dir):
+        args = compose(config_name, overrides)
+        args._set_flag("allow_objects", True)
+        OmegaConf.resolve(args)  # resolve object attributes
+        OmegaConf.set_struct(args, False)
+
+    return args
+
+
+def initialize_hydra_config(config_path: str, config_name: str, overrides: list[str]) -> DictConfig:
+    with initialize(version_base=None, config_path=config_path):
+        args = compose(config_name, overrides)
+        args._set_flag("allow_objects", True)
+        OmegaConf.resolve(args)  # resolve object attributes
+        OmegaConf.set_struct(args, False)
+
+    return args
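
For orientation, a minimal sketch of how these two helpers might be called. The config directory, paths, and the override key are illustrative, not values from this commit; only the config name "Llama-3_1-8B" appears elsewhere in the diff:

from modelopt.torch._compress.hydra import (
    initialize_hydra_config,
    initialize_hydra_config_for_dir,
)

# Absolute config directory (hypothetical path; hypothetical override key).
cfg = initialize_hydra_config_for_dir(
    config_dir="/workspace/configs",
    config_name="Llama-3_1-8B",
    overrides=["puzzle_dir=/tmp/puzzle"],
)

# Same flow, but with a config path relative to the calling module.
cfg = initialize_hydra_config(
    config_path="../configs",
    config_name="Llama-3_1-8B",
    overrides=[],
)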

modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py

Lines changed: 1 addition & 0 deletions

@@ -159,6 +159,7 @@ def config_class(self) -> type[ModeloptBaseConfig]:
     @property
     def search_algorithm(self) -> type[BaseSearcher]:
         """Return the associated searcher implementation."""
+
         return CompressSearcher
 
     @property

setup.py

Lines changed: 5 additions & 1 deletion

@@ -100,7 +100,11 @@
         "setuptools-scm>=8",
     ],
     # Dependedencies for modelopt.torch._compress subpackage
-    "compress": ["fire"],
+    "compress": [
+        "fire",
+        "hydra-core==1.3.2",
+        "omegaconf==2.3.0",
+    ],
 }
 
 # create "compound" optional dependencies
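
For reference, a minimal self-contained sketch of how an extras table like this is consumed by setuptools. The package name is illustrative; this is not the project's actual setup() call:

from setuptools import setup

optional_deps = {
    "compress": [
        "fire",
        "hydra-core==1.3.2",
        "omegaconf==2.3.0",
    ],
}

setup(
    name="example-package",  # illustrative name only
    extras_require=optional_deps,
)

# Users would then pull in the pinned hydra/omegaconf stack with:
#   pip install "example-package[compress]"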
tests/experimental/torch/_compress/compress_test_utils.py

Lines changed: 119 additions & 0 deletions

@@ -0,0 +1,119 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+from pathlib import Path
+
+import torch
+from datasets import Dataset, DatasetDict
+from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase
+
+
+def create_and_save_small_llama_model(
+    output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase
+):
+    """
+    Create and save a small Llama model for testing the conversion pipeline.
+    This mimics having a real Llama checkpoint that needs to be converted.
+    """
+    os.makedirs(output_path, exist_ok=True)
+
+    # Create a minimal Llama config (small for testing)
+    # Note: intermediate_size must be divisible by 256 per DeciLM config requirements
+    # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility
+    llama_config = LlamaConfig(
+        vocab_size=vocab_size,
+        hidden_size=256,  # 32 heads times 8 head_dim = 256 (matches bypass config expectations)
+        intermediate_size=512,  # Must be divisible by 256
+        num_hidden_layers=2,
+        num_attention_heads=32,  # Matches original test
+        num_key_value_heads=8,  # GQA: 32÷4=8 (matches original n_heads_in_group=4)
+        max_position_embeddings=512,
+        rms_norm_eps=1e-5,
+        rope_theta=10000.0,
+        attention_bias=False,
+        hidden_act="silu",
+        tie_word_embeddings=False,
+    )
+
+    # Create and save the Llama model
+    model = LlamaForCausalLM(llama_config)
+    model.to(dtype=torch.bfloat16).save_pretrained(output_path)
+
+    # Save tokenizer
+    tokenizer.save_pretrained(output_path)
+
+    # Save config
+    llama_config.save_pretrained(output_path)
+
+
+def create_tokenizer(project_root_path: Path) -> PreTrainedTokenizerBase:
+    """
+    Create a tokenizer for the Llama model.
+    """
+    tokenizer_path = project_root_path / "tests/experimental/torch/_compress/resources/tokenizer"
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+    return tokenizer
+
+
+def setup_puzzle_dir(puzzle_dir: str):
+    """
+    Setup puzzle directory by removing existing directory and creating a new one.
+    """
+    if Path(puzzle_dir).exists():
+        shutil.rmtree(puzzle_dir)
+    Path(puzzle_dir).mkdir(parents=True, exist_ok=True)
+
+
+def save_dummy_dataset(dataset_path: str):
+    """
+    Save a dummy dataset for testing purposes.
+    """
+    # dummy sample
+    sample = [
+        {"role": "user", "content": "please cite Lorem Ipsum?"},
+        {
+            "role": "assistant",
+            "content": (
+                "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed in blandit ante. "
+                "Sed tempus erat urna, ac elementum nisl facilisis quis. Aliquam consectetur mollis massa, "
+                "in elementum sem venenatis posuere. Fusce lorem arcu, egestas vel massa sollicitudin, "
+                "dictum mollis purus. Proin in ullamcorper elit. Nam tellus nisi, volutpat a mattis vel, "
+                "pretium in purus. Nunc at lectus facilisis risus scelerisque rhoncus eu nec ex. "
+                "Maecenas semper, tellus non placerat vulputate, urna felis facilisis diam, "
+                "sit amet vestibulum erat sapien nec libero. Praesent non massa velit. Donec faucibus mi eros. "
+                "Nam turpis nulla, congue sit amet mi at, porttitor scelerisque elit. Nunc id sodales lorem, "
+                "nec tincidunt leo. Quisque a neque nec ligula porttitor auctor. "
+                "Nunc accumsan nunc ac tellus congue vehicula. Praesent tellus eros, luctus non gravida dapibus, "
+                "faucibus eu ex. Quisque bibendum leo pharetra, tristique est vitae, hendrerit nunc. "
+                "Duis nec congue dolor. Donec commodo ipsum non efficitur volutpat. "
+                "Nulla risus nulla, efficitur et urna at, imperdiet sodales lorem. "
+                "Suspendisse erat est, sollicitudin at nisl tincidunt, vehicula hendrerit lectus. "
+                "Nam quis nisi ullamcorper, rhoncus massa vel, tempus purus. "
+                "Duis pulvinar eros vel nulla pellentesque, at dapibus justo laoreet. "
+                "Praesent tortor orci, vulputate fermentum dapibus nec, feugiat vitae tortor. "
+                "Donec mollis convallis massa quis iaculis."
+            ),
+        },
+    ]
+
+    # Prepare train and val splits with sample repeated, 2500 samples are for
+    # 128 samples with block-size 8192 and LLama3 tokenizer
+    data = [{"conversation": sample}] * 2500
+
+    # For train-val splits
+    data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)})
+    data_dict.save_to_disk(dataset_path)
tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 import json
 from pathlib import Path
 
-from experimental.torch._compress.test_utils import (
+from experimental.torch._compress.compress_test_utils import (
     create_and_save_small_llama_model,
     create_tokenizer,
 )

tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@
 
 import torch
 from _test_utils.torch.distributed.utils import spawn_multiprocess_job
-from experimental.torch._compress.test_utils import (
+from experimental.torch._compress.compress_test_utils import (
     create_and_save_small_llama_model,
     create_tokenizer,
     save_dummy_dataset,

tests/experimental/torch/_compress/test_compress.py

Lines changed: 3 additions & 5 deletions

@@ -20,7 +20,7 @@
 
 import torch
 from _test_utils.torch.distributed.utils import spawn_multiprocess_job
-from experimental.torch._compress.test_utils import (
+from experimental.torch._compress.compress_test_utils import (
     create_and_save_small_llama_model,
     create_tokenizer,
     save_dummy_dataset,
@@ -76,11 +76,9 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran
     hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs"
     hydra_config_name = "Llama-3_1-8B"
 
-    runtime = NativeDdpRuntime(
+    with NativeDdpRuntime(
         dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
-    )
-
-    with runtime as runtime:
+    ) as runtime:
         #
         # Test setup
         #
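
The refactor above folds construction into the with-statement, so the runtime is released even when the test body raises. A minimal sketch of the pattern, assuming NativeDdpRuntime implements the context-manager protocol as the diff implies; its import path is not shown in the diff and is assumed here:

import datetime

import torch

# Import path assumed; the diff only shows the class name.
from puzzle_tools.runtime import NativeDdpRuntime

with NativeDdpRuntime(
    dtype=torch.bfloat16,
    torch_distributed_timeout=datetime.timedelta(10),  # 10 days, as in the test
) as runtime:
    # Test setup and the compression run execute while the process group is alive.
    pass
# __exit__ tears the runtime down even if an exception escapes the block.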
