
Commit 932daf5

set_layer -> Module.set_submodule
1 parent 06a1e71 commit 932daf5

2 files changed: +5, -22 lines


src/llmcompressor/modifiers/distillation/output/base.py
Lines changed: 5 additions & 5 deletions

@@ -11,7 +11,7 @@
 )
 from llmcompressor.utils.fsdp.context import summon_full_params_context
 from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped, set_wrapped_model
-from llmcompressor.utils.pytorch.module import get_layers, set_layer
+from llmcompressor.utils.pytorch.module import get_layers
 
 __all__ = ["OutputDistillationModifier"]
 
@@ -85,8 +85,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         with summon_full_params_context(state.teacher_model, offload_to_cpu=True):
             for key, (student_wrapper, teacher_wrapper) in self.wrappers_.items():
-                set_layer(key, student_wrapper, state.model)
-                set_layer(key, teacher_wrapper, state.teacher_model)
+                Module.set_submodule(state.model, key, student_wrapper)
+                Module.set_submodule(state.teacher_model, key, teacher_wrapper)
 
         self.wrapped_kd_model_ = self._create_model_wrapper(
             student_model=maybe_get_wrapped(state.model),
@@ -109,8 +109,8 @@ def on_finalize(self, state: State, **kwargs) -> bool:
 
         with summon_full_params_context(state.teacher_model, offload_to_cpu=True):
             for key, (student_wrapper, teacher_wrapper) in self.wrappers_.items():
-                set_layer(key, student_wrapper.layer, state.model)
-                set_layer(key, teacher_wrapper.layer, state.teacher_model)
+                Module.set_submodule(state.model, key, student_wrapper.layer)
+                Module.set_submodule(state.teacher_model, key, teacher_wrapper.layer)
                 del student_wrapper
                 del teacher_wrapper

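Note: torch.nn.Module.set_submodule(target, module) resolves a dotted path such as "layers.0" against the receiving module and swaps the named child in place, which is what the deleted set_layer helper did by hand. One behavioral difference: set_layer returned the module it replaced, while set_submodule returns None, so callers that still need the old module must fetch it first. A minimal sketch, assuming a PyTorch version that ships set_submodule; the toy Sequential model is illustrative, not from this repo:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

# set_submodule returns None, so grab a reference to the old child first
# (the removed set_layer helper returned it for you).
old = model.get_submodule("0")

# Resolve the dotted path "0" and swap in a replacement, mirroring
# Module.set_submodule(state.model, key, student_wrapper) above.
model.set_submodule("0", nn.Linear(4, 8))
assert model.get_submodule("0").out_features == 8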
src/llmcompressor/utils/pytorch/module.py
Lines changed: 0 additions & 17 deletions

@@ -51,7 +51,6 @@
     "match_layers_params",
     "get_layers",
     "get_layer",
-    "set_layer",
     "get_params",
     "get_param",
     "get_terminal_layers",
@@ -197,22 +196,6 @@ def get_layer(target: str, module: Module) -> Tuple[str, Module]:
     return name, layer
 
 
-def set_layer(target: str, layer: Module, module: Module) -> Module:
-    with summon_full_params_context(module):
-        # importing here to avoid circular import
-        from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped
-
-        parent_target = ".".join(target.split(".")[:-1])
-        if parent_target != "":
-            parent_layer = get_layer(parent_target, module)[1]
-        else:
-            parent_layer = maybe_get_wrapped(module)
-        old_layer = getattr(parent_layer, target.split(".")[-1])
-        setattr(parent_layer, target.split(".")[-1], layer)
-
-        return old_layer
-
-
 def get_params(targets: Union[str, List[str]], module: Module) -> Dict[str, Parameter]:
     return match_layers_params(targets, module, params=True)

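The deleted helper boiled down to resolving the parent module and doing getattr/setattr, plus FSDP unwrapping via maybe_get_wrapped; the call sites in base.py already run under summon_full_params_context, so nothing is lost on that front. For anyone who still needs the swap-and-return behavior, a sketch using only torch built-ins (swap_submodule is a hypothetical name, not part of llm-compressor):

from torch.nn import Module

def swap_submodule(model: Module, target: str, new: Module) -> Module:
    # Hypothetical stand-in for the deleted set_layer: fetch the current
    # child at the dotted path, replace it, and return the old one.
    old = model.get_submodule(target)
    model.set_submodule(target, new)
    return old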