@@ -8,18 +8,21 @@ class DownSampleBlock : public GGMLBlock {
     int channels;
     int out_channels;
     bool vae_downsample;
+    bool direct = false;

 public:
     DownSampleBlock(int channels,
                     int out_channels,
-                    bool vae_downsample = false)
+                    bool vae_downsample = false,
+                    bool direct = false)
         : channels(channels),
           out_channels(out_channels),
-          vae_downsample(vae_downsample) {
+          vae_downsample(vae_downsample),
+          direct(direct) {
         if (vae_downsample) {
-            blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0}, {1, 1}, true, true));
+            blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0}, {1, 1}, true, direct));
         } else {
-            blocks["op"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {1, 1}));
+            blocks["op"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, true, direct));
         }
     }

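For reference, the trailing arguments threaded into the Conv2d calls in this diff follow the order (kernel, stride, padding, dilation, bias, direct). The sketch below is inferred purely from these call sites, not quoted from the actual Conv2d header; the class body, parameter names, and defaults are illustrative assumptions.

// Sketch of the assumed Conv2d constructor shape, inferred from the call sites
// in this diff; names and defaults are illustrative and may differ from the
// real declaration.
class Conv2d : public GGMLBlock {
public:
    Conv2d(int64_t in_channels,
           int64_t out_channels,
           std::pair<int, int> kernel_size,
           std::pair<int, int> stride   = {1, 1},
           std::pair<int, int> padding  = {0, 0},
           std::pair<int, int> dilation = {1, 1},
           bool bias   = true,
           bool direct = false);  // trailing flag that this diff forwards
};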
@@ -43,13 +46,16 @@ class UpSampleBlock : public GGMLBlock {
 protected:
     int channels;
     int out_channels;
+    bool direct = false;

 public:
     UpSampleBlock(int channels,
-                  int out_channels)
+                  int out_channels,
+                  bool direct = false)
         : channels(channels),
-          out_channels(out_channels) {
-        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));
+          out_channels(out_channels),
+          direct(direct) {
+        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
     }

     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -381,7 +387,8 @@ class SpatialTransformer : public GGMLBlock {
                        int64_t d_head,
                        int64_t depth,
                        int64_t context_dim,
-                       bool flash_attn = false)
+                       bool flash_attn = false,
+                       bool direct = false)
         : in_channels(in_channels),
           n_head(n_head),
           d_head(d_head),
@@ -391,14 +398,14 @@ class SpatialTransformer : public GGMLBlock {
         // disable_self_attn is always False
         int64_t inner_dim = n_head * d_head;  // in_channels
         blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
-        blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
+        blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, direct));

         for (int i = 0; i < depth; i++) {
             std::string name = "transformer_blocks." + std::to_string(i);
             blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
         }

-        blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
+        blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, direct));
     }

     virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
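Taken together, DownSampleBlock, UpSampleBlock, and SpatialTransformer now accept a trailing direct flag and forward it to every Conv2d they construct. A minimal caller-side sketch follows, assuming only the constructor signatures shown in this diff; the flag name use_conv2d_direct and the channel/head/depth/context numbers are hypothetical placeholders, not values taken from the codebase.

// Hypothetical call sites; only the parameter order mirrors this diff.
bool use_conv2d_direct = true;  // e.g. toggled by a runtime option

auto down = std::shared_ptr<GGMLBlock>(
    new DownSampleBlock(320, 320, /*vae_downsample=*/false, /*direct=*/use_conv2d_direct));
auto up = std::shared_ptr<GGMLBlock>(
    new UpSampleBlock(640, 640, /*direct=*/use_conv2d_direct));
auto attn = std::shared_ptr<GGMLBlock>(
    new SpatialTransformer(320, 8, 40, 1, 768, /*flash_attn=*/false, /*direct=*/use_conv2d_direct));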