# Use every available core for torch CPU ops (os and torch are imported
# above this chunk — not visible here).
torch.set_num_threads(os.cpu_count())

from .load_params import convert_dalle_bart_torch_from_flax_params
from .min_dalle_base import MinDalleBase
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
class MinDalleTorch(MinDalleBase):
    """Torch implementation of min-DALL·E text-to-image generation.

    Builds the BART text encoder, the autoregressive image-token decoder,
    and the detokenizer from the flax parameter dump held by
    ``MinDalleBase``.  With ``is_expendable=True`` each sub-model is
    constructed lazily right before use and deleted immediately after,
    trading construction time for lower peak memory.
    """

    def __init__(
        self,
        is_mega: bool,
        is_expendable: bool = False,
        token_count: int = 256
    ):
        """
        :param is_mega: select the "mega" checkpoint (forwarded to base).
        :param is_expendable: if True, defer building each sub-model until
            it is needed and free it right after use.
        :param token_count: number of image tokens to sample per image.
        """
        super().__init__(is_mega)
        self.is_expendable = is_expendable
        self.token_count = token_count
        print("initializing MinDalleTorch")
        if not is_expendable:
            # Eager mode: build everything up front, once.
            self.init_encoder()
            self.init_decoder()
            self.init_detokenizer()


    def init_encoder(self):
        """Build the text encoder and load its converted flax weights."""
        print("initializing DalleBartEncoderTorch")
        self.encoder = DalleBartEncoderTorch(
            layer_count = self.config['encoder_layers'],
            embed_count = self.config['d_model'],
            # NOTE(review): the next two kwargs were hidden by a diff hunk
            # boundary in the reviewed text; restored from upstream — verify.
            attention_head_count = self.config['encoder_attention_heads'],
            text_vocab_count = self.config['encoder_vocab_size'],
            text_token_count = self.config['max_text_length'],
            glu_embed_count = self.config['encoder_ffn_dim']
        )
        # NOTE(review): `pop` removes the encoder weights from
        # self.model_params, so this method succeeds only once per
        # instance — a second call raises KeyError.
        params = convert_dalle_bart_torch_from_flax_params(
            self.model_params.pop('encoder'),
            layer_count = self.config['encoder_layers'],
            is_encoder = True
        )
        self.encoder.load_state_dict(params, strict = False)
        if torch.cuda.is_available(): self.encoder = self.encoder.cuda()
        # Drop the CPU-side state dict promptly to reduce peak memory.
        del params


    def init_decoder(self):
        """Build the image-token decoder and load its converted weights."""
        print("initializing DalleBartDecoderTorch")
        self.decoder = DalleBartDecoderTorch(
            image_vocab_size = self.config['image_vocab_size'],
            image_token_count = self.config['image_length'],
            sample_token_count = self.token_count,
            embed_count = self.config['d_model'],
            attention_head_count = self.config['decoder_attention_heads'],
            glu_embed_count = self.config['decoder_ffn_dim'],
            # NOTE(review): the next two kwargs were hidden by a diff hunk
            # boundary in the reviewed text; restored from upstream — verify.
            layer_count = self.config['decoder_layers'],
            batch_count = 2,
            start_token = self.config['decoder_start_token_id'],
            is_verbose = True
        )
        # Weights are popped: this method is single-use (see init_encoder).
        params = convert_dalle_bart_torch_from_flax_params(
            self.model_params.pop('decoder'),
            layer_count = self.config['decoder_layers'],
            is_encoder = False
        )
        self.decoder.load_state_dict(params, strict = False)
        if torch.cuda.is_available(): self.decoder = self.decoder.cuda()
        del params


    def init_detokenizer(self):
        """Build the detokenizer via the base class, then move it to GPU."""
        super().init_detokenizer()
        if torch.cuda.is_available():
            self.detokenizer = self.detokenizer.cuda()


    def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
        """Encode `text` and sample a sequence of image tokens.

        Deterministic for a fixed `seed` (seeds torch's global RNG).

        NOTE(review): in expendable mode this works only once per
        instance — init_encoder/init_decoder pop their weights from
        self.model_params, so a second call raises KeyError.
        """
        text_tokens = self.tokenize_text(text)
        text_tokens = torch.tensor(text_tokens).to(torch.long)
        if torch.cuda.is_available(): text_tokens = text_tokens.cuda()

        if self.is_expendable: self.init_encoder()
        print("encoding text tokens")
        encoder_state = self.encoder.forward(text_tokens)
        if self.is_expendable: del self.encoder

        if self.is_expendable: self.init_decoder()
        print("sampling image tokens")
        torch.manual_seed(seed)
        image_tokens = self.decoder.forward(text_tokens, encoder_state)
        if self.is_expendable: del self.decoder
        return image_tokens


    def generate_image(self, text: str, seed: int) -> Image.Image:
        """Generate a PIL image for `text`, deterministic in `seed`."""
        image_tokens = self.generate_image_tokens(text, seed)
        if self.is_expendable: self.init_detokenizer()
        print("detokenizing image")
        image = self.detokenizer.forward(image_tokens).to(torch.uint8)
        if self.is_expendable: del self.detokenizer
        image = Image.fromarray(image.to('cpu').detach().numpy())
        return image
0 commit comments