diff --git a/extensions_built_in/diffusion_models/chroma/chroma_model.py b/extensions_built_in/diffusion_models/chroma/chroma_model.py
index 236d9508b..83ec1a669 100644
--- a/extensions_built_in/diffusion_models/chroma/chroma_model.py
+++ b/extensions_built_in/diffusion_models/chroma/chroma_model.py
@@ -167,8 +167,13 @@ def load_model(self):
         chroma_params.depth = double_blocks
         chroma_params.depth_single_blocks = single_blocks
+
+        # load Chroma into RAM in bfloat16, go back to fp32 afterwards
+        def_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.bfloat16)
         transformer = Chroma(chroma_params)
-
+        torch.set_default_dtype(def_dtype)
+
         # add dtype, not sure why it doesnt have it
         transformer.dtype = dtype
 
         # load the state dict into the model
@@ -420,9 +425,13 @@ def get_model_has_grad(self):
         return self.model.final_layer.linear.weight.requires_grad
 
     def get_te_has_grad(self):
-        # return from a weight if it has grad
-        return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
-
+        from toolkit.unloader import FakeTextEncoder
+
+        te = self.text_encoder[1]
+        if isinstance(te, FakeTextEncoder):
+            return False
+        return te.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
+
     def save_model(self, output_path, meta, save_dtype):
         if not output_path.endswith(".safetensors"):
             output_path = output_path + ".safetensors"
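
For context on the first hunk: `torch.set_default_dtype` controls the dtype of newly created floating-point tensors, including module parameters, so constructing the transformer under a bfloat16 default allocates its weights at 2 bytes per parameter instead of 4 before the checkpoint is loaded over them. A minimal sketch of the same pattern, using a hypothetical `TinyNet` module in place of `Chroma` (the `try/finally` is an extra safeguard not present in the patch itself):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the Chroma transformer; any nn.Module works.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 64)

# Parameters created while the default dtype is bfloat16 are allocated
# in bf16, roughly halving peak RAM during construction. Restoring the
# previous default afterwards keeps unrelated tensor creation unaffected.
def_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    model = TinyNet()
finally:
    torch.set_default_dtype(def_dtype)

assert model.linear.weight.dtype == torch.bfloat16
assert torch.empty(1).dtype == def_dtype  # default was restored
```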
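
The second hunk guards against `self.text_encoder[1]` being a `FakeTextEncoder`, presumably a parameter-less placeholder left behind once the real encoder has been unloaded to save memory; probing `te.encoder.block[0]...` on such a placeholder would raise an `AttributeError`. A minimal sketch of the guard, with `FakeTextEncoder` stubbed out since `toolkit.unloader` is not shown in this patch:

```python
# Stub standing in for toolkit.unloader.FakeTextEncoder; the real class
# is assumed to carry no weights, and is not shown in this patch.
class FakeTextEncoder:
    pass

def te_has_grad(te) -> bool:
    # A placeholder encoder has no weights to probe, so it can never be
    # trainable; answering False avoids an AttributeError.
    if isinstance(te, FakeTextEncoder):
        return False
    # Otherwise inspect one representative attention weight: its
    # requires_grad flag indicates whether the encoder is being trained.
    return te.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad

print(te_has_grad(FakeTextEncoder()))  # -> False
```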