Skip to content

Commit 1e18ba0

Browse files
committed
is_expendable argument reduces memory usage for command line script
1 parent 3837710 commit 1e18ba0

File tree

4 files changed

+81
-31
lines changed

4 files changed

+81
-31
lines changed

image_from_text.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
parser.add_argument('--text', type=str, default='alien life')
1616
parser.add_argument('--seed', type=int, default=7)
1717
parser.add_argument('--image_path', type=str, default='generated')
18-
parser.add_argument('--sample_token_count', type=int, default=256) # for debugging
18+
parser.add_argument('--token_count', type=int, default=256) # for debugging
1919

2020

2121
def ascii_from_image(image: Image.Image, size: int) -> str:
@@ -42,20 +42,21 @@ def generate_image(
4242
text: str,
4343
seed: int,
4444
image_path: str,
45-
sample_token_count: int
45+
token_count: int
4646
):
47+
is_expendable = True
4748
if is_torch:
48-
image_generator = MinDalleTorch(is_mega, sample_token_count)
49-
image_tokens = image_generator.generate_image_tokens(text, seed)
49+
image_generator = MinDalleTorch(is_mega, is_expendable, token_count)
5050

51-
if sample_token_count < image_generator.config['image_length']:
51+
if token_count < image_generator.config['image_length']:
52+
image_tokens = image_generator.generate_image_tokens(text, seed)
5253
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
5354
return
5455
else:
5556
image = image_generator.generate_image(text, seed)
5657

5758
else:
58-
image_generator = MinDalleFlax(is_mega)
59+
image_generator = MinDalleFlax(is_mega, is_expendable=True)
5960
image = image_generator.generate_image(text, seed)
6061

6162
save_image(image, image_path)
@@ -71,5 +72,5 @@ def generate_image(
7172
text=args.text,
7273
seed=args.seed,
7374
image_path=args.image_path,
74-
sample_token_count=args.sample_token_count
75+
token_count=args.token_count
7576
)
Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
77
from .models.vqgan_detokenizer import VQGanDetokenizer
88

9-
class MinDalle:
9+
class MinDalleBase:
1010
def __init__(self, is_mega: bool):
1111
self.is_mega = is_mega
1212
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
@@ -25,11 +25,15 @@ def __init__(self, is_mega: bool):
2525
merges = f.read().split("\n")[1:-1]
2626

2727
self.model_params = load_dalle_bart_flax_params(model_path)
28-
2928
self.tokenizer = TextTokenizer(vocab, merges)
29+
30+
31+
def init_detokenizer(self):
32+
print("initializing VQGanDetokenizer")
33+
params = load_vqgan_torch_params('./pretrained/vqgan')
3034
self.detokenizer = VQGanDetokenizer()
31-
vqgan_params = load_vqgan_torch_params('./pretrained/vqgan')
32-
self.detokenizer.load_state_dict(vqgan_params)
35+
self.detokenizer.load_state_dict(params)
36+
del params
3337

3438

3539
def tokenize_text(self, text: str) -> numpy.ndarray:

min_dalle/min_dalle_flax.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,25 @@
33
from PIL import Image
44
import torch
55

6-
from .min_dalle import MinDalle
6+
from .min_dalle_base import MinDalleBase
77
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
88
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
99

1010

11-
class MinDalleFlax(MinDalle):
12-
def __init__(self, is_mega: bool):
11+
class MinDalleFlax(MinDalleBase):
12+
def __init__(self, is_mega: bool, is_expendable: bool = False):
1313
super().__init__(is_mega)
14+
self.is_expendable = is_expendable
1415
print("initializing MinDalleFlax")
16+
if not is_expendable:
17+
self.init_encoder()
18+
self.init_decoder()
19+
self.init_detokenizer()
1520

16-
print("loading encoder")
17-
self.encoder = DalleBartEncoderFlax(
21+
22+
def init_encoder(self):
23+
print("initializing DalleBartEncoderFlax")
24+
self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
1825
attention_head_count = self.config['encoder_attention_heads'],
1926
embed_count = self.config['d_model'],
2027
glu_embed_count = self.config['encoder_ffn_dim'],
@@ -23,7 +30,9 @@ def __init__(self, is_mega: bool):
2330
layer_count = self.config['encoder_layers']
2431
).bind({'params': self.model_params.pop('encoder')})
2532

