Skip to content

Commit 354a6b9

Browse files
authored
reposition private methods (#213)
1 parent e5df80c commit 354a6b9

File tree

1 file changed

+46
-46
lines changed

1 file changed

+46
-46
lines changed

finetrainers/trainer.py

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -122,52 +122,6 @@ def prepare_dataset(self) -> None:
122122
pin_memory=self.args.pin_memory,
123123
)
124124

125-
def _get_load_components_kwargs(self) -> Dict[str, Any]:
126-
load_component_kwargs = {
127-
"text_encoder_dtype": self.args.text_encoder_dtype,
128-
"text_encoder_2_dtype": self.args.text_encoder_2_dtype,
129-
"text_encoder_3_dtype": self.args.text_encoder_3_dtype,
130-
"transformer_dtype": self.args.transformer_dtype,
131-
"vae_dtype": self.args.vae_dtype,
132-
"shift": self.args.flow_shift,
133-
"revision": self.args.revision,
134-
"cache_dir": self.args.cache_dir,
135-
}
136-
if self.args.pretrained_model_name_or_path is not None:
137-
load_component_kwargs["model_id"] = self.args.pretrained_model_name_or_path
138-
return load_component_kwargs
139-
140-
def _set_components(self, components: Dict[str, Any]) -> None:
141-
# Set models
142-
self.tokenizer = components.get("tokenizer", self.tokenizer)
143-
self.tokenizer_2 = components.get("tokenizer_2", self.tokenizer_2)
144-
self.tokenizer_3 = components.get("tokenizer_3", self.tokenizer_3)
145-
self.text_encoder = components.get("text_encoder", self.text_encoder)
146-
self.text_encoder_2 = components.get("text_encoder_2", self.text_encoder_2)
147-
self.text_encoder_3 = components.get("text_encoder_3", self.text_encoder_3)
148-
self.transformer = components.get("transformer", self.transformer)
149-
self.unet = components.get("unet", self.unet)
150-
self.vae = components.get("vae", self.vae)
151-
self.scheduler = components.get("scheduler", self.scheduler)
152-
153-
# Set configs
154-
self.transformer_config = self.transformer.config if self.transformer is not None else self.transformer_config
155-
self.vae_config = self.vae.config if self.vae is not None else self.vae_config
156-
157-
def _delete_components(self) -> None:
158-
self.tokenizer = None
159-
self.tokenizer_2 = None
160-
self.tokenizer_3 = None
161-
self.text_encoder = None
162-
self.text_encoder_2 = None
163-
self.text_encoder_3 = None
164-
self.transformer = None
165-
self.unet = None
166-
self.vae = None
167-
self.scheduler = None
168-
free_memory()
169-
torch.cuda.synchronize(self.state.accelerator.device)
170-
171125
def prepare_models(self) -> None:
172126
logger.info("Initializing models")
173127

@@ -1109,6 +1063,52 @@ def _move_components_to_device(self):
11091063
if self.vae is not None:
11101064
self.vae = self.vae.to(self.state.accelerator.device)
11111065

1066+
def _get_load_components_kwargs(self) -> Dict[str, Any]:
1067+
load_component_kwargs = {
1068+
"text_encoder_dtype": self.args.text_encoder_dtype,
1069+
"text_encoder_2_dtype": self.args.text_encoder_2_dtype,
1070+
"text_encoder_3_dtype": self.args.text_encoder_3_dtype,
1071+
"transformer_dtype": self.args.transformer_dtype,
1072+
"vae_dtype": self.args.vae_dtype,
1073+
"shift": self.args.flow_shift,
1074+
"revision": self.args.revision,
1075+
"cache_dir": self.args.cache_dir,
1076+
}
1077+
if self.args.pretrained_model_name_or_path is not None:
1078+
load_component_kwargs["model_id"] = self.args.pretrained_model_name_or_path
1079+
return load_component_kwargs
1080+
1081+
def _set_components(self, components: Dict[str, Any]) -> None:
1082+
# Set models
1083+
self.tokenizer = components.get("tokenizer", self.tokenizer)
1084+
self.tokenizer_2 = components.get("tokenizer_2", self.tokenizer_2)
1085+
self.tokenizer_3 = components.get("tokenizer_3", self.tokenizer_3)
1086+
self.text_encoder = components.get("text_encoder", self.text_encoder)
1087+
self.text_encoder_2 = components.get("text_encoder_2", self.text_encoder_2)
1088+
self.text_encoder_3 = components.get("text_encoder_3", self.text_encoder_3)
1089+
self.transformer = components.get("transformer", self.transformer)
1090+
self.unet = components.get("unet", self.unet)
1091+
self.vae = components.get("vae", self.vae)
1092+
self.scheduler = components.get("scheduler", self.scheduler)
1093+
1094+
# Set configs
1095+
self.transformer_config = self.transformer.config if self.transformer is not None else self.transformer_config
1096+
self.vae_config = self.vae.config if self.vae is not None else self.vae_config
1097+
1098+
def _delete_components(self) -> None:
1099+
self.tokenizer = None
1100+
self.tokenizer_2 = None
1101+
self.tokenizer_3 = None
1102+
self.text_encoder = None
1103+
self.text_encoder_2 = None
1104+
self.text_encoder_3 = None
1105+
self.transformer = None
1106+
self.unet = None
1107+
self.vae = None
1108+
self.scheduler = None
1109+
free_memory()
1110+
torch.cuda.synchronize(self.state.accelerator.device)
1111+
11121112
def _get_training_dtype(self, accelerator) -> torch.dtype:
11131113
weight_dtype = torch.float32
11141114
if accelerator.state.deepspeed_plugin:

0 commit comments

Comments
 (0)