Skip to content

Commit 04ca926

Browse files
committed
Refactor: fix controlnet and tae
1 parent 371d81f commit 04ca926

File tree

3 files changed

+8
-9
lines changed

3 files changed

+8
-9
lines changed

control.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -317,12 +317,10 @@ struct ControlNet : public GGMLRunner {
317317
bool guided_hint_cached = false;
318318

319319
ControlNet(ggml_backend_t backend,
320-
SDVersion version = VERSION_SD1)
320+
std::map<std::string, enum ggml_type>& tensor_types,
321+
SDVersion version = VERSION_SD1)
321322
: GGMLRunner(backend), control_net(version) {
322-
}
323-
324-
void init_params(std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix) {
325-
control_net.init(params_ctx, tensor_types, prefix);
323+
control_net.init(params_ctx, tensor_types, "");
326324
}
327325

328326
~ControlNet() {

stable-diffusion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ class StableDiffusionGGML {
358358
first_stage_model->alloc_params_buffer();
359359
first_stage_model->get_param_tensors(tensors, "first_stage_model");
360360
} else {
361-
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, vae_decode_only);
361+
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only);
362362
}
363363
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");
364364

@@ -370,7 +370,7 @@ class StableDiffusionGGML {
370370
} else {
371371
controlnet_backend = backend;
372372
}
373-
control_net = std::make_shared<ControlNet>(controlnet_backend, version);
373+
control_net = std::make_shared<ControlNet>(controlnet_backend, model_loader.tensor_storages_types, version);
374374
}
375375

376376
if (id_embeddings_path.find("v2") != std::string::npos) {

tae.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,14 +188,15 @@ struct TinyAutoEncoder : public GGMLRunner {
188188
bool decode_only = false;
189189

190190
TinyAutoEncoder(ggml_backend_t backend,
191+
std::map<std::string, enum ggml_type>& tensor_types,
192+
const std::string prefix,
191193
bool decoder_only = true)
192194
: decode_only(decoder_only),
193195
taesd(decode_only),
194196
GGMLRunner(backend) {
195-
}
196-
void init_params(std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix) {
197197
taesd.init(params_ctx, tensor_types, prefix);
198198
}
199+
199200
std::string get_desc() {
200201
return "taesd";
201202
}

0 commit comments

Comments
 (0)