Merged
common.hpp (4 changes: 2 additions & 2 deletions)
@@ -17,7 +17,7 @@ class DownSampleBlock : public GGMLBlock {
           out_channels(out_channels),
           vae_downsample(vae_downsample) {
         if (vae_downsample) {
-            blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0}));
+            blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0}, {1, 1}, true, true));
         } else {
             blocks["op"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {1, 1}));
         }
@@ -49,7 +49,7 @@ class UpSampleBlock : public GGMLBlock {
                   int out_channels)
         : channels(channels),
           out_channels(out_channels) {
-        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
+        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));
     }

     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
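Note: the two extra trailing arguments in each changed line are the pre-existing defaults for dilation ({1, 1}) and bias (true), followed by the new direct flag that this PR adds to Conv2d in ggml_extend.hpp below. For reference, the changed DownSampleBlock call with every argument annotated (the comments here are explanatory and not part of the PR):

    blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(
        channels, out_channels,
        {3, 3},    // kernel size
        {2, 2},    // stride: downsample by 2
        {0, 0},    // padding
        {1, 1},    // dilation (default)
        true,      // bias (default)
        true));    // direct: request ggml_conv_2d_direct where supported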
ggml_extend.hpp (36 changes: 33 additions & 3 deletions)
@@ -706,6 +706,25 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
     return x;
 }

+__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d_direct(struct ggml_context* ctx,
+                                                             struct ggml_tensor* x,
+                                                             struct ggml_tensor* w,
+                                                             struct ggml_tensor* b,
+                                                             int s0 = 1,
+                                                             int s1 = 1,
+                                                             int p0 = 0,
+                                                             int p1 = 0,
+                                                             int d0 = 1,
+                                                             int d1 = 1) {
+    x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
+    if (b != NULL) {
+        b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
+        // b = ggml_repeat(ctx, b, x);
+        x = ggml_add(ctx, x, b);
+    }
+    return x;
+}
+
 // w: [OC,IC, KD, 1 * 1]
 // x: [N, IC, IH, IW]
 // b: [OC,]
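Note on the bias handling in ggml_nn_conv_2d_direct: ggml orders tensor dimensions as ne = {W, H, C, N}, so reshaping the bias from ne = {OC} to ne = {1, 1, OC, 1} lines it up with the channel dimension of the conv output, and ggml_add then broadcasts the singleton width, height, and batch dimensions across x. That is why the explicit ggml_repeat can stay commented out. A minimal sketch of the same pattern, with a hypothetical helper name and illustrative shapes:

    #include "ggml.h"

    // x: conv output with ne = {OW, OH, OC, N}; b: per-channel bias, ne = {OC}.
    static struct ggml_tensor* add_channel_bias(struct ggml_context* ctx,
                                                struct ggml_tensor* x,
                                                struct ggml_tensor* b) {
        // Reshape b to ne = {1, 1, OC, 1}; ggml_add broadcasts the singleton
        // dims over width, height, and batch, so no ggml_repeat is needed.
        b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
        return ggml_add(ctx, x, b);
    }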
@@ -1456,6 +1475,7 @@ class Conv2d : public UnaryBlock {
     std::pair<int, int> padding;
     std::pair<int, int> dilation;
     bool bias;
+    bool direct;

     void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
         enum ggml_type wtype = GGML_TYPE_F16;  //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F16;
@@ -1473,22 +1493,32 @@
            std::pair<int, int> stride = {1, 1},
            std::pair<int, int> padding = {0, 0},
            std::pair<int, int> dilation = {1, 1},
-           bool bias = true)
+           bool bias = true,
+           bool direct = false)
         : in_channels(in_channels),
           out_channels(out_channels),
           kernel_size(kernel_size),
           stride(stride),
           padding(padding),
           dilation(dilation),
-          bias(bias) {}
+          bias(bias),
+          direct(direct) {}

     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
         struct ggml_tensor* w = params["weight"];
         struct ggml_tensor* b = NULL;
         if (bias) {
             b = params["bias"];
         }
-        return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
+        if (direct) {
+#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL) || defined(SD_USE_METAL) || defined(SD_USE_OPENCL)
+            return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
+#else
+            return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
+#endif
+        } else {
+            return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
+        }
     }
 };

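Note: direct is opt-in per layer, and even an opted-in layer only reaches ggml_conv_2d_direct on builds where none of SD_USE_CUDA, SD_USE_SYCL, SD_USE_METAL, or SD_USE_OPENCL is defined (plain CPU builds, and Vulkan, which is notably absent from the guard). On the guarded backends it falls back to the im2col-based ggml_nn_conv_2d, presumably because those backends do not yet have a usable implementation of the direct conv op. A hypothetical call site, with illustrative argument values, to make the dispatch concrete:

    // A 3x3 convolution that prefers the direct path where available.
    Conv2d conv(64, 64,                  // in_channels, out_channels
                {3, 3}, {1, 1}, {1, 1},  // kernel, stride, padding
                {1, 1}, true,            // dilation, bias (the old defaults)
                true);                   // direct
    // conv.forward(ctx, x) lowers to ggml_nn_conv_2d_direct on unguarded
    // builds and to the im2col-based ggml_nn_conv_2d elsewhere.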
vae.hpp (48 changes: 31 additions & 17 deletions)
@@ -20,13 +20,13 @@ class ResnetBlock : public UnaryBlock {
           out_channels(out_channels) {
         // temb_channels is always 0
         blocks["norm1"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
-        blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
+        blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));

         blocks["norm2"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
-        blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
+        blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));

         if (out_channels != in_channels) {
-            blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {1, 1}));
+            blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, true));
         }
     }

@@ -69,11 +69,11 @@ class AttnBlock : public UnaryBlock {
     AttnBlock(int64_t in_channels)
         : in_channels(in_channels) {
         blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
-        blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
-        blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
-        blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
+        blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, true));
+        blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, true));
+        blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, true));

-        blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
+        blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, true));
     }

     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -123,8 +123,9 @@ class AE3DConv : public Conv2d {
              std::pair<int, int> stride = {1, 1},
              std::pair<int, int> padding = {0, 0},
              std::pair<int, int> dilation = {1, 1},
-             bool bias = true)
-        : Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
+             bool bias = true,
+             bool direct = false)
+        : Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias, direct) {
         int64_t kernel_padding = video_kernel_size / 2;
         blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3dnx1x1(out_channels,
                                                                              out_channels,
@@ -240,7 +241,7 @@ class Encoder : public GGMLBlock {
           in_channels(in_channels),
           z_channels(z_channels),
           double_z(double_z) {
-        blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1}));
+        blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));

         size_t num_resolutions = ch_mult.size();

@@ -268,7 +269,7 @@
blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));

blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));
}

virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -328,11 +329,14 @@ class Decoder : public GGMLBlock {
                                             int64_t out_channels,
                                             std::pair<int, int> kernel_size,
                                             std::pair<int, int> stride = {1, 1},
-                                            std::pair<int, int> padding = {0, 0}) {
+                                            std::pair<int, int> padding = {0, 0},
+                                            std::pair<int, int> dilation = {1, 1},
+                                            bool bias = true,
+                                            bool direct = false) {
         if (video_decoder) {
             return std::shared_ptr<GGMLBlock>(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding));
         } else {
-            return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, stride, padding));
+            return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias, direct));
         }
     }

@@ -363,7 +367,7 @@
         size_t num_resolutions = ch_mult.size();
         int block_in = ch * ch_mult[num_resolutions - 1];

-        blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
+        blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));

         blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
         blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in));
@@ -385,7 +389,7 @@
         }

         blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
-        blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1});
+        blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true);
     }

     virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) {
@@ -474,7 +478,12 @@ class AutoencodingEngine : public GGMLBlock {
         if (use_quant) {
             blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
                                                                               embed_dim,
-                                                                              {1, 1}));
+                                                                              {1, 1},
+                                                                              {1, 1},
+                                                                              {0, 0},
+                                                                              {1, 1},
+                                                                              true,
+                                                                              true));
         }
         if (!decode_only) {
             blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder(dd_config.ch,
@@ -488,7 +497,12 @@

blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(embed_dim * factor,
dd_config.z_channels * factor,
{1, 1}));
{1, 1},
{1, 1},
{0, 0},
{1, 1},
true,
true));
}
}
}
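Note: taken together, every 2D convolution on the VAE path (conv_in and conv_out, the ResnetBlock convs and nin_shortcut, the AttnBlock q/k/v and proj_out 1x1 projections, and the quant/post_quant convs) now opts into the direct path, while AE3DConv and Decoder::get_conv_out simply forward the new parameters so the video-decoder path keeps compiling unchanged. The defaults have to be restated at every call site because C++ has no named arguments and direct is the trailing parameter; an illustrative before/after pair:

    // Both expressions build the same 1x1 convolution; only the second can
    // set the trailing direct flag, so it restates stride/padding/dilation/bias.
    new Conv2d(in_channels, out_channels, {1, 1});
    new Conv2d(in_channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, true);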