26-
print("loading decoder")
33+
34+
def init_decoder(self):
35+
print("initializing DalleBartDecoderFlax")
2736
self.decoder = DalleBartDecoderFlax(
2837
image_token_count = self.config['image_length'],
2938
text_token_count = self.config['max_text_length'],
@@ -39,20 +48,30 @@ def __init__(self, is_mega: bool):
3948
def generate_image(self, text: str, seed: int) -> Image.Image:
4049
text_tokens = self.tokenize_text(text)
4150

51+
if self.is_expendable: self.init_encoder()
4252
print("encoding text tokens")
4353
encoder_state = self.encoder(text_tokens)
54+
if self.is_expendable: del self.encoder
4455

56+
if self.is_expendable:
57+
self.init_decoder()
58+
params = self.model_params.pop('decoder')
59+
else:
60+
params = self.model_params['decoder']
4561
print("sampling image tokens")
4662
image_tokens = self.decoder.sample_image_tokens(
4763
text_tokens,
4864
encoder_state,
4965
jax.random.PRNGKey(seed),
50-
self.model_params['decoder']
66+
params
5167
)
68+
if self.is_expendable: del self.decoder
5269

5370
image_tokens = torch.tensor(numpy.array(image_tokens))
5471

72+
if self.is_expendable: self.init_detokenizer()
5573
print("detokenizing image")
5674
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
75+
if self.is_expendable: del self.detokenizer
5776
image = Image.fromarray(image.to('cpu').detach().numpy())
5877
return image

min_dalle/min_dalle_torch.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,30 @@
99
torch.set_num_threads(os.cpu_count())
1010

1111
from .load_params import convert_dalle_bart_torch_from_flax_params
12-
from .min_dalle import MinDalle
12+
from .min_dalle_base import MinDalleBase
1313
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
1414
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
1515

1616

17-
class MinDalleTorch(MinDalle):
18-
def __init__(self, is_mega: bool, sample_token_count: int = 256):
17+
class MinDalleTorch(MinDalleBase):
18+
def __init__(
19+
self,
20+
is_mega: bool,
21+
is_expendable: bool = False,
22+
token_count: int = 256
23+
):
1924
super().__init__(is_mega)
25+
self.is_expendable = is_expendable
26+
self.token_count = token_count
2027
print("initializing MinDalleTorch")
28+
if not is_expendable:
29+
self.init_encoder()
30+
self.init_decoder()
31+
self.init_detokenizer()
2132

22-
print("loading encoder")
33+
34+
def init_encoder(self):
35+
print("initializing DalleBartEncoderTorch")
2336
self.encoder = DalleBartEncoderTorch(
2437
layer_count = self.config['encoder_layers'],
2538
embed_count = self.config['d_model'],
@@ -28,18 +41,22 @@ def __init__(self, is_mega: bool, sample_token_count: int = 256):
2841
text_token_count = self.config['max_text_length'],
2942
glu_embed_count = self.config['encoder_ffn_dim']
3043
)
31-
encoder_params = convert_dalle_bart_torch_from_flax_params(
44+
params = convert_dalle_bart_torch_from_flax_params(
3245
self.model_params.pop('encoder'),
3346
layer_count=self.config['encoder_layers'],
3447
is_encoder=True
3548
)
36-
self.encoder.load_state_dict(encoder_params, strict=False)
49+
self.encoder.load_state_dict(params, strict=False)
50+
if torch.cuda.is_available(): self.encoder = self.encoder.cuda()
51+
del params
52+
3753

38-
print("loading decoder")
54+
def init_decoder(self):
55+
print("initializing DalleBartDecoderTorch")
3956
self.decoder = DalleBartDecoderTorch(
4057
image_vocab_size = self.config['image_vocab_size'],
4158
image_token_count = self.config['image_length'],
42-
sample_token_count = sample_token_count,
59+
sample_token_count = self.token_count,
4360
embed_count = self.config['d_model'],
4461
attention_head_count = self.config['decoder_attention_heads'],
4562
glu_embed_count = self.config['decoder_ffn_dim'],
@@ -48,36 +65,45 @@ def __init__(self, is_mega: bool, sample_token_count: int = 256):
4865
start_token = self.config['decoder_start_token_id'],
4966
is_verbose = True
5067
)
51-
decoder_params = convert_dalle_bart_torch_from_flax_params(
68+
params = convert_dalle_bart_torch_from_flax_params(
5269
self.model_params.pop('decoder'),
5370
layer_count=self.config['decoder_layers'],
5471
is_encoder=False
5572
)
56-
self.decoder.load_state_dict(decoder_params, strict=False)
73+
self.decoder.load_state_dict(params, strict=False)
74+
if torch.cuda.is_available(): self.decoder = self.decoder.cuda()
75+
del params
5776

77+
78+
def init_detokenizer(self):
79+
super().init_detokenizer()
5880
if torch.cuda.is_available():
59-
self.encoder = self.encoder.cuda()
60-
self.decoder = self.decoder.cuda()
6181
self.detokenizer = self.detokenizer.cuda()
62-
82+
6383

6484
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
6585
text_tokens = self.tokenize_text(text)
6686
text_tokens = torch.tensor(text_tokens).to(torch.long)
6787
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
6888

89+
if self.is_expendable: self.init_encoder()
6990
print("encoding text tokens")
7091
encoder_state = self.encoder.forward(text_tokens)
92+
if self.is_expendable: del self.encoder
7193

94+
if self.is_expendable: self.init_decoder()
7295
print("sampling image tokens")
7396
torch.manual_seed(seed)
7497
image_tokens = self.decoder.forward(text_tokens, encoder_state)
98+
if self.is_expendable: del self.decoder
7599
return image_tokens
76100

77101

78102
def generate_image(self, text: str, seed: int) -> Image.Image:
79103
image_tokens = self.generate_image_tokens(text, seed)
104+
if self.is_expendable: self.init_detokenizer()
80105
print("detokenizing image")
81106
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
107+
if self.is_expendable: del self.detokenizer
82108
image = Image.fromarray(image.to('cpu').detach().numpy())
83109
return image

0 commit comments

Comments (0)