File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change 55
55
print ("---- Model loading" , flush = True )
56
56
model_id = args .model_id
57
57
if args .dtype == "float16" :
58
- model = FlaxGPTJForCausalLM .from_pretrained (model_id , dtype = jax .numpy .float16 )
58
+ model = FlaxGPTJForCausalLM .from_pretrained (
59
+ pretrained_model_name_or_path = model_id , n_layer = args .num_layer , dtype = jax .numpy .float16 )
59
60
model .params = model .to_fp16 (model .params )
60
61
elif args .dtype == "bfloat16" :
61
- model = FlaxGPTJForCausalLM .from_pretrained (model_id , dtype = jax .numpy .bfloat16 )
62
+ model = FlaxGPTJForCausalLM .from_pretrained (
63
+ pretrained_model_name_or_path = model_id , n_layer = args .num_layer , dtype = jax .numpy .bfloat16 )
62
64
model .params = model .to_bf16 (model .params )
63
65
else :
64
- model = FlaxGPTJForCausalLM .from_pretrained (model_id , dtype = jax . numpy . float32 )
65
- model . config . n_layer = args .num_layer
66
+ model = FlaxGPTJForCausalLM .from_pretrained (
67
+ pretrained_model_name_or_path = model_id , n_layer = args .num_layer , dtype = jax . numpy . float32 )
66
68
print (model .config )
67
69
print ("---- Model loading done" , flush = True )
68
70
You can’t perform that action at this time.
0 commit comments