Skip to content

Commit 371d81f

Browse files
committed
unet: refactor the refactoring
1 parent cb46146 commit 371d81f

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

diffusion_model.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ struct UNetModel : public DiffusionModel {
3434
std::map<std::string, enum ggml_type>& tensor_types,
3535
SDVersion version = VERSION_SD1,
3636
bool flash_attn = false)
37-
: unet(backend, version, flash_attn) {
38-
unet.init_params(tensor_types, "model.diffusion_model");
37+
: unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
3938
}
4039

4140
void alloc_params_buffer() {

unet.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -532,12 +532,11 @@ struct UNetModelRunner : public GGMLRunner {
532532
UnetModelBlock unet;
533533

534534
UNetModelRunner(ggml_backend_t backend,
535+
std::map<std::string, enum ggml_type>& tensor_types,
536+
const std::string prefix,
535537
SDVersion version = VERSION_SD1,
536538
bool flash_attn = false)
537539
: GGMLRunner(backend), unet(version, flash_attn) {
538-
}
539-
540-
void init_params(std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix) {
541540
unet.init(params_ctx, tensor_types, prefix);
542541
}
543542

0 commit comments

Comments
 (0)