@@ -11,7 +11,7 @@
 )
 from llmcompressor.utils.fsdp.context import summon_full_params_context
 from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped, set_wrapped_model
-from llmcompressor.utils.pytorch.module import get_layers, set_layer
+from llmcompressor.utils.pytorch.module import get_layers

 __all__ = ["OutputDistillationModifier"]

@@ -85,8 +85,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         with summon_full_params_context(state.teacher_model, offload_to_cpu=True):
             for key, (student_wrapper, teacher_wrapper) in self.wrappers_.items():
-                set_layer(key, student_wrapper, state.model)
-                set_layer(key, teacher_wrapper, state.teacher_model)
+                state.model.set_submodule(key, student_wrapper)
+                state.teacher_model.set_submodule(key, teacher_wrapper)

         self.wrapped_kd_model_ = self._create_model_wrapper(
             student_model=maybe_get_wrapped(state.model),
@@ -109,8 +109,8 @@ def on_finalize(self, state: State, **kwargs) -> bool:
 
         with summon_full_params_context(state.teacher_model, offload_to_cpu=True):
             for key, (student_wrapper, teacher_wrapper) in self.wrappers_.items():
-                set_layer(key, student_wrapper.layer, state.model)
-                set_layer(key, teacher_wrapper.layer, state.teacher_model)
+                state.model.set_submodule(key, student_wrapper.layer)
+                state.teacher_model.set_submodule(key, teacher_wrapper.layer)
                 del student_wrapper
                 del teacher_wrapper

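For context, `torch.nn.Module.set_submodule(target, module)` is the built-in API this diff adopts in place of the removed `set_layer` helper: it resolves a dotted submodule path (as produced by `get_layers`) and swaps in the given module, so the extra llmcompressor import can be dropped. A minimal sketch of the swap-in/restore pattern used in `on_initialize` and `on_finalize`; the model, `LayerWrapper`, and path names below are illustrative stand-ins, not code from this PR:

import torch.nn as nn

# Toy student model standing in for state.model.
model = nn.Sequential()
model.add_module("encoder", nn.Sequential(nn.Linear(8, 8), nn.ReLU()))

class LayerWrapper(nn.Module):
    """Stand-in for a KD wrapper: holds the original layer as .layer."""

    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        # A real KD wrapper would also record outputs here.
        return self.layer(x)

key = "encoder.0"  # dotted path to the target layer

# on_initialize: replace the target layer with its wrapper.
wrapper = LayerWrapper(model.get_submodule(key))
model.set_submodule(key, wrapper)

# on_finalize: restore the original layer from the wrapper.
model.set_submodule(key, wrapper.layer)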