
import torch
from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, GlmModel


# Resolve the project root three directory levels above this file
# (parents[2] == grandparent of this file's directory) — presumably the
# repository root; confirm against the actual test-file layout.
project_root = pathlib.Path(__file__).resolve().parents[2]
@@ -17,39 +17,26 @@ def __init__(self, **kwargs):
1717 super ().__init__ (** kwargs )
1818
1919 def load_condition_models (self ):
20- text_encoder_config = GlmConfig (
21- hidden_size = 32 , intermediate_size = 8 , num_hidden_layers = 2 , num_attention_heads = 4 , head_dim = 8
20+ text_encoder = GlmModel .from_pretrained (
21+ "hf-internal-testing/tiny-random-cogview4" , subfolder = "text_encoder" , torch_dtype = self .text_encoder_dtype
22+ )
23+ tokenizer = AutoTokenizer .from_pretrained (
24+ "hf-internal-testing/tiny-random-cogview4" , subfolder = "tokenizer" , trust_remote_code = True
2225 )
23- text_encoder = GlmModel (text_encoder_config )
24- # TODO(aryan): try to not rely on trust_remote_code by creating dummy tokenizer
25- tokenizer = AutoTokenizer .from_pretrained ("THUDM/glm-4-9b-chat" , trust_remote_code = True )
2626 return {"text_encoder" : text_encoder , "tokenizer" : tokenizer }
2727
2828 def load_latent_models (self ):
2929 torch .manual_seed (0 )
30- vae = AutoencoderKL (
31- block_out_channels = [32 , 64 ],
32- in_channels = 3 ,
33- out_channels = 3 ,
34- down_block_types = ["DownEncoderBlock2D" , "DownEncoderBlock2D" ],
35- up_block_types = ["UpDecoderBlock2D" , "UpDecoderBlock2D" ],
36- latent_channels = 4 ,
37- sample_size = 128 ,
30+ vae = AutoencoderKL .from_pretrained (
31+ "hf-internal-testing/tiny-random-cogview4" , subfolder = "vae" , torch_dtype = self .vae_dtype
3832 )
33+ self .vae_config = vae .config
3934 return {"vae" : vae }
4035
4136 def load_diffusion_models (self ):
4237 torch .manual_seed (0 )
43- transformer = CogView4Transformer2DModel (
44- patch_size = 2 ,
45- in_channels = 4 ,
46- num_layers = 2 ,
47- attention_head_dim = 4 ,
48- num_attention_heads = 4 ,
49- out_channels = 4 ,
50- text_embed_dim = 32 ,
51- time_embed_dim = 8 ,
52- condition_dim = 4 ,
38+ transformer = CogView4Transformer2DModel .from_pretrained (
39+ "hf-internal-testing/tiny-random-cogview4" , subfolder = "transformer" , torch_dtype = self .transformer_dtype
5340 )
5441 scheduler = FlowMatchEulerDiscreteScheduler ()
5542 return {"transformer" : transformer , "scheduler" : scheduler }