Skip to content

Commit 6504c44

Browse files
Utilitities for hydra initialization
Signed-off-by: Daniel Korzekwa <[email protected]>
1 parent 9bfcc21 commit 6504c44

File tree

3 files changed

+56
-5
lines changed

3 files changed

+56
-5
lines changed

modelopt/torch/_compress/compress.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
from omegaconf import DictConfig
2929
from puzzle_tools.runtime import IRuntime
3030

31-
# TODO Move initialize_hydra_config_for_dir from tests to main
32-
from tests.utils.test_utils import initialize_hydra_config_for_dir
31+
from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir
3332

3433

3534
def compress(

modelopt/torch/_compress/hydra.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from hydra import compose, initialize, initialize_config_dir
17+
from omegaconf import DictConfig, OmegaConf
18+
19+
"""
20+
Utilities for hydra config initialization.
21+
"""
22+
23+
24+
def initialize_hydra_config_for_dir(
25+
config_dir: str, config_name: str, overrides: list[str]
26+
) -> DictConfig:
27+
"""Initialize a hydra config from an absolute path for a config directory
28+
29+
Args:
30+
config_dir (str):
31+
config_name (str):
32+
overrides (List[str]):
33+
34+
Returns:
35+
DictConfig:
36+
"""
37+
38+
with initialize_config_dir(version_base=None, config_dir=config_dir):
39+
args = compose(config_name, overrides)
40+
args._set_flag("allow_objects", True)
41+
OmegaConf.resolve(args) # resolve object attributes
42+
OmegaConf.set_struct(args, False)
43+
44+
return args
45+
46+
47+
def initialize_hydra_config(config_path: str, config_name: str, overrides: list[str]) -> DictConfig:
48+
with initialize(version_base=None, config_path=config_path):
49+
args = compose(config_name, overrides)
50+
args._set_flag("allow_objects", True)
51+
OmegaConf.resolve(args) # resolve object attributes
52+
OmegaConf.set_struct(args, False)
53+
54+
return args

modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from scripts.convert_llama3_to_decilm import convert_llama3_to_decilm
2727
from torch import nn
2828

29+
from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir
2930
from modelopt.torch._compress.runtime import NativeDdpRuntime
3031
from modelopt.torch.nas.conversion import NASModeRegistry
3132
from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
@@ -38,9 +39,6 @@
3839
)
3940
from modelopt.torch.opt.searcher import BaseSearcher
4041

41-
# TODO Move initialize_hydra_config_for_dir from tests to main
42-
from tests.utils.test_utils import initialize_hydra_config_for_dir
43-
4442

4543
class CompressModel(nn.Module):
4644
pass # No model implementation is needed for the compress mode

0 commit comments

Comments
 (0